diff --git a/api/libs/agents/src/agent.rs b/api/libs/agents/src/agent.rs index 5a50c7970..1f155b0ba 100644 --- a/api/libs/agents/src/agent.rs +++ b/api/libs/agents/src/agent.rs @@ -329,7 +329,7 @@ impl Agent { model: self.model.clone(), messages: thread.messages.clone(), tools: if tools.is_empty() { None } else { Some(tools) }, - tool_choice: Some(ToolChoice::Required), + tool_choice: Some(ToolChoice::Auto), stream: Some(true), // Enable streaming metadata: Some(Metadata { generation_name: "agent".to_string(), @@ -337,7 +337,6 @@ impl Agent { session_id: thread.id.to_string(), trace_id: thread.id.to_string(), }), - store: Some(true), ..Default::default() }; diff --git a/api/libs/agents/src/agents/buster_super_agent.rs b/api/libs/agents/src/agents/buster_super_agent.rs index fe110c9e6..25059268a 100644 --- a/api/libs/agents/src/agents/buster_super_agent.rs +++ b/api/libs/agents/src/agents/buster_super_agent.rs @@ -104,7 +104,7 @@ impl BusterSuperAgent { HashMap::new(), user_id, session_id, - "manager_agent".to_string(), + "buster_super_agent".to_string(), )); let manager = Self { agent }; @@ -116,7 +116,7 @@ impl BusterSuperAgent { // Create a new agent with the same core properties and shared state/stream let agent = Arc::new(Agent::from_existing( existing_agent, - "manager_agent".to_string(), + "buster_super_agent".to_string(), )); let manager = Self { agent }; manager.load_tools().await?; @@ -127,7 +127,7 @@ impl BusterSuperAgent { &self, thread: &mut AgentThread, ) -> Result>> { - thread.set_developer_message(MANAGER_AGENT_PROMPT.to_string()); + thread.set_developer_message(BUSTER_SUPER_AGENT_PROMPT.to_string()); // Get shutdown receiver let rx = self.stream_process_thread(thread).await?; @@ -141,7 +141,7 @@ impl BusterSuperAgent { } } -const MANAGER_AGENT_PROMPT: &str = r##"### Role & Task +const BUSTER_SUPER_AGENT_PROMPT: &str = r##"### Role & Task You are Buster, an expert analytics and data engineer. Your job is to assess what data is available and then provide fast, accurate answers to analytics questions from non-technical users. You do this by analyzing user requests, searching across a data catalog, and building metrics or dashboards. --- ### Actions Available (Tools) diff --git a/api/libs/handlers/src/chats/context_loaders/chat_context.rs b/api/libs/handlers/src/chats/context_loaders/chat_context.rs index 6d44d3fbc..b7cbe5dcc 100644 --- a/api/libs/handlers/src/chats/context_loaders/chat_context.rs +++ b/api/libs/handlers/src/chats/context_loaders/chat_context.rs @@ -1,4 +1,5 @@ use std::sync::Arc; +use std::collections::HashSet; use anyhow::Result; use async_trait::async_trait; @@ -10,6 +11,7 @@ use database::{ use diesel::prelude::*; use diesel_async::RunQueryDsl; use agents::{Agent, AgentMessage}; +use serde_json::Value; use uuid::Uuid; use super::ContextLoader; @@ -22,6 +24,33 @@ impl ChatContextLoader { pub fn new(chat_id: Uuid) -> Self { Self { chat_id } } + + // Helper function to check for tool usage and set appropriate context + async fn update_context_from_tool_calls(agent: &Arc, message: &AgentMessage) { + if let AgentMessage::Assistant { tool_calls: Some(tool_calls), .. } = message { + for tool_call in tool_calls { + match tool_call.function.name.as_str() { + "search_data_catalog" => { + agent.set_state_value(String::from("data_context"), Value::Bool(true)) + .await; + }, + "create_metrics" | "update_metrics" => { + agent.set_state_value(String::from("metrics_available"), Value::Bool(true)) + .await; + }, + "create_dashboard" | "update_dashboard" => { + agent.set_state_value(String::from("dashboards_available"), Value::Bool(true)) + .await; + }, + name if name.contains("file") || name.contains("read") || name.contains("write") || name.contains("edit") => { + agent.set_state_value(String::from("files_available"), Value::Bool(true)) + .await; + }, + _ => {} + } + } + } + } } #[async_trait] @@ -36,23 +65,33 @@ impl ContextLoader for ChatContextLoader { .first::(&mut conn) .await?; - // Get all messages for the chat - let messages = messages::table + // Get only the most recent message for the chat + let message = messages::table .filter(messages::chat_id.eq(chat.id)) - .order_by(messages::created_at.asc()) - .load::(&mut conn) + .order_by(messages::created_at.desc()) + .first::(&mut conn) .await?; + // Track seen message IDs + let mut seen_ids = HashSet::new(); // Convert messages to AgentMessages let mut agent_messages = Vec::new(); - for message in messages { - // Add user message - agent_messages.push(AgentMessage::user(message.request_message)); - - // Add assistant messages from response - if let Ok(response_messages) = serde_json::from_value::>(message.response_messages) - { - agent_messages.extend(response_messages); + + // Process only the most recent message's raw LLM messages + if let Ok(raw_messages) = serde_json::from_value::>(message.raw_llm_messages) { + // Check each message for tool calls and update context + for agent_message in &raw_messages { + Self::update_context_from_tool_calls(agent, agent_message).await; + + // Only add messages with new IDs + if let Some(id) = agent_message.get_id() { + if seen_ids.insert(id.to_string()) { + agent_messages.push(agent_message.clone()); + } + } else { + // Messages without IDs are always included + agent_messages.push(agent_message.clone()); + } } } diff --git a/api/libs/handlers/src/chats/post_chat_handler.rs b/api/libs/handlers/src/chats/post_chat_handler.rs index b2b0eef02..1a69ecb00 100644 --- a/api/libs/handlers/src/chats/post_chat_handler.rs +++ b/api/libs/handlers/src/chats/post_chat_handler.rs @@ -28,6 +28,7 @@ use crate::chats::{ chat_context::ChatContextLoader, dashboard_context::DashboardContextLoader, metric_context::MetricContextLoader, validate_context_request, ContextLoader, }, + get_chat_handler, streaming_parser::StreamingParser, }; use crate::messages::types::{ChatMessage, ChatUserMessage}; @@ -42,6 +43,7 @@ pub enum ThreadEvent { GeneratingReasoningMessage, GeneratingTitle, InitializeChat, + Completed, } #[derive(Debug, Deserialize, Clone)] @@ -59,12 +61,8 @@ pub async fn post_chat_handler( tx: Option>>, ) -> Result { let reasoning_duration = Instant::now(); - // Validate context request validate_context_request(request.chat_id, request.metric_id, request.dashboard_id)?; - let chat_id = request.chat_id.unwrap_or_else(Uuid::new_v4); - let message_id = request.message_id.unwrap_or_else(Uuid::new_v4); - let user_org_id = match user.attributes.get("organization_id") { Some(Value::String(org_id)) => Uuid::parse_str(&org_id).unwrap_or_default(), _ => { @@ -73,47 +71,15 @@ pub async fn post_chat_handler( } }; + // Initialize chat - either get existing or create new + let (chat_id, message_id, mut chat_with_messages) = + initialize_chat(&request, &user, user_org_id).await?; + tracing::info!( "Starting post_chat_handler for chat_id: {}, message_id: {}, organization_id: {}, user_id: {}", chat_id, message_id, user_org_id, user.id ); - // Create chat - let chat = Chat { - id: chat_id, - title: request.prompt.clone(), - organization_id: user_org_id, - created_by: user.id.clone(), - created_at: Utc::now(), - updated_at: Utc::now(), - deleted_at: None, - updated_by: user.id.clone(), - }; - - let mut chat_with_messages = ChatWithMessages { - id: chat_id, - title: request.prompt.clone(), - is_favorited: false, - messages: vec![ChatMessage { - id: message_id, - request_message: ChatUserMessage { - request: request.prompt.clone(), - sender_id: user.id.clone(), - sender_name: user.name.clone().unwrap_or_default(), - sender_avatar: None, - }, - response_messages: vec![], - reasoning: vec![], - created_at: Utc::now().to_string(), - }], - created_at: Utc::now().to_string(), - updated_at: Utc::now().to_string(), - created_by: user.id.to_string(), - created_by_id: user.id.to_string(), - created_by_name: user.name.clone().unwrap_or_default(), - created_by_avatar: None, - }; - // Send initial chat state to client if let Some(tx) = tx.clone() { tx.send(Ok(( @@ -126,12 +92,6 @@ pub async fn post_chat_handler( // Create database connection let mut conn = get_pg_pool().get().await?; - // Create chat in database - insert_into(chats::table) - .values(&chat) - .execute(&mut conn) - .await?; - // Initialize agent with context if provided let mut initial_messages = vec![]; @@ -162,6 +122,9 @@ pub async fn post_chat_handler( // Add the new user message initial_messages.push(AgentMessage::user(request.prompt.clone())); + // Initialize raw_llm_messages with initial_messages + let mut raw_llm_messages = initial_messages.clone(); + // Initialize the agent thread let mut chat = AgentThread::new(Some(chat_id), user.id, initial_messages); @@ -187,11 +150,30 @@ pub async fn post_chat_handler( while let Ok(message_result) = rx.recv().await { match message_result { Ok(msg) => { - // Store the original message + // Store the original message for file processing all_messages.push(msg.clone()); + // Only store completed messages in raw_llm_messages + match &msg { + AgentMessage::Assistant { progress, .. } => { + if matches!(progress, MessageProgress::Complete) { + raw_llm_messages.push(msg.clone()); + } + } + AgentMessage::Tool { progress, .. } => { + if matches!(progress, MessageProgress::Complete) { + raw_llm_messages.push(msg.clone()); + } + } + // User messages and other types don't have progress, so we store them all + AgentMessage::User { .. } => { + raw_llm_messages.push(msg.clone()); + } + _ => {} // Ignore other message types + } + // Always transform the message - match transform_message(&chat_id, &message_id, msg) { + match transform_message(&chat_id, &message_id, msg, tx.as_ref()).await { Ok((containers, event)) => { // Store all transformed containers for container in containers.clone() { @@ -261,7 +243,7 @@ pub async fn post_chat_handler( reasoning: serde_json::to_value(&reasoning_messages)?, final_reasoning_message, title: title.title.clone().unwrap_or_default(), - raw_llm_messages: Value::Array(vec![]), + raw_llm_messages: serde_json::to_value(&raw_llm_messages)?, }; // Insert message into database @@ -283,6 +265,15 @@ pub async fn post_chat_handler( chat_with_messages.title = title; } + // Send final completed state + if let Some(tx) = &tx { + tx.send(Ok(( + BusterContainer::Chat(chat_with_messages.clone()), + ThreadEvent::Completed, + ))) + .await?; + } + tracing::info!("Completed post_chat_handler for chat_id: {}", chat_id); Ok(chat_with_messages) } @@ -346,15 +337,14 @@ async fn process_completed_files( user_id: &Uuid, ) -> Result<()> { // Transform messages to BusterContainer format - let transformed_messages: Vec = messages - .iter() - .filter_map(|msg| { - transform_message(&message.chat_id, &message.id, msg.clone()) - .ok() - .map(|(containers, _)| containers) - }) - .flatten() - .collect(); + let mut transformed_messages = Vec::new(); + for msg in messages { + if let Ok((containers, _)) = + transform_message(&message.chat_id, &message.id, msg.clone(), None).await + { + transformed_messages.extend(containers); + } + } // Process any completed metric or dashboard files for container in transformed_messages { @@ -575,10 +565,11 @@ pub enum BusterContainer { GeneratingTitle(BusterGeneratingTitle), } -pub fn transform_message( +pub async fn transform_message( chat_id: &Uuid, message_id: &Uuid, message: AgentMessage, + tx: Option<&mpsc::Sender>>, ) -> Result<(Vec, ThreadEvent)> { println!("MESSAGE_STREAM: Transforming message: {:?}", message); @@ -631,13 +622,25 @@ pub fn transform_message( status: Some("completed".to_string()), }); - containers.push(BusterContainer::ReasoningMessage( - BusterReasoningMessageContainer { + let reasoning_container = + BusterContainer::ReasoningMessage(BusterReasoningMessageContainer { reasoning: reasoning_message, chat_id: *chat_id, message_id: *message_id, - }, - )); + }); + + // Send the finished reasoning message separately + if let Some(tx) = tx { + if let Err(e) = tx + .send(Ok(( + reasoning_container, + ThreadEvent::GeneratingReasoningMessage, + ))) + .await + { + tracing::warn!("Failed to send finished reasoning message: {:?}", e); + } + } } return Ok((containers, ThreadEvent::GeneratingResponseMessage)); @@ -1393,3 +1396,78 @@ pub async fn generate_conversation_title( Ok(title) } + +async fn initialize_chat( + request: &ChatCreateNewChat, + user: &User, + user_org_id: Uuid, +) -> Result<(Uuid, Uuid, ChatWithMessages)> { + let message_id = request.message_id.unwrap_or_else(Uuid::new_v4); + + if let Some(existing_chat_id) = request.chat_id { + // Get existing chat - no need to create new chat in DB + let mut existing_chat = get_chat_handler(&existing_chat_id, &user.id).await?; + + // Add new message to existing chat + existing_chat.messages.push(ChatMessage { + id: message_id, + request_message: ChatUserMessage { + request: request.prompt.clone(), + sender_id: user.id.clone(), + sender_name: user.name.clone().unwrap_or_default(), + sender_avatar: None, + }, + response_messages: vec![], + reasoning: vec![], + created_at: Utc::now().to_string(), + }); + + Ok((existing_chat_id, message_id, existing_chat)) + } else { + // Create new chat since we don't have an existing one + let chat_id = Uuid::new_v4(); + let chat = Chat { + id: chat_id, + title: request.prompt.clone(), + organization_id: user_org_id, + created_by: user.id.clone(), + created_at: Utc::now(), + updated_at: Utc::now(), + deleted_at: None, + updated_by: user.id.clone(), + }; + + let chat_with_messages = ChatWithMessages { + id: chat_id, + title: request.prompt.clone(), + is_favorited: false, + messages: vec![ChatMessage { + id: message_id, + request_message: ChatUserMessage { + request: request.prompt.clone(), + sender_id: user.id.clone(), + sender_name: user.name.clone().unwrap_or_default(), + sender_avatar: None, + }, + response_messages: vec![], + reasoning: vec![], + created_at: Utc::now().to_string(), + }], + created_at: Utc::now().to_string(), + updated_at: Utc::now().to_string(), + created_by: user.id.to_string(), + created_by_id: user.id.to_string(), + created_by_name: user.name.clone().unwrap_or_default(), + created_by_avatar: None, + }; + + // Only create new chat in DB if this is a new chat + let mut conn = get_pg_pool().get().await?; + insert_into(chats::table) + .values(&chat) + .execute(&mut conn) + .await?; + + Ok((chat_id, message_id, chat_with_messages)) + } +} diff --git a/api/libs/litellm/src/types.rs b/api/libs/litellm/src/types.rs index 221737177..50769cd47 100644 --- a/api/libs/litellm/src/types.rs +++ b/api/libs/litellm/src/types.rs @@ -250,6 +250,15 @@ impl AgentMessage { Self::User { id, .. } => *id = Some(new_id), } } + + pub fn get_id(&self) -> Option { + match self { + Self::Assistant { id, .. } => id.clone(), + Self::Tool { id, .. } => id.clone(), + Self::Developer { id, .. } => id.clone(), + Self::User { id, .. } => id.clone(), + } + } } #[derive(Debug, Serialize, Deserialize, Clone)] @@ -584,7 +593,9 @@ mod tests { async fn test_chat_completion_request_with_tools() { let request = ChatCompletionRequest { model: "o1".to_string(), - messages: vec![AgentMessage::user("Hello whats the weather in vineyard ut!")], + messages: vec![AgentMessage::user( + "Hello whats the weather in vineyard ut!", + )], max_completion_tokens: Some(100), tools: Some(vec![Tool { tool_type: "function".to_string(), @@ -877,7 +888,9 @@ mod tests { // Test request with function tool let request = ChatCompletionRequest { model: "gpt-4o".to_string(), - messages: vec![AgentMessage::user("What's the weather like in Boston today?")], + messages: vec![AgentMessage::user( + "What's the weather like in Boston today?", + )], tools: Some(vec![Tool { tool_type: "function".to_string(), function: json!({ diff --git a/api/src/routes/ws/threads_and_messages/post_thread.rs b/api/src/routes/ws/threads_and_messages/post_thread.rs index 92d9015fe..2759e5762 100644 --- a/api/src/routes/ws/threads_and_messages/post_thread.rs +++ b/api/src/routes/ws/threads_and_messages/post_thread.rs @@ -46,6 +46,9 @@ pub async fn post_thread( ThreadEvent::InitializeChat => { WsEvent::Threads(WSThreadEvent::InitializeChat) } + ThreadEvent::Completed => { + WsEvent::Threads(WSThreadEvent::Complete) + } }; let response = WsResponseMessage::new_no_user(