From 443257408612509cb8ef511a056f35328f7e2f6e Mon Sep 17 00:00:00 2001 From: dal Date: Wed, 12 Mar 2025 08:27:59 -0600 Subject: [PATCH] better use of user throughout agents and tools --- api/libs/agents/Cargo.toml | 1 + api/libs/agents/src/agent.rs | 95 ++++++++----- .../agents/src/agents/buster_super_agent.rs | 5 +- .../tools/categories/agents_as_tools/mod.rs | 1 - .../file_tools/create_dashboard_files.rs | 3 +- .../file_tools/create_metric_files.rs | 3 +- .../file_tools/modify_dashboard_files.rs | 129 +----------------- .../file_tools/modify_metric_files.rs | 122 +---------------- .../file_tools/search_data_catalog.rs | 3 +- api/libs/agents/src/tools/categories/mod.rs | 1 - .../categories/planning_tools/create_plan.rs | 3 +- api/libs/agents/src/tools/executor.rs | 12 +- api/libs/agents/src/tools/mod.rs | 3 +- .../handlers/src/chats/post_chat_handler.rs | 32 +++-- 14 files changed, 109 insertions(+), 304 deletions(-) delete mode 100644 api/libs/agents/src/tools/categories/agents_as_tools/mod.rs diff --git a/api/libs/agents/Cargo.toml b/api/libs/agents/Cargo.toml index b6e11114b..c987ff38f 100644 --- a/api/libs/agents/Cargo.toml +++ b/api/libs/agents/Cargo.toml @@ -15,6 +15,7 @@ uuid = { workspace = true } litellm = { path = "../litellm" } database = { path = "../database" } query_engine = { path = "../query_engine" } +middleware = { path = "../middleware" } serde_json = { workspace = true } futures = { workspace = true } futures-util = { workspace = true } diff --git a/api/libs/agents/src/agent.rs b/api/libs/agents/src/agent.rs index 96b8fb7de..985869095 100644 --- a/api/libs/agents/src/agent.rs +++ b/api/libs/agents/src/agent.rs @@ -4,11 +4,12 @@ use litellm::{ AgentMessage, ChatCompletionRequest, DeltaToolCall, FunctionCall, LiteLLMClient, MessageProgress, Metadata, Tool, ToolCall, ToolChoice, }; +use middleware::AuthenticatedUser; use serde_json::Value; +use std::time::{Duration, Instant}; use std::{collections::HashMap, env, sync::Arc}; use tokio::sync::{broadcast, RwLock}; use uuid::Uuid; -use std::time::{Duration, Instant}; use crate::models::AgentThread; @@ -34,7 +35,6 @@ struct MessageBuffer { first_message_sent: bool, } - impl MessageBuffer { fn new() -> Self { Self { @@ -80,7 +80,11 @@ impl MessageBuffer { // Create and send the message let message = AgentMessage::assistant( self.message_id.clone(), - if self.content.is_empty() { None } else { Some(self.content.clone()) }, + if self.content.is_empty() { + None + } else { + Some(self.content.clone()) + }, tool_calls, MessageProgress::InProgress, Some(!self.first_message_sent), @@ -88,7 +92,7 @@ impl MessageBuffer { ); agent.get_stream_sender().await.send(Ok(message))?; - + // Update state self.first_message_sent = true; self.last_flush = Instant::now(); @@ -98,8 +102,6 @@ impl MessageBuffer { } } - - #[derive(Clone)] /// The Agent struct is responsible for managing conversations with the LLM /// and coordinating tool executions. It maintains a registry of available tools @@ -122,7 +124,7 @@ pub struct Agent { /// Sender for streaming messages from this agent and sub-agents stream_tx: Arc>>>, /// The user ID for the current thread - user_id: Uuid, + user: AuthenticatedUser, /// The session ID for the current thread session_id: Uuid, /// Agent name @@ -136,7 +138,7 @@ impl Agent { pub fn new( model: String, tools: HashMap + Send + Sync>>, - user_id: Uuid, + user: AuthenticatedUser, session_id: Uuid, name: String, ) -> Self { @@ -155,7 +157,7 @@ impl Agent { state: Arc::new(RwLock::new(HashMap::new())), current_thread: Arc::new(RwLock::new(None)), stream_tx: Arc::new(RwLock::new(Some(tx))), - user_id, + user, session_id, shutdown_tx: Arc::new(RwLock::new(shutdown_tx)), name, @@ -176,7 +178,7 @@ impl Agent { state: Arc::clone(&existing_agent.state), current_thread: Arc::clone(&existing_agent.current_thread), stream_tx: Arc::clone(&existing_agent.stream_tx), - user_id: existing_agent.user_id, + user: existing_agent.user.clone(), session_id: existing_agent.session_id, shutdown_tx: Arc::clone(&existing_agent.shutdown_tx), name, @@ -241,7 +243,11 @@ impl Agent { } pub fn get_user_id(&self) -> Uuid { - self.user_id + self.user.id + } + + pub fn get_user(&self) -> AuthenticatedUser { + self.user.clone() } pub fn get_session_id(&self) -> Uuid { @@ -450,7 +456,8 @@ impl Agent { if let Some(tool_calls) = &delta.tool_calls { for tool_call in tool_calls { let id = tool_call.id.clone().unwrap_or_else(|| { - buffer.tool_calls + buffer + .tool_calls .keys() .next() .map(|s| s.clone()) @@ -458,7 +465,8 @@ impl Agent { }); // Get or create the pending tool call - let pending_call = buffer.tool_calls + let pending_call = buffer + .tool_calls .entry(id.clone()) .or_insert_with(PendingToolCall::new); @@ -484,7 +492,8 @@ impl Agent { // Create and send the final message let final_tool_calls: Option> = if !buffer.tool_calls.is_empty() { Some( - buffer.tool_calls + buffer + .tool_calls .values() .map(|p| p.clone().into_tool_call()) .collect(), @@ -495,7 +504,11 @@ impl Agent { let final_message = AgentMessage::assistant( buffer.message_id, - if buffer.content.is_empty() { None } else { Some(buffer.content) }, + if buffer.content.is_empty() { + None + } else { + Some(buffer.content) + }, final_tool_calls.clone(), MessageProgress::Complete, Some(false), @@ -527,7 +540,7 @@ impl Agent { for tool_call in tool_calls { if let Some(tool) = self.tools.read().await.get(&tool_call.function.name) { let params: Value = serde_json::from_str(&tool_call.function.arguments)?; - let result = tool.execute(params, tool_call.id.clone()).await?; + let result = tool.execute(params, tool_call.id.clone(), self.get_user()).await?; let result_str = serde_json::to_string(&result)?; let tool_message = AgentMessage::tool( None, @@ -671,12 +684,32 @@ mod tests { use super::*; use crate::tools::ToolExecutor; use async_trait::async_trait; + use chrono::{Utc}; use litellm::MessageProgress; use serde_json::{json, Value}; use uuid::Uuid; + use middleware::types::AuthenticatedUser; fn setup() { dotenv::dotenv().ok(); + std::env::set_var("LLM_API_KEY", "test_key"); + std::env::set_var("LLM_BASE_URL", "http://localhost:8000"); + } + + // Create a mock AuthenticatedUser for testing + fn create_test_user() -> AuthenticatedUser { + AuthenticatedUser { + id: Uuid::new_v4(), + email: "test@example.com".to_string(), + name: Some("Test User".to_string()), + config: json!({}), + created_at: Utc::now(), + updated_at: Utc::now(), + attributes: json!({}), + avatar_url: None, + organizations: vec![], + teams: vec![], + } } struct WeatherTool { @@ -696,13 +729,8 @@ mod tests { tool_id: String, progress: MessageProgress, ) -> Result<()> { - let message = AgentMessage::tool( - None, - content, - tool_id, - Some(self.get_name()), - progress, - ); + let message = + AgentMessage::tool(None, content, tool_id, Some(self.get_name()), progress); self.agent.get_stream_sender().await.send(Ok(message))?; Ok(()) } @@ -713,7 +741,12 @@ mod tests { type Output = Value; type Params = Value; - async fn execute(&self, params: Self::Params, tool_call_id: String) -> Result { + async fn execute( + &self, + params: Self::Params, + tool_call_id: String, + user: AuthenticatedUser, + ) -> Result { self.send_progress( "Fetching weather data...".to_string(), "123".to_string(), @@ -778,14 +811,14 @@ mod tests { let agent = Agent::new( "o1".to_string(), HashMap::new(), - Uuid::new_v4(), + create_test_user(), Uuid::new_v4(), "test_agent".to_string(), ); let thread = AgentThread::new( None, - Uuid::new_v4(), + create_test_user().id, vec![AgentMessage::user("Hello, world!".to_string())], ); @@ -803,7 +836,7 @@ mod tests { let mut agent = Agent::new( "o1".to_string(), HashMap::new(), - Uuid::new_v4(), + create_test_user(), Uuid::new_v4(), "test_agent".to_string(), ); @@ -816,7 +849,7 @@ mod tests { let thread = AgentThread::new( None, - Uuid::new_v4(), + create_test_user().id, vec![AgentMessage::user( "What is the weather in vineyard ut?".to_string(), )], @@ -836,7 +869,7 @@ mod tests { let mut agent = Agent::new( "o1".to_string(), HashMap::new(), - Uuid::new_v4(), + create_test_user(), Uuid::new_v4(), "test_agent".to_string(), ); @@ -847,7 +880,7 @@ mod tests { let thread = AgentThread::new( None, - Uuid::new_v4(), + create_test_user().id, vec![AgentMessage::user( "What is the weather in vineyard ut and san francisco?".to_string(), )], @@ -867,7 +900,7 @@ mod tests { let agent = Agent::new( "o1".to_string(), HashMap::new(), - Uuid::new_v4(), + create_test_user(), Uuid::new_v4(), "test_agent".to_string(), ); diff --git a/api/libs/agents/src/agents/buster_super_agent.rs b/api/libs/agents/src/agents/buster_super_agent.rs index 25059268a..3a5cfe3d7 100644 --- a/api/libs/agents/src/agents/buster_super_agent.rs +++ b/api/libs/agents/src/agents/buster_super_agent.rs @@ -1,4 +1,5 @@ use anyhow::Result; +use middleware::AuthenticatedUser; use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::sync::Arc; @@ -97,12 +98,12 @@ impl BusterSuperAgent { Ok(()) } - pub async fn new(user_id: Uuid, session_id: Uuid) -> Result { + pub async fn new(user: AuthenticatedUser, session_id: Uuid) -> Result { // Create agent with empty tools map let agent = Arc::new(Agent::new( "o3-mini".to_string(), HashMap::new(), - user_id, + user, session_id, "buster_super_agent".to_string(), )); diff --git a/api/libs/agents/src/tools/categories/agents_as_tools/mod.rs b/api/libs/agents/src/tools/categories/agents_as_tools/mod.rs deleted file mode 100644 index 8b1378917..000000000 --- a/api/libs/agents/src/tools/categories/agents_as_tools/mod.rs +++ /dev/null @@ -1 +0,0 @@ - diff --git a/api/libs/agents/src/tools/categories/file_tools/create_dashboard_files.rs b/api/libs/agents/src/tools/categories/file_tools/create_dashboard_files.rs index 96a81717e..19af894db 100644 --- a/api/libs/agents/src/tools/categories/file_tools/create_dashboard_files.rs +++ b/api/libs/agents/src/tools/categories/file_tools/create_dashboard_files.rs @@ -11,6 +11,7 @@ use serde::{Deserialize, Serialize}; use serde_json::{self, json, Value}; use tracing::debug; use uuid::Uuid; +use middleware::AuthenticatedUser; use crate::{ agent::Agent, @@ -131,7 +132,7 @@ impl ToolExecutor for CreateDashboardFilesTool { } } - async fn execute(&self, params: Self::Params, tool_call_id: String) -> Result { + async fn execute(&self, params: Self::Params, tool_call_id: String, user: AuthenticatedUser) -> Result { let start_time = Instant::now(); let files = params.files; diff --git a/api/libs/agents/src/tools/categories/file_tools/create_metric_files.rs b/api/libs/agents/src/tools/categories/file_tools/create_metric_files.rs index a7ca0fec2..f9f6a1951 100644 --- a/api/libs/agents/src/tools/categories/file_tools/create_metric_files.rs +++ b/api/libs/agents/src/tools/categories/file_tools/create_metric_files.rs @@ -12,6 +12,7 @@ use serde::{Deserialize, Serialize}; use serde_json::Value; use tracing::debug; use uuid::Uuid; +use middleware::AuthenticatedUser; use crate::{ agent::Agent, @@ -81,7 +82,7 @@ impl ToolExecutor for CreateMetricFilesTool { } } - async fn execute(&self, params: Self::Params, tool_call_id: String) -> Result { + async fn execute(&self, params: Self::Params, tool_call_id: String, user: AuthenticatedUser) -> Result { let start_time = Instant::now(); let files = params.files; diff --git a/api/libs/agents/src/tools/categories/file_tools/modify_dashboard_files.rs b/api/libs/agents/src/tools/categories/file_tools/modify_dashboard_files.rs index ce45811de..1a989bc2e 100644 --- a/api/libs/agents/src/tools/categories/file_tools/modify_dashboard_files.rs +++ b/api/libs/agents/src/tools/categories/file_tools/modify_dashboard_files.rs @@ -12,6 +12,7 @@ use indexmap::IndexMap; use query_engine::data_types::DataType; use serde_json::Value; use tracing::{debug, error, info}; +use middleware::AuthenticatedUser; use super::{ common::{ @@ -67,7 +68,7 @@ impl ToolExecutor for ModifyDashboardFilesTool { } } - async fn execute(&self, params: Self::Params, tool_call_id: String) -> Result { + async fn execute(&self, params: Self::Params, tool_call_id: String, user: AuthenticatedUser) -> Result { let start_time = Instant::now(); debug!("Starting file modification execution"); @@ -215,128 +216,4 @@ impl ToolExecutor for ModifyDashboardFilesTool { "description": "Makes content-based modifications to one or more existing dashboard YAML files in a single call. Each modification specifies the exact content to replace and its replacement. If you need to update chart config or other sections within a file, use this. Guard Rail: Do not execute any file creation or modifications until a thorough data catalog search has been completed and reviewed." }) } -} - -#[cfg(test)] -mod tests { - use std::collections::HashMap; - - use super::*; - use crate::tools::categories::file_tools::common::{ - apply_modifications_to_content, Modification, ModificationResult, - }; - use chrono::Utc; - use serde_json::json; - use uuid::Uuid; - - #[test] - fn test_apply_modifications_to_content() { - let original_content = - "name: test_dashboard\ntype: dashboard\ndescription: A test dashboard"; - - // Test single modification - let mods1 = vec![Modification { - content_to_replace: "type: dashboard".to_string(), - new_content: "type: custom_dashboard".to_string(), - }]; - let result1 = apply_modifications_to_content(original_content, &mods1, "test.yml").unwrap(); - assert_eq!( - result1, - "name: test_dashboard\ntype: custom_dashboard\ndescription: A test dashboard" - ); - - // Test multiple non-overlapping modifications - let mods2 = vec![ - Modification { - content_to_replace: "test_dashboard".to_string(), - new_content: "new_dashboard".to_string(), - }, - Modification { - content_to_replace: "A test dashboard".to_string(), - new_content: "An updated dashboard".to_string(), - }, - ]; - let result2 = apply_modifications_to_content(original_content, &mods2, "test.yml").unwrap(); - assert_eq!( - result2, - "name: new_dashboard\ntype: dashboard\ndescription: An updated dashboard" - ); - - // Test content not found - let mods3 = vec![Modification { - content_to_replace: "nonexistent content".to_string(), - new_content: "new content".to_string(), - }]; - let result3 = apply_modifications_to_content(original_content, &mods3, "test.yml"); - assert!(result3.is_err()); - assert!(result3 - .unwrap_err() - .to_string() - .contains("Content to replace not found")); - } - - #[test] - fn test_modification_result_tracking() { - let result = ModificationResult { - file_id: Uuid::new_v4(), - file_name: "test.yml".to_string(), - success: true, - error: None, - modification_type: "content".to_string(), - timestamp: Utc::now(), - duration: 0, - }; - - assert!(result.success); - assert!(result.error.is_none()); - - let error_result = ModificationResult { - success: false, - error: Some("Failed to parse YAML".to_string()), - ..result - }; - assert!(!error_result.success); - assert!(error_result.error.is_some()); - assert_eq!(error_result.error.unwrap(), "Failed to parse YAML"); - } - - #[test] - fn test_tool_parameter_validation() { - let tool = ModifyDashboardFilesTool { - agent: Arc::new(Agent::new( - "o3-mini".to_string(), - HashMap::new(), - Uuid::new_v4(), - Uuid::new_v4(), - "test_agent".to_string(), - )), - }; - - // Test valid parameters - let valid_params = json!({ - "files": [{ - "id": Uuid::new_v4().to_string(), - "file_name": "test.yml", - "modifications": [{ - "content_to_replace": "old content", - "new_content": "new content" - }] - }] - }); - let valid_args = serde_json::to_string(&valid_params).unwrap(); - let result = serde_json::from_str::(&valid_args); - assert!(result.is_ok()); - - // Test missing required fields - let missing_fields_params = json!({ - "files": [{ - "id": Uuid::new_v4().to_string(), - "file_name": "test.yml" - // missing modifications - }] - }); - let missing_args = serde_json::to_string(&missing_fields_params).unwrap(); - let result = serde_json::from_str::(&missing_args); - assert!(result.is_err()); - } -} +} \ No newline at end of file diff --git a/api/libs/agents/src/tools/categories/file_tools/modify_metric_files.rs b/api/libs/agents/src/tools/categories/file_tools/modify_metric_files.rs index 801a3997f..20a832c4f 100644 --- a/api/libs/agents/src/tools/categories/file_tools/modify_metric_files.rs +++ b/api/libs/agents/src/tools/categories/file_tools/modify_metric_files.rs @@ -8,6 +8,7 @@ use database::{enums::Verification, models::MetricFile, pool::get_pg_pool, schem use diesel::{upsert::excluded, ExpressionMethods, QueryDsl}; use diesel_async::RunQueryDsl; use indexmap::IndexMap; +use middleware::AuthenticatedUser; use query_engine::data_types::DataType; use serde::{Deserialize, Serialize}; use serde_json::Value; @@ -65,7 +66,7 @@ impl ToolExecutor for ModifyMetricFilesTool { } } - async fn execute(&self, params: Self::Params, tool_call_id: String) -> Result { + async fn execute(&self, params: Self::Params, tool_call_id: String, user: AuthenticatedUser) -> Result { let start_time = Instant::now(); debug!("Starting file modification execution"); @@ -274,122 +275,3 @@ impl ToolExecutor for ModifyMetricFilesTool { }) } } - -#[cfg(test)] -mod tests { - use std::collections::HashMap; - - use super::*; - use chrono::Utc; - use serde_json::json; - - #[test] - fn test_apply_modifications_to_content() { - let original_content = "name: test_metric\ntype: counter\ndescription: A test metric"; - - // Test single modification - let mods1 = vec![Modification { - content_to_replace: "type: counter".to_string(), - new_content: "type: gauge".to_string(), - }]; - let result1 = apply_modifications_to_content(original_content, &mods1, "test.yml").unwrap(); - assert_eq!( - result1, - "name: test_metric\ntype: gauge\ndescription: A test metric" - ); - - // Test multiple non-overlapping modifications - let mods2 = vec![ - Modification { - content_to_replace: "test_metric".to_string(), - new_content: "new_metric".to_string(), - }, - Modification { - content_to_replace: "A test metric".to_string(), - new_content: "An updated metric".to_string(), - }, - ]; - let result2 = apply_modifications_to_content(original_content, &mods2, "test.yml").unwrap(); - assert_eq!( - result2, - "name: new_metric\ntype: counter\ndescription: An updated metric" - ); - - // Test content not found - let mods3 = vec![Modification { - content_to_replace: "nonexistent content".to_string(), - new_content: "new content".to_string(), - }]; - let result3 = apply_modifications_to_content(original_content, &mods3, "test.yml"); - assert!(result3.is_err()); - assert!(result3 - .unwrap_err() - .to_string() - .contains("Content to replace not found")); - } - - #[test] - fn test_modification_result_tracking() { - let result = ModificationResult { - file_id: Uuid::new_v4(), - file_name: "test.yml".to_string(), - success: true, - error: None, - modification_type: "content".to_string(), - timestamp: Utc::now(), - duration: 0, - }; - - assert!(result.success); - assert!(result.error.is_none()); - - let error_result = ModificationResult { - success: false, - error: Some("Failed to parse YAML".to_string()), - ..result - }; - assert!(!error_result.success); - assert!(error_result.error.is_some()); - assert_eq!(error_result.error.unwrap(), "Failed to parse YAML"); - } - - #[test] - fn test_tool_parameter_validation() { - let tool = ModifyMetricFilesTool { - agent: Arc::new(Agent::new( - "o3-mini".to_string(), - HashMap::new(), - Uuid::new_v4(), - Uuid::new_v4(), - "test_agent".to_string(), - )), - }; - - // Test valid parameters - let valid_params = json!({ - "files": [{ - "id": Uuid::new_v4().to_string(), - "file_name": "test.yml", - "modifications": [{ - "content_to_replace": "old content", - "new_content": "new content" - }] - }] - }); - let valid_args = serde_json::to_string(&valid_params).unwrap(); - let result = serde_json::from_str::(&valid_args); - assert!(result.is_ok()); - - // Test missing required fields - let missing_fields_params = json!({ - "files": [{ - "id": Uuid::new_v4().to_string(), - "file_name": "test.yml" - // missing modifications - }] - }); - let missing_args = serde_json::to_string(&missing_fields_params).unwrap(); - let result = serde_json::from_str::(&missing_args); - assert!(result.is_err()); - } -} diff --git a/api/libs/agents/src/tools/categories/file_tools/search_data_catalog.rs b/api/libs/agents/src/tools/categories/file_tools/search_data_catalog.rs index 93c590aad..5d1477c5c 100644 --- a/api/libs/agents/src/tools/categories/file_tools/search_data_catalog.rs +++ b/api/libs/agents/src/tools/categories/file_tools/search_data_catalog.rs @@ -7,6 +7,7 @@ use chrono::{DateTime, Utc}; use database::{pool::get_pg_pool, schema::datasets}; use diesel::prelude::*; use diesel_async::RunQueryDsl; +use middleware::AuthenticatedUser; use serde::{Deserialize, Serialize}; use serde_json::{json, Value}; use tracing::{debug, error, warn}; @@ -260,7 +261,7 @@ impl ToolExecutor for SearchDataCatalogTool { type Output = SearchDataCatalogOutput; type Params = SearchDataCatalogParams; - async fn execute(&self, params: Self::Params, tool_call_id: String) -> Result { + async fn execute(&self, params: Self::Params, tool_call_id: String, user: AuthenticatedUser) -> Result { let start_time = Instant::now(); // Fetch all non-deleted datasets diff --git a/api/libs/agents/src/tools/categories/mod.rs b/api/libs/agents/src/tools/categories/mod.rs index 0ec899fe9..4cd1a6a92 100644 --- a/api/libs/agents/src/tools/categories/mod.rs +++ b/api/libs/agents/src/tools/categories/mod.rs @@ -7,6 +7,5 @@ //! - interaction_tools: Tools for user interaction and UI manipulation //! - planning_tools: Tools for planning and scheduling -pub mod agents_as_tools; pub mod file_tools; pub mod planning_tools; \ No newline at end of file diff --git a/api/libs/agents/src/tools/categories/planning_tools/create_plan.rs b/api/libs/agents/src/tools/categories/planning_tools/create_plan.rs index c701b9934..325ebe8ef 100644 --- a/api/libs/agents/src/tools/categories/planning_tools/create_plan.rs +++ b/api/libs/agents/src/tools/categories/planning_tools/create_plan.rs @@ -1,5 +1,6 @@ use anyhow::Result; use async_trait::async_trait; +use middleware::AuthenticatedUser; use serde::{Deserialize, Serialize}; use serde_json::Value; use std::sync::Arc; @@ -36,7 +37,7 @@ impl ToolExecutor for CreatePlan { "create_plan".to_string() } - async fn execute(&self, params: Self::Params, tool_call_id: String) -> Result { + async fn execute(&self, params: Self::Params, tool_call_id: String, user: AuthenticatedUser) -> Result { self.agent .set_state_value(String::from("plan_available"), Value::Bool(true)) .await; diff --git a/api/libs/agents/src/tools/executor.rs b/api/libs/agents/src/tools/executor.rs index b031546c8..8e7687771 100644 --- a/api/libs/agents/src/tools/executor.rs +++ b/api/libs/agents/src/tools/executor.rs @@ -1,6 +1,8 @@ use anyhow::Result; +use middleware::AuthenticatedUser; use serde::{de::DeserializeOwned, Serialize}; use serde_json::Value; +use uuid::Uuid; /// A trait that defines how tools should be implemented. /// Any struct that wants to be used as a tool must implement this trait. @@ -13,7 +15,7 @@ pub trait ToolExecutor: Send + Sync { type Params: DeserializeOwned + Send; /// Execute the tool with the given parameters and tool call ID. - async fn execute(&self, params: Self::Params, tool_call_id: String) -> Result; + async fn execute(&self, params: Self::Params, tool_call_id: String, user_id: AuthenticatedUser) -> Result; /// Get the JSON schema for this tool fn get_schema(&self) -> Value; @@ -53,9 +55,9 @@ where type Output = Value; type Params = Value; - async fn execute(&self, params: Self::Params, tool_call_id: String) -> Result { + async fn execute(&self, params: Self::Params, tool_call_id: String, user: AuthenticatedUser) -> Result { let params = serde_json::from_value(params)?; - let result = self.inner.execute(params, tool_call_id).await?; + let result = self.inner.execute(params, tool_call_id, user).await?; Ok(serde_json::to_value(result)?) } @@ -78,8 +80,8 @@ impl + Send + Sync> ToolExecutor type Output = Value; type Params = Value; - async fn execute(&self, params: Self::Params, tool_call_id: String) -> Result { - (**self).execute(params, tool_call_id).await + async fn execute(&self, params: Self::Params, tool_call_id: String, user: AuthenticatedUser) -> Result { + (**self).execute(params, tool_call_id, user).await } fn get_schema(&self) -> Value { diff --git a/api/libs/agents/src/tools/mod.rs b/api/libs/agents/src/tools/mod.rs index 0875640cc..a572a70bd 100644 --- a/api/libs/agents/src/tools/mod.rs +++ b/api/libs/agents/src/tools/mod.rs @@ -10,5 +10,4 @@ pub use executor::{ToolExecutor, ToolCallExecutor, IntoToolCallExecutor}; // Re-export commonly used tool categories pub use categories::file_tools; -pub use categories::planning_tools; -pub use categories::agents_as_tools; \ No newline at end of file +pub use categories::planning_tools; \ No newline at end of file diff --git a/api/libs/handlers/src/chats/post_chat_handler.rs b/api/libs/handlers/src/chats/post_chat_handler.rs index ae5bdc0fe..70dec9b1f 100644 --- a/api/libs/handlers/src/chats/post_chat_handler.rs +++ b/api/libs/handlers/src/chats/post_chat_handler.rs @@ -3,10 +3,14 @@ use once_cell::sync::OnceCell; use std::{collections::HashMap, sync::Mutex, time::Instant}; use agents::{ - tools::{file_tools::{ - common::ModifyFilesOutput, create_dashboard_files::CreateDashboardFilesOutput, - create_metric_files::CreateMetricFilesOutput, search_data_catalog::SearchDataCatalogOutput, - }, planning_tools::CreatePlanOutput}, + tools::{ + file_tools::{ + common::ModifyFilesOutput, create_dashboard_files::CreateDashboardFilesOutput, + create_metric_files::CreateMetricFilesOutput, + search_data_catalog::SearchDataCatalogOutput, + }, + planning_tools::CreatePlanOutput, + }, AgentExt, AgentMessage, AgentThread, BusterSuperAgent, }; @@ -175,7 +179,7 @@ pub async fn post_chat_handler( let mut initial_messages = vec![]; // Initialize agent to add context - let agent = BusterSuperAgent::new(user.id, chat_id).await?; + let agent = BusterSuperAgent::new(user.clone(), chat_id).await?; // Load context if provided if let Some(existing_chat_id) = request.chat_id { @@ -416,7 +420,8 @@ pub async fn post_chat_handler( fn prepare_final_message_state(containers: &[BusterContainer]) -> Result<(Vec, Vec)> { let mut response_messages = Vec::new(); // Use a Vec to maintain order, with a HashMap to track latest version of each message - let mut reasoning_map: std::collections::HashMap = std::collections::HashMap::new(); + let mut reasoning_map: std::collections::HashMap = + std::collections::HashMap::new(); let mut reasoning_order = Vec::new(); for container in containers { @@ -466,7 +471,7 @@ fn prepare_final_message_state(containers: &[BusterContainer]) -> Result<(Vec