Merge branch 'evals' of https://github.com/buster-so/buster into evals

This commit is contained in:
Nate Kelley 2025-04-09 14:00:50 -06:00
commit bcbe5d6f97
No known key found for this signature in database
GPG Key ID: FD90372AB8D98B4F
1 changed file with 460 additions and 88 deletions

View File

@ -1,5 +1,6 @@
use dashmap::DashMap;
use middleware::AuthenticatedUser;
use std::collections::HashSet;
use std::{collections::HashMap, time::Instant};
use agents::{
@ -48,6 +49,15 @@ use crate::messages::types::{ChatMessage, ChatUserMessage};
use super::types::ChatWithMessages;
use tokio::sync::mpsc;
use database::types::dashboard_yml::DashboardYml;
use database::pool::PgPool;
use diesel::OptionalExtension;
use diesel_async::AsyncPgConnection; // Import PgPool
// Add imports for version history types
use database::types::version_history::{VersionContent, VersionHistory};
// Define the helper struct at the module level
#[derive(Debug, Clone)]
struct CompletedFileInfo {
@ -55,6 +65,7 @@ struct CompletedFileInfo {
file_type: String, // "metric" or "dashboard"
file_name: String,
version_number: i32,
content: String, // Added to store file content for parsing
}
// Define ThreadEvent
@ -215,45 +226,48 @@ pub async fn post_chat_handler(
.values(&message)
.execute(&mut conn)
.await?;
// After message is inserted, create file association if needed
if message.response_messages.is_array() {
let response_arr = message.response_messages.as_array().unwrap();
// Find a file response in the array
for response in response_arr {
if response.get("type").map_or(false, |t| t == "file") {
// Extract version_number from response, default to 1 if not found
let asset_version_number = response.get("version_number")
let asset_version_number = response
.get("version_number")
.and_then(|v| v.as_i64())
.map(|v| v as i32)
.unwrap_or(1);
// Ensure the response id matches the asset_id
let response_id = response.get("id")
let response_id = response
.get("id")
.and_then(|id| id.as_str())
.and_then(|id_str| Uuid::parse_str(id_str).ok())
.unwrap_or(asset_id_value);
// Verify the response ID matches the asset ID
if response_id == asset_id_value {
// Create association in database - now the message exists in DB
if let Err(e) = create_message_file_association(
message.id,
asset_id_value,
asset_version_number,
asset_type_value,
)
.await {
tracing::warn!("Failed to create message file association: {}", e);
}
// Verify the response ID matches the asset ID
if response_id == asset_id_value {
// Create association in database - now the message exists in DB
if let Err(e) = create_message_file_association(
message.id,
asset_id_value,
asset_version_number,
asset_type_value,
)
.await
{
tracing::warn!("Failed to create message file association: {}", e);
}
// We only need to process one file association
break;
}
// We only need to process one file association
break;
}
}
}
// Add to updated messages for the response
updated_messages.push(message);
@ -277,7 +291,7 @@ pub async fn post_chat_handler(
);
chat_with_messages.add_message(chat_message);
// We don't need to process the raw_llm_messages here
// The ChatContextLoader.update_context_from_tool_calls function will handle the asset state
// when the agent is initialized and loads the context
@ -290,7 +304,7 @@ pub async fn post_chat_handler(
AssetType::DashboardFile => Some("dashboard".to_string()),
_ => None,
};
if let Some(file_type) = asset_type_string {
// Update the chat directly to ensure it has the most_recent_file information
let mut conn = get_pg_pool().get().await?;
@ -305,10 +319,12 @@ pub async fn post_chat_handler(
tracing::info!(
"Updated chat {} with most_recent_file_id: {}, most_recent_file_type: {}",
chat_id, asset_id_value, file_type
chat_id,
asset_id_value,
file_type
);
}
// Return early with auto-generated messages - no need for agent processing
return Ok(chat_with_messages);
}
@ -378,6 +394,42 @@ pub async fn post_chat_handler(
let mut sent_initial_files = false; // Flag to track if initial files have been sent
let mut early_sent_file_messages: Vec<Value> = Vec::new(); // Store file messages sent early
// --- START: Load History and Find Context Dashboard ---
let mut context_dashboard_id: Option<Uuid> = None;
if let Some(existing_chat_id) = request.chat_id {
// Fetch the most recent message for this chat to find the last dashboard shown
let pool = get_pg_pool();
let mut conn = pool.get().await?;
let last_message_result = messages::table
.filter(messages::chat_id.eq(existing_chat_id))
.order(messages::created_at.desc())
.first::<Message>(&mut conn)
.await
.optional()?; // Use optional() to handle chats with no previous messages gracefully
if let Some(last_message) = last_message_result {
if let Ok(last_response_values) =
serde_json::from_value::<Vec<Value>>(last_message.response_messages)
{
for value in last_response_values {
if let Ok(response_msg) = serde_json::from_value::<BusterChatMessage>(value) {
if let BusterChatMessage::File { id, file_type, .. } = response_msg {
if file_type == "dashboard" {
if let Ok(uuid) = Uuid::parse_str(&id) {
context_dashboard_id = Some(uuid);
tracing::debug!("Found context dashboard ID: {}", uuid);
break; // Found the most recent dashboard
}
}
}
}
}
}
}
}
// --- END: Load History and Find Context Dashboard ---
// Process all messages from the agent
while let Ok(message_result) = rx.recv().await {
match message_result {
@ -451,42 +503,91 @@ pub async fn post_chat_handler(
// Store transformed containers BEFORE potential early file sending
// This ensures the files are based on the most up-to-date reasoning
let transformed_results = transform_message(&chat_id, &message_id, msg.clone(), tx.as_ref(), &chunk_tracker).await;
let transformed_results = transform_message(
&chat_id,
&message_id,
msg.clone(),
tx.as_ref(),
&chunk_tracker,
)
.await;
match transformed_results {
Ok(containers) => {
// Store all transformed containers first
all_transformed_containers.extend(containers.iter().map(|(c, _)| c.clone()));
all_transformed_containers
.extend(containers.iter().map(|(c, _)| c.clone()));
// --- START: Early File Sending Logic ---
// Check if this is the first text chunk and we haven't sent files yet
if !sent_initial_files {
// Look for an incoming text chunk within the *current* message `msg`
if let AgentMessage::Assistant { content: Some(_), progress: MessageProgress::InProgress, .. } = &msg {
if let AgentMessage::Assistant {
content: Some(_),
progress: MessageProgress::InProgress,
..
} = &msg
{
if let Some(tx_channel) = &tx {
// Set flag immediately to prevent re-entry
sent_initial_files = true;
// Perform filtering based on containers received SO FAR
let current_completed_files = collect_completed_files(&all_transformed_containers);
let filtered_files = apply_file_filtering_rules(&current_completed_files);
early_sent_file_messages = generate_file_response_values(&filtered_files);
let current_completed_files =
collect_completed_files(&all_transformed_containers);
// Pass context_dashboard_id and pool, await the async function
match apply_file_filtering_rules(
&current_completed_files,
context_dashboard_id,
&get_pg_pool(),
)
.await
{
Ok(filtered_files_info) => {
early_sent_file_messages =
generate_file_response_values(&filtered_files_info);
// Send the filtered file messages FIRST
for file_value in &early_sent_file_messages {
if let Ok(buster_chat_message) = serde_json::from_value::<BusterChatMessage>(file_value.clone()) {
let file_container = BusterContainer::ChatMessage(BusterChatMessageContainer {
response_message: buster_chat_message,
chat_id,
message_id,
});
if tx_channel.send(Ok((file_container, ThreadEvent::GeneratingResponseMessage))).await.is_err() {
tracing::warn!("Client disconnected while sending early file messages");
// Setting the flag ensures we don't retry, but allows loop to continue processing other messages if needed
// Potentially break here if sending is critical: break;
// Send the filtered file messages FIRST
for file_value in &early_sent_file_messages {
if let Ok(buster_chat_message) =
serde_json::from_value::<BusterChatMessage>(
file_value.clone(),
)
{
let file_container =
BusterContainer::ChatMessage(
BusterChatMessageContainer {
response_message:
buster_chat_message,
chat_id,
message_id,
},
);
if tx_channel
.send(Ok((
file_container,
ThreadEvent::GeneratingResponseMessage,
)))
.await
.is_err()
{
tracing::warn!("Client disconnected while sending early file messages");
// Setting the flag ensures we don't retry, but allows loop to continue processing other messages if needed
// Potentially break here if sending is critical: break;
}
} else {
tracing::error!("Failed to deserialize early file message value: {:?}", file_value);
}
}
} else {
tracing::error!("Failed to deserialize early file message value: {:?}", file_value);
}
Err(e) => {
tracing::error!(
"Error applying file filtering rules early: {}",
e
);
// Optionally send an error over tx_channel or handle otherwise
// For now, proceed without sending early files if filtering fails
early_sent_file_messages = vec![]; // Ensure list is empty
}
}
}
@ -496,13 +597,19 @@ pub async fn post_chat_handler(
// Now send the transformed containers for the current message
if let Some(tx_channel) = &tx {
for (container, thread_event) in containers {
if tx_channel.send(Ok((container, thread_event))).await.is_err() {
tracing::warn!("Client disconnected, but continuing to process messages");
// Don't break immediately, allow storing final state
}
}
}
for (container, thread_event) in containers {
if tx_channel
.send(Ok((container, thread_event)))
.await
.is_err()
{
tracing::warn!(
"Client disconnected, but continuing to process messages"
);
// Don't break immediately, allow storing final state
}
}
}
}
Err(e) => {
tracing::error!("Error transforming message: {}", e);
@ -747,28 +854,41 @@ async fn process_completed_files(
// Transform messages again specifically for DB processing if needed,
// or directly use reasoning messages if they contain enough info.
let mut transformed_messages_for_db = Vec::new();
for msg in messages {
// Use a temporary tracker instance if needed, or reuse the main one
if let Ok(containers) = transform_message(
&message.chat_id, &message.id, msg.clone(), None, chunk_tracker
).await {
transformed_messages_for_db.extend(containers.into_iter().map(|(c, _)| c));
}
}
for msg in messages {
// Use a temporary tracker instance if needed, or reuse the main one
if let Ok(containers) = transform_message(
&message.chat_id,
&message.id,
msg.clone(),
None,
chunk_tracker,
)
.await
{
transformed_messages_for_db.extend(containers.into_iter().map(|(c, _)| c));
}
}
let mut processed_file_ids = std::collections::HashSet::new();
for container in transformed_messages_for_db { // Use the re-transformed messages
for container in transformed_messages_for_db {
// Use the re-transformed messages
if let BusterContainer::ReasoningMessage(msg) = container {
match &msg.reasoning {
BusterReasoningMessage::File(file) if file.message_type == "files" && file.status == "completed" => {
BusterReasoningMessage::File(file)
if file.message_type == "files" && file.status == "completed" =>
{
for (file_id_key, file_content) in &file.files {
if file_content.status == "completed" { // Ensure inner file is also complete
if file_content.status == "completed" {
// Ensure inner file is also complete
let file_uuid = match Uuid::parse_str(file_id_key) {
Ok(uuid) => uuid,
Err(_) => {
tracing::warn!("Invalid UUID format for file ID in reasoning: {}", file_id_key);
continue; // Skip this file
tracing::warn!(
"Invalid UUID format for file ID in reasoning: {}",
file_id_key
);
continue; // Skip this file
}
};
@ -793,21 +913,25 @@ async fn process_completed_files(
if let Err(e) = diesel::insert_into(messages_to_files::table)
.values(&message_to_file)
.execute(conn)
.await {
tracing::error!("Failed to insert message_to_file link for file {}: {}", file_uuid, e);
continue; // Skip chat update if DB link fails
}
.await
{
tracing::error!(
"Failed to insert message_to_file link for file {}: {}",
file_uuid,
e
);
continue; // Skip chat update if DB link fails
}
// Determine file type for chat update
let file_type_for_chat = match file_content.file_type.as_str() {
"dashboard" => Some("dashboard".to_string()),
"metric" => Some("metric".to_string()),
_ => None,
};
let file_type_for_chat = match file_content.file_type.as_str() {
"dashboard" => Some("dashboard".to_string()),
"metric" => Some("metric".to_string()),
_ => None,
};
// Update the chat with the most recent file info
if let Err(e) = diesel::update(chats::table.find(message.chat_id))
if let Err(e) = diesel::update(chats::table.find(message.chat_id))
.set((
chats::most_recent_file_id.eq(Some(file_uuid)),
chats::most_recent_file_type.eq(file_type_for_chat),
@ -815,10 +939,11 @@ async fn process_completed_files(
chats::updated_at.eq(Utc::now()),
))
.execute(conn)
.await {
tracing::error!("Failed to update chat {} with most recent file info for {}: {}", message.chat_id, file_uuid, e);
}
}
.await
{
tracing::error!("Failed to update chat {} with most recent file info for {}: {}", message.chat_id, file_uuid, e);
}
}
}
}
_ => (),
@ -2157,11 +2282,15 @@ fn collect_completed_files(containers: &[BusterContainer]) -> Vec<CompletedFileI
if file_reasoning.message_type == "files" && file_reasoning.status == "completed" {
for (_file_id_key, file_detail) in &file_reasoning.files {
if file_detail.status == "completed" {
// Extract content, default to empty string if None
let content = file_detail.file.text.clone().unwrap_or_default();
completed_files.push(CompletedFileInfo {
id: file_detail.id.clone(),
file_type: file_detail.file_type.clone(),
file_name: file_detail.file_name.clone(),
version_number: file_detail.version_number,
content, // Populate the content field
});
}
}
@ -2172,17 +2301,260 @@ fn collect_completed_files(containers: &[BusterContainer]) -> Vec<CompletedFileI
completed_files
}
// Helper function to encapsulate filtering rules
fn apply_file_filtering_rules(completed_files: &[CompletedFileInfo]) -> Vec<CompletedFileInfo> {
let contains_metrics = completed_files.iter().any(|f| f.file_type == "metric");
let contains_dashboards = completed_files.iter().any(|f| f.file_type == "dashboard");
// --- START: New Helper Function ---
// Fetches dashboard details from the database
async fn fetch_dashboard_details(
id: Uuid,
conn: &mut AsyncPgConnection,
) -> Result<Option<CompletedFileInfo>> {
match dashboard_files::table
.filter(dashboard_files::id.eq(id))
// Select id, name, and version_history
.select((dashboard_files::id, dashboard_files::name, dashboard_files::version_history))
// Adjust expected tuple type
.first::<(Uuid, String, VersionHistory)>(conn)
.await
.optional()? // Handle case where dashboard might not be found
{
Some((db_id, name, version_history)) => {
if let Some(latest_version) = version_history.get_latest_version() {
// Extract dashboard_yml content and serialize it back to string
if let VersionContent::DashboardYml(dashboard_yml) = &latest_version.content {
match serde_json::to_string(dashboard_yml) { // Serialize to JSON string
Ok(content_string) => {
Ok(Some(CompletedFileInfo {
id: db_id.to_string(),
file_type: "dashboard".to_string(),
file_name: name,
version_number: latest_version.version_number,
content: content_string, // Use serialized string
}))
}
Err(e) => {
tracing::error!("Failed to serialize DashboardYml content for {}: {}", db_id, e);
Ok(None) // Treat serialization error as unable to fetch details
}
}
} else {
// Content was not DashboardYml, unexpected
tracing::warn!("Latest version content for dashboard {} is not DashboardYml type.", db_id);
Ok(None)
}
} else {
// Version history exists but has no versions
tracing::warn!("Version history for dashboard {} has no versions.", db_id);
Ok(None)
}
}
None => Ok(None), // Dashboard not found in the table
}
}
// --- END: New Helper Function ---
if contains_dashboards {
completed_files.iter().filter(|f| f.file_type == "dashboard").cloned().collect()
// Helper function to encapsulate filtering rules - NOW ASYNC because it may
// fetch the context dashboard's latest version from Postgres.
//
// Decides which completed files to surface to the user for this turn:
//   - No context dashboard: defer entirely to `process_current_turn_files`.
//   - Context dashboard exists and ONLY metrics belonging to it were modified:
//     return just the (re-fetched) context dashboard.
//   - Context dashboard exists but new dashboards were created, or modified
//     metrics don't all belong to it: process the current turn's files, and
//     prepend the context dashboard if one of its metrics was touched.
//
// Errors only on DB failures (pool checkout / dashboard fetch); parse failures
// of the context dashboard are logged and degrade gracefully.
async fn apply_file_filtering_rules(
    completed_files_this_turn: &[CompletedFileInfo],
    context_dashboard_id: Option<Uuid>,
    pool: &PgPool, // Pass pool
) -> Result<Vec<CompletedFileInfo>> {
    // Return Result
    // Partition this turn's completed files by type; everything below works on
    // these two views.
    let metrics_this_turn: Vec<_> = completed_files_this_turn
        .iter()
        .filter(|f| f.file_type == "metric")
        .cloned()
        .collect();
    let dashboards_this_turn: Vec<_> = completed_files_this_turn
        .iter()
        .filter(|f| f.file_type == "dashboard")
        .cloned()
        .collect();
    match context_dashboard_id {
        // --- Context Exists ---
        Some(ctx_id) => {
            // Fetch context dashboard details once upfront
            let mut conn = pool.get().await?;
            let context_dashboard_info_opt = fetch_dashboard_details(ctx_id, &mut conn).await?;
            // If context dashboard couldn't be fetched (e.g., deleted), treat as no context
            if context_dashboard_info_opt.is_none() {
                tracing::warn!(
                    "Context dashboard ID {} not found, falling back to current turn logic.",
                    ctx_id
                );
                return process_current_turn_files(&metrics_this_turn, &dashboards_this_turn);
            }
            let context_dashboard_info = context_dashboard_info_opt.unwrap(); // Safe unwrap due to check above
            // Case 2 Check: Only context metrics modified, no new dashboards?
            if dashboards_this_turn.is_empty() && !metrics_this_turn.is_empty() {
                // Parse context dashboard to see if *all* modified metrics belong to it
                let mut all_metrics_belong_to_context = true;
                // Metric IDs referenced by the context dashboard's YML rows.
                let context_metric_ids = match DashboardYml::new(
                    context_dashboard_info.content.clone(),
                ) {
                    Ok(yml) => yml
                        .rows
                        .iter()
                        .flat_map(|r| r.items.iter().map(|i| i.id))
                        .collect::<HashSet<Uuid>>(),
                    Err(e) => {
                        tracing::warn!("Failed to parse context dashboard {} for Case 2 check: {}. Assuming metrics might not belong.", ctx_id, e);
                        all_metrics_belong_to_context = false; // Cannot confirm, assume they don't all belong
                        HashSet::new()
                    }
                };
                if all_metrics_belong_to_context {
                    // Only proceed if context dashboard parsed
                    for metric in &metrics_this_turn {
                        if let Ok(metric_uuid) = Uuid::parse_str(&metric.id) {
                            if !context_metric_ids.contains(&metric_uuid) {
                                all_metrics_belong_to_context = false;
                                break;
                            }
                        }
                    }
                }
                // If all modified metrics seem to belong to the context dashboard, return only it.
                if all_metrics_belong_to_context {
                    tracing::debug!(
                        "Context dashboard {} exists, only metrics belonging to it were modified. Returning context dashboard.",
                        ctx_id
                    );
                    return Ok(vec![context_dashboard_info]);
                }
                // If not all metrics belong, fall through to Case 3 logic below
            }
            // Case 1 & 3: New dashboard created OR complex state (metrics modified don't all belong to context).
            tracing::debug!(
                "New dashboard created or metrics modified don't all belong to context dashboard {}. Checking for modified context metrics.",
                ctx_id
            );
            // Check if any metric modified *this turn* belongs to the context dashboard.
            let mut modified_context_metric_this_turn = false;
            // NOTE(review): the context dashboard YML is parsed a second time
            // here; result could be reused from the Case 2 check above.
            let context_metric_ids = match DashboardYml::new(context_dashboard_info.content.clone())
            {
                Ok(yml) => yml
                    .rows
                    .iter()
                    .flat_map(|r| r.items.iter().map(|i| i.id))
                    .collect::<HashSet<Uuid>>(),
                Err(e) => {
                    tracing::warn!("Failed to parse context dashboard {} for Case 3 check: {}. Assuming no context metrics modified.", ctx_id, e);
                    HashSet::new() // Assume no overlap if parsing fails
                }
            };
            for metric in &metrics_this_turn {
                if let Ok(metric_uuid) = Uuid::parse_str(&metric.id) {
                    if context_metric_ids.contains(&metric_uuid) {
                        modified_context_metric_this_turn = true;
                        break;
                    }
                }
            }
            // If a context metric was modified AND other assets exist this turn...
            if modified_context_metric_this_turn {
                tracing::debug!("Context metric modified alongside other assets. Combining context dashboard with current turn processing.");
                // Process current turn files
                let new_filtered_assets =
                    process_current_turn_files(&metrics_this_turn, &dashboards_this_turn)?;
                // Return context dashboard first, then the processed new assets
                Ok(vec![context_dashboard_info]
                    .into_iter()
                    .chain(new_filtered_assets.into_iter())
                    .collect())
            } else {
                // No context metric modified, or context parsing failed. Process current turn only.
                tracing::debug!("No context metric modified (or context parse failed). Processing current turn files only.");
                process_current_turn_files(&metrics_this_turn, &dashboards_this_turn)
            }
        }
        // --- No Context ---
        None => {
            // Case 1 (No context): Process current turn's files normally.
            tracing::debug!("No context dashboard ID found. Processing current turn files only.");
            process_current_turn_files(&metrics_this_turn, &dashboards_this_turn)
        }
    }
}
// Helper for the previous filtering logic (refactored, kept synchronous as it doesn't do IO)
//
// Decides which of the current turn's completed files to surface:
//   - metrics AND dashboards: return only the metrics NOT referenced by any of
//     this turn's dashboards, followed by all dashboards (referenced metrics
//     are represented by their dashboard)
//   - only dashboards / only metrics: return them unchanged
//   - neither: empty list
//
// Returns Result purely for signature consistency with the async caller.
fn process_current_turn_files(
    metrics: &[CompletedFileInfo],
    dashboards: &[CompletedFileInfo],
) -> Result<Vec<CompletedFileInfo>> {
    let contains_metrics = !metrics.is_empty();
    let contains_dashboards = !dashboards.is_empty();

    if contains_metrics && contains_dashboards {
        // Parse dashboards, find referenced metrics, filter unreferenced, combine.
        let mut metric_uuids = HashSet::new();
        for metric in metrics {
            if let Ok(uuid) = Uuid::parse_str(&metric.id) {
                metric_uuids.insert(uuid);
            }
        }

        // Every metric ID referenced by any dashboard completed this turn.
        let mut referenced_metric_uuids = HashSet::new();
        for dashboard_info in dashboards {
            match DashboardYml::new(dashboard_info.content.clone()) {
                Ok(dashboard_yml) => {
                    for row in dashboard_yml.rows {
                        for item in row.items {
                            referenced_metric_uuids.insert(item.id);
                        }
                    }
                }
                Err(e) => {
                    // Unparseable dashboard: it stays in the output below, but
                    // cannot vouch for any metric references.
                    tracing::warn!(
                        "Failed to parse dashboard YML content for ID {} during current turn processing: {}. Skipping for metric reference check.",
                        dashboard_info.id,
                        e
                    );
                }
            }
        }

        let unreferenced_metric_uuids: HashSet<_> = metric_uuids
            .difference(&referenced_metric_uuids)
            .copied()
            .collect();

        if unreferenced_metric_uuids.is_empty() {
            // All metrics referenced, return only dashboards
            Ok(dashboards.to_vec())
        } else {
            let unreferenced_metrics: Vec<_> = metrics
                .iter()
                .filter(|m| {
                    Uuid::parse_str(&m.id)
                        .map_or(false, |uuid| unreferenced_metric_uuids.contains(&uuid))
                })
                .cloned()
                .collect();
            // Return unreferenced metrics first, then dashboards
            let mut combined = unreferenced_metrics;
            combined.extend(dashboards.iter().cloned());
            Ok(combined)
        }
    } else if contains_dashboards {
        // Only dashboards
        Ok(dashboards.to_vec())
    } else if contains_metrics {
        // Only metrics
        Ok(metrics.to_vec())
    } else {
        // Neither
        Ok(vec![])
    }
}