diff --git a/api/src/utils/agent/agent.rs b/api/src/utils/agent/agent.rs index 4ca18e704..abe00c33b 100644 --- a/api/src/utils/agent/agent.rs +++ b/api/src/utils/agent/agent.rs @@ -19,12 +19,14 @@ struct ToolCallExecutor { impl ToolCallExecutor { fn new(inner: T) -> Self { - Self { inner: Box::new(inner) } + Self { + inner: Box::new(inner), + } } } #[async_trait::async_trait] -impl ToolExecutor for ToolCallExecutor +impl ToolExecutor for ToolCallExecutor where T::Params: serde::de::DeserializeOwned, T::Output: serde::Serialize, @@ -45,6 +47,10 @@ where fn get_name(&self) -> String { self.inner.get_name() } + + async fn is_enabled(&self) -> bool { + self.inner.is_enabled().await + } } // Add this near the top of the file, with other trait implementations @@ -64,6 +70,10 @@ impl + Send + Sync> ToolExecutor fn get_name(&self) -> String { (**self).get_name() } + + async fn is_enabled(&self) -> bool { + (**self).is_enabled().await + } } #[derive(Clone)] @@ -74,7 +84,11 @@ pub struct Agent { /// Client for communicating with the LLM provider llm_client: LiteLLMClient, /// Registry of available tools, mapped by their names - tools: Arc + Send + Sync>>>>, + tools: Arc< + RwLock< + HashMap + Send + Sync>>, + >, + >, /// The model identifier to use (e.g., "gpt-4") model: String, /// Flexible state storage for maintaining memory across interactions @@ -136,6 +150,24 @@ impl Agent { } } + pub async fn get_enabled_tools(&self) -> Vec { + // Collect all registered tools and their schemas + let tools = self.tools.read().await; + + let mut enabled_tools = Vec::new(); + + for (_, tool) in tools.iter() { + if tool.is_enabled().await { + enabled_tools.push(Tool { + tool_type: "function".to_string(), + function: tool.get_schema(), + }); + } + } + + enabled_tools + } + /// Update the stream sender for this agent pub async fn set_stream_sender(&self, tx: mpsc::Sender>) { *self.stream_tx.write().await = tx; @@ -308,16 +340,7 @@ impl Agent { } // Collect all registered tools and their schemas - let tools: Vec = self - .tools - .read() - .await - .iter() - .map(|(name, tool)| Tool { - tool_type: "function".to_string(), - function: tool.get_schema(), - }) - .collect(); + let tools = self.get_enabled_tools().await; // Create the tool-enabled request let request = ChatCompletionRequest { @@ -382,6 +405,7 @@ impl Agent { result_str, tool_call.id.clone(), Some(tool_call.function.name.clone()), + // TODO: need the progress for streaming None, ); @@ -461,7 +485,10 @@ impl PendingToolCall { pub trait AgentExt { fn get_agent(&self) -> &Arc; - async fn stream_process_thread(&self, thread: &AgentThread) -> Result>> { + async fn stream_process_thread( + &self, + thread: &AgentThread, + ) -> Result>> { (*self.get_agent()).process_thread_streaming(thread).await } @@ -540,6 +567,10 @@ mod tests { Ok(result) } + async fn is_enabled(&self) -> bool { + true + } + fn get_schema(&self) -> Value { json!({ "name": "get_weather", diff --git a/api/src/utils/agent/agents/manager_agent.rs b/api/src/utils/agent/agents/manager_agent.rs index 2a6f267dc..513141055 100644 --- a/api/src/utils/agent/agents/manager_agent.rs +++ b/api/src/utils/agent/agents/manager_agent.rs @@ -1,27 +1,22 @@ -use anyhow::{anyhow, Result}; +use anyhow::Result; use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::sync::{Arc, RwLock}; use tokio::sync::mpsc::Receiver; -use tracing::{debug, info}; use uuid::Uuid; use crate::utils::tools::agents_as_tools::{DashboardAgentTool, MetricAgentTool}; -use crate::utils::tools::file_tools::{send_assets_to_user, SendAssetsToUserTool}; +use crate::utils::tools::file_tools::SendAssetsToUserTool; use crate::utils::{ agent::{Agent, AgentExt, AgentThread}, tools::{ agents_as_tools::ExploratoryAgentTool, - file_tools::{ - CreateFilesTool, ModifyFilesTool, OpenFilesTool, SearchDataCatalogTool, SearchFilesTool, - }, + file_tools::{SearchDataCatalogTool, SearchFilesTool}, IntoValueTool, ToolExecutor, }, }; -use litellm::{Message as AgentMessage, ToolCall}; - -use super::MetricAgent; +use litellm::Message as AgentMessage; #[derive(Debug, Serialize, Deserialize)] pub struct ManagerAgentOutput { @@ -159,7 +154,21 @@ impl ManagerAgent { ) -> Result>> { thread.set_developer_message(MANAGER_AGENT_PROMPT.to_string()); - self.stream_process_thread(thread).await + let mut rx = self.stream_process_thread(thread).await?; + + while let Some(message) = rx.recv().await { + let message = message?; + if let AgentMessage::Tool { + id, + content, + tool_call_id, + name, + progress, + } = message + {} + } + + Ok(rx) } } diff --git a/api/src/utils/agent/types.rs b/api/src/utils/agent/types.rs index 97e090a17..45e3051fe 100644 --- a/api/src/utils/agent/types.rs +++ b/api/src/utils/agent/types.rs @@ -26,7 +26,11 @@ impl AgentThread { /// Set the developer message in the thread pub fn set_developer_message(&mut self, message: String) { // Look for an existing developer message - if let Some(pos) = self.messages.iter().position(|msg| matches!(msg, Message::Developer { .. })) { + if let Some(pos) = self + .messages + .iter() + .position(|msg| matches!(msg, Message::Developer { .. })) + { // Update existing developer message self.messages[pos] = Message::developer(message); } else { @@ -37,8 +41,17 @@ impl AgentThread { /// Remove the most recent assistant message from the thread pub fn remove_last_assistant_message(&mut self) { - if let Some(pos) = self.messages.iter().rposition(|msg| matches!(msg, Message::Assistant { .. })) { + if let Some(pos) = self + .messages + .iter() + .rposition(|msg| matches!(msg, Message::Assistant { .. })) + { self.messages.remove(pos); } } + + /// Add a user message to the thread + pub fn add_user_message(&mut self, content: String) { + self.messages.push(Message::user(content)); + } } diff --git a/api/src/utils/tools/agents_as_tools/dashboard_agent_tool.rs b/api/src/utils/tools/agents_as_tools/dashboard_agent_tool.rs index 95c4c07c0..b60fd67cd 100644 --- a/api/src/utils/tools/agents_as_tools/dashboard_agent_tool.rs +++ b/api/src/utils/tools/agents_as_tools/dashboard_agent_tool.rs @@ -42,7 +42,14 @@ impl ToolExecutor for DashboardAgentTool { type Params = DashboardAgentParams; fn get_name(&self) -> String { - "create_or_modify_dashboard".to_string() + "create_or_modify_dashboards".to_string() + } + + async fn is_enabled(&self) -> bool { + match self.agent.get_state_value("data_context").await { + Some(_) => true, + None => false, + } } async fn execute(&self, params: Self::Params) -> Result { @@ -63,6 +70,8 @@ impl ToolExecutor for DashboardAgentTool { current_thread.remove_last_assistant_message(); println!("DashboardAgentTool: Last assistant message removed"); + current_thread.add_user_message(params.ticket_description); + println!("DashboardAgentTool: Starting dashboard agent run"); // Run the dashboard agent and get the output let output = dashboard_agent.run(&mut current_thread).await?; diff --git a/api/src/utils/tools/agents_as_tools/exploratory_agent_tool.rs b/api/src/utils/tools/agents_as_tools/exploratory_agent_tool.rs index 9030c74ab..ca60dce0a 100644 --- a/api/src/utils/tools/agents_as_tools/exploratory_agent_tool.rs +++ b/api/src/utils/tools/agents_as_tools/exploratory_agent_tool.rs @@ -1,6 +1,5 @@ use anyhow::Result; use async_trait::async_trait; -use litellm::Message as AgentMessage; use serde::Deserialize; use serde_json::Value; use std::sync::Arc; @@ -33,6 +32,13 @@ impl ToolExecutor for ExploratoryAgentTool { "explore_data".to_string() } + async fn is_enabled(&self) -> bool { + match self.agent.get_state_value("data_context").await { + Some(_) => true, + None => false, + } + } + async fn execute(&self, params: Self::Params) -> Result { // Create and initialize the agent let exploratory_agent = ExploratoryAgent::from_existing(&self.agent).await?; @@ -45,6 +51,8 @@ impl ToolExecutor for ExploratoryAgentTool { current_thread.remove_last_assistant_message(); + current_thread.add_user_message(params.ticket_description); + // Run the exploratory agent and get the receiver let _rx = exploratory_agent.run(&mut current_thread).await?; diff --git a/api/src/utils/tools/agents_as_tools/metric_agent_tool.rs b/api/src/utils/tools/agents_as_tools/metric_agent_tool.rs index 402933aaf..ca8bee224 100644 --- a/api/src/utils/tools/agents_as_tools/metric_agent_tool.rs +++ b/api/src/utils/tools/agents_as_tools/metric_agent_tool.rs @@ -35,6 +35,13 @@ impl ToolExecutor for MetricAgentTool { "create_or_modify_metrics".to_string() } + async fn is_enabled(&self) -> bool { + match self.agent.get_state_value("data_context").await { + Some(_) => true, + None => false, + } + } + async fn execute(&self, params: Self::Params) -> Result { // Create and initialize the agent let metric_agent = MetricAgent::from_existing(&self.agent).await?; @@ -46,13 +53,10 @@ impl ToolExecutor for MetricAgentTool { .await .ok_or_else(|| anyhow::anyhow!("No current thread"))?; - // Parse input parameters - let agent_input = MetricAgentInput { - ticket_description: params.ticket_description, - }; - current_thread.remove_last_assistant_message(); + current_thread.add_user_message(params.ticket_description); + // Run the metric agent and get the receiver let _rx = metric_agent.run(&mut current_thread).await?; diff --git a/api/src/utils/tools/data_tools/create_plan.rs b/api/src/utils/tools/data_tools/create_plan.rs index c554c8c2b..03586136a 100644 --- a/api/src/utils/tools/data_tools/create_plan.rs +++ b/api/src/utils/tools/data_tools/create_plan.rs @@ -2,13 +2,10 @@ use anyhow::Result; use async_trait::async_trait; use serde::{Deserialize, Serialize}; use serde_json::Value; -use uuid::Uuid; use std::sync::Arc; +use uuid::Uuid; -use crate::utils::{ - tools::ToolExecutor, - agent::Agent, -}; +use crate::utils::{agent::Agent, tools::ToolExecutor}; use litellm::ToolCall; #[derive(Debug, Serialize, Deserialize)] @@ -25,7 +22,7 @@ pub struct PlanInput { } pub struct CreatePlan { - agent: Arc + agent: Arc, } impl CreatePlan { @@ -45,7 +42,7 @@ impl ToolExecutor for CreatePlan { async fn execute(&self, params: Self::Params) -> Result { let input = params; - + // TODO: Implement actual plan creation logic here // This would typically involve: // 1. Validating the markdown content @@ -59,6 +56,10 @@ impl ToolExecutor for CreatePlan { }) } + async fn is_enabled(&self) -> bool { + true + } + fn get_schema(&self) -> Value { serde_json::json!({ "name": "create_plan", @@ -79,4 +80,4 @@ impl ToolExecutor for CreatePlan { } }) } -} \ No newline at end of file +} diff --git a/api/src/utils/tools/data_tools/review_plan.rs b/api/src/utils/tools/data_tools/review_plan.rs index 1fa85d200..c0097a2bd 100644 --- a/api/src/utils/tools/data_tools/review_plan.rs +++ b/api/src/utils/tools/data_tools/review_plan.rs @@ -47,6 +47,10 @@ impl ToolExecutor for ReviewPlan { "review_plan".to_string() } + async fn is_enabled(&self) -> bool { + true + } + async fn execute(&self, params: Self::Params) -> Result { let input = params; diff --git a/api/src/utils/tools/data_tools/run_sql.rs b/api/src/utils/tools/data_tools/run_sql.rs index 72730989c..7592265cf 100644 --- a/api/src/utils/tools/data_tools/run_sql.rs +++ b/api/src/utils/tools/data_tools/run_sql.rs @@ -57,6 +57,10 @@ impl ToolExecutor for SqlQuery { "run_sql".to_string() } + async fn is_enabled(&self) -> bool { + true + } + async fn execute(&self, params: Self::Params) -> Result { let input = params; let mut results = Vec::new(); diff --git a/api/src/utils/tools/file_tools/create_dashboard_files.rs b/api/src/utils/tools/file_tools/create_dashboard_files.rs index 7985ad182..65ee480fe 100644 --- a/api/src/utils/tools/file_tools/create_dashboard_files.rs +++ b/api/src/utils/tools/file_tools/create_dashboard_files.rs @@ -118,6 +118,10 @@ impl ToolExecutor for CreateDashboardFilesTool { "create_dashboard_files".to_string() } + async fn is_enabled(&self) -> bool { + true + } + async fn execute(&self, params: Self::Params) -> Result { let start_time = Instant::now(); diff --git a/api/src/utils/tools/file_tools/create_files.rs b/api/src/utils/tools/file_tools/create_files.rs index e07dce551..c9fe94ed1 100644 --- a/api/src/utils/tools/file_tools/create_files.rs +++ b/api/src/utils/tools/file_tools/create_files.rs @@ -167,6 +167,14 @@ impl ToolExecutor for CreateFilesTool { type Output = CreateFilesOutput; type Params = CreateFilesParams; + fn get_name(&self) -> String { + "create_files".to_string() + } + + async fn is_enabled(&self) -> bool { + true + } + async fn execute(&self, params: Self::Params) -> Result { let start_time = Instant::now(); @@ -314,10 +322,6 @@ impl ToolExecutor for CreateFilesTool { }) } - fn get_name(&self) -> String { - "create_files".to_string() - } - fn get_schema(&self) -> Value { serde_json::json!({ "name": "create_files", diff --git a/api/src/utils/tools/file_tools/create_metric_files.rs b/api/src/utils/tools/file_tools/create_metric_files.rs index 472129000..6d970aa9f 100644 --- a/api/src/utils/tools/file_tools/create_metric_files.rs +++ b/api/src/utils/tools/file_tools/create_metric_files.rs @@ -113,6 +113,10 @@ impl ToolExecutor for CreateMetricFilesTool { "create_metric_files".to_string() } + async fn is_enabled(&self) -> bool { + true + } + async fn execute(&self, params: Self::Params) -> Result { let start_time = Instant::now(); diff --git a/api/src/utils/tools/file_tools/modify_dashboard_files.rs b/api/src/utils/tools/file_tools/modify_dashboard_files.rs index 6dd2d63ce..4fa62c8b9 100644 --- a/api/src/utils/tools/file_tools/modify_dashboard_files.rs +++ b/api/src/utils/tools/file_tools/modify_dashboard_files.rs @@ -236,6 +236,10 @@ impl ToolExecutor for ModifyDashboardFilesTool { "modify_dashboard_files".to_string() } + async fn is_enabled(&self) -> bool { + true + } + async fn execute(&self, params: Self::Params) -> Result { let start_time = Instant::now(); diff --git a/api/src/utils/tools/file_tools/modify_files.rs b/api/src/utils/tools/file_tools/modify_files.rs index 30d71dd82..3d0f72fba 100644 --- a/api/src/utils/tools/file_tools/modify_files.rs +++ b/api/src/utils/tools/file_tools/modify_files.rs @@ -249,6 +249,14 @@ impl ToolExecutor for ModifyFilesTool { type Output = ModifyFilesOutput; type Params = ModifyFilesParams; + fn get_name(&self) -> String { + "modify_files".to_string() + } + + async fn is_enabled(&self) -> bool { + true + } + async fn execute(&self, params: Self::Params) -> Result { let start_time = Instant::now(); @@ -467,10 +475,6 @@ impl ToolExecutor for ModifyFilesTool { Ok(output) } - fn get_name(&self) -> String { - "modify_files".to_string() - } - fn get_schema(&self) -> Value { serde_json::json!({ "name": "modify_files", diff --git a/api/src/utils/tools/file_tools/modify_metric_files.rs b/api/src/utils/tools/file_tools/modify_metric_files.rs index 9a46144b2..41dfdde1a 100644 --- a/api/src/utils/tools/file_tools/modify_metric_files.rs +++ b/api/src/utils/tools/file_tools/modify_metric_files.rs @@ -249,6 +249,10 @@ impl ToolExecutor for ModifyMetricFilesTool { "modify_metric_files".to_string() } + async fn is_enabled(&self) -> bool { + true + } + async fn execute(&self, params: Self::Params) -> Result { let start_time = Instant::now(); diff --git a/api/src/utils/tools/file_tools/open_files.rs b/api/src/utils/tools/file_tools/open_files.rs index a5fb695b0..31dd476b3 100644 --- a/api/src/utils/tools/file_tools/open_files.rs +++ b/api/src/utils/tools/file_tools/open_files.rs @@ -67,6 +67,14 @@ impl ToolExecutor for OpenFilesTool { type Output = OpenFilesOutput; type Params = OpenFilesParams; + fn get_name(&self) -> String { + "open_files".to_string() + } + + async fn is_enabled(&self) -> bool { + true + } + async fn execute(&self, params: Self::Params) -> Result { let start_time = Instant::now(); @@ -219,10 +227,6 @@ impl ToolExecutor for OpenFilesTool { }) } - fn get_name(&self) -> String { - "open_files".to_string() - } - fn get_schema(&self) -> Value { serde_json::json!({ "name": "open_files", diff --git a/api/src/utils/tools/file_tools/search_data_catalog.rs b/api/src/utils/tools/file_tools/search_data_catalog.rs index 0e62aeb8f..c1a0631b6 100644 --- a/api/src/utils/tools/file_tools/search_data_catalog.rs +++ b/api/src/utils/tools/file_tools/search_data_catalog.rs @@ -85,6 +85,10 @@ impl SearchDataCatalogTool { Self { agent } } + async fn is_enabled(&self) -> bool { + true + } + fn format_search_prompt(query_params: &[String], datasets: &[DatasetRecord]) -> Result { let datasets_json = datasets .iter() @@ -125,6 +129,7 @@ impl SearchDataCatalogTool { user_id: user_id.to_string(), session_id: session_id.to_string(), }), + reasoning_effort: Some("low".to_string()), ..Default::default() }; @@ -259,6 +264,10 @@ impl ToolExecutor for SearchDataCatalogTool { }) } + async fn is_enabled(&self) -> bool { + true + } + fn get_name(&self) -> String { "search_data_catalog".to_string() } diff --git a/api/src/utils/tools/file_tools/search_files.rs b/api/src/utils/tools/file_tools/search_files.rs index 8215ab1ab..9797c3cdc 100644 --- a/api/src/utils/tools/file_tools/search_files.rs +++ b/api/src/utils/tools/file_tools/search_files.rs @@ -90,6 +90,10 @@ impl SearchFilesTool { Self { agent } } + async fn is_enabled(&self) -> bool { + true + } + fn format_search_prompt(query_params: &[String], files_array: &[Value]) -> Result { let queries_joined = query_params.join("\n"); let files_json = serde_json::to_string_pretty(&files_array)?; @@ -166,6 +170,10 @@ impl ToolExecutor for SearchFilesTool { "search_files".to_string() } + async fn is_enabled(&self) -> bool { + true + } + async fn execute(&self, params: Self::Params) -> Result { let start_time = Instant::now(); diff --git a/api/src/utils/tools/file_tools/send_assets_to_user.rs b/api/src/utils/tools/file_tools/send_assets_to_user.rs index acbd23b79..154d44bb8 100644 --- a/api/src/utils/tools/file_tools/send_assets_to_user.rs +++ b/api/src/utils/tools/file_tools/send_assets_to_user.rs @@ -43,47 +43,54 @@ impl ToolExecutor for SendAssetsToUserTool { "decide_assets_to_return".to_string() } + async fn is_enabled(&self) -> bool { + match self.agent.get_state_value("files_created").await { + Some(_) => true, + None => false, + } + } + fn get_schema(&self) -> Value { serde_json::json!({ - "name": "decide_assets_to_return", - "description": "Use after you have created or modified any assets (metrics or dashboards) to specify exactly which assets to present in the final response. If you have not created or modified any assets, do not call this action.", - "strict": true, - "parameters": { - "type": "object", - "required": [ - "assets_to_return", - "ticket_description" - ], - "properties": { - "assets_to_return": { - "type": "array", - "description": "List of assets to present in the final response, each with an ID and a name", - "items": { - "type": "object", - "properties": { - "id": { - "type": "string", - "description": "Unique identifier for the asset" - }, - "name": { - "type": "string", - "description": "Name of the asset" - } + "name": "decide_assets_to_return", + "description": "Use after you have created or modified any assets (metrics or dashboards) to specify exactly which assets to present in the final response. If you have not created or modified any assets, do not call this action.", + "strict": true, + "parameters": { + "type": "object", + "required": [ + "assets_to_return", + "ticket_description" + ], + "properties": { + "assets_to_return": { + "type": "array", + "description": "List of assets to present in the final response, each with an ID and a name", + "items": { + "type": "object", + "properties": { + "id": { + "type": "string", + "description": "Unique identifier for the asset" }, - "required": [ - "id", - "name" - ], - "additionalProperties": false - } - }, - "ticket_description": { - "type": "string", - "description": "Description of the ticket related to the assets" + "name": { + "type": "string", + "description": "Name of the asset" + } + }, + "required": [ + "id", + "name" + ], + "additionalProperties": false } }, - "additionalProperties": false - } - }) + "ticket_description": { + "type": "string", + "description": "Description of the ticket related to the assets" + } + }, + "additionalProperties": false + } + }) } } diff --git a/api/src/utils/tools/interaction_tools/send_message_to_user.rs b/api/src/utils/tools/interaction_tools/send_message_to_user.rs index 2a7dfa10d..3cf451cb1 100644 --- a/api/src/utils/tools/interaction_tools/send_message_to_user.rs +++ b/api/src/utils/tools/interaction_tools/send_message_to_user.rs @@ -55,4 +55,8 @@ impl ToolExecutor for SendMessageToUser { fn get_name(&self) -> String { "send_message_to_user".to_string() } + + async fn is_enabled(&self) -> bool { + true + } } diff --git a/api/src/utils/tools/mod.rs b/api/src/utils/tools/mod.rs index 2b9dd1ab5..f61114c42 100644 --- a/api/src/utils/tools/mod.rs +++ b/api/src/utils/tools/mod.rs @@ -31,6 +31,9 @@ pub trait ToolExecutor: Send + Sync { /// Get the name of this tool fn get_name(&self) -> String; + + /// Check if this tool is currently enabled + async fn is_enabled(&self) -> bool; } /// A wrapper type that converts any ToolExecutor to one that outputs Value @@ -61,6 +64,10 @@ impl ToolExecutor for ValueToolExecutor { fn get_name(&self) -> String { self.inner.get_name() } + + async fn is_enabled(&self) -> bool { + self.inner.is_enabled().await + } } /// Extension trait to add value conversion methods to ToolExecutor