diff --git a/api/libs/handlers/src/chats/context_loaders/chat_context.rs b/api/libs/handlers/src/chats/context_loaders/chat_context.rs index 374851b33..e1544ef3c 100644 --- a/api/libs/handlers/src/chats/context_loaders/chat_context.rs +++ b/api/libs/handlers/src/chats/context_loaders/chat_context.rs @@ -1,6 +1,7 @@ -use std::sync::Arc; use std::collections::HashSet; +use std::sync::Arc; +use agents::{Agent, AgentMessage}; use anyhow::Result; use async_trait::async_trait; use database::{ @@ -9,7 +10,6 @@ use database::{ }; use diesel::prelude::*; use diesel_async::RunQueryDsl; -use agents::{Agent, AgentMessage}; use middleware::AuthenticatedUser; use serde_json::Value; use uuid::Uuid; @@ -28,36 +28,56 @@ impl ChatContextLoader { // Helper function to check for tool usage and set appropriate context async fn update_context_from_tool_calls(agent: &Arc, message: &AgentMessage) { // Handle tool calls from assistant messages - if let AgentMessage::Assistant { tool_calls: Some(tool_calls), .. } = message { + if let AgentMessage::Assistant { + tool_calls: Some(tool_calls), + .. + } = message + { for tool_call in tool_calls { match tool_call.function.name.as_str() { "search_data_catalog" => { - agent.set_state_value(String::from("data_context"), Value::Bool(true)) + agent + .set_state_value(String::from("data_context"), Value::Bool(true)) .await; - }, + } "create_metrics" | "update_metrics" => { - agent.set_state_value(String::from("metrics_available"), Value::Bool(true)) + agent + .set_state_value(String::from("metrics_available"), Value::Bool(true)) .await; - }, + } "create_dashboards" | "update_dashboards" => { - agent.set_state_value(String::from("dashboards_available"), Value::Bool(true)) + agent + .set_state_value( + String::from("dashboards_available"), + Value::Bool(true), + ) .await; - }, + } "import_assets" => { // When we see import_assets, we need to check the content in the corresponding tool response // This will be handled separately when processing tool messages - }, - name if name.contains("file") || name.contains("read") || name.contains("write") || name.contains("edit") => { - agent.set_state_value(String::from("files_available"), Value::Bool(true)) + } + name if name.contains("file") + || name.contains("read") + || name.contains("write") + || name.contains("edit") => + { + agent + .set_state_value(String::from("files_available"), Value::Bool(true)) .await; - }, + } _ => {} } } } - + // Handle tool responses - important for import_assets - if let AgentMessage::Tool { name: Some(tool_name), content, .. } = message { + if let AgentMessage::Tool { + name: Some(tool_name), + content, + .. + } = message + { if tool_name == "import_assets" { // Parse the tool response to see what was imported if let Ok(import_result) = serde_json::from_str::(content) { @@ -65,65 +85,94 @@ impl ChatContextLoader { if let Some(files) = import_result.get("files").and_then(|f| f.as_array()) { if !files.is_empty() { // Set files_available for any imported files - agent.set_state_value(String::from("files_available"), Value::Bool(true)) + agent + .set_state_value(String::from("files_available"), Value::Bool(true)) .await; - + // Check each file to determine its type let mut has_metrics = false; let mut has_dashboards = false; let mut has_datasets = false; - + for file in files { // Check file_type/asset_type to determine what kind of asset this is - let file_type = file.get("file_type").and_then(|ft| ft.as_str()) + let file_type = file + .get("file_type") + .and_then(|ft| ft.as_str()) .or_else(|| file.get("asset_type").and_then(|at| at.as_str())); - - tracing::debug!("Processing imported file with type: {:?}", file_type); - + + tracing::debug!( + "Processing imported file with type: {:?}", + file_type + ); + match file_type { Some("metric") => { has_metrics = true; - + // Check if the metric has dataset references - if let Some(yml_content) = file.get("yml_content").and_then(|y| y.as_str()) { - if yml_content.contains("dataset") || yml_content.contains("datasetIds") { + if let Some(yml_content) = + file.get("yml_content").and_then(|y| y.as_str()) + { + if yml_content.contains("dataset") + || yml_content.contains("datasetIds") + { has_datasets = true; } } - }, + } Some("dashboard") => { has_dashboards = true; - + // Dashboards often reference metrics too has_metrics = true; - + // Check if the dashboard has dataset references via metrics - if let Some(yml_content) = file.get("yml_content").and_then(|y| y.as_str()) { - if yml_content.contains("dataset") || yml_content.contains("datasetIds") { + if let Some(yml_content) = + file.get("yml_content").and_then(|y| y.as_str()) + { + if yml_content.contains("dataset") + || yml_content.contains("datasetIds") + { has_datasets = true; } } - }, + } _ => { - tracing::debug!("Unknown file type in import_assets: {:?}", file_type); + tracing::debug!( + "Unknown file type in import_assets: {:?}", + file_type + ); } } } - + // Set appropriate state values based on what we found if has_metrics { tracing::debug!("Setting metrics_available state to true"); - agent.set_state_value(String::from("metrics_available"), Value::Bool(true)) + agent + .set_state_value( + String::from("metrics_available"), + Value::Bool(true), + ) .await; } if has_dashboards { tracing::debug!("Setting dashboards_available state to true"); - agent.set_state_value(String::from("dashboards_available"), Value::Bool(true)) + agent + .set_state_value( + String::from("dashboards_available"), + Value::Bool(true), + ) .await; } if has_datasets { tracing::debug!("Setting data_context state to true"); - agent.set_state_value(String::from("data_context"), Value::Bool(true)) + agent + .set_state_value( + String::from("data_context"), + Value::Bool(true), + ) .await; } } @@ -136,34 +185,47 @@ impl ChatContextLoader { #[async_trait] impl ContextLoader for ChatContextLoader { - async fn load_context(&self, user: &AuthenticatedUser, agent: &Arc) -> Result> { + async fn load_context( + &self, + user: &AuthenticatedUser, + agent: &Arc, + ) -> Result> { let mut conn = get_pg_pool().get().await?; // First verify the chat exists and user has access let chat = chats::table .filter(chats::id.eq(self.chat_id)) .filter(chats::created_by.eq(&user.id)) + .filter(chats::deleted_at.is_null()) .first::(&mut conn) .await?; // Get only the most recent message for the chat - let message = messages::table + let message = match messages::table .filter(messages::chat_id.eq(chat.id)) + .filter(messages::deleted_at.is_null()) .order_by(messages::created_at.desc()) .first::(&mut conn) - .await?; + .await + { + Ok(message) => message, + Err(diesel::NotFound) => return Ok(vec![]), + Err(e) => return Err(anyhow::anyhow!("Failed to get message: {}", e)), + }; // Track seen message IDs let mut seen_ids = HashSet::new(); // Convert messages to AgentMessages let mut agent_messages = Vec::new(); - + // Process only the most recent message's raw LLM messages - if let Ok(raw_messages) = serde_json::from_value::>(message.raw_llm_messages) { + if let Ok(raw_messages) = + serde_json::from_value::>(message.raw_llm_messages) + { // Check each message for tool calls and update context for agent_message in &raw_messages { Self::update_context_from_tool_calls(agent, agent_message).await; - + // Only add messages with new IDs if let Some(id) = agent_message.get_id() { if seen_ids.insert(id.to_string()) { @@ -178,4 +240,4 @@ impl ContextLoader for ChatContextLoader { Ok(agent_messages) } -} \ No newline at end of file +} diff --git a/api/libs/handlers/src/chats/post_chat_handler.rs b/api/libs/handlers/src/chats/post_chat_handler.rs index 0762d95ea..59fcbe055 100644 --- a/api/libs/handlers/src/chats/post_chat_handler.rs +++ b/api/libs/handlers/src/chats/post_chat_handler.rs @@ -228,8 +228,6 @@ pub async fn post_chat_handler( let messages = generate_asset_messages(asset_id_value, asset_type_value, &user).await?; - println!("messages: {:?}", messages); - // Add messages to chat and associate with chat_id let mut updated_messages = Vec::new(); for mut message in messages {