From 99d9399f5108d3c0bb4d0e6e523273080115558a Mon Sep 17 00:00:00 2001 From: dal Date: Tue, 4 Mar 2025 09:40:27 -0700 Subject: [PATCH] context loaders. --- api/.cursor/rules/global.mdc | 3 + api/.cursor/rules/prds.mdc | 2 +- api/libs/handlers/Cargo.toml | 1 + .../src/chats/context_loaders/chat_context.rs | 61 ++++++ .../context_loaders/dashboard_context.rs | 181 ++++++++++++++++++ .../chats/context_loaders/metric_context.rs | 126 ++++++++++++ .../handlers/src/chats/context_loaders/mod.rs | 43 +++++ api/libs/handlers/src/chats/mod.rs | 1 + .../handlers/src/chats/post_chat_handler.rs | 54 +++++- 9 files changed, 462 insertions(+), 10 deletions(-) create mode 100644 api/libs/handlers/src/chats/context_loaders/chat_context.rs create mode 100644 api/libs/handlers/src/chats/context_loaders/dashboard_context.rs create mode 100644 api/libs/handlers/src/chats/context_loaders/metric_context.rs create mode 100644 api/libs/handlers/src/chats/context_loaders/mod.rs diff --git a/api/.cursor/rules/global.mdc b/api/.cursor/rules/global.mdc index 6a99128ab..5e5c60abb 100644 --- a/api/.cursor/rules/global.mdc +++ b/api/.cursor/rules/global.mdc @@ -36,6 +36,9 @@ let mut conn = get_pg_pool().get().await?; - Avoid unnecessary `.clone()` calls - Use `&str` instead of `String` for function parameters when the string doesn't need to be owned +### Importing packages/crates +- Please make the dependency as short as possible in the actual logic by importing the crate/package. + ### Database Operations - Use Diesel for database migrations and query building - Migrations are stored in the `migrations/` directory diff --git a/api/.cursor/rules/prds.mdc b/api/.cursor/rules/prds.mdc index 971e489a2..2e2ff80f5 100644 --- a/api/.cursor/rules/prds.mdc +++ b/api/.cursor/rules/prds.mdc @@ -115,7 +115,7 @@ Reference these example PRDs for guidance: - [ ] All template sections completed - [ ] Technical design is detailed and complete - [ ] File changes are documented -- [ ] Implementation phases are clear +- [ ] Implementation phases are clear (can be as many as you need.) - [ ] Testing strategy is defined - [ ] Security considerations addressed - [ ] Dependencies and Files listed diff --git a/api/libs/handlers/Cargo.toml b/api/libs/handlers/Cargo.toml index 1f974be75..69151cd59 100644 --- a/api/libs/handlers/Cargo.toml +++ b/api/libs/handlers/Cargo.toml @@ -19,6 +19,7 @@ futures = { workspace = true } redis = { workspace = true } regex = { workspace = true } indexmap = { workspace = true } +async-trait = { workspace = true } # Local dependencies database = { path = "../database" } diff --git a/api/libs/handlers/src/chats/context_loaders/chat_context.rs b/api/libs/handlers/src/chats/context_loaders/chat_context.rs new file mode 100644 index 000000000..f239b8134 --- /dev/null +++ b/api/libs/handlers/src/chats/context_loaders/chat_context.rs @@ -0,0 +1,61 @@ +use std::sync::Arc; + +use anyhow::Result; +use async_trait::async_trait; +use database::{ + models::User, + pool::get_pg_pool, + schema::{chats, messages}, +}; +use diesel::prelude::*; +use diesel_async::RunQueryDsl; +use agents::{Agent, AgentMessage}; +use uuid::Uuid; + +use super::ContextLoader; + +pub struct ChatContextLoader { + pub chat_id: Uuid, +} + +impl ChatContextLoader { + pub fn new(chat_id: Uuid) -> Self { + Self { chat_id } + } +} + +#[async_trait] +impl ContextLoader for ChatContextLoader { + async fn load_context(&self, user: &User, agent: &Arc) -> Result> { + let mut conn = get_pg_pool().get().await?; + + // First verify the chat exists and user has access + let chat = chats::table + .filter(chats::id.eq(self.chat_id)) + .filter(chats::created_by.eq(&user.id)) + .first::(&mut conn) + .await?; + + // Get all messages for the chat + let messages = messages::table + .filter(messages::chat_id.eq(chat.id)) + .order_by(messages::created_at.asc()) + .load::(&mut conn) + .await?; + + // Convert messages to AgentMessages + let mut agent_messages = Vec::new(); + for message in messages { + // Add user message + agent_messages.push(AgentMessage::user(message.request)); + + // Add assistant messages from response + if let Ok(response_messages) = serde_json::from_value::>(message.response) + { + agent_messages.extend(response_messages); + } + } + + Ok(agent_messages) + } +} \ No newline at end of file diff --git a/api/libs/handlers/src/chats/context_loaders/dashboard_context.rs b/api/libs/handlers/src/chats/context_loaders/dashboard_context.rs new file mode 100644 index 000000000..598d358a2 --- /dev/null +++ b/api/libs/handlers/src/chats/context_loaders/dashboard_context.rs @@ -0,0 +1,181 @@ +use anyhow::{Result, anyhow}; +use async_trait::async_trait; +use database::{ + models::{User, Dataset, MetricFile}, + pool::get_pg_pool, + schema::{dashboard_files, metric_files, datasets}, +}; +use diesel::prelude::*; +use diesel_async::RunQueryDsl; +use agents::{AgentMessage, Agent}; +use litellm::MessageProgress; +use serde_json::Value; +use std::sync::Arc; +use uuid::Uuid; +use std::collections::HashSet; + +use super::ContextLoader; + +pub struct DashboardContextLoader { + pub dashboard_id: Uuid, +} + +impl DashboardContextLoader { + pub fn new(dashboard_id: Uuid) -> Self { + Self { dashboard_id } + } +} + +#[async_trait] +impl ContextLoader for DashboardContextLoader { + async fn load_context(&self, user: &User, agent: &Arc) -> Result> { + let mut conn = get_pg_pool().get().await.map_err(|e| { + anyhow!("Failed to get database connection for dashboard context loading: {}", e) + })?; + + // First verify the dashboard exists and user has access + let dashboard = dashboard_files::table + .filter(dashboard_files::id.eq(self.dashboard_id)) + // .filter(dashboard_files::created_by.eq(&user.id)) + .first::(&mut conn) + .await + .map_err(|e| { + anyhow!("Failed to load dashboard (id: {}). Either it doesn't exist or user {} doesn't have access: {}", + self.dashboard_id, user.id, e) + })?; + + // Parse dashboard content to DashboardYml + let dashboard_yml: agents::tools::categories::file_tools::file_types::dashboard_yml::DashboardYml = + serde_json::from_value(dashboard.content.clone()) + .map_err(|e| anyhow!("Failed to parse dashboard content as YAML for dashboard {}: {}", dashboard.name, e))?; + + // Collect all metric IDs from the dashboard + let mut metric_ids = HashSet::new(); + for row in &dashboard_yml.rows { + for item in &row.items { + metric_ids.insert(item.id); + } + } + + // Load all referenced metrics + let mut metrics_vec = Vec::new(); + let mut all_dataset_ids = HashSet::new(); + let mut failed_metric_loads = Vec::new(); + + for metric_id in metric_ids { + match metric_files::table + .filter(metric_files::id.eq(metric_id)) + .first::(&mut conn) + .await + { + Ok(metric) => { + // Parse metric content + match serde_json::from_value::(metric.content.clone()) { + Ok(metric_yml) => { + all_dataset_ids.extend(metric_yml.dataset_ids); + metrics_vec.push(metric); + } + Err(e) => { + failed_metric_loads.push((metric_id, format!("Failed to parse metric content: {}", e))); + } + } + } + Err(e) => { + failed_metric_loads.push((metric_id, format!("Failed to load metric: {}", e))); + } + } + } + + if !failed_metric_loads.is_empty() { + tracing::warn!( + "Failed to load some metrics for dashboard {}: {:?}", + dashboard.name, + failed_metric_loads + ); + } + + // Load all unique datasets + let mut datasets_vec = Vec::new(); + let mut failed_dataset_loads = Vec::new(); + + for dataset_id in all_dataset_ids { + match datasets::table + .filter(datasets::id.eq(dataset_id)) + .first::(&mut conn) + .await + { + Ok(dataset) => datasets_vec.push(dataset), + Err(e) => failed_dataset_loads.push((dataset_id, e.to_string())), + } + } + + if !failed_dataset_loads.is_empty() { + tracing::warn!( + "Failed to load some datasets for dashboard {}: {:?}", + dashboard.name, + failed_dataset_loads + ); + } + + // Set agent state based on loaded assets + agent.set_state_value(String::from("dashboards_available"), Value::Bool(true)) + .await; + + agent.set_state_value(String::from("files_available"), Value::Bool(true)) + .await; + + if !metrics_vec.is_empty() { + agent.set_state_value(String::from("metrics_available"), Value::Bool(true)) + .await; + }; + + if !datasets_vec.is_empty() { + agent.set_state_value(String::from("data_context"), Value::Bool(true)) + .await; + }; + + // Format the context message with dashboard, metrics, and dataset information + let dashboard_yaml = serde_yaml::to_string(&dashboard_yml) + .map_err(|e| anyhow!("Failed to serialize dashboard {} to YAML: {}", dashboard.name, e))?; + + let mut context_message = format!( + "This conversation is continuing with context from the dashboard. Here is the relevant information:\n\nDashboard Definition:\n{}\n\n", + dashboard_yaml + ); + + if !metrics_vec.is_empty() { + context_message.push_str("Referenced Metrics:\n"); + for metric in metrics_vec { + match serde_json::from_value::(metric.content) { + Ok(metric_yml) => { + match serde_yaml::to_string(&metric_yml) { + Ok(yaml) => context_message.push_str(&format!("\n{}\n", yaml)), + Err(e) => tracing::warn!("Failed to serialize metric {} to YAML: {}", metric.id, e), + } + } + Err(e) => tracing::warn!("Failed to parse metric {} content: {}", metric.id, e), + } + } + } + + if !datasets_vec.is_empty() { + context_message.push_str("\nReferenced Datasets:\n"); + for dataset in datasets_vec { + if let Some(yml_content) = dataset.yml_file { + context_message.push_str(&format!("\n{}\n", yml_content)); + } else { + tracing::warn!("Dataset {} has no YML content", dataset.id); + } + } + } + + Ok(vec![AgentMessage::Assistant { + id: None, + content: Some(context_message), + name: None, + tool_calls: None, + progress: MessageProgress::Complete, + initial: true, + }]) + } +} \ No newline at end of file diff --git a/api/libs/handlers/src/chats/context_loaders/metric_context.rs b/api/libs/handlers/src/chats/context_loaders/metric_context.rs new file mode 100644 index 000000000..4ea050e93 --- /dev/null +++ b/api/libs/handlers/src/chats/context_loaders/metric_context.rs @@ -0,0 +1,126 @@ +use agents::{Agent, AgentMessage}; +use anyhow::{anyhow, Result}; +use async_trait::async_trait; +use database::{ + models::{Dataset, User}, + pool::get_pg_pool, + schema::{datasets, metric_files}, +}; +use diesel::prelude::*; +use diesel_async::RunQueryDsl; +use serde_json::Value; +use std::sync::Arc; +use uuid::Uuid; + +use super::ContextLoader; + +pub struct MetricContextLoader { + pub metric_id: Uuid, +} + +impl MetricContextLoader { + pub fn new(metric_id: Uuid) -> Self { + Self { metric_id } + } +} + +#[async_trait] +impl ContextLoader for MetricContextLoader { + async fn load_context(&self, user: &User, agent: &Arc) -> Result> { + let mut conn = get_pg_pool().get().await.map_err(|e| { + anyhow!( + "Failed to get database connection for metric context loading: {}", + e + ) + })?; + + // First verify the metric exists and user has access + let metric = metric_files::table + .filter(metric_files::id.eq(self.metric_id)) + // .filter(metric_files::created_by.eq(&user.id)) + .first::(&mut conn) + .await + .map_err(|e| { + anyhow!("Failed to load metric (id: {}). Either it doesn't exist or user {} doesn't have access: {}", + self.metric_id, user.id, e) + })?; + + // Get the metric content as MetricYml + let metric_yml: agents::tools::categories::file_tools::file_types::metric_yml::MetricYml = + serde_json::from_value(metric.content.clone()).map_err(|e| { + anyhow!( + "Failed to parse metric content as YAML for metric {}: {}", + metric.name, + e + ) + })?; + + // Load all referenced datasets + let dataset_ids = &metric_yml.dataset_ids; + let mut datasets_vec = Vec::new(); + let mut failed_dataset_loads = Vec::new(); + + for dataset_id in dataset_ids { + match datasets::table + .filter(datasets::id.eq(dataset_id)) + .first::(&mut conn) + .await + { + Ok(dataset) => datasets_vec.push(dataset), + Err(e) => failed_dataset_loads.push((dataset_id, e.to_string())), + } + } + + if !failed_dataset_loads.is_empty() { + tracing::warn!( + "Failed to load some datasets for metric {}: {:?}", + metric.name, + failed_dataset_loads + ); + } + + // Set agent state based on loaded assets + agent + .set_state_value(String::from("metrics_available"), Value::Bool(true)) + .await; + + agent + .set_state_value(String::from("files_available"), Value::Bool(true)) + .await; + + if !datasets_vec.is_empty() { + agent + .set_state_value(String::from("data_context"), Value::Bool(true)) + .await; + }; + + // Format the context message with metric and dataset information + let metric_yaml = serde_yaml::to_string(&metric_yml) + .map_err(|e| anyhow!("Failed to serialize metric {} to YAML: {}", metric.name, e))?; + + let mut context_message = format!( + "This conversation is continuing with context from the metric. Here is the relevant information:\n\nMetric Definition:\n{}\n\n", + metric_yaml + ); + + if !datasets_vec.is_empty() { + context_message.push_str("Referenced Datasets:\n"); + for dataset in datasets_vec { + if let Some(yml_content) = dataset.yml_file { + context_message.push_str(&format!("\n{}\n", yml_content)); + } else { + tracing::warn!("Dataset {} has no YML content", dataset.id); + } + } + } + + Ok(vec![AgentMessage::Assistant { + id: None, + content: Some(context_message), + name: None, + tool_calls: None, + progress: litellm::MessageProgress::Complete, + initial: true, + }]) + } +} diff --git a/api/libs/handlers/src/chats/context_loaders/mod.rs b/api/libs/handlers/src/chats/context_loaders/mod.rs new file mode 100644 index 000000000..40acbf184 --- /dev/null +++ b/api/libs/handlers/src/chats/context_loaders/mod.rs @@ -0,0 +1,43 @@ +use anyhow::Result; +use async_trait::async_trait; +use database::models::User; +use agents::AgentMessage; +use std::sync::Arc; +use agents::Agent; + +pub mod chat_context; +pub mod metric_context; +pub mod dashboard_context; + +pub use chat_context::ChatContextLoader; +pub use metric_context::MetricContextLoader; +pub use dashboard_context::DashboardContextLoader; + +#[async_trait] +pub trait ContextLoader { + async fn load_context(&self, user: &User, agent: &Arc) -> Result>; +} + +// Validate that only one context type is provided +pub fn validate_context_request( + chat_id: Option, + metric_id: Option, + dashboard_id: Option, +) -> Result<()> { + let context_count = [ + chat_id.is_some(), + metric_id.is_some(), + dashboard_id.is_some(), + ] + .iter() + .filter(|&&b| b) + .count(); + + if context_count > 1 { + return Err(anyhow::anyhow!( + "Only one context type (chat, metric, or dashboard) can be provided" + )); + } + + Ok(()) +} \ No newline at end of file diff --git a/api/libs/handlers/src/chats/mod.rs b/api/libs/handlers/src/chats/mod.rs index 3563b5b12..bf86d8f27 100644 --- a/api/libs/handlers/src/chats/mod.rs +++ b/api/libs/handlers/src/chats/mod.rs @@ -4,6 +4,7 @@ pub mod update_chats_handler; pub mod delete_chats_handler; pub mod types; pub mod streaming_parser; +pub mod context_loaders; pub use get_chat_handler::get_chat_handler; pub use post_chat_handler::post_chat_handler; diff --git a/api/libs/handlers/src/chats/post_chat_handler.rs b/api/libs/handlers/src/chats/post_chat_handler.rs index 4e27ae52d..42ea056bc 100644 --- a/api/libs/handlers/src/chats/post_chat_handler.rs +++ b/api/libs/handlers/src/chats/post_chat_handler.rs @@ -1,8 +1,8 @@ use std::collections::HashMap; use agents::{ - tools::file_tools::search_data_catalog::SearchDataCatalogOutput, AgentMessage, AgentThread, - BusterSuperAgent, + tools::file_tools::search_data_catalog::SearchDataCatalogOutput, AgentExt, AgentMessage, + AgentThread, BusterSuperAgent, }; use anyhow::{anyhow, Result}; @@ -23,7 +23,13 @@ use serde::{Deserialize, Serialize}; use serde_json::Value; use uuid::Uuid; -use crate::chats::streaming_parser::StreamingParser; +use crate::chats::{ + context_loaders::{ + chat_context::ChatContextLoader, dashboard_context::DashboardContextLoader, + metric_context::MetricContextLoader, validate_context_request, ContextLoader, + }, + streaming_parser::StreamingParser, +}; use crate::messages::types::{ChatMessage, ChatUserMessage}; use super::types::ChatWithMessages; @@ -44,6 +50,8 @@ pub struct ChatCreateNewChat { pub prompt: String, pub chat_id: Option, pub message_id: Option, + pub metric_id: Option, + pub dashboard_id: Option, } pub async fn post_chat_handler( @@ -51,6 +59,9 @@ pub async fn post_chat_handler( user: User, tx: Option>>, ) -> Result { + // Validate context request + validate_context_request(request.chat_id, request.metric_id, request.dashboard_id)?; + let chat_id = request.chat_id.unwrap_or_else(Uuid::new_v4); let message_id = request.message_id.unwrap_or_else(Uuid::new_v4); @@ -120,13 +131,38 @@ pub async fn post_chat_handler( .execute(&mut conn) .await?; - // Initialize agent and process request + // Initialize agent with context if provided + let mut initial_messages = vec![]; + + // Initialize agent to add context let agent = BusterSuperAgent::new(user.id, chat_id).await?; - let mut chat = AgentThread::new( - Some(chat_id), - user.id, - vec![AgentMessage::user(request.prompt.clone())], - ); + + // Load context if provided + if let Some(existing_chat_id) = request.chat_id { + let context_loader = ChatContextLoader::new(existing_chat_id); + let context_messages = context_loader + .load_context(&user, agent.get_agent()) + .await?; + initial_messages.extend(context_messages); + } else if let Some(metric_id) = request.metric_id { + let context_loader = MetricContextLoader::new(metric_id); + let context_messages = context_loader + .load_context(&user, agent.get_agent()) + .await?; + initial_messages.extend(context_messages); + } else if let Some(dashboard_id) = request.dashboard_id { + let context_loader = DashboardContextLoader::new(dashboard_id); + let context_messages = context_loader + .load_context(&user, agent.get_agent()) + .await?; + initial_messages.extend(context_messages); + } + + // Add the new user message + initial_messages.push(AgentMessage::user(request.prompt.clone())); + + // Initialize the agent thread + let mut chat = AgentThread::new(Some(chat_id), user.id, initial_messages); let title_handle = { let tx = tx.clone();