Merge branch 'evals' of https://github.com/buster-so/buster into evals

This commit is contained in:
Nate Kelley 2025-04-16 15:45:24 -06:00
commit db60c686e6
No known key found for this signature in database
GPG Key ID: FD90372AB8D98B4F
7 changed files with 61 additions and 27 deletions

View File

@ -1,5 +1,6 @@
use anyhow::{anyhow, Result}; use anyhow::{anyhow, Result};
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use database::enums::MessageFeedback;
use diesel::prelude::Queryable; use diesel::prelude::Queryable;
use diesel::{ExpressionMethods, JoinOnDsl, NullableExpressionMethods, QueryDsl}; use diesel::{ExpressionMethods, JoinOnDsl, NullableExpressionMethods, QueryDsl};
use diesel_async::RunQueryDsl; use diesel_async::RunQueryDsl;
@ -41,6 +42,7 @@ pub struct MessageWithUser {
pub user_id: Uuid, pub user_id: Uuid,
pub user_name: Option<String>, pub user_name: Option<String>,
pub user_attributes: Value, pub user_attributes: Value,
pub feedback: Option<String>,
} }
#[derive(Queryable)] #[derive(Queryable)]
@ -122,6 +124,7 @@ pub async fn get_chat_handler(
users::id, users::id,
users::name.nullable(), users::name.nullable(),
users::attributes, users::attributes,
messages::feedback.nullable(),
)) ))
.load::<MessageWithUser>(&mut conn) .load::<MessageWithUser>(&mut conn)
.await .await
@ -275,6 +278,7 @@ pub async fn get_chat_handler(
reasoning, reasoning,
msg.final_reasoning_message, msg.final_reasoning_message,
msg.created_at, msg.created_at,
msg.feedback,
) )
}) })
.collect(); .collect();

View File

@ -298,6 +298,7 @@ pub async fn post_chat_handler(
vec![], vec![],
None, None,
message.created_at, message.created_at,
None
); );
chat_with_messages.add_message(chat_message); chat_with_messages.add_message(chat_message);
@ -817,6 +818,7 @@ pub async fn post_chat_handler(
reasoning_messages.clone(), reasoning_messages.clone(),
Some(formatted_final_reasoning_duration.clone()), // Use formatted reasoning duration Some(formatted_final_reasoning_duration.clone()), // Use formatted reasoning duration
Utc::now(), Utc::now(),
None,
); );
chat_with_messages.update_message(final_message); chat_with_messages.update_message(final_message);
@ -2442,6 +2444,7 @@ async fn initialize_chat(
Vec::new(), Vec::new(),
None, None,
Utc::now(), Utc::now(),
None,
); );
// Add message to existing chat // Add message to existing chat
@ -2481,6 +2484,7 @@ async fn initialize_chat(
Vec::new(), Vec::new(),
None, None,
Utc::now(), Utc::now(),
None,
); );
let mut chat_with_messages = ChatWithMessages::new( let mut chat_with_messages = ChatWithMessages::new(

View File

@ -4,11 +4,8 @@ use database::{pool::get_pg_pool, schema::messages};
use diesel::prelude::*; use diesel::prelude::*;
use diesel_async::RunQueryDsl; use diesel_async::RunQueryDsl;
use middleware::AuthenticatedUser; use middleware::AuthenticatedUser;
use std::str::FromStr;
use uuid::Uuid; use uuid::Uuid;
use crate::messages::types::MessageFeedback;
/// Update a message with new properties /// Update a message with new properties
/// ///
/// # Arguments /// # Arguments
@ -47,15 +44,11 @@ pub async fn update_message_handler(
// Add feedback if provided // Add feedback if provided
if let Some(fb_str) = feedback { if let Some(fb_str) = feedback {
// Validate feedback value
let feedback = MessageFeedback::from_str(&fb_str)
.map_err(|e| anyhow!(e))?;
// Update the feedback column directly // Update the feedback column directly
update_statement update_statement
.set(( .set((
messages::updated_at.eq(Utc::now()), messages::updated_at.eq(Utc::now()),
messages::feedback.eq(feedback.to_string()) messages::feedback.eq(fb_str)
)) ))
.execute(&mut conn) .execute(&mut conn)
.await?; .await?;

View File

@ -6,7 +6,6 @@ use serde_json::Value;
use uuid::Uuid; use uuid::Uuid;
pub mod message_feedback; pub mod message_feedback;
pub use message_feedback::MessageFeedback;
#[derive(Debug, Serialize, Deserialize, Clone)] #[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ChatMessage { pub struct ChatMessage {
@ -20,6 +19,7 @@ pub struct ChatMessage {
pub reasoning_messages: HashMap<String, Value>, pub reasoning_messages: HashMap<String, Value>,
pub created_at: chrono::DateTime<chrono::Utc>, pub created_at: chrono::DateTime<chrono::Utc>,
pub final_reasoning_message: Option<String>, pub final_reasoning_message: Option<String>,
pub feedback: Option<String>,
} }
#[derive(Debug, Serialize, Deserialize, Clone)] #[derive(Debug, Serialize, Deserialize, Clone)]
@ -51,6 +51,7 @@ impl ChatMessage {
reasoning_messages: HashMap::new(), reasoning_messages: HashMap::new(),
created_at: Utc::now(), created_at: Utc::now(),
final_reasoning_message: None, final_reasoning_message: None,
feedback: None,
} }
} }
@ -61,6 +62,7 @@ impl ChatMessage {
reasoning_messages: Vec<Value>, reasoning_messages: Vec<Value>,
final_reasoning_message: Option<String>, final_reasoning_message: Option<String>,
created_at: chrono::DateTime<chrono::Utc>, created_at: chrono::DateTime<chrono::Utc>,
feedback: Option<String>,
) -> Self { ) -> Self {
let response_message_ids: Vec<String> = response_messages let response_message_ids: Vec<String> = response_messages
.iter() .iter()
@ -97,6 +99,7 @@ impl ChatMessage {
reasoning_messages: reasoning_messages_map, reasoning_messages: reasoning_messages_map,
created_at, created_at,
final_reasoning_message, final_reasoning_message,
feedback,
} }
} }
} }

View File

@ -2,11 +2,11 @@ use anyhow::{Context, Result};
use chrono::Utc; use chrono::Utc;
use database::{ use database::{
self, self,
enums::{UserOrganizationRole, UserOrganizationStatus, SharingSetting}, enums::{SharingSetting, UserOrganizationRole, UserOrganizationStatus},
models::{User, UserToOrganization}, models::{User, UserToOrganization},
pool::get_pg_pool,
schema::{users, users_to_organizations}, schema::{users, users_to_organizations},
}; };
use diesel::prelude::*;
use diesel_async::{AsyncPgConnection, RunQueryDsl}; use diesel_async::{AsyncPgConnection, RunQueryDsl};
use middleware::AuthenticatedUser; use middleware::AuthenticatedUser;
use serde_json::json; use serde_json::json;
@ -17,7 +17,6 @@ use uuid::Uuid;
pub async fn invite_user_handler( pub async fn invite_user_handler(
inviting_user: &AuthenticatedUser, inviting_user: &AuthenticatedUser,
emails: Vec<String>, emails: Vec<String>,
conn: &mut AsyncPgConnection, // Accept the connection directly
) -> Result<()> { ) -> Result<()> {
let organization_id = inviting_user let organization_id = inviting_user
.organizations .organizations
@ -52,12 +51,24 @@ pub async fn invite_user_handler(
avatar_url: None, avatar_url: None,
}; };
let mut conn = match get_pg_pool().get().await {
Ok(mut conn) => conn,
Err(e) => {
return Err(e.into());
}
};
// 3. Insert user // 3. Insert user
diesel::insert_into(users::table) match diesel::insert_into(users::table)
.values(&user_to_insert) .values(&user_to_insert)
.execute(conn) .execute(&mut conn)
.await .await
.context("Failed to insert new user")?; {
Ok(_) => (),
Err(e) => {
return Err(e.into());
}
};
// 4. Create UserToOrganization struct instance // 4. Create UserToOrganization struct instance
let user_org_to_insert = UserToOrganization { let user_org_to_insert = UserToOrganization {
@ -79,11 +90,16 @@ pub async fn invite_user_handler(
}; };
// 5. Insert user organization mapping // 5. Insert user organization mapping
diesel::insert_into(users_to_organizations::table) match diesel::insert_into(users_to_organizations::table)
.values(&user_org_to_insert) .values(&user_org_to_insert)
.execute(conn) .execute(&mut conn)
.await .await
.context("Failed to map user to organization")?; {
Ok(_) => (),
Err(e) => {
return Err(e.into());
}
};
} }
Ok(()) Ok(())

View File

@ -1 +1,3 @@
pub mod invite_user_handler; pub mod invite_user_handler;
pub use invite_user_handler::*;

View File

@ -1,11 +1,12 @@
use anyhow::Result; use anyhow::Result;
use axum::{Extension, Json}; use axum::{Extension, Json};
use handlers::users::invite_user_handler;
use crate::routes::rest::ApiResponse; use crate::routes::rest::ApiResponse;
use axum::http::StatusCode; use axum::http::StatusCode;
use database::enums::UserOrganizationRole;
use middleware::AuthenticatedUser; use middleware::AuthenticatedUser;
use serde::Deserialize; use serde::Deserialize;
use tracing::error;
#[derive(Deserialize)] #[derive(Deserialize)]
pub struct InviteUsersRequest { pub struct InviteUsersRequest {
@ -16,5 +17,16 @@ pub async fn invite_users(
Extension(user): Extension<AuthenticatedUser>, Extension(user): Extension<AuthenticatedUser>,
Json(body): Json<InviteUsersRequest>, Json(body): Json<InviteUsersRequest>,
) -> Result<ApiResponse<()>, (StatusCode, &'static str)> { ) -> Result<ApiResponse<()>, (StatusCode, &'static str)> {
Ok(ApiResponse::NoContent) let result = invite_user_handler(&user, body.emails).await;
match result {
Ok(_) => Ok(ApiResponse::NoContent),
Err(e) => {
error!("Failed to invite users: {}", e);
Err((
StatusCode::INTERNAL_SERVER_ERROR,
"Failed to process invitation request",
))
}
}
} }