associate files correctly.

This commit is contained in:
dal 2025-03-21 11:36:13 -06:00
parent dec98bacfe
commit 26a9397228
No known key found for this signature in database
GPG Key ID: 16F4B0E1E9F61122
1 changed file with 74 additions and 54 deletions

View File

@ -1,12 +1,15 @@
use dashmap::DashMap;
use middleware::AuthenticatedUser;
use std::{collections::HashMap, time::Instant};
use dashmap::DashMap;
use agents::{
tools::{file_tools::{
tools::{
file_tools::{
common::ModifyFilesOutput, create_dashboards::CreateDashboardFilesOutput,
create_metrics::CreateMetricFilesOutput, search_data_catalog::SearchDataCatalogOutput,
}, planning_tools::CreatePlanOutput},
},
planning_tools::CreatePlanOutput,
},
AgentExt, AgentMessage, AgentThread, BusterSuperAgent,
};
@ -83,7 +86,9 @@ impl ChunkTracker {
let mut delta_to_return = String::new();
{
self.chunks.entry(chunk_id.clone()).or_insert_with(|| ChunkState {
self.chunks
.entry(chunk_id.clone())
.or_insert_with(|| ChunkState {
complete_text: String::new(),
last_seen_content: String::new(),
});
@ -122,7 +127,9 @@ impl ChunkTracker {
}
pub fn get_complete_text(&self, chunk_id: String) -> Option<String> {
self.chunks.get(&chunk_id).map(|state| state.complete_text.clone())
self.chunks
.get(&chunk_id)
.map(|state| state.complete_text.clone())
}
pub fn clear_chunk(&self, chunk_id: String) {
@ -235,12 +242,17 @@ pub async fn post_chat_handler(
// Only store completed messages in raw_llm_messages
match &msg {
AgentMessage::Assistant {
progress, content, id, ..
progress,
content,
id,
..
} => {
// Store chunks in the tracker to ensure deduplication
if let Some(content_str) = content {
// Use message ID as chunk ID, or generate a consistent one if missing
let chunk_id = id.clone().unwrap_or_else(|| "assistant_message".to_string());
let chunk_id = id
.clone()
.unwrap_or_else(|| "assistant_message".to_string());
// Add to chunk tracker to handle deduplication
chunk_tracker.add_chunk(chunk_id.clone(), content_str.clone());
}
@ -248,10 +260,13 @@ pub async fn post_chat_handler(
if matches!(progress, MessageProgress::Complete) {
if let Some(content_str) = content {
// Use message ID as chunk ID, or generate a consistent one if missing
let chunk_id = id.clone().unwrap_or_else(|| "assistant_message".to_string());
let chunk_id = id
.clone()
.unwrap_or_else(|| "assistant_message".to_string());
// Get the complete deduplicated text from the chunk tracker
let complete_text = chunk_tracker.get_complete_text(chunk_id.clone())
let complete_text = chunk_tracker
.get_complete_text(chunk_id.clone())
.unwrap_or_else(|| content_str.clone());
// Create a new message with the deduplicated content
@ -285,7 +300,9 @@ pub async fn post_chat_handler(
}
// Always transform the message
match transform_message(&chat_id, &message_id, msg, tx.as_ref(), &chunk_tracker).await {
match transform_message(&chat_id, &message_id, msg, tx.as_ref(), &chunk_tracker)
.await
{
Ok(containers) => {
// Store all transformed containers
for (container, _) in containers.clone() {
@ -369,7 +386,7 @@ pub async fn post_chat_handler(
final_reasoning_message: format!("Reasoned for {} seconds", reasoning_duration),
title: title.title.clone().unwrap_or_default(),
raw_llm_messages: serde_json::to_value(&raw_llm_messages)?,
feedback: None
feedback: None,
};
let mut conn = get_pg_pool().get().await?;
@ -435,7 +452,8 @@ pub async fn post_chat_handler(
fn prepare_final_message_state(containers: &[BusterContainer]) -> Result<(Vec<Value>, Vec<Value>)> {
let mut response_messages = Vec::new();
// Use a Vec to maintain order, with a HashMap to track latest version of each message
let mut reasoning_map: std::collections::HashMap<String, (usize, Value)> = std::collections::HashMap::new();
let mut reasoning_map: std::collections::HashMap<String, (usize, Value)> =
std::collections::HashMap::new();
let mut reasoning_order = Vec::new();
for container in containers {
@ -522,8 +540,14 @@ async fn process_completed_files(
let mut processed_file_ids = std::collections::HashSet::new();
for msg in messages {
if let Ok(containers) =
transform_message(&message.chat_id, &message.id, msg.clone(), None, chunk_tracker).await
if let Ok(containers) = transform_message(
&message.chat_id,
&message.id,
msg.clone(),
None,
chunk_tracker,
)
.await
{
transformed_messages.extend(containers);
}
@ -540,9 +564,8 @@ async fn process_completed_files(
continue;
}
if let Some(file_content) = file.files.get(file_id) {
if let Some(_) = file.files.get(file_id) {
// Only process files that have completed reasoning
if file.status == "completed" {
// Create message-to-file association
let message_to_file = MessageToFile {
id: Uuid::new_v4(),
@ -560,7 +583,6 @@ async fn process_completed_files(
}
}
}
}
_ => (),
},
_ => (),
@ -725,7 +747,6 @@ pub async fn transform_message(
tx: Option<&mpsc::Sender<Result<(BusterContainer, ThreadEvent)>>>,
tracker: &ChunkTracker,
) -> Result<Vec<(BusterContainer, ThreadEvent)>> {
match message {
AgentMessage::Assistant {
id,
@ -1022,7 +1043,6 @@ fn transform_tool_message(
}
fn tool_create_plan(id: String, content: String) -> Result<Vec<BusterReasoningMessage>> {
let plan_markdown = match serde_json::from_str::<CreatePlanOutput>(&content) {
Ok(result) => result.plan_markdown,
Err(e) => {
@ -1046,7 +1066,6 @@ fn tool_create_plan(id: String, content: String) -> Result<Vec<BusterReasoningMe
// Update tool_create_metrics to require ID
fn tool_create_metrics(id: String, content: String) -> Result<Vec<BusterReasoningMessage>> {
// Parse the CreateMetricFilesOutput from content
let create_metrics_result = match serde_json::from_str::<CreateMetricFilesOutput>(&content) {
Ok(result) => result,
@ -1102,7 +1121,6 @@ fn tool_create_metrics(id: String, content: String) -> Result<Vec<BusterReasonin
// Update tool_modify_metrics to require ID
fn tool_modify_metrics(id: String, content: String) -> Result<Vec<BusterReasoningMessage>> {
// Parse the ModifyFilesOutput from content
let modify_metrics_result = match serde_json::from_str::<ModifyFilesOutput>(&content) {
Ok(result) => result,
@ -1158,7 +1176,6 @@ fn tool_modify_metrics(id: String, content: String) -> Result<Vec<BusterReasonin
// Update tool_create_dashboards to require ID
fn tool_create_dashboards(id: String, content: String) -> Result<Vec<BusterReasoningMessage>> {
// Parse the CreateDashboardFilesOutput from content
let create_dashboards_result =
match serde_json::from_str::<CreateDashboardFilesOutput>(&content) {
@ -1444,7 +1461,8 @@ fn transform_assistant_tool_message(
let mut updated_files = std::collections::HashMap::new();
for (file_id, file_content) in file.files.iter() {
let chunk_id = format!("{}_{}", file.id, file_content.file_name);
let chunk_id =
format!("{}_{}", file.id, file_content.file_name);
let complete_text = tracker
.get_complete_text(chunk_id.clone())
.unwrap_or_else(|| {
@ -1471,10 +1489,12 @@ fn transform_assistant_tool_message(
let mut updated_files = std::collections::HashMap::new();
for (file_id, file_content) in file.files.iter() {
let chunk_id = format!("{}_{}", file.id, file_content.file_name);
let chunk_id =
format!("{}_{}", file.id, file_content.file_name);
if let Some(chunk) = &file_content.file.text_chunk {
let delta = tracker.add_chunk(chunk_id.clone(), chunk.clone());
let delta =
tracker.add_chunk(chunk_id.clone(), chunk.clone());
if !delta.is_empty() {
let mut updated_content = file_content.clone();