mirror of https://github.com/buster-so/buster.git
associate files correclty.
This commit is contained in:
parent
dec98bacfe
commit
26a9397228
|
@ -1,12 +1,15 @@
|
|||
use dashmap::DashMap;
|
||||
use middleware::AuthenticatedUser;
|
||||
use std::{collections::HashMap, time::Instant};
|
||||
use dashmap::DashMap;
|
||||
|
||||
use agents::{
|
||||
tools::{file_tools::{
|
||||
common::ModifyFilesOutput, create_dashboards::CreateDashboardFilesOutput,
|
||||
create_metrics::CreateMetricFilesOutput, search_data_catalog::SearchDataCatalogOutput,
|
||||
}, planning_tools::CreatePlanOutput},
|
||||
tools::{
|
||||
file_tools::{
|
||||
common::ModifyFilesOutput, create_dashboards::CreateDashboardFilesOutput,
|
||||
create_metrics::CreateMetricFilesOutput, search_data_catalog::SearchDataCatalogOutput,
|
||||
},
|
||||
planning_tools::CreatePlanOutput,
|
||||
},
|
||||
AgentExt, AgentMessage, AgentThread, BusterSuperAgent,
|
||||
};
|
||||
|
||||
|
@ -83,10 +86,12 @@ impl ChunkTracker {
|
|||
let mut delta_to_return = String::new();
|
||||
|
||||
{
|
||||
self.chunks.entry(chunk_id.clone()).or_insert_with(|| ChunkState {
|
||||
complete_text: String::new(),
|
||||
last_seen_content: String::new(),
|
||||
});
|
||||
self.chunks
|
||||
.entry(chunk_id.clone())
|
||||
.or_insert_with(|| ChunkState {
|
||||
complete_text: String::new(),
|
||||
last_seen_content: String::new(),
|
||||
});
|
||||
|
||||
// Now that we've initialized the entry if needed, get mutable access to update it
|
||||
if let Some(mut entry) = self.chunks.get_mut(&chunk_id) {
|
||||
|
@ -122,7 +127,9 @@ impl ChunkTracker {
|
|||
}
|
||||
|
||||
pub fn get_complete_text(&self, chunk_id: String) -> Option<String> {
|
||||
self.chunks.get(&chunk_id).map(|state| state.complete_text.clone())
|
||||
self.chunks
|
||||
.get(&chunk_id)
|
||||
.map(|state| state.complete_text.clone())
|
||||
}
|
||||
|
||||
pub fn clear_chunk(&self, chunk_id: String) {
|
||||
|
@ -235,12 +242,17 @@ pub async fn post_chat_handler(
|
|||
// Only store completed messages in raw_llm_messages
|
||||
match &msg {
|
||||
AgentMessage::Assistant {
|
||||
progress, content, id, ..
|
||||
progress,
|
||||
content,
|
||||
id,
|
||||
..
|
||||
} => {
|
||||
// Store chunks in the tracker to ensure deduplication
|
||||
if let Some(content_str) = content {
|
||||
// Use message ID as chunk ID, or generate a consistent one if missing
|
||||
let chunk_id = id.clone().unwrap_or_else(|| "assistant_message".to_string());
|
||||
let chunk_id = id
|
||||
.clone()
|
||||
.unwrap_or_else(|| "assistant_message".to_string());
|
||||
// Add to chunk tracker to handle deduplication
|
||||
chunk_tracker.add_chunk(chunk_id.clone(), content_str.clone());
|
||||
}
|
||||
|
@ -248,10 +260,13 @@ pub async fn post_chat_handler(
|
|||
if matches!(progress, MessageProgress::Complete) {
|
||||
if let Some(content_str) = content {
|
||||
// Use message ID as chunk ID, or generate a consistent one if missing
|
||||
let chunk_id = id.clone().unwrap_or_else(|| "assistant_message".to_string());
|
||||
let chunk_id = id
|
||||
.clone()
|
||||
.unwrap_or_else(|| "assistant_message".to_string());
|
||||
|
||||
// Get the complete deduplicated text from the chunk tracker
|
||||
let complete_text = chunk_tracker.get_complete_text(chunk_id.clone())
|
||||
let complete_text = chunk_tracker
|
||||
.get_complete_text(chunk_id.clone())
|
||||
.unwrap_or_else(|| content_str.clone());
|
||||
|
||||
// Create a new message with the deduplicated content
|
||||
|
@ -285,7 +300,9 @@ pub async fn post_chat_handler(
|
|||
}
|
||||
|
||||
// Always transform the message
|
||||
match transform_message(&chat_id, &message_id, msg, tx.as_ref(), &chunk_tracker).await {
|
||||
match transform_message(&chat_id, &message_id, msg, tx.as_ref(), &chunk_tracker)
|
||||
.await
|
||||
{
|
||||
Ok(containers) => {
|
||||
// Store all transformed containers
|
||||
for (container, _) in containers.clone() {
|
||||
|
@ -369,7 +386,7 @@ pub async fn post_chat_handler(
|
|||
final_reasoning_message: format!("Reasoned for {} seconds", reasoning_duration),
|
||||
title: title.title.clone().unwrap_or_default(),
|
||||
raw_llm_messages: serde_json::to_value(&raw_llm_messages)?,
|
||||
feedback: None
|
||||
feedback: None,
|
||||
};
|
||||
|
||||
let mut conn = get_pg_pool().get().await?;
|
||||
|
@ -435,7 +452,8 @@ pub async fn post_chat_handler(
|
|||
fn prepare_final_message_state(containers: &[BusterContainer]) -> Result<(Vec<Value>, Vec<Value>)> {
|
||||
let mut response_messages = Vec::new();
|
||||
// Use a Vec to maintain order, with a HashMap to track latest version of each message
|
||||
let mut reasoning_map: std::collections::HashMap<String, (usize, Value)> = std::collections::HashMap::new();
|
||||
let mut reasoning_map: std::collections::HashMap<String, (usize, Value)> =
|
||||
std::collections::HashMap::new();
|
||||
let mut reasoning_order = Vec::new();
|
||||
|
||||
for container in containers {
|
||||
|
@ -522,8 +540,14 @@ async fn process_completed_files(
|
|||
let mut processed_file_ids = std::collections::HashSet::new();
|
||||
|
||||
for msg in messages {
|
||||
if let Ok(containers) =
|
||||
transform_message(&message.chat_id, &message.id, msg.clone(), None, chunk_tracker).await
|
||||
if let Ok(containers) = transform_message(
|
||||
&message.chat_id,
|
||||
&message.id,
|
||||
msg.clone(),
|
||||
None,
|
||||
chunk_tracker,
|
||||
)
|
||||
.await
|
||||
{
|
||||
transformed_messages.extend(containers);
|
||||
}
|
||||
|
@ -540,24 +564,22 @@ async fn process_completed_files(
|
|||
continue;
|
||||
}
|
||||
|
||||
if let Some(file_content) = file.files.get(file_id) {
|
||||
if let Some(_) = file.files.get(file_id) {
|
||||
// Only process files that have completed reasoning
|
||||
if file.status == "completed" {
|
||||
// Create message-to-file association
|
||||
let message_to_file = MessageToFile {
|
||||
id: Uuid::new_v4(),
|
||||
message_id: message.id,
|
||||
file_id: Uuid::parse_str(&file_id)?,
|
||||
created_at: Utc::now(),
|
||||
updated_at: Utc::now(),
|
||||
deleted_at: None,
|
||||
};
|
||||
// Create message-to-file association
|
||||
let message_to_file = MessageToFile {
|
||||
id: Uuid::new_v4(),
|
||||
message_id: message.id,
|
||||
file_id: Uuid::parse_str(&file_id)?,
|
||||
created_at: Utc::now(),
|
||||
updated_at: Utc::now(),
|
||||
deleted_at: None,
|
||||
};
|
||||
|
||||
diesel::insert_into(messages_to_files::table)
|
||||
.values(&message_to_file)
|
||||
.execute(conn)
|
||||
.await?;
|
||||
}
|
||||
diesel::insert_into(messages_to_files::table)
|
||||
.values(&message_to_file)
|
||||
.execute(conn)
|
||||
.await?;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -725,7 +747,6 @@ pub async fn transform_message(
|
|||
tx: Option<&mpsc::Sender<Result<(BusterContainer, ThreadEvent)>>>,
|
||||
tracker: &ChunkTracker,
|
||||
) -> Result<Vec<(BusterContainer, ThreadEvent)>> {
|
||||
|
||||
match message {
|
||||
AgentMessage::Assistant {
|
||||
id,
|
||||
|
@ -1022,7 +1043,6 @@ fn transform_tool_message(
|
|||
}
|
||||
|
||||
fn tool_create_plan(id: String, content: String) -> Result<Vec<BusterReasoningMessage>> {
|
||||
|
||||
let plan_markdown = match serde_json::from_str::<CreatePlanOutput>(&content) {
|
||||
Ok(result) => result.plan_markdown,
|
||||
Err(e) => {
|
||||
|
@ -1046,7 +1066,6 @@ fn tool_create_plan(id: String, content: String) -> Result<Vec<BusterReasoningMe
|
|||
|
||||
// Update tool_create_metrics to require ID
|
||||
fn tool_create_metrics(id: String, content: String) -> Result<Vec<BusterReasoningMessage>> {
|
||||
|
||||
// Parse the CreateMetricFilesOutput from content
|
||||
let create_metrics_result = match serde_json::from_str::<CreateMetricFilesOutput>(&content) {
|
||||
Ok(result) => result,
|
||||
|
@ -1102,7 +1121,6 @@ fn tool_create_metrics(id: String, content: String) -> Result<Vec<BusterReasonin
|
|||
|
||||
// Update tool_modify_metrics to require ID
|
||||
fn tool_modify_metrics(id: String, content: String) -> Result<Vec<BusterReasoningMessage>> {
|
||||
|
||||
// Parse the ModifyFilesOutput from content
|
||||
let modify_metrics_result = match serde_json::from_str::<ModifyFilesOutput>(&content) {
|
||||
Ok(result) => result,
|
||||
|
@ -1158,7 +1176,6 @@ fn tool_modify_metrics(id: String, content: String) -> Result<Vec<BusterReasonin
|
|||
|
||||
// Update tool_create_dashboards to require ID
|
||||
fn tool_create_dashboards(id: String, content: String) -> Result<Vec<BusterReasoningMessage>> {
|
||||
|
||||
// Parse the CreateDashboardFilesOutput from content
|
||||
let create_dashboards_result =
|
||||
match serde_json::from_str::<CreateDashboardFilesOutput>(&content) {
|
||||
|
@ -1444,7 +1461,8 @@ fn transform_assistant_tool_message(
|
|||
let mut updated_files = std::collections::HashMap::new();
|
||||
|
||||
for (file_id, file_content) in file.files.iter() {
|
||||
let chunk_id = format!("{}_{}", file.id, file_content.file_name);
|
||||
let chunk_id =
|
||||
format!("{}_{}", file.id, file_content.file_name);
|
||||
let complete_text = tracker
|
||||
.get_complete_text(chunk_id.clone())
|
||||
.unwrap_or_else(|| {
|
||||
|
@ -1471,10 +1489,12 @@ fn transform_assistant_tool_message(
|
|||
let mut updated_files = std::collections::HashMap::new();
|
||||
|
||||
for (file_id, file_content) in file.files.iter() {
|
||||
let chunk_id = format!("{}_{}", file.id, file_content.file_name);
|
||||
let chunk_id =
|
||||
format!("{}_{}", file.id, file_content.file_name);
|
||||
|
||||
if let Some(chunk) = &file_content.file.text_chunk {
|
||||
let delta = tracker.add_chunk(chunk_id.clone(), chunk.clone());
|
||||
let delta =
|
||||
tracker.add_chunk(chunk_id.clone(), chunk.clone());
|
||||
|
||||
if !delta.is_empty() {
|
||||
let mut updated_content = file_content.clone();
|
||||
|
|
Loading…
Reference in New Issue