context loaders.

This commit is contained in:
dal 2025-03-04 09:40:27 -07:00
parent fee69dbab7
commit 99d9399f51
No known key found for this signature in database
GPG Key ID: 16F4B0E1E9F61122
9 changed files with 462 additions and 10 deletions

View File

@ -36,6 +36,9 @@ let mut conn = get_pg_pool().get().await?;
- Avoid unnecessary `.clone()` calls
- Use `&str` instead of `String` for function parameters when the string doesn't need to be owned
### Importing packages/crates
- Import the crate/package with a `use` statement so that references to it in the actual logic stay as short as possible.
### Database Operations
- Use Diesel for database migrations and query building
- Migrations are stored in the `migrations/` directory

View File

@ -115,7 +115,7 @@ Reference these example PRDs for guidance:
- [ ] All template sections completed
- [ ] Technical design is detailed and complete
- [ ] File changes are documented
- [ ] Implementation phases are clear
- [ ] Implementation phases are clear (use as many phases as you need)
- [ ] Testing strategy is defined
- [ ] Security considerations addressed
- [ ] Dependencies and Files listed

View File

@ -19,6 +19,7 @@ futures = { workspace = true }
redis = { workspace = true }
regex = { workspace = true }
indexmap = { workspace = true }
async-trait = { workspace = true }
# Local dependencies
database = { path = "../database" }

View File

@ -0,0 +1,61 @@
use std::sync::Arc;
use anyhow::Result;
use async_trait::async_trait;
use database::{
models::User,
pool::get_pg_pool,
schema::{chats, messages},
};
use diesel::prelude::*;
use diesel_async::RunQueryDsl;
use agents::{Agent, AgentMessage};
use uuid::Uuid;
use super::ContextLoader;
/// Context loader that replays the stored messages of an existing chat.
#[derive(Debug, Clone)]
pub struct ChatContextLoader {
    /// Identifier of the chat whose messages are loaded.
    pub chat_id: Uuid,
}

impl ChatContextLoader {
    /// Creates a loader for the chat identified by `chat_id`.
    pub fn new(chat_id: Uuid) -> Self {
        Self { chat_id }
    }
}
#[async_trait]
impl ContextLoader for ChatContextLoader {
async fn load_context(&self, user: &User, agent: &Arc<Agent>) -> Result<Vec<AgentMessage>> {
let mut conn = get_pg_pool().get().await?;
// First verify the chat exists and user has access
let chat = chats::table
.filter(chats::id.eq(self.chat_id))
.filter(chats::created_by.eq(&user.id))
.first::<database::models::Chat>(&mut conn)
.await?;
// Get all messages for the chat
let messages = messages::table
.filter(messages::chat_id.eq(chat.id))
.order_by(messages::created_at.asc())
.load::<database::models::Message>(&mut conn)
.await?;
// Convert messages to AgentMessages
let mut agent_messages = Vec::new();
for message in messages {
// Add user message
agent_messages.push(AgentMessage::user(message.request));
// Add assistant messages from response
if let Ok(response_messages) = serde_json::from_value::<Vec<AgentMessage>>(message.response)
{
agent_messages.extend(response_messages);
}
}
Ok(agent_messages)
}
}

View File

@ -0,0 +1,181 @@
use anyhow::{Result, anyhow};
use async_trait::async_trait;
use database::{
models::{User, Dataset, MetricFile},
pool::get_pg_pool,
schema::{dashboard_files, metric_files, datasets},
};
use diesel::prelude::*;
use diesel_async::RunQueryDsl;
use agents::{AgentMessage, Agent};
use litellm::MessageProgress;
use serde_json::Value;
use std::sync::Arc;
use uuid::Uuid;
use std::collections::HashSet;
use super::ContextLoader;
/// Context loader that describes a dashboard, its metrics and their datasets.
#[derive(Debug, Clone)]
pub struct DashboardContextLoader {
    /// Identifier of the dashboard file to load.
    pub dashboard_id: Uuid,
}

impl DashboardContextLoader {
    /// Creates a loader for the dashboard identified by `dashboard_id`.
    pub fn new(dashboard_id: Uuid) -> Self {
        Self { dashboard_id }
    }
}
#[async_trait]
impl ContextLoader for DashboardContextLoader {
async fn load_context(&self, user: &User, agent: &Arc<Agent>) -> Result<Vec<AgentMessage>> {
let mut conn = get_pg_pool().get().await.map_err(|e| {
anyhow!("Failed to get database connection for dashboard context loading: {}", e)
})?;
// First verify the dashboard exists and user has access
let dashboard = dashboard_files::table
.filter(dashboard_files::id.eq(self.dashboard_id))
// .filter(dashboard_files::created_by.eq(&user.id))
.first::<database::models::DashboardFile>(&mut conn)
.await
.map_err(|e| {
anyhow!("Failed to load dashboard (id: {}). Either it doesn't exist or user {} doesn't have access: {}",
self.dashboard_id, user.id, e)
})?;
// Parse dashboard content to DashboardYml
let dashboard_yml: agents::tools::categories::file_tools::file_types::dashboard_yml::DashboardYml =
serde_json::from_value(dashboard.content.clone())
.map_err(|e| anyhow!("Failed to parse dashboard content as YAML for dashboard {}: {}", dashboard.name, e))?;
// Collect all metric IDs from the dashboard
let mut metric_ids = HashSet::new();
for row in &dashboard_yml.rows {
for item in &row.items {
metric_ids.insert(item.id);
}
}
// Load all referenced metrics
let mut metrics_vec = Vec::new();
let mut all_dataset_ids = HashSet::new();
let mut failed_metric_loads = Vec::new();
for metric_id in metric_ids {
match metric_files::table
.filter(metric_files::id.eq(metric_id))
.first::<MetricFile>(&mut conn)
.await
{
Ok(metric) => {
// Parse metric content
match serde_json::from_value::<agents::tools::categories::file_tools::file_types::metric_yml::MetricYml>(metric.content.clone()) {
Ok(metric_yml) => {
all_dataset_ids.extend(metric_yml.dataset_ids);
metrics_vec.push(metric);
}
Err(e) => {
failed_metric_loads.push((metric_id, format!("Failed to parse metric content: {}", e)));
}
}
}
Err(e) => {
failed_metric_loads.push((metric_id, format!("Failed to load metric: {}", e)));
}
}
}
if !failed_metric_loads.is_empty() {
tracing::warn!(
"Failed to load some metrics for dashboard {}: {:?}",
dashboard.name,
failed_metric_loads
);
}
// Load all unique datasets
let mut datasets_vec = Vec::new();
let mut failed_dataset_loads = Vec::new();
for dataset_id in all_dataset_ids {
match datasets::table
.filter(datasets::id.eq(dataset_id))
.first::<Dataset>(&mut conn)
.await
{
Ok(dataset) => datasets_vec.push(dataset),
Err(e) => failed_dataset_loads.push((dataset_id, e.to_string())),
}
}
if !failed_dataset_loads.is_empty() {
tracing::warn!(
"Failed to load some datasets for dashboard {}: {:?}",
dashboard.name,
failed_dataset_loads
);
}
// Set agent state based on loaded assets
agent.set_state_value(String::from("dashboards_available"), Value::Bool(true))
.await;
agent.set_state_value(String::from("files_available"), Value::Bool(true))
.await;
if !metrics_vec.is_empty() {
agent.set_state_value(String::from("metrics_available"), Value::Bool(true))
.await;
};
if !datasets_vec.is_empty() {
agent.set_state_value(String::from("data_context"), Value::Bool(true))
.await;
};
// Format the context message with dashboard, metrics, and dataset information
let dashboard_yaml = serde_yaml::to_string(&dashboard_yml)
.map_err(|e| anyhow!("Failed to serialize dashboard {} to YAML: {}", dashboard.name, e))?;
let mut context_message = format!(
"This conversation is continuing with context from the dashboard. Here is the relevant information:\n\nDashboard Definition:\n{}\n\n",
dashboard_yaml
);
if !metrics_vec.is_empty() {
context_message.push_str("Referenced Metrics:\n");
for metric in metrics_vec {
match serde_json::from_value::<agents::tools::categories::file_tools::file_types::metric_yml::MetricYml>(metric.content) {
Ok(metric_yml) => {
match serde_yaml::to_string(&metric_yml) {
Ok(yaml) => context_message.push_str(&format!("\n{}\n", yaml)),
Err(e) => tracing::warn!("Failed to serialize metric {} to YAML: {}", metric.id, e),
}
}
Err(e) => tracing::warn!("Failed to parse metric {} content: {}", metric.id, e),
}
}
}
if !datasets_vec.is_empty() {
context_message.push_str("\nReferenced Datasets:\n");
for dataset in datasets_vec {
if let Some(yml_content) = dataset.yml_file {
context_message.push_str(&format!("\n{}\n", yml_content));
} else {
tracing::warn!("Dataset {} has no YML content", dataset.id);
}
}
}
Ok(vec![AgentMessage::Assistant {
id: None,
content: Some(context_message),
name: None,
tool_calls: None,
progress: MessageProgress::Complete,
initial: true,
}])
}
}

View File

@ -0,0 +1,126 @@
use agents::{Agent, AgentMessage};
use anyhow::{anyhow, Result};
use async_trait::async_trait;
use database::{
models::{Dataset, User},
pool::get_pg_pool,
schema::{datasets, metric_files},
};
use diesel::prelude::*;
use diesel_async::RunQueryDsl;
use serde_json::Value;
use std::sync::Arc;
use uuid::Uuid;
use super::ContextLoader;
/// Context loader that describes a metric and the datasets it references.
#[derive(Debug, Clone)]
pub struct MetricContextLoader {
    /// Identifier of the metric file to load.
    pub metric_id: Uuid,
}

impl MetricContextLoader {
    /// Creates a loader for the metric identified by `metric_id`.
    pub fn new(metric_id: Uuid) -> Self {
        Self { metric_id }
    }
}
#[async_trait]
impl ContextLoader for MetricContextLoader {
    /// Builds a single assistant message describing the metric `self.metric_id`.
    ///
    /// The message holds the metric definition as YAML, followed by the stored YML of
    /// each dataset listed in the metric's `dataset_ids` that could be loaded.
    ///
    /// Side effects on `agent` state: `metrics_available` and `files_available` are
    /// always set to `true`; `data_context` is set when at least one dataset was
    /// loaded.
    ///
    /// # Errors
    ///
    /// Returns an error if a pooled connection cannot be obtained, if the metric row
    /// cannot be loaded, if its content does not deserialize, or if it cannot be
    /// serialized to YAML. Dataset lookups that fail are logged and skipped.
    async fn load_context(&self, user: &User, agent: &Arc<Agent>) -> Result<Vec<AgentMessage>> {
        use agents::tools::categories::file_tools::file_types::metric_yml::MetricYml;

        let mut conn = match get_pg_pool().get().await {
            Ok(conn) => conn,
            Err(e) => {
                return Err(anyhow!(
                    "Failed to get database connection for metric context loading: {}",
                    e
                ))
            }
        };

        // Fetch the metric row by id.
        // NOTE(review): the ownership filter below is disabled, so `user` only appears
        // in the error text and no access check actually happens here — confirm that
        // access is enforced elsewhere.
        let metric_file = metric_files::table
            .filter(metric_files::id.eq(self.metric_id))
            // .filter(metric_files::created_by.eq(&user.id))
            .first::<database::models::MetricFile>(&mut conn)
            .await
            .map_err(|e| {
                anyhow!("Failed to load metric (id: {}). Either it doesn't exist or user {} doesn't have access: {}",
                    self.metric_id, user.id, e)
            })?;

        // Deserialize the stored JSON content into the typed metric definition.
        let definition: MetricYml =
            serde_json::from_value(metric_file.content.clone()).map_err(|e| {
                anyhow!(
                    "Failed to parse metric content as YAML for metric {}: {}",
                    metric_file.name,
                    e
                )
            })?;

        // Look up each referenced dataset, remembering the ones that fail.
        let mut loaded_datasets = Vec::new();
        let mut lookup_failures = Vec::new();
        for dataset_id in &definition.dataset_ids {
            let lookup = datasets::table
                .filter(datasets::id.eq(dataset_id))
                .first::<Dataset>(&mut conn)
                .await;
            match lookup {
                Ok(dataset) => loaded_datasets.push(dataset),
                Err(e) => lookup_failures.push((dataset_id, e.to_string())),
            }
        }

        if !lookup_failures.is_empty() {
            tracing::warn!(
                "Failed to load some datasets for metric {}: {:?}",
                metric_file.name,
                lookup_failures
            );
        }

        // Record on the agent which kinds of assets are now available.
        agent
            .set_state_value(String::from("metrics_available"), Value::Bool(true))
            .await;
        agent
            .set_state_value(String::from("files_available"), Value::Bool(true))
            .await;
        if !loaded_datasets.is_empty() {
            agent
                .set_state_value(String::from("data_context"), Value::Bool(true))
                .await;
        }

        // Render the metric definition, then append every dataset YML that exists.
        let metric_yaml = serde_yaml::to_string(&definition).map_err(|e| {
            anyhow!("Failed to serialize metric {} to YAML: {}", metric_file.name, e)
        })?;

        let mut context_message = format!(
            "This conversation is continuing with context from the metric. Here is the relevant information:\n\nMetric Definition:\n{}\n\n",
            metric_yaml
        );

        if !loaded_datasets.is_empty() {
            context_message.push_str("Referenced Datasets:\n");
            for dataset in loaded_datasets {
                match dataset.yml_file {
                    Some(yml_content) => {
                        context_message.push_str(&format!("\n{}\n", yml_content));
                    }
                    None => tracing::warn!("Dataset {} has no YML content", dataset.id),
                }
            }
        }

        let message = AgentMessage::Assistant {
            id: None,
            content: Some(context_message),
            name: None,
            tool_calls: None,
            progress: litellm::MessageProgress::Complete,
            initial: true,
        };
        Ok(vec![message])
    }
}

View File

@ -0,0 +1,43 @@
use anyhow::Result;
use async_trait::async_trait;
use database::models::User;
use agents::AgentMessage;
use std::sync::Arc;
use agents::Agent;
pub mod chat_context;
pub mod metric_context;
pub mod dashboard_context;
pub use chat_context::ChatContextLoader;
pub use metric_context::MetricContextLoader;
pub use dashboard_context::DashboardContextLoader;
/// A source of messages used to seed an agent thread with prior context.
///
/// Implemented in this module by `ChatContextLoader`, `MetricContextLoader` and
/// `DashboardContextLoader`.
#[async_trait]
pub trait ContextLoader {
    /// Loads the context messages for `user`.
    ///
    /// `agent` is passed so that an implementation can record state on it; the
    /// metric and dashboard loaders do so via `set_state_value`.
    ///
    /// Returns the messages to place at the start of the thread, or an error if
    /// the context could not be loaded.
    async fn load_context(&self, user: &User, agent: &Arc<Agent>) -> Result<Vec<AgentMessage>>;
}
/// Ensures that at most one context source was supplied with a request.
///
/// Passing none of the ids, or exactly one of them, is accepted.
///
/// # Errors
///
/// Returns an error when two or more of `chat_id`, `metric_id` and `dashboard_id`
/// are `Some`.
pub fn validate_context_request(
    chat_id: Option<uuid::Uuid>,
    metric_id: Option<uuid::Uuid>,
    dashboard_id: Option<uuid::Uuid>,
) -> Result<()> {
    // Each present id contributes 1 to the tally.
    let provided = usize::from(chat_id.is_some())
        + usize::from(metric_id.is_some())
        + usize::from(dashboard_id.is_some());

    if provided <= 1 {
        Ok(())
    } else {
        Err(anyhow::anyhow!(
            "Only one context type (chat, metric, or dashboard) can be provided"
        ))
    }
}

View File

@ -4,6 +4,7 @@ pub mod update_chats_handler;
pub mod delete_chats_handler;
pub mod types;
pub mod streaming_parser;
pub mod context_loaders;
pub use get_chat_handler::get_chat_handler;
pub use post_chat_handler::post_chat_handler;

View File

@ -1,8 +1,8 @@
use std::collections::HashMap;
use agents::{
tools::file_tools::search_data_catalog::SearchDataCatalogOutput, AgentMessage, AgentThread,
BusterSuperAgent,
tools::file_tools::search_data_catalog::SearchDataCatalogOutput, AgentExt, AgentMessage,
AgentThread, BusterSuperAgent,
};
use anyhow::{anyhow, Result};
@ -23,7 +23,13 @@ use serde::{Deserialize, Serialize};
use serde_json::Value;
use uuid::Uuid;
use crate::chats::streaming_parser::StreamingParser;
use crate::chats::{
context_loaders::{
chat_context::ChatContextLoader, dashboard_context::DashboardContextLoader,
metric_context::MetricContextLoader, validate_context_request, ContextLoader,
},
streaming_parser::StreamingParser,
};
use crate::messages::types::{ChatMessage, ChatUserMessage};
use super::types::ChatWithMessages;
@ -44,6 +50,8 @@ pub struct ChatCreateNewChat {
pub prompt: String,
pub chat_id: Option<Uuid>,
pub message_id: Option<Uuid>,
pub metric_id: Option<Uuid>,
pub dashboard_id: Option<Uuid>,
}
pub async fn post_chat_handler(
@ -51,6 +59,9 @@ pub async fn post_chat_handler(
user: User,
tx: Option<mpsc::Sender<Result<(BusterContainer, ThreadEvent)>>>,
) -> Result<ChatWithMessages> {
// Validate context request
validate_context_request(request.chat_id, request.metric_id, request.dashboard_id)?;
let chat_id = request.chat_id.unwrap_or_else(Uuid::new_v4);
let message_id = request.message_id.unwrap_or_else(Uuid::new_v4);
@ -120,13 +131,38 @@ pub async fn post_chat_handler(
.execute(&mut conn)
.await?;
// Initialize agent and process request
// Initialize agent with context if provided
let mut initial_messages = vec![];
// Initialize agent to add context
let agent = BusterSuperAgent::new(user.id, chat_id).await?;
let mut chat = AgentThread::new(
Some(chat_id),
user.id,
vec![AgentMessage::user(request.prompt.clone())],
);
// Load context if provided
if let Some(existing_chat_id) = request.chat_id {
let context_loader = ChatContextLoader::new(existing_chat_id);
let context_messages = context_loader
.load_context(&user, agent.get_agent())
.await?;
initial_messages.extend(context_messages);
} else if let Some(metric_id) = request.metric_id {
let context_loader = MetricContextLoader::new(metric_id);
let context_messages = context_loader
.load_context(&user, agent.get_agent())
.await?;
initial_messages.extend(context_messages);
} else if let Some(dashboard_id) = request.dashboard_id {
let context_loader = DashboardContextLoader::new(dashboard_id);
let context_messages = context_loader
.load_context(&user, agent.get_agent())
.await?;
initial_messages.extend(context_messages);
}
// Add the new user message
initial_messages.push(AgentMessage::user(request.prompt.clone()));
// Initialize the agent thread
let mut chat = AgentThread::new(Some(chat_id), user.id, initial_messages);
let title_handle = {
let tx = tx.clone();