mirror of https://github.com/buster-so/buster.git
context loaders.
This commit is contained in:
parent
fee69dbab7
commit
99d9399f51
|
@ -36,6 +36,9 @@ let mut conn = get_pg_pool().get().await?;
|
||||||
- Avoid unnecessary `.clone()` calls
|
- Avoid unnecessary `.clone()` calls
|
||||||
- Use `&str` instead of `String` for function parameters when the string doesn't need to be owned
|
- Use `&str` instead of `String` for function parameters when the string doesn't need to be owned
|
||||||
|
|
||||||
|
### Importing packages/crates
|
||||||
|
- Please make the dependency as short as possible in the actual logic by importing the crate/package.
|
||||||
|
|
||||||
### Database Operations
|
### Database Operations
|
||||||
- Use Diesel for database migrations and query building
|
- Use Diesel for database migrations and query building
|
||||||
- Migrations are stored in the `migrations/` directory
|
- Migrations are stored in the `migrations/` directory
|
||||||
|
|
|
@ -115,7 +115,7 @@ Reference these example PRDs for guidance:
|
||||||
- [ ] All template sections completed
|
- [ ] All template sections completed
|
||||||
- [ ] Technical design is detailed and complete
|
- [ ] Technical design is detailed and complete
|
||||||
- [ ] File changes are documented
|
- [ ] File changes are documented
|
||||||
- [ ] Implementation phases are clear
|
- [ ] Implementation phases are clear (can be as many as you need.)
|
||||||
- [ ] Testing strategy is defined
|
- [ ] Testing strategy is defined
|
||||||
- [ ] Security considerations addressed
|
- [ ] Security considerations addressed
|
||||||
- [ ] Dependencies and Files listed
|
- [ ] Dependencies and Files listed
|
||||||
|
|
|
@ -19,6 +19,7 @@ futures = { workspace = true }
|
||||||
redis = { workspace = true }
|
redis = { workspace = true }
|
||||||
regex = { workspace = true }
|
regex = { workspace = true }
|
||||||
indexmap = { workspace = true }
|
indexmap = { workspace = true }
|
||||||
|
async-trait = { workspace = true }
|
||||||
|
|
||||||
# Local dependencies
|
# Local dependencies
|
||||||
database = { path = "../database" }
|
database = { path = "../database" }
|
||||||
|
|
|
@ -0,0 +1,61 @@
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use anyhow::Result;
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use database::{
|
||||||
|
models::User,
|
||||||
|
pool::get_pg_pool,
|
||||||
|
schema::{chats, messages},
|
||||||
|
};
|
||||||
|
use diesel::prelude::*;
|
||||||
|
use diesel_async::RunQueryDsl;
|
||||||
|
use agents::{Agent, AgentMessage};
|
||||||
|
use uuid::Uuid;
|
||||||
|
|
||||||
|
use super::ContextLoader;
|
||||||
|
|
||||||
|
pub struct ChatContextLoader {
|
||||||
|
pub chat_id: Uuid,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ChatContextLoader {
|
||||||
|
pub fn new(chat_id: Uuid) -> Self {
|
||||||
|
Self { chat_id }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl ContextLoader for ChatContextLoader {
|
||||||
|
async fn load_context(&self, user: &User, agent: &Arc<Agent>) -> Result<Vec<AgentMessage>> {
|
||||||
|
let mut conn = get_pg_pool().get().await?;
|
||||||
|
|
||||||
|
// First verify the chat exists and user has access
|
||||||
|
let chat = chats::table
|
||||||
|
.filter(chats::id.eq(self.chat_id))
|
||||||
|
.filter(chats::created_by.eq(&user.id))
|
||||||
|
.first::<database::models::Chat>(&mut conn)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
// Get all messages for the chat
|
||||||
|
let messages = messages::table
|
||||||
|
.filter(messages::chat_id.eq(chat.id))
|
||||||
|
.order_by(messages::created_at.asc())
|
||||||
|
.load::<database::models::Message>(&mut conn)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
// Convert messages to AgentMessages
|
||||||
|
let mut agent_messages = Vec::new();
|
||||||
|
for message in messages {
|
||||||
|
// Add user message
|
||||||
|
agent_messages.push(AgentMessage::user(message.request));
|
||||||
|
|
||||||
|
// Add assistant messages from response
|
||||||
|
if let Ok(response_messages) = serde_json::from_value::<Vec<AgentMessage>>(message.response)
|
||||||
|
{
|
||||||
|
agent_messages.extend(response_messages);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(agent_messages)
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,181 @@
|
||||||
|
use anyhow::{Result, anyhow};
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use database::{
|
||||||
|
models::{User, Dataset, MetricFile},
|
||||||
|
pool::get_pg_pool,
|
||||||
|
schema::{dashboard_files, metric_files, datasets},
|
||||||
|
};
|
||||||
|
use diesel::prelude::*;
|
||||||
|
use diesel_async::RunQueryDsl;
|
||||||
|
use agents::{AgentMessage, Agent};
|
||||||
|
use litellm::MessageProgress;
|
||||||
|
use serde_json::Value;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use uuid::Uuid;
|
||||||
|
use std::collections::HashSet;
|
||||||
|
|
||||||
|
use super::ContextLoader;
|
||||||
|
|
||||||
|
pub struct DashboardContextLoader {
|
||||||
|
pub dashboard_id: Uuid,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl DashboardContextLoader {
|
||||||
|
pub fn new(dashboard_id: Uuid) -> Self {
|
||||||
|
Self { dashboard_id }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl ContextLoader for DashboardContextLoader {
|
||||||
|
async fn load_context(&self, user: &User, agent: &Arc<Agent>) -> Result<Vec<AgentMessage>> {
|
||||||
|
let mut conn = get_pg_pool().get().await.map_err(|e| {
|
||||||
|
anyhow!("Failed to get database connection for dashboard context loading: {}", e)
|
||||||
|
})?;
|
||||||
|
|
||||||
|
// First verify the dashboard exists and user has access
|
||||||
|
let dashboard = dashboard_files::table
|
||||||
|
.filter(dashboard_files::id.eq(self.dashboard_id))
|
||||||
|
// .filter(dashboard_files::created_by.eq(&user.id))
|
||||||
|
.first::<database::models::DashboardFile>(&mut conn)
|
||||||
|
.await
|
||||||
|
.map_err(|e| {
|
||||||
|
anyhow!("Failed to load dashboard (id: {}). Either it doesn't exist or user {} doesn't have access: {}",
|
||||||
|
self.dashboard_id, user.id, e)
|
||||||
|
})?;
|
||||||
|
|
||||||
|
// Parse dashboard content to DashboardYml
|
||||||
|
let dashboard_yml: agents::tools::categories::file_tools::file_types::dashboard_yml::DashboardYml =
|
||||||
|
serde_json::from_value(dashboard.content.clone())
|
||||||
|
.map_err(|e| anyhow!("Failed to parse dashboard content as YAML for dashboard {}: {}", dashboard.name, e))?;
|
||||||
|
|
||||||
|
// Collect all metric IDs from the dashboard
|
||||||
|
let mut metric_ids = HashSet::new();
|
||||||
|
for row in &dashboard_yml.rows {
|
||||||
|
for item in &row.items {
|
||||||
|
metric_ids.insert(item.id);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Load all referenced metrics
|
||||||
|
let mut metrics_vec = Vec::new();
|
||||||
|
let mut all_dataset_ids = HashSet::new();
|
||||||
|
let mut failed_metric_loads = Vec::new();
|
||||||
|
|
||||||
|
for metric_id in metric_ids {
|
||||||
|
match metric_files::table
|
||||||
|
.filter(metric_files::id.eq(metric_id))
|
||||||
|
.first::<MetricFile>(&mut conn)
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
Ok(metric) => {
|
||||||
|
// Parse metric content
|
||||||
|
match serde_json::from_value::<agents::tools::categories::file_tools::file_types::metric_yml::MetricYml>(metric.content.clone()) {
|
||||||
|
Ok(metric_yml) => {
|
||||||
|
all_dataset_ids.extend(metric_yml.dataset_ids);
|
||||||
|
metrics_vec.push(metric);
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
failed_metric_loads.push((metric_id, format!("Failed to parse metric content: {}", e)));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
failed_metric_loads.push((metric_id, format!("Failed to load metric: {}", e)));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !failed_metric_loads.is_empty() {
|
||||||
|
tracing::warn!(
|
||||||
|
"Failed to load some metrics for dashboard {}: {:?}",
|
||||||
|
dashboard.name,
|
||||||
|
failed_metric_loads
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Load all unique datasets
|
||||||
|
let mut datasets_vec = Vec::new();
|
||||||
|
let mut failed_dataset_loads = Vec::new();
|
||||||
|
|
||||||
|
for dataset_id in all_dataset_ids {
|
||||||
|
match datasets::table
|
||||||
|
.filter(datasets::id.eq(dataset_id))
|
||||||
|
.first::<Dataset>(&mut conn)
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
Ok(dataset) => datasets_vec.push(dataset),
|
||||||
|
Err(e) => failed_dataset_loads.push((dataset_id, e.to_string())),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !failed_dataset_loads.is_empty() {
|
||||||
|
tracing::warn!(
|
||||||
|
"Failed to load some datasets for dashboard {}: {:?}",
|
||||||
|
dashboard.name,
|
||||||
|
failed_dataset_loads
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set agent state based on loaded assets
|
||||||
|
agent.set_state_value(String::from("dashboards_available"), Value::Bool(true))
|
||||||
|
.await;
|
||||||
|
|
||||||
|
agent.set_state_value(String::from("files_available"), Value::Bool(true))
|
||||||
|
.await;
|
||||||
|
|
||||||
|
if !metrics_vec.is_empty() {
|
||||||
|
agent.set_state_value(String::from("metrics_available"), Value::Bool(true))
|
||||||
|
.await;
|
||||||
|
};
|
||||||
|
|
||||||
|
if !datasets_vec.is_empty() {
|
||||||
|
agent.set_state_value(String::from("data_context"), Value::Bool(true))
|
||||||
|
.await;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Format the context message with dashboard, metrics, and dataset information
|
||||||
|
let dashboard_yaml = serde_yaml::to_string(&dashboard_yml)
|
||||||
|
.map_err(|e| anyhow!("Failed to serialize dashboard {} to YAML: {}", dashboard.name, e))?;
|
||||||
|
|
||||||
|
let mut context_message = format!(
|
||||||
|
"This conversation is continuing with context from the dashboard. Here is the relevant information:\n\nDashboard Definition:\n{}\n\n",
|
||||||
|
dashboard_yaml
|
||||||
|
);
|
||||||
|
|
||||||
|
if !metrics_vec.is_empty() {
|
||||||
|
context_message.push_str("Referenced Metrics:\n");
|
||||||
|
for metric in metrics_vec {
|
||||||
|
match serde_json::from_value::<agents::tools::categories::file_tools::file_types::metric_yml::MetricYml>(metric.content) {
|
||||||
|
Ok(metric_yml) => {
|
||||||
|
match serde_yaml::to_string(&metric_yml) {
|
||||||
|
Ok(yaml) => context_message.push_str(&format!("\n{}\n", yaml)),
|
||||||
|
Err(e) => tracing::warn!("Failed to serialize metric {} to YAML: {}", metric.id, e),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => tracing::warn!("Failed to parse metric {} content: {}", metric.id, e),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !datasets_vec.is_empty() {
|
||||||
|
context_message.push_str("\nReferenced Datasets:\n");
|
||||||
|
for dataset in datasets_vec {
|
||||||
|
if let Some(yml_content) = dataset.yml_file {
|
||||||
|
context_message.push_str(&format!("\n{}\n", yml_content));
|
||||||
|
} else {
|
||||||
|
tracing::warn!("Dataset {} has no YML content", dataset.id);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(vec![AgentMessage::Assistant {
|
||||||
|
id: None,
|
||||||
|
content: Some(context_message),
|
||||||
|
name: None,
|
||||||
|
tool_calls: None,
|
||||||
|
progress: MessageProgress::Complete,
|
||||||
|
initial: true,
|
||||||
|
}])
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,126 @@
|
||||||
|
use agents::{Agent, AgentMessage};
|
||||||
|
use anyhow::{anyhow, Result};
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use database::{
|
||||||
|
models::{Dataset, User},
|
||||||
|
pool::get_pg_pool,
|
||||||
|
schema::{datasets, metric_files},
|
||||||
|
};
|
||||||
|
use diesel::prelude::*;
|
||||||
|
use diesel_async::RunQueryDsl;
|
||||||
|
use serde_json::Value;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use uuid::Uuid;
|
||||||
|
|
||||||
|
use super::ContextLoader;
|
||||||
|
|
||||||
|
pub struct MetricContextLoader {
|
||||||
|
pub metric_id: Uuid,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl MetricContextLoader {
|
||||||
|
pub fn new(metric_id: Uuid) -> Self {
|
||||||
|
Self { metric_id }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl ContextLoader for MetricContextLoader {
|
||||||
|
async fn load_context(&self, user: &User, agent: &Arc<Agent>) -> Result<Vec<AgentMessage>> {
|
||||||
|
let mut conn = get_pg_pool().get().await.map_err(|e| {
|
||||||
|
anyhow!(
|
||||||
|
"Failed to get database connection for metric context loading: {}",
|
||||||
|
e
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
|
||||||
|
// First verify the metric exists and user has access
|
||||||
|
let metric = metric_files::table
|
||||||
|
.filter(metric_files::id.eq(self.metric_id))
|
||||||
|
// .filter(metric_files::created_by.eq(&user.id))
|
||||||
|
.first::<database::models::MetricFile>(&mut conn)
|
||||||
|
.await
|
||||||
|
.map_err(|e| {
|
||||||
|
anyhow!("Failed to load metric (id: {}). Either it doesn't exist or user {} doesn't have access: {}",
|
||||||
|
self.metric_id, user.id, e)
|
||||||
|
})?;
|
||||||
|
|
||||||
|
// Get the metric content as MetricYml
|
||||||
|
let metric_yml: agents::tools::categories::file_tools::file_types::metric_yml::MetricYml =
|
||||||
|
serde_json::from_value(metric.content.clone()).map_err(|e| {
|
||||||
|
anyhow!(
|
||||||
|
"Failed to parse metric content as YAML for metric {}: {}",
|
||||||
|
metric.name,
|
||||||
|
e
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
|
||||||
|
// Load all referenced datasets
|
||||||
|
let dataset_ids = &metric_yml.dataset_ids;
|
||||||
|
let mut datasets_vec = Vec::new();
|
||||||
|
let mut failed_dataset_loads = Vec::new();
|
||||||
|
|
||||||
|
for dataset_id in dataset_ids {
|
||||||
|
match datasets::table
|
||||||
|
.filter(datasets::id.eq(dataset_id))
|
||||||
|
.first::<Dataset>(&mut conn)
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
Ok(dataset) => datasets_vec.push(dataset),
|
||||||
|
Err(e) => failed_dataset_loads.push((dataset_id, e.to_string())),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !failed_dataset_loads.is_empty() {
|
||||||
|
tracing::warn!(
|
||||||
|
"Failed to load some datasets for metric {}: {:?}",
|
||||||
|
metric.name,
|
||||||
|
failed_dataset_loads
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set agent state based on loaded assets
|
||||||
|
agent
|
||||||
|
.set_state_value(String::from("metrics_available"), Value::Bool(true))
|
||||||
|
.await;
|
||||||
|
|
||||||
|
agent
|
||||||
|
.set_state_value(String::from("files_available"), Value::Bool(true))
|
||||||
|
.await;
|
||||||
|
|
||||||
|
if !datasets_vec.is_empty() {
|
||||||
|
agent
|
||||||
|
.set_state_value(String::from("data_context"), Value::Bool(true))
|
||||||
|
.await;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Format the context message with metric and dataset information
|
||||||
|
let metric_yaml = serde_yaml::to_string(&metric_yml)
|
||||||
|
.map_err(|e| anyhow!("Failed to serialize metric {} to YAML: {}", metric.name, e))?;
|
||||||
|
|
||||||
|
let mut context_message = format!(
|
||||||
|
"This conversation is continuing with context from the metric. Here is the relevant information:\n\nMetric Definition:\n{}\n\n",
|
||||||
|
metric_yaml
|
||||||
|
);
|
||||||
|
|
||||||
|
if !datasets_vec.is_empty() {
|
||||||
|
context_message.push_str("Referenced Datasets:\n");
|
||||||
|
for dataset in datasets_vec {
|
||||||
|
if let Some(yml_content) = dataset.yml_file {
|
||||||
|
context_message.push_str(&format!("\n{}\n", yml_content));
|
||||||
|
} else {
|
||||||
|
tracing::warn!("Dataset {} has no YML content", dataset.id);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(vec![AgentMessage::Assistant {
|
||||||
|
id: None,
|
||||||
|
content: Some(context_message),
|
||||||
|
name: None,
|
||||||
|
tool_calls: None,
|
||||||
|
progress: litellm::MessageProgress::Complete,
|
||||||
|
initial: true,
|
||||||
|
}])
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,43 @@
|
||||||
|
use anyhow::Result;
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use database::models::User;
|
||||||
|
use agents::AgentMessage;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use agents::Agent;
|
||||||
|
|
||||||
|
pub mod chat_context;
|
||||||
|
pub mod metric_context;
|
||||||
|
pub mod dashboard_context;
|
||||||
|
|
||||||
|
pub use chat_context::ChatContextLoader;
|
||||||
|
pub use metric_context::MetricContextLoader;
|
||||||
|
pub use dashboard_context::DashboardContextLoader;
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
pub trait ContextLoader {
|
||||||
|
async fn load_context(&self, user: &User, agent: &Arc<Agent>) -> Result<Vec<AgentMessage>>;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate that only one context type is provided
|
||||||
|
pub fn validate_context_request(
|
||||||
|
chat_id: Option<uuid::Uuid>,
|
||||||
|
metric_id: Option<uuid::Uuid>,
|
||||||
|
dashboard_id: Option<uuid::Uuid>,
|
||||||
|
) -> Result<()> {
|
||||||
|
let context_count = [
|
||||||
|
chat_id.is_some(),
|
||||||
|
metric_id.is_some(),
|
||||||
|
dashboard_id.is_some(),
|
||||||
|
]
|
||||||
|
.iter()
|
||||||
|
.filter(|&&b| b)
|
||||||
|
.count();
|
||||||
|
|
||||||
|
if context_count > 1 {
|
||||||
|
return Err(anyhow::anyhow!(
|
||||||
|
"Only one context type (chat, metric, or dashboard) can be provided"
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
|
@ -4,6 +4,7 @@ pub mod update_chats_handler;
|
||||||
pub mod delete_chats_handler;
|
pub mod delete_chats_handler;
|
||||||
pub mod types;
|
pub mod types;
|
||||||
pub mod streaming_parser;
|
pub mod streaming_parser;
|
||||||
|
pub mod context_loaders;
|
||||||
|
|
||||||
pub use get_chat_handler::get_chat_handler;
|
pub use get_chat_handler::get_chat_handler;
|
||||||
pub use post_chat_handler::post_chat_handler;
|
pub use post_chat_handler::post_chat_handler;
|
||||||
|
|
|
@ -1,8 +1,8 @@
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
|
||||||
use agents::{
|
use agents::{
|
||||||
tools::file_tools::search_data_catalog::SearchDataCatalogOutput, AgentMessage, AgentThread,
|
tools::file_tools::search_data_catalog::SearchDataCatalogOutput, AgentExt, AgentMessage,
|
||||||
BusterSuperAgent,
|
AgentThread, BusterSuperAgent,
|
||||||
};
|
};
|
||||||
|
|
||||||
use anyhow::{anyhow, Result};
|
use anyhow::{anyhow, Result};
|
||||||
|
@ -23,7 +23,13 @@ use serde::{Deserialize, Serialize};
|
||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
use crate::chats::streaming_parser::StreamingParser;
|
use crate::chats::{
|
||||||
|
context_loaders::{
|
||||||
|
chat_context::ChatContextLoader, dashboard_context::DashboardContextLoader,
|
||||||
|
metric_context::MetricContextLoader, validate_context_request, ContextLoader,
|
||||||
|
},
|
||||||
|
streaming_parser::StreamingParser,
|
||||||
|
};
|
||||||
use crate::messages::types::{ChatMessage, ChatUserMessage};
|
use crate::messages::types::{ChatMessage, ChatUserMessage};
|
||||||
|
|
||||||
use super::types::ChatWithMessages;
|
use super::types::ChatWithMessages;
|
||||||
|
@ -44,6 +50,8 @@ pub struct ChatCreateNewChat {
|
||||||
pub prompt: String,
|
pub prompt: String,
|
||||||
pub chat_id: Option<Uuid>,
|
pub chat_id: Option<Uuid>,
|
||||||
pub message_id: Option<Uuid>,
|
pub message_id: Option<Uuid>,
|
||||||
|
pub metric_id: Option<Uuid>,
|
||||||
|
pub dashboard_id: Option<Uuid>,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn post_chat_handler(
|
pub async fn post_chat_handler(
|
||||||
|
@ -51,6 +59,9 @@ pub async fn post_chat_handler(
|
||||||
user: User,
|
user: User,
|
||||||
tx: Option<mpsc::Sender<Result<(BusterContainer, ThreadEvent)>>>,
|
tx: Option<mpsc::Sender<Result<(BusterContainer, ThreadEvent)>>>,
|
||||||
) -> Result<ChatWithMessages> {
|
) -> Result<ChatWithMessages> {
|
||||||
|
// Validate context request
|
||||||
|
validate_context_request(request.chat_id, request.metric_id, request.dashboard_id)?;
|
||||||
|
|
||||||
let chat_id = request.chat_id.unwrap_or_else(Uuid::new_v4);
|
let chat_id = request.chat_id.unwrap_or_else(Uuid::new_v4);
|
||||||
let message_id = request.message_id.unwrap_or_else(Uuid::new_v4);
|
let message_id = request.message_id.unwrap_or_else(Uuid::new_v4);
|
||||||
|
|
||||||
|
@ -120,13 +131,38 @@ pub async fn post_chat_handler(
|
||||||
.execute(&mut conn)
|
.execute(&mut conn)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
// Initialize agent and process request
|
// Initialize agent with context if provided
|
||||||
|
let mut initial_messages = vec![];
|
||||||
|
|
||||||
|
// Initialize agent to add context
|
||||||
let agent = BusterSuperAgent::new(user.id, chat_id).await?;
|
let agent = BusterSuperAgent::new(user.id, chat_id).await?;
|
||||||
let mut chat = AgentThread::new(
|
|
||||||
Some(chat_id),
|
// Load context if provided
|
||||||
user.id,
|
if let Some(existing_chat_id) = request.chat_id {
|
||||||
vec![AgentMessage::user(request.prompt.clone())],
|
let context_loader = ChatContextLoader::new(existing_chat_id);
|
||||||
);
|
let context_messages = context_loader
|
||||||
|
.load_context(&user, agent.get_agent())
|
||||||
|
.await?;
|
||||||
|
initial_messages.extend(context_messages);
|
||||||
|
} else if let Some(metric_id) = request.metric_id {
|
||||||
|
let context_loader = MetricContextLoader::new(metric_id);
|
||||||
|
let context_messages = context_loader
|
||||||
|
.load_context(&user, agent.get_agent())
|
||||||
|
.await?;
|
||||||
|
initial_messages.extend(context_messages);
|
||||||
|
} else if let Some(dashboard_id) = request.dashboard_id {
|
||||||
|
let context_loader = DashboardContextLoader::new(dashboard_id);
|
||||||
|
let context_messages = context_loader
|
||||||
|
.load_context(&user, agent.get_agent())
|
||||||
|
.await?;
|
||||||
|
initial_messages.extend(context_messages);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add the new user message
|
||||||
|
initial_messages.push(AgentMessage::user(request.prompt.clone()));
|
||||||
|
|
||||||
|
// Initialize the agent thread
|
||||||
|
let mut chat = AgentThread::new(Some(chat_id), user.id, initial_messages);
|
||||||
|
|
||||||
let title_handle = {
|
let title_handle = {
|
||||||
let tx = tx.clone();
|
let tx = tx.clone();
|
||||||
|
|
Loading…
Reference in New Issue