mirror of https://github.com/buster-so/buster.git
associate files correclty.
This commit is contained in:
parent
dec98bacfe
commit
26a9397228
|
@ -1,12 +1,15 @@
|
||||||
|
use dashmap::DashMap;
|
||||||
use middleware::AuthenticatedUser;
|
use middleware::AuthenticatedUser;
|
||||||
use std::{collections::HashMap, time::Instant};
|
use std::{collections::HashMap, time::Instant};
|
||||||
use dashmap::DashMap;
|
|
||||||
|
|
||||||
use agents::{
|
use agents::{
|
||||||
tools::{file_tools::{
|
tools::{
|
||||||
|
file_tools::{
|
||||||
common::ModifyFilesOutput, create_dashboards::CreateDashboardFilesOutput,
|
common::ModifyFilesOutput, create_dashboards::CreateDashboardFilesOutput,
|
||||||
create_metrics::CreateMetricFilesOutput, search_data_catalog::SearchDataCatalogOutput,
|
create_metrics::CreateMetricFilesOutput, search_data_catalog::SearchDataCatalogOutput,
|
||||||
}, planning_tools::CreatePlanOutput},
|
},
|
||||||
|
planning_tools::CreatePlanOutput,
|
||||||
|
},
|
||||||
AgentExt, AgentMessage, AgentThread, BusterSuperAgent,
|
AgentExt, AgentMessage, AgentThread, BusterSuperAgent,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -83,7 +86,9 @@ impl ChunkTracker {
|
||||||
let mut delta_to_return = String::new();
|
let mut delta_to_return = String::new();
|
||||||
|
|
||||||
{
|
{
|
||||||
self.chunks.entry(chunk_id.clone()).or_insert_with(|| ChunkState {
|
self.chunks
|
||||||
|
.entry(chunk_id.clone())
|
||||||
|
.or_insert_with(|| ChunkState {
|
||||||
complete_text: String::new(),
|
complete_text: String::new(),
|
||||||
last_seen_content: String::new(),
|
last_seen_content: String::new(),
|
||||||
});
|
});
|
||||||
|
@ -122,7 +127,9 @@ impl ChunkTracker {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn get_complete_text(&self, chunk_id: String) -> Option<String> {
|
pub fn get_complete_text(&self, chunk_id: String) -> Option<String> {
|
||||||
self.chunks.get(&chunk_id).map(|state| state.complete_text.clone())
|
self.chunks
|
||||||
|
.get(&chunk_id)
|
||||||
|
.map(|state| state.complete_text.clone())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn clear_chunk(&self, chunk_id: String) {
|
pub fn clear_chunk(&self, chunk_id: String) {
|
||||||
|
@ -235,12 +242,17 @@ pub async fn post_chat_handler(
|
||||||
// Only store completed messages in raw_llm_messages
|
// Only store completed messages in raw_llm_messages
|
||||||
match &msg {
|
match &msg {
|
||||||
AgentMessage::Assistant {
|
AgentMessage::Assistant {
|
||||||
progress, content, id, ..
|
progress,
|
||||||
|
content,
|
||||||
|
id,
|
||||||
|
..
|
||||||
} => {
|
} => {
|
||||||
// Store chunks in the tracker to ensure deduplication
|
// Store chunks in the tracker to ensure deduplication
|
||||||
if let Some(content_str) = content {
|
if let Some(content_str) = content {
|
||||||
// Use message ID as chunk ID, or generate a consistent one if missing
|
// Use message ID as chunk ID, or generate a consistent one if missing
|
||||||
let chunk_id = id.clone().unwrap_or_else(|| "assistant_message".to_string());
|
let chunk_id = id
|
||||||
|
.clone()
|
||||||
|
.unwrap_or_else(|| "assistant_message".to_string());
|
||||||
// Add to chunk tracker to handle deduplication
|
// Add to chunk tracker to handle deduplication
|
||||||
chunk_tracker.add_chunk(chunk_id.clone(), content_str.clone());
|
chunk_tracker.add_chunk(chunk_id.clone(), content_str.clone());
|
||||||
}
|
}
|
||||||
|
@ -248,10 +260,13 @@ pub async fn post_chat_handler(
|
||||||
if matches!(progress, MessageProgress::Complete) {
|
if matches!(progress, MessageProgress::Complete) {
|
||||||
if let Some(content_str) = content {
|
if let Some(content_str) = content {
|
||||||
// Use message ID as chunk ID, or generate a consistent one if missing
|
// Use message ID as chunk ID, or generate a consistent one if missing
|
||||||
let chunk_id = id.clone().unwrap_or_else(|| "assistant_message".to_string());
|
let chunk_id = id
|
||||||
|
.clone()
|
||||||
|
.unwrap_or_else(|| "assistant_message".to_string());
|
||||||
|
|
||||||
// Get the complete deduplicated text from the chunk tracker
|
// Get the complete deduplicated text from the chunk tracker
|
||||||
let complete_text = chunk_tracker.get_complete_text(chunk_id.clone())
|
let complete_text = chunk_tracker
|
||||||
|
.get_complete_text(chunk_id.clone())
|
||||||
.unwrap_or_else(|| content_str.clone());
|
.unwrap_or_else(|| content_str.clone());
|
||||||
|
|
||||||
// Create a new message with the deduplicated content
|
// Create a new message with the deduplicated content
|
||||||
|
@ -285,7 +300,9 @@ pub async fn post_chat_handler(
|
||||||
}
|
}
|
||||||
|
|
||||||
// Always transform the message
|
// Always transform the message
|
||||||
match transform_message(&chat_id, &message_id, msg, tx.as_ref(), &chunk_tracker).await {
|
match transform_message(&chat_id, &message_id, msg, tx.as_ref(), &chunk_tracker)
|
||||||
|
.await
|
||||||
|
{
|
||||||
Ok(containers) => {
|
Ok(containers) => {
|
||||||
// Store all transformed containers
|
// Store all transformed containers
|
||||||
for (container, _) in containers.clone() {
|
for (container, _) in containers.clone() {
|
||||||
|
@ -369,7 +386,7 @@ pub async fn post_chat_handler(
|
||||||
final_reasoning_message: format!("Reasoned for {} seconds", reasoning_duration),
|
final_reasoning_message: format!("Reasoned for {} seconds", reasoning_duration),
|
||||||
title: title.title.clone().unwrap_or_default(),
|
title: title.title.clone().unwrap_or_default(),
|
||||||
raw_llm_messages: serde_json::to_value(&raw_llm_messages)?,
|
raw_llm_messages: serde_json::to_value(&raw_llm_messages)?,
|
||||||
feedback: None
|
feedback: None,
|
||||||
};
|
};
|
||||||
|
|
||||||
let mut conn = get_pg_pool().get().await?;
|
let mut conn = get_pg_pool().get().await?;
|
||||||
|
@ -435,7 +452,8 @@ pub async fn post_chat_handler(
|
||||||
fn prepare_final_message_state(containers: &[BusterContainer]) -> Result<(Vec<Value>, Vec<Value>)> {
|
fn prepare_final_message_state(containers: &[BusterContainer]) -> Result<(Vec<Value>, Vec<Value>)> {
|
||||||
let mut response_messages = Vec::new();
|
let mut response_messages = Vec::new();
|
||||||
// Use a Vec to maintain order, with a HashMap to track latest version of each message
|
// Use a Vec to maintain order, with a HashMap to track latest version of each message
|
||||||
let mut reasoning_map: std::collections::HashMap<String, (usize, Value)> = std::collections::HashMap::new();
|
let mut reasoning_map: std::collections::HashMap<String, (usize, Value)> =
|
||||||
|
std::collections::HashMap::new();
|
||||||
let mut reasoning_order = Vec::new();
|
let mut reasoning_order = Vec::new();
|
||||||
|
|
||||||
for container in containers {
|
for container in containers {
|
||||||
|
@ -522,8 +540,14 @@ async fn process_completed_files(
|
||||||
let mut processed_file_ids = std::collections::HashSet::new();
|
let mut processed_file_ids = std::collections::HashSet::new();
|
||||||
|
|
||||||
for msg in messages {
|
for msg in messages {
|
||||||
if let Ok(containers) =
|
if let Ok(containers) = transform_message(
|
||||||
transform_message(&message.chat_id, &message.id, msg.clone(), None, chunk_tracker).await
|
&message.chat_id,
|
||||||
|
&message.id,
|
||||||
|
msg.clone(),
|
||||||
|
None,
|
||||||
|
chunk_tracker,
|
||||||
|
)
|
||||||
|
.await
|
||||||
{
|
{
|
||||||
transformed_messages.extend(containers);
|
transformed_messages.extend(containers);
|
||||||
}
|
}
|
||||||
|
@ -540,9 +564,8 @@ async fn process_completed_files(
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Some(file_content) = file.files.get(file_id) {
|
if let Some(_) = file.files.get(file_id) {
|
||||||
// Only process files that have completed reasoning
|
// Only process files that have completed reasoning
|
||||||
if file.status == "completed" {
|
|
||||||
// Create message-to-file association
|
// Create message-to-file association
|
||||||
let message_to_file = MessageToFile {
|
let message_to_file = MessageToFile {
|
||||||
id: Uuid::new_v4(),
|
id: Uuid::new_v4(),
|
||||||
|
@ -560,7 +583,6 @@ async fn process_completed_files(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
_ => (),
|
_ => (),
|
||||||
},
|
},
|
||||||
_ => (),
|
_ => (),
|
||||||
|
@ -725,7 +747,6 @@ pub async fn transform_message(
|
||||||
tx: Option<&mpsc::Sender<Result<(BusterContainer, ThreadEvent)>>>,
|
tx: Option<&mpsc::Sender<Result<(BusterContainer, ThreadEvent)>>>,
|
||||||
tracker: &ChunkTracker,
|
tracker: &ChunkTracker,
|
||||||
) -> Result<Vec<(BusterContainer, ThreadEvent)>> {
|
) -> Result<Vec<(BusterContainer, ThreadEvent)>> {
|
||||||
|
|
||||||
match message {
|
match message {
|
||||||
AgentMessage::Assistant {
|
AgentMessage::Assistant {
|
||||||
id,
|
id,
|
||||||
|
@ -1022,7 +1043,6 @@ fn transform_tool_message(
|
||||||
}
|
}
|
||||||
|
|
||||||
fn tool_create_plan(id: String, content: String) -> Result<Vec<BusterReasoningMessage>> {
|
fn tool_create_plan(id: String, content: String) -> Result<Vec<BusterReasoningMessage>> {
|
||||||
|
|
||||||
let plan_markdown = match serde_json::from_str::<CreatePlanOutput>(&content) {
|
let plan_markdown = match serde_json::from_str::<CreatePlanOutput>(&content) {
|
||||||
Ok(result) => result.plan_markdown,
|
Ok(result) => result.plan_markdown,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
|
@ -1046,7 +1066,6 @@ fn tool_create_plan(id: String, content: String) -> Result<Vec<BusterReasoningMe
|
||||||
|
|
||||||
// Update tool_create_metrics to require ID
|
// Update tool_create_metrics to require ID
|
||||||
fn tool_create_metrics(id: String, content: String) -> Result<Vec<BusterReasoningMessage>> {
|
fn tool_create_metrics(id: String, content: String) -> Result<Vec<BusterReasoningMessage>> {
|
||||||
|
|
||||||
// Parse the CreateMetricFilesOutput from content
|
// Parse the CreateMetricFilesOutput from content
|
||||||
let create_metrics_result = match serde_json::from_str::<CreateMetricFilesOutput>(&content) {
|
let create_metrics_result = match serde_json::from_str::<CreateMetricFilesOutput>(&content) {
|
||||||
Ok(result) => result,
|
Ok(result) => result,
|
||||||
|
@ -1102,7 +1121,6 @@ fn tool_create_metrics(id: String, content: String) -> Result<Vec<BusterReasonin
|
||||||
|
|
||||||
// Update tool_modify_metrics to require ID
|
// Update tool_modify_metrics to require ID
|
||||||
fn tool_modify_metrics(id: String, content: String) -> Result<Vec<BusterReasoningMessage>> {
|
fn tool_modify_metrics(id: String, content: String) -> Result<Vec<BusterReasoningMessage>> {
|
||||||
|
|
||||||
// Parse the ModifyFilesOutput from content
|
// Parse the ModifyFilesOutput from content
|
||||||
let modify_metrics_result = match serde_json::from_str::<ModifyFilesOutput>(&content) {
|
let modify_metrics_result = match serde_json::from_str::<ModifyFilesOutput>(&content) {
|
||||||
Ok(result) => result,
|
Ok(result) => result,
|
||||||
|
@ -1158,7 +1176,6 @@ fn tool_modify_metrics(id: String, content: String) -> Result<Vec<BusterReasonin
|
||||||
|
|
||||||
// Update tool_create_dashboards to require ID
|
// Update tool_create_dashboards to require ID
|
||||||
fn tool_create_dashboards(id: String, content: String) -> Result<Vec<BusterReasoningMessage>> {
|
fn tool_create_dashboards(id: String, content: String) -> Result<Vec<BusterReasoningMessage>> {
|
||||||
|
|
||||||
// Parse the CreateDashboardFilesOutput from content
|
// Parse the CreateDashboardFilesOutput from content
|
||||||
let create_dashboards_result =
|
let create_dashboards_result =
|
||||||
match serde_json::from_str::<CreateDashboardFilesOutput>(&content) {
|
match serde_json::from_str::<CreateDashboardFilesOutput>(&content) {
|
||||||
|
@ -1444,7 +1461,8 @@ fn transform_assistant_tool_message(
|
||||||
let mut updated_files = std::collections::HashMap::new();
|
let mut updated_files = std::collections::HashMap::new();
|
||||||
|
|
||||||
for (file_id, file_content) in file.files.iter() {
|
for (file_id, file_content) in file.files.iter() {
|
||||||
let chunk_id = format!("{}_{}", file.id, file_content.file_name);
|
let chunk_id =
|
||||||
|
format!("{}_{}", file.id, file_content.file_name);
|
||||||
let complete_text = tracker
|
let complete_text = tracker
|
||||||
.get_complete_text(chunk_id.clone())
|
.get_complete_text(chunk_id.clone())
|
||||||
.unwrap_or_else(|| {
|
.unwrap_or_else(|| {
|
||||||
|
@ -1471,10 +1489,12 @@ fn transform_assistant_tool_message(
|
||||||
let mut updated_files = std::collections::HashMap::new();
|
let mut updated_files = std::collections::HashMap::new();
|
||||||
|
|
||||||
for (file_id, file_content) in file.files.iter() {
|
for (file_id, file_content) in file.files.iter() {
|
||||||
let chunk_id = format!("{}_{}", file.id, file_content.file_name);
|
let chunk_id =
|
||||||
|
format!("{}_{}", file.id, file_content.file_name);
|
||||||
|
|
||||||
if let Some(chunk) = &file_content.file.text_chunk {
|
if let Some(chunk) = &file_content.file.text_chunk {
|
||||||
let delta = tracker.add_chunk(chunk_id.clone(), chunk.clone());
|
let delta =
|
||||||
|
tracker.add_chunk(chunk_id.clone(), chunk.clone());
|
||||||
|
|
||||||
if !delta.is_empty() {
|
if !delta.is_empty() {
|
||||||
let mut updated_content = file_content.clone();
|
let mut updated_content = file_content.clone();
|
||||||
|
|
Loading…
Reference in New Issue