mirror of https://github.com/buster-so/buster.git
context loaders.
This commit is contained in:
parent
fee69dbab7
commit
99d9399f51
|
@ -36,6 +36,9 @@ let mut conn = get_pg_pool().get().await?;
|
|||
- Avoid unnecessary `.clone()` calls
|
||||
- Use `&str` instead of `String` for function parameters when the string doesn't need to be owned
|
||||
|
||||
### Importing packages/crates
|
||||
- Please make the dependency as short as possible in the actual logic by importing the crate/package.
|
||||
|
||||
### Database Operations
|
||||
- Use Diesel for database migrations and query building
|
||||
- Migrations are stored in the `migrations/` directory
|
||||
|
|
|
@ -115,7 +115,7 @@ Reference these example PRDs for guidance:
|
|||
- [ ] All template sections completed
|
||||
- [ ] Technical design is detailed and complete
|
||||
- [ ] File changes are documented
|
||||
- [ ] Implementation phases are clear
|
||||
- [ ] Implementation phases are clear (can be as many as you need.)
|
||||
- [ ] Testing strategy is defined
|
||||
- [ ] Security considerations addressed
|
||||
- [ ] Dependencies and Files listed
|
||||
|
|
|
@ -19,6 +19,7 @@ futures = { workspace = true }
|
|||
redis = { workspace = true }
|
||||
regex = { workspace = true }
|
||||
indexmap = { workspace = true }
|
||||
async-trait = { workspace = true }
|
||||
|
||||
# Local dependencies
|
||||
database = { path = "../database" }
|
||||
|
|
|
@ -0,0 +1,61 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use anyhow::Result;
|
||||
use async_trait::async_trait;
|
||||
use database::{
|
||||
models::User,
|
||||
pool::get_pg_pool,
|
||||
schema::{chats, messages},
|
||||
};
|
||||
use diesel::prelude::*;
|
||||
use diesel_async::RunQueryDsl;
|
||||
use agents::{Agent, AgentMessage};
|
||||
use uuid::Uuid;
|
||||
|
||||
use super::ContextLoader;
|
||||
|
||||
pub struct ChatContextLoader {
|
||||
pub chat_id: Uuid,
|
||||
}
|
||||
|
||||
impl ChatContextLoader {
|
||||
pub fn new(chat_id: Uuid) -> Self {
|
||||
Self { chat_id }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ContextLoader for ChatContextLoader {
|
||||
async fn load_context(&self, user: &User, agent: &Arc<Agent>) -> Result<Vec<AgentMessage>> {
|
||||
let mut conn = get_pg_pool().get().await?;
|
||||
|
||||
// First verify the chat exists and user has access
|
||||
let chat = chats::table
|
||||
.filter(chats::id.eq(self.chat_id))
|
||||
.filter(chats::created_by.eq(&user.id))
|
||||
.first::<database::models::Chat>(&mut conn)
|
||||
.await?;
|
||||
|
||||
// Get all messages for the chat
|
||||
let messages = messages::table
|
||||
.filter(messages::chat_id.eq(chat.id))
|
||||
.order_by(messages::created_at.asc())
|
||||
.load::<database::models::Message>(&mut conn)
|
||||
.await?;
|
||||
|
||||
// Convert messages to AgentMessages
|
||||
let mut agent_messages = Vec::new();
|
||||
for message in messages {
|
||||
// Add user message
|
||||
agent_messages.push(AgentMessage::user(message.request));
|
||||
|
||||
// Add assistant messages from response
|
||||
if let Ok(response_messages) = serde_json::from_value::<Vec<AgentMessage>>(message.response)
|
||||
{
|
||||
agent_messages.extend(response_messages);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(agent_messages)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,181 @@
|
|||
use anyhow::{Result, anyhow};
|
||||
use async_trait::async_trait;
|
||||
use database::{
|
||||
models::{User, Dataset, MetricFile},
|
||||
pool::get_pg_pool,
|
||||
schema::{dashboard_files, metric_files, datasets},
|
||||
};
|
||||
use diesel::prelude::*;
|
||||
use diesel_async::RunQueryDsl;
|
||||
use agents::{AgentMessage, Agent};
|
||||
use litellm::MessageProgress;
|
||||
use serde_json::Value;
|
||||
use std::sync::Arc;
|
||||
use uuid::Uuid;
|
||||
use std::collections::HashSet;
|
||||
|
||||
use super::ContextLoader;
|
||||
|
||||
pub struct DashboardContextLoader {
|
||||
pub dashboard_id: Uuid,
|
||||
}
|
||||
|
||||
impl DashboardContextLoader {
|
||||
pub fn new(dashboard_id: Uuid) -> Self {
|
||||
Self { dashboard_id }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ContextLoader for DashboardContextLoader {
|
||||
async fn load_context(&self, user: &User, agent: &Arc<Agent>) -> Result<Vec<AgentMessage>> {
|
||||
let mut conn = get_pg_pool().get().await.map_err(|e| {
|
||||
anyhow!("Failed to get database connection for dashboard context loading: {}", e)
|
||||
})?;
|
||||
|
||||
// First verify the dashboard exists and user has access
|
||||
let dashboard = dashboard_files::table
|
||||
.filter(dashboard_files::id.eq(self.dashboard_id))
|
||||
// .filter(dashboard_files::created_by.eq(&user.id))
|
||||
.first::<database::models::DashboardFile>(&mut conn)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
anyhow!("Failed to load dashboard (id: {}). Either it doesn't exist or user {} doesn't have access: {}",
|
||||
self.dashboard_id, user.id, e)
|
||||
})?;
|
||||
|
||||
// Parse dashboard content to DashboardYml
|
||||
let dashboard_yml: agents::tools::categories::file_tools::file_types::dashboard_yml::DashboardYml =
|
||||
serde_json::from_value(dashboard.content.clone())
|
||||
.map_err(|e| anyhow!("Failed to parse dashboard content as YAML for dashboard {}: {}", dashboard.name, e))?;
|
||||
|
||||
// Collect all metric IDs from the dashboard
|
||||
let mut metric_ids = HashSet::new();
|
||||
for row in &dashboard_yml.rows {
|
||||
for item in &row.items {
|
||||
metric_ids.insert(item.id);
|
||||
}
|
||||
}
|
||||
|
||||
// Load all referenced metrics
|
||||
let mut metrics_vec = Vec::new();
|
||||
let mut all_dataset_ids = HashSet::new();
|
||||
let mut failed_metric_loads = Vec::new();
|
||||
|
||||
for metric_id in metric_ids {
|
||||
match metric_files::table
|
||||
.filter(metric_files::id.eq(metric_id))
|
||||
.first::<MetricFile>(&mut conn)
|
||||
.await
|
||||
{
|
||||
Ok(metric) => {
|
||||
// Parse metric content
|
||||
match serde_json::from_value::<agents::tools::categories::file_tools::file_types::metric_yml::MetricYml>(metric.content.clone()) {
|
||||
Ok(metric_yml) => {
|
||||
all_dataset_ids.extend(metric_yml.dataset_ids);
|
||||
metrics_vec.push(metric);
|
||||
}
|
||||
Err(e) => {
|
||||
failed_metric_loads.push((metric_id, format!("Failed to parse metric content: {}", e)));
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
failed_metric_loads.push((metric_id, format!("Failed to load metric: {}", e)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !failed_metric_loads.is_empty() {
|
||||
tracing::warn!(
|
||||
"Failed to load some metrics for dashboard {}: {:?}",
|
||||
dashboard.name,
|
||||
failed_metric_loads
|
||||
);
|
||||
}
|
||||
|
||||
// Load all unique datasets
|
||||
let mut datasets_vec = Vec::new();
|
||||
let mut failed_dataset_loads = Vec::new();
|
||||
|
||||
for dataset_id in all_dataset_ids {
|
||||
match datasets::table
|
||||
.filter(datasets::id.eq(dataset_id))
|
||||
.first::<Dataset>(&mut conn)
|
||||
.await
|
||||
{
|
||||
Ok(dataset) => datasets_vec.push(dataset),
|
||||
Err(e) => failed_dataset_loads.push((dataset_id, e.to_string())),
|
||||
}
|
||||
}
|
||||
|
||||
if !failed_dataset_loads.is_empty() {
|
||||
tracing::warn!(
|
||||
"Failed to load some datasets for dashboard {}: {:?}",
|
||||
dashboard.name,
|
||||
failed_dataset_loads
|
||||
);
|
||||
}
|
||||
|
||||
// Set agent state based on loaded assets
|
||||
agent.set_state_value(String::from("dashboards_available"), Value::Bool(true))
|
||||
.await;
|
||||
|
||||
agent.set_state_value(String::from("files_available"), Value::Bool(true))
|
||||
.await;
|
||||
|
||||
if !metrics_vec.is_empty() {
|
||||
agent.set_state_value(String::from("metrics_available"), Value::Bool(true))
|
||||
.await;
|
||||
};
|
||||
|
||||
if !datasets_vec.is_empty() {
|
||||
agent.set_state_value(String::from("data_context"), Value::Bool(true))
|
||||
.await;
|
||||
};
|
||||
|
||||
// Format the context message with dashboard, metrics, and dataset information
|
||||
let dashboard_yaml = serde_yaml::to_string(&dashboard_yml)
|
||||
.map_err(|e| anyhow!("Failed to serialize dashboard {} to YAML: {}", dashboard.name, e))?;
|
||||
|
||||
let mut context_message = format!(
|
||||
"This conversation is continuing with context from the dashboard. Here is the relevant information:\n\nDashboard Definition:\n{}\n\n",
|
||||
dashboard_yaml
|
||||
);
|
||||
|
||||
if !metrics_vec.is_empty() {
|
||||
context_message.push_str("Referenced Metrics:\n");
|
||||
for metric in metrics_vec {
|
||||
match serde_json::from_value::<agents::tools::categories::file_tools::file_types::metric_yml::MetricYml>(metric.content) {
|
||||
Ok(metric_yml) => {
|
||||
match serde_yaml::to_string(&metric_yml) {
|
||||
Ok(yaml) => context_message.push_str(&format!("\n{}\n", yaml)),
|
||||
Err(e) => tracing::warn!("Failed to serialize metric {} to YAML: {}", metric.id, e),
|
||||
}
|
||||
}
|
||||
Err(e) => tracing::warn!("Failed to parse metric {} content: {}", metric.id, e),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !datasets_vec.is_empty() {
|
||||
context_message.push_str("\nReferenced Datasets:\n");
|
||||
for dataset in datasets_vec {
|
||||
if let Some(yml_content) = dataset.yml_file {
|
||||
context_message.push_str(&format!("\n{}\n", yml_content));
|
||||
} else {
|
||||
tracing::warn!("Dataset {} has no YML content", dataset.id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(vec![AgentMessage::Assistant {
|
||||
id: None,
|
||||
content: Some(context_message),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
progress: MessageProgress::Complete,
|
||||
initial: true,
|
||||
}])
|
||||
}
|
||||
}
|
|
@ -0,0 +1,126 @@
|
|||
use agents::{Agent, AgentMessage};
|
||||
use anyhow::{anyhow, Result};
|
||||
use async_trait::async_trait;
|
||||
use database::{
|
||||
models::{Dataset, User},
|
||||
pool::get_pg_pool,
|
||||
schema::{datasets, metric_files},
|
||||
};
|
||||
use diesel::prelude::*;
|
||||
use diesel_async::RunQueryDsl;
|
||||
use serde_json::Value;
|
||||
use std::sync::Arc;
|
||||
use uuid::Uuid;
|
||||
|
||||
use super::ContextLoader;
|
||||
|
||||
pub struct MetricContextLoader {
|
||||
pub metric_id: Uuid,
|
||||
}
|
||||
|
||||
impl MetricContextLoader {
|
||||
pub fn new(metric_id: Uuid) -> Self {
|
||||
Self { metric_id }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ContextLoader for MetricContextLoader {
|
||||
async fn load_context(&self, user: &User, agent: &Arc<Agent>) -> Result<Vec<AgentMessage>> {
|
||||
let mut conn = get_pg_pool().get().await.map_err(|e| {
|
||||
anyhow!(
|
||||
"Failed to get database connection for metric context loading: {}",
|
||||
e
|
||||
)
|
||||
})?;
|
||||
|
||||
// First verify the metric exists and user has access
|
||||
let metric = metric_files::table
|
||||
.filter(metric_files::id.eq(self.metric_id))
|
||||
// .filter(metric_files::created_by.eq(&user.id))
|
||||
.first::<database::models::MetricFile>(&mut conn)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
anyhow!("Failed to load metric (id: {}). Either it doesn't exist or user {} doesn't have access: {}",
|
||||
self.metric_id, user.id, e)
|
||||
})?;
|
||||
|
||||
// Get the metric content as MetricYml
|
||||
let metric_yml: agents::tools::categories::file_tools::file_types::metric_yml::MetricYml =
|
||||
serde_json::from_value(metric.content.clone()).map_err(|e| {
|
||||
anyhow!(
|
||||
"Failed to parse metric content as YAML for metric {}: {}",
|
||||
metric.name,
|
||||
e
|
||||
)
|
||||
})?;
|
||||
|
||||
// Load all referenced datasets
|
||||
let dataset_ids = &metric_yml.dataset_ids;
|
||||
let mut datasets_vec = Vec::new();
|
||||
let mut failed_dataset_loads = Vec::new();
|
||||
|
||||
for dataset_id in dataset_ids {
|
||||
match datasets::table
|
||||
.filter(datasets::id.eq(dataset_id))
|
||||
.first::<Dataset>(&mut conn)
|
||||
.await
|
||||
{
|
||||
Ok(dataset) => datasets_vec.push(dataset),
|
||||
Err(e) => failed_dataset_loads.push((dataset_id, e.to_string())),
|
||||
}
|
||||
}
|
||||
|
||||
if !failed_dataset_loads.is_empty() {
|
||||
tracing::warn!(
|
||||
"Failed to load some datasets for metric {}: {:?}",
|
||||
metric.name,
|
||||
failed_dataset_loads
|
||||
);
|
||||
}
|
||||
|
||||
// Set agent state based on loaded assets
|
||||
agent
|
||||
.set_state_value(String::from("metrics_available"), Value::Bool(true))
|
||||
.await;
|
||||
|
||||
agent
|
||||
.set_state_value(String::from("files_available"), Value::Bool(true))
|
||||
.await;
|
||||
|
||||
if !datasets_vec.is_empty() {
|
||||
agent
|
||||
.set_state_value(String::from("data_context"), Value::Bool(true))
|
||||
.await;
|
||||
};
|
||||
|
||||
// Format the context message with metric and dataset information
|
||||
let metric_yaml = serde_yaml::to_string(&metric_yml)
|
||||
.map_err(|e| anyhow!("Failed to serialize metric {} to YAML: {}", metric.name, e))?;
|
||||
|
||||
let mut context_message = format!(
|
||||
"This conversation is continuing with context from the metric. Here is the relevant information:\n\nMetric Definition:\n{}\n\n",
|
||||
metric_yaml
|
||||
);
|
||||
|
||||
if !datasets_vec.is_empty() {
|
||||
context_message.push_str("Referenced Datasets:\n");
|
||||
for dataset in datasets_vec {
|
||||
if let Some(yml_content) = dataset.yml_file {
|
||||
context_message.push_str(&format!("\n{}\n", yml_content));
|
||||
} else {
|
||||
tracing::warn!("Dataset {} has no YML content", dataset.id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(vec![AgentMessage::Assistant {
|
||||
id: None,
|
||||
content: Some(context_message),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
progress: litellm::MessageProgress::Complete,
|
||||
initial: true,
|
||||
}])
|
||||
}
|
||||
}
|
|
@ -0,0 +1,43 @@
|
|||
use anyhow::Result;
|
||||
use async_trait::async_trait;
|
||||
use database::models::User;
|
||||
use agents::AgentMessage;
|
||||
use std::sync::Arc;
|
||||
use agents::Agent;
|
||||
|
||||
pub mod chat_context;
|
||||
pub mod metric_context;
|
||||
pub mod dashboard_context;
|
||||
|
||||
pub use chat_context::ChatContextLoader;
|
||||
pub use metric_context::MetricContextLoader;
|
||||
pub use dashboard_context::DashboardContextLoader;
|
||||
|
||||
#[async_trait]
|
||||
pub trait ContextLoader {
|
||||
async fn load_context(&self, user: &User, agent: &Arc<Agent>) -> Result<Vec<AgentMessage>>;
|
||||
}
|
||||
|
||||
// Validate that only one context type is provided
|
||||
pub fn validate_context_request(
|
||||
chat_id: Option<uuid::Uuid>,
|
||||
metric_id: Option<uuid::Uuid>,
|
||||
dashboard_id: Option<uuid::Uuid>,
|
||||
) -> Result<()> {
|
||||
let context_count = [
|
||||
chat_id.is_some(),
|
||||
metric_id.is_some(),
|
||||
dashboard_id.is_some(),
|
||||
]
|
||||
.iter()
|
||||
.filter(|&&b| b)
|
||||
.count();
|
||||
|
||||
if context_count > 1 {
|
||||
return Err(anyhow::anyhow!(
|
||||
"Only one context type (chat, metric, or dashboard) can be provided"
|
||||
));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
|
@ -4,6 +4,7 @@ pub mod update_chats_handler;
|
|||
pub mod delete_chats_handler;
|
||||
pub mod types;
|
||||
pub mod streaming_parser;
|
||||
pub mod context_loaders;
|
||||
|
||||
pub use get_chat_handler::get_chat_handler;
|
||||
pub use post_chat_handler::post_chat_handler;
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
use std::collections::HashMap;
|
||||
|
||||
use agents::{
|
||||
tools::file_tools::search_data_catalog::SearchDataCatalogOutput, AgentMessage, AgentThread,
|
||||
BusterSuperAgent,
|
||||
tools::file_tools::search_data_catalog::SearchDataCatalogOutput, AgentExt, AgentMessage,
|
||||
AgentThread, BusterSuperAgent,
|
||||
};
|
||||
|
||||
use anyhow::{anyhow, Result};
|
||||
|
@ -23,7 +23,13 @@ use serde::{Deserialize, Serialize};
|
|||
use serde_json::Value;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::chats::streaming_parser::StreamingParser;
|
||||
use crate::chats::{
|
||||
context_loaders::{
|
||||
chat_context::ChatContextLoader, dashboard_context::DashboardContextLoader,
|
||||
metric_context::MetricContextLoader, validate_context_request, ContextLoader,
|
||||
},
|
||||
streaming_parser::StreamingParser,
|
||||
};
|
||||
use crate::messages::types::{ChatMessage, ChatUserMessage};
|
||||
|
||||
use super::types::ChatWithMessages;
|
||||
|
@ -44,6 +50,8 @@ pub struct ChatCreateNewChat {
|
|||
pub prompt: String,
|
||||
pub chat_id: Option<Uuid>,
|
||||
pub message_id: Option<Uuid>,
|
||||
pub metric_id: Option<Uuid>,
|
||||
pub dashboard_id: Option<Uuid>,
|
||||
}
|
||||
|
||||
pub async fn post_chat_handler(
|
||||
|
@ -51,6 +59,9 @@ pub async fn post_chat_handler(
|
|||
user: User,
|
||||
tx: Option<mpsc::Sender<Result<(BusterContainer, ThreadEvent)>>>,
|
||||
) -> Result<ChatWithMessages> {
|
||||
// Validate context request
|
||||
validate_context_request(request.chat_id, request.metric_id, request.dashboard_id)?;
|
||||
|
||||
let chat_id = request.chat_id.unwrap_or_else(Uuid::new_v4);
|
||||
let message_id = request.message_id.unwrap_or_else(Uuid::new_v4);
|
||||
|
||||
|
@ -120,13 +131,38 @@ pub async fn post_chat_handler(
|
|||
.execute(&mut conn)
|
||||
.await?;
|
||||
|
||||
// Initialize agent and process request
|
||||
// Initialize agent with context if provided
|
||||
let mut initial_messages = vec![];
|
||||
|
||||
// Initialize agent to add context
|
||||
let agent = BusterSuperAgent::new(user.id, chat_id).await?;
|
||||
let mut chat = AgentThread::new(
|
||||
Some(chat_id),
|
||||
user.id,
|
||||
vec![AgentMessage::user(request.prompt.clone())],
|
||||
);
|
||||
|
||||
// Load context if provided
|
||||
if let Some(existing_chat_id) = request.chat_id {
|
||||
let context_loader = ChatContextLoader::new(existing_chat_id);
|
||||
let context_messages = context_loader
|
||||
.load_context(&user, agent.get_agent())
|
||||
.await?;
|
||||
initial_messages.extend(context_messages);
|
||||
} else if let Some(metric_id) = request.metric_id {
|
||||
let context_loader = MetricContextLoader::new(metric_id);
|
||||
let context_messages = context_loader
|
||||
.load_context(&user, agent.get_agent())
|
||||
.await?;
|
||||
initial_messages.extend(context_messages);
|
||||
} else if let Some(dashboard_id) = request.dashboard_id {
|
||||
let context_loader = DashboardContextLoader::new(dashboard_id);
|
||||
let context_messages = context_loader
|
||||
.load_context(&user, agent.get_agent())
|
||||
.await?;
|
||||
initial_messages.extend(context_messages);
|
||||
}
|
||||
|
||||
// Add the new user message
|
||||
initial_messages.push(AgentMessage::user(request.prompt.clone()));
|
||||
|
||||
// Initialize the agent thread
|
||||
let mut chat = AgentThread::new(Some(chat_id), user.id, initial_messages);
|
||||
|
||||
let title_handle = {
|
||||
let tx = tx.clone();
|
||||
|
|
Loading…
Reference in New Issue