From d8ee830c6a4009c6005075384c1c039bda7d026c Mon Sep 17 00:00:00 2001 From: dal Date: Mon, 24 Feb 2025 06:20:16 -0700 Subject: [PATCH] ok just need to tie up the last few things --- api/src/routes/rest/routes/chats/post_chat.rs | 11 +- api/src/utils/agent/agent.rs | 135 +++++++++++------- api/src/utils/agent/agents/dashboard_agent.rs | 115 ++++++++------- .../utils/agent/agents/exploratory_agent.rs | 86 +++++++++-- api/src/utils/agent/agents/manager_agent.rs | 52 ++++++- api/src/utils/agent/agents/metric_agent.rs | 80 ++++++++++- .../agents_as_tools/dashboard_agent_tool.rs | 10 +- api/src/utils/tools/mod.rs | 16 ++- 8 files changed, 371 insertions(+), 134 deletions(-) diff --git a/api/src/routes/rest/routes/chats/post_chat.rs b/api/src/routes/rest/routes/chats/post_chat.rs index 57507f11a..9e4837326 100644 --- a/api/src/routes/rest/routes/chats/post_chat.rs +++ b/api/src/routes/rest/routes/chats/post_chat.rs @@ -10,6 +10,7 @@ use handlers::threads::types::ThreadWithMessages; use litellm::Message as AgentMessage; use serde::{Deserialize, Serialize}; use serde_json::Value; +use tokio::sync::broadcast; use uuid::Uuid; use crate::routes::rest::ApiResponse; @@ -98,10 +99,12 @@ async fn process_chat(request: ChatCreateNewChat, user: User) -> Result messages.push(msg), - Err(e) => return Err(e.into()), + loop { + match rx.recv().await { + Ok(Ok(msg)) => messages.push(msg), + Ok(Err(e)) => return Err(e.into()), + Err(broadcast::error::RecvError::Closed) => break, + Err(e) => return Err(anyhow!(e)), } } diff --git a/api/src/utils/agent/agent.rs b/api/src/utils/agent/agent.rs index abe00c33b..c6e3c2398 100644 --- a/api/src/utils/agent/agent.rs +++ b/api/src/utils/agent/agent.rs @@ -5,13 +5,26 @@ use litellm::{ }; use serde_json::Value; use std::{collections::HashMap, env, sync::Arc}; -use tokio::sync::{mpsc, RwLock}; +use tokio::sync::{broadcast, RwLock}; use uuid::Uuid; use crate::utils::tools::ToolExecutor; use super::types::AgentThread; +#[derive(Debug, Clone)] +pub struct AgentError(pub String); + +impl std::error::Error for AgentError {} + +impl std::fmt::Display for AgentError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } +} + +type MessageResult = Result; + /// A wrapper type that converts ToolCall parameters to Value before executing struct ToolCallExecutor { inner: Box, @@ -96,11 +109,13 @@ pub struct Agent { /// The current thread being processed, if any current_thread: Arc>>, /// Sender for streaming messages from this agent and sub-agents - stream_tx: Arc>>>, + stream_tx: Arc>>, /// The user ID for the current thread user_id: Uuid, /// The session ID for the current thread session_id: Uuid, + /// Shutdown signal sender + shutdown_tx: Arc>>, } impl Agent { @@ -116,8 +131,10 @@ impl Agent { let llm_client = LiteLLMClient::new(Some(llm_api_key), Some(llm_base_url)); - // Create a default channel that just drops messages - let (tx, _rx) = mpsc::channel(1); + // Create a broadcast channel with buffer size 1000 + let (tx, _rx) = broadcast::channel(1000); + // Create shutdown channel with buffer size 1 + let (shutdown_tx, _) = broadcast::channel(1); Self { llm_client, @@ -128,6 +145,7 @@ impl Agent { stream_tx: Arc::new(RwLock::new(tx)), user_id, session_id, + shutdown_tx: Arc::new(RwLock::new(shutdown_tx)), } } @@ -140,13 +158,14 @@ impl Agent { Self { llm_client, - tools: Arc::new(RwLock::new(HashMap::new())), // Start with empty tools + tools: Arc::new(RwLock::new(HashMap::new())), model: existing_agent.model.clone(), state: Arc::clone(&existing_agent.state), current_thread: Arc::clone(&existing_agent.current_thread), stream_tx: Arc::clone(&existing_agent.stream_tx), user_id: existing_agent.user_id, session_id: existing_agent.session_id, + shutdown_tx: Arc::clone(&existing_agent.shutdown_tx), } } @@ -168,13 +187,13 @@ impl Agent { enabled_tools } - /// Update the stream sender for this agent - pub async fn set_stream_sender(&self, tx: mpsc::Sender>) { - *self.stream_tx.write().await = tx; + /// Get a new receiver for the broadcast channel + pub async fn get_stream_receiver(&self) -> broadcast::Receiver { + self.stream_tx.read().await.subscribe() } /// Get a clone of the current stream sender - pub async fn get_stream_sender(&self) -> mpsc::Sender> { + pub async fn get_stream_sender(&self) -> broadcast::Sender { self.stream_tx.read().await.clone() } @@ -276,7 +295,7 @@ impl Agent { let mut rx = self.process_thread_streaming(thread).await?; let mut final_message = None; - while let Some(msg) = rx.recv().await { + while let Ok(msg) = rx.recv().await { final_message = Some(msg?); } @@ -294,26 +313,37 @@ impl Agent { pub async fn process_thread_streaming( &self, thread: &AgentThread, - ) -> Result>> { - // Create new channel for this processing session - let (tx, rx) = mpsc::channel(100); - self.set_stream_sender(tx).await; - + ) -> Result> { // Spawn the processing task let agent_clone = self.clone(); let thread_clone = thread.clone(); + // Get shutdown receiver + let mut shutdown_rx = self.get_shutdown_receiver().await; + tokio::spawn(async move { - if let Err(e) = agent_clone - .process_thread_with_depth(&thread_clone, 0) - .await - { - let err_msg = format!("Error processing thread: {:?}", e); - let _ = agent_clone.get_stream_sender().await.send(Err(e)).await; + tokio::select! { + result = agent_clone.process_thread_with_depth(&thread_clone, 0) => { + if let Err(e) = result { + let err_msg = format!("Error processing thread: {:?}", e); + let _ = agent_clone.get_stream_sender().await.send(Err(AgentError(err_msg))); + } + } + _ = shutdown_rx.recv() => { + let _ = agent_clone.get_stream_sender().await.send( + Ok(Message::assistant( + Some("shutdown_message".to_string()), + Some("Processing interrupted due to shutdown signal".to_string()), + None, + None, + None, + )) + ); + } } }); - Ok(rx) + Ok(self.get_stream_receiver().await) } async fn process_thread_with_depth( @@ -335,7 +365,7 @@ impl Agent { None, None, ); - self.get_stream_sender().await.send(Ok(message)).await?; + self.get_stream_sender().await.send(Ok(message))?; return Ok(()); } @@ -425,6 +455,23 @@ impl Agent { Ok(()) } } + + /// Get a receiver for the shutdown signal + pub async fn get_shutdown_receiver(&self) -> broadcast::Receiver<()> { + self.shutdown_tx.read().await.subscribe() + } + + /// Signal shutdown to all receivers + pub async fn shutdown(&self) -> Result<()> { + // Send shutdown signal + self.shutdown_tx.read().await.send(())?; + Ok(()) + } + + /// Get a reference to the tools map + pub async fn get_tools(&self) -> tokio::sync::RwLockReadGuard<'_, HashMap + Send + Sync>>> { + self.tools.read().await + } } #[derive(Debug, Default)] @@ -488,7 +535,7 @@ pub trait AgentExt { async fn stream_process_thread( &self, thread: &AgentThread, - ) -> Result>> { + ) -> Result> { (*self.get_agent()).process_thread_streaming(thread).await } @@ -524,24 +571,27 @@ mod tests { } } + impl WeatherTool { + async fn send_progress(&self, content: String, tool_id: String, progress: MessageProgress) -> Result<()> { + let message = Message::tool( + None, + content, + tool_id, + Some(self.get_name()), + Some(progress), + ); + self.agent.get_stream_sender().await.send(Ok(message))?; + Ok(()) + } + } + #[async_trait] impl ToolExecutor for WeatherTool { type Output = Value; type Params = Value; async fn execute(&self, params: Self::Params) -> Result { - // Send progress using agent's stream sender - self.agent - .get_stream_sender() - .await - .send(Ok(Message::tool( - None, - "Fetching weather data...".to_string(), - "123".to_string(), - Some(self.get_name()), - Some(MessageProgress::InProgress), - ))) - .await?; + self.send_progress("Fetching weather data...".to_string(), "123".to_string(), MessageProgress::InProgress).await?; // Simulate a delay tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; @@ -551,18 +601,7 @@ mod tests { "unit": "fahrenheit" }); - // Send completion message using agent's stream sender - self.agent - .get_stream_sender() - .await - .send(Ok(Message::tool( - None, - serde_json::to_string(&result)?, - "123".to_string(), - Some(self.get_name()), - Some(MessageProgress::Complete), - ))) - .await?; + self.send_progress(serde_json::to_string(&result)?, "123".to_string(), MessageProgress::Complete).await?; Ok(result) } diff --git a/api/src/utils/agent/agents/dashboard_agent.rs b/api/src/utils/agent/agents/dashboard_agent.rs index 62ced9e44..fbbf87334 100644 --- a/api/src/utils/agent/agents/dashboard_agent.rs +++ b/api/src/utils/agent/agents/dashboard_agent.rs @@ -5,7 +5,7 @@ use std::collections::HashMap; use uuid::Uuid; use crate::utils::{ - agent::{Agent, AgentExt, AgentThread}, + agent::{agent::AgentError, Agent, AgentExt, AgentThread}, tools::{ agents_as_tools::dashboard_agent_tool::DashboardAgentOutput, file_tools::{ CreateDashboardFilesTool, CreateMetricFilesTool, ModifyDashboardFilesTool, @@ -15,6 +15,7 @@ use crate::utils::{ }; use litellm::Message as AgentMessage; +use tokio::sync::broadcast; pub struct DashboardAgent { agent: Arc, @@ -84,71 +85,75 @@ impl DashboardAgent { if content == "AGENT_COMPLETE") } - pub async fn run(&self, thread: &mut AgentThread) -> Result { - println!("Running dashboard agent"); - println!("Setting developer message"); + pub async fn run(&self, thread: &mut AgentThread) -> Result>> { thread.set_developer_message(DASHBOARD_AGENT_PROMPT.to_string()); - println!("Starting stream_process_thread"); + // Get shutdown receiver + let mut shutdown_rx = self.get_agent().get_shutdown_receiver().await; let mut rx = self.stream_process_thread(thread).await?; - println!("Got receiver from stream_process_thread"); - println!("Starting message processing loop"); + let rx_return = rx.resubscribe(); + // Process messages internally until we determine we're done - while let Some(msg_result) = rx.recv().await { - println!("Received message from channel"); - match msg_result { - Ok(msg) => { - println!("Message content: {:?}", msg.get_content()); - println!("Message has tool calls: {:?}", msg.get_tool_calls()); - - println!("Forwarding message to stream sender"); - if let Err(e) = self.get_agent().get_stream_sender().await.send(Ok(msg.clone())).await { - println!("Error forwarding message: {:?}", e); - // Continue processing even if we fail to forward - continue; - } - - if let Some(content) = msg.get_content() { - println!("Message has content: {}", content); - if content == "AGENT_COMPLETE" { - println!("Found completion signal, breaking loop"); - break; + loop { + tokio::select! { + recv_result = rx.recv() => { + match recv_result { + Ok(msg_result) => { + match msg_result { + Ok(msg) => { + // Forward message to stream sender + let sender = self.get_agent().get_stream_sender().await; + if let Err(e) = sender.send(Ok(msg.clone())) { + let err_msg = format!("Error forwarding message: {:?}", e); + let _ = sender.send(Err(AgentError(err_msg))); + continue; + } + + if let Some(content) = msg.get_content() { + if content == "AGENT_COMPLETE" { + return Ok(rx_return); + } + } + } + Err(e) => { + let err_msg = format!("Error processing message: {:?}", e); + let _ = self.get_agent().get_stream_sender().await.send(Err(AgentError(err_msg))); + continue; + } + } + } + Err(e) => { + let err_msg = format!("Error receiving message: {:?}", e); + let _ = self.get_agent().get_stream_sender().await.send(Err(AgentError(err_msg))); + continue; } } } - Err(e) => { - println!("Error receiving message: {:?}", e); - println!("Error details: {:?}", e.to_string()); - // Log error but continue processing instead of returning error - continue; + _ = shutdown_rx.recv() => { + // Handle shutdown gracefully + let tools = self.get_agent().get_tools().await; + for (_, tool) in tools.iter() { + if let Err(e) = tool.handle_shutdown().await { + let err_msg = format!("Error shutting down tool: {:?}", e); + let _ = self.get_agent().get_stream_sender().await.send(Err(AgentError(err_msg))); + } + } + + let _ = self.get_agent().get_stream_sender().await.send( + Ok(AgentMessage::assistant( + Some("shutdown_message".to_string()), + Some("Dashboard agent shutting down gracefully".to_string()), + None, + None, + None, + )) + ); + + return Ok(rx_return); } } } - println!("Exited message processing loop"); - - println!("Creating completion signal"); - let completion_msg = AgentMessage::assistant( - None, - Some("AGENT_COMPLETE".to_string()), - None, - None, - None, - ); - - println!("Sending completion signal"); - self.get_agent() - .get_stream_sender() - .await - .send(Ok(completion_msg)) - .await?; - - println!("Sent completion signal, returning output"); - Ok(DashboardAgentOutput { - message: "Dashboard processing complete".to_string(), - duration: 0, - files: vec![], - }) } } diff --git a/api/src/utils/agent/agents/exploratory_agent.rs b/api/src/utils/agent/agents/exploratory_agent.rs index 2bd320fa4..8fc2945b2 100644 --- a/api/src/utils/agent/agents/exploratory_agent.rs +++ b/api/src/utils/agent/agents/exploratory_agent.rs @@ -2,17 +2,15 @@ use std::sync::Arc; use anyhow::Result; use std::collections::HashMap; -use tokio::sync::mpsc::Receiver; use uuid::Uuid; use crate::utils::{ - agent::{Agent, AgentExt, AgentThread}, - tools::{ - IntoValueTool, ToolExecutor, - }, + agent::{agent::AgentError, Agent, AgentExt, AgentThread}, + tools::{IntoValueTool, ToolExecutor}, }; use litellm::Message as AgentMessage; +use tokio::sync::broadcast; pub struct ExploratoryAgent { agent: Arc, @@ -46,11 +44,83 @@ impl ExploratoryAgent { Ok(exploratory) } - pub async fn run(&self, thread: &mut AgentThread) -> Result>> { - // Process using agent's streaming functionality + pub async fn run( + &self, + thread: &mut AgentThread, + ) -> Result>> { thread.set_developer_message(EXPLORATORY_AGENT_PROMPT.to_string()); - self.stream_process_thread(thread).await + // Get shutdown receiver + let mut shutdown_rx = self.get_agent().get_shutdown_receiver().await; + let mut rx = self.stream_process_thread(thread).await?; + + // Clone what we need for the processing task + let agent = Arc::clone(self.get_agent()); + + let rx_return = rx.resubscribe(); + + tokio::spawn(async move { + loop { + tokio::select! { + recv_result = rx.recv() => { + match recv_result { + Ok(msg_result) => { + match msg_result { + Ok(msg) => { + // Forward message to stream sender + let sender = agent.get_stream_sender().await; + if let Err(e) = sender.send(Ok(msg.clone())) { + let err_msg = format!("Error forwarding message: {:?}", e); + let _ = sender.send(Err(AgentError(err_msg))); + continue; + } + + if let Some(content) = msg.get_content() { + if content == "AGENT_COMPLETE" { + break; + } + } + } + Err(e) => { + let err_msg = format!("Error processing message: {:?}", e); + let _ = agent.get_stream_sender().await.send(Err(AgentError(err_msg))); + continue; + } + } + } + Err(e) => { + let err_msg = format!("Error receiving message: {:?}", e); + let _ = agent.get_stream_sender().await.send(Err(AgentError(err_msg))); + continue; + } + } + } + _ = shutdown_rx.recv() => { + // Handle shutdown gracefully + let tools = agent.get_tools().await; + for (_, tool) in tools.iter() { + if let Err(e) = tool.handle_shutdown().await { + let err_msg = format!("Error shutting down tool: {:?}", e); + let _ = agent.get_stream_sender().await.send(Err(AgentError(err_msg))); + } + } + + let _ = agent.get_stream_sender().await.send( + Ok(AgentMessage::assistant( + Some("shutdown_message".to_string()), + Some("Exploratory agent shutting down gracefully".to_string()), + None, + None, + None, + )) + ); + break; + } + } + } + }); + + Ok(rx_return) } } diff --git a/api/src/utils/agent/agents/manager_agent.rs b/api/src/utils/agent/agents/manager_agent.rs index a9f01c284..9735bb542 100644 --- a/api/src/utils/agent/agents/manager_agent.rs +++ b/api/src/utils/agent/agents/manager_agent.rs @@ -3,13 +3,13 @@ use serde::{Deserialize, Serialize}; use serde_json::Value; use std::collections::HashMap; use std::sync::Arc; -use tokio::sync::mpsc::Receiver; +use tokio::sync::broadcast; use uuid::Uuid; use crate::utils::tools::agents_as_tools::{DashboardAgentTool, MetricAgentTool}; use crate::utils::tools::file_tools::SendAssetsToUserTool; use crate::utils::{ - agent::{Agent, AgentExt, AgentThread}, + agent::{agent::AgentError, Agent, AgentExt, AgentThread}, tools::{ agents_as_tools::ExploratoryAgentTool, file_tools::{SearchDataCatalogTool, SearchFilesTool}, @@ -123,13 +123,57 @@ impl ManagerAgent { pub async fn run( &self, thread: &mut AgentThread, - ) -> Result>> { + ) -> Result>> { thread.set_developer_message(MANAGER_AGENT_PROMPT.to_string()); + + // Use existing channel - important for sub-agents + let rx = self.get_agent().get_stream_receiver().await; + + // Get shutdown receiver + let mut shutdown_rx = self.get_agent().get_shutdown_receiver().await; + + // Clone only what we need + let agent = Arc::clone(self.get_agent()); + let thread = thread.clone(); + + tokio::spawn(async move { + tokio::select! { + result = agent.process_thread(&thread) => { + if let Err(e) = result { + let err_msg = format!("Manager agent processing failed: {:?}", e); + let _ = agent.get_stream_sender().await.send(Err(AgentError(err_msg))); + } + } + _ = shutdown_rx.recv() => { + // Shutdown all tools + let tools = agent.get_tools().await; + for (_, tool) in tools.iter() { + if let Err(e) = tool.handle_shutdown().await { + let err_msg = format!("Error shutting down tool: {:?}", e); + let _ = agent.get_stream_sender().await.send(Err(AgentError(err_msg))); + } + } - let mut rx = self.stream_process_thread(thread).await?; + let _ = agent.get_stream_sender().await.send( + Ok(AgentMessage::assistant( + Some("shutdown_message".to_string()), + Some("Manager agent shutting down gracefully".to_string()), + None, + None, + None, + )) + ); + } + } + }); Ok(rx) } + + /// Shutdown the manager agent and all its tools + pub async fn shutdown(&self) -> Result<()> { + self.get_agent().shutdown().await + } } const MANAGER_AGENT_PROMPT: &str = r##" diff --git a/api/src/utils/agent/agents/metric_agent.rs b/api/src/utils/agent/agents/metric_agent.rs index d1231922a..1db654582 100644 --- a/api/src/utils/agent/agents/metric_agent.rs +++ b/api/src/utils/agent/agents/metric_agent.rs @@ -2,11 +2,10 @@ use std::sync::Arc; use anyhow::Result; use std::collections::HashMap; -use tokio::sync::mpsc::Receiver; use uuid::Uuid; use crate::utils::{ - agent::{Agent, AgentExt, AgentThread}, + agent::{agent::AgentError, Agent, AgentExt, AgentThread}, tools::{ file_tools::{CreateMetricFilesTool, ModifyMetricFilesTool}, IntoValueTool, ToolExecutor, @@ -14,6 +13,7 @@ use crate::utils::{ }; use litellm::Message as AgentMessage; +use tokio::sync::broadcast; pub struct MetricAgent { agent: Arc, @@ -67,10 +67,80 @@ impl MetricAgent { pub async fn run( &self, thread: &mut AgentThread, - ) -> Result>> { + ) -> Result>> { thread.set_developer_message(METRIC_AGENT_PROMPT.to_string()); - // Process using agent's streaming functionality - self.stream_process_thread(thread).await + + // Get shutdown receiver + let mut shutdown_rx = self.get_agent().get_shutdown_receiver().await; + let mut rx = self.stream_process_thread(thread).await?; + + // Clone what we need for the processing task + let agent = Arc::clone(self.get_agent()); + + let rx_return = rx.resubscribe(); + + tokio::spawn(async move { + loop { + tokio::select! { + recv_result = rx.recv() => { + match recv_result { + Ok(msg_result) => { + match msg_result { + Ok(msg) => { + // Forward message to stream sender + let sender = agent.get_stream_sender().await; + if let Err(e) = sender.send(Ok(msg.clone())) { + let err_msg = format!("Error forwarding message: {:?}", e); + let _ = sender.send(Err(AgentError(err_msg))); + continue; + } + + if let Some(content) = msg.get_content() { + if content == "AGENT_COMPLETE" { + break; + } + } + } + Err(e) => { + let err_msg = format!("Error processing message: {:?}", e); + let _ = agent.get_stream_sender().await.send(Err(AgentError(err_msg))); + continue; + } + } + } + Err(e) => { + let err_msg = format!("Error receiving message: {:?}", e); + let _ = agent.get_stream_sender().await.send(Err(AgentError(err_msg))); + continue; + } + } + } + _ = shutdown_rx.recv() => { + // Handle shutdown gracefully + let tools = agent.get_tools().await; + for (_, tool) in tools.iter() { + if let Err(e) = tool.handle_shutdown().await { + let err_msg = format!("Error shutting down tool: {:?}", e); + let _ = agent.get_stream_sender().await.send(Err(AgentError(err_msg))); + } + } + + let _ = agent.get_stream_sender().await.send( + Ok(AgentMessage::assistant( + Some("shutdown_message".to_string()), + Some("Metric agent shutting down gracefully".to_string()), + None, + None, + None, + )) + ); + break; + } + } + } + }); + + Ok(rx_return) } } diff --git a/api/src/utils/tools/agents_as_tools/dashboard_agent_tool.rs b/api/src/utils/tools/agents_as_tools/dashboard_agent_tool.rs index b4b35aa41..6ed6f11bc 100644 --- a/api/src/utils/tools/agents_as_tools/dashboard_agent_tool.rs +++ b/api/src/utils/tools/agents_as_tools/dashboard_agent_tool.rs @@ -74,7 +74,7 @@ impl ToolExecutor for DashboardAgentTool { println!("DashboardAgentTool: Starting dashboard agent run"); // Run the dashboard agent and get the output - let output = dashboard_agent.run(&mut current_thread).await?; + let _receiver = dashboard_agent.run(&mut current_thread).await?; println!("DashboardAgentTool: Dashboard agent run completed"); println!("DashboardAgentTool: Preparing success response"); @@ -83,12 +83,12 @@ impl ToolExecutor for DashboardAgentTool { .set_state_value(String::from("files_available"), Value::Bool(false)) .await; - // Return success with the output + // Return dummy data for testing Ok(serde_json::json!({ "status": "success", - "message": output.message, - "duration": output.duration, - "files": output.files + "message": "Test dashboard creation", + "duration": 0, + "files": [] })) } diff --git a/api/src/utils/tools/mod.rs b/api/src/utils/tools/mod.rs index f61114c42..c7f014155 100644 --- a/api/src/utils/tools/mod.rs +++ b/api/src/utils/tools/mod.rs @@ -15,7 +15,7 @@ pub mod interaction_tools; /// A trait that defines how tools should be implemented. /// Any struct that wants to be used as a tool must implement this trait. /// Tools are constructed with a reference to their agent and can access its capabilities. -#[async_trait] +#[async_trait::async_trait] pub trait ToolExecutor: Send + Sync { /// The type of the output of the tool type Output: Serialize + Send; @@ -34,21 +34,27 @@ pub trait ToolExecutor: Send + Sync { /// Check if this tool is currently enabled async fn is_enabled(&self) -> bool; + + /// Handle shutdown signal. Default implementation does nothing. + /// Tools should override this if they need to perform cleanup on shutdown. + async fn handle_shutdown(&self) -> Result<()> { + Ok(()) + } } /// A wrapper type that converts any ToolExecutor to one that outputs Value -pub struct ValueToolExecutor { +pub struct ValueToolExecutor { inner: T, } -impl ValueToolExecutor { +impl ValueToolExecutor { pub fn new(inner: T) -> Self { Self { inner } } } -#[async_trait] -impl ToolExecutor for ValueToolExecutor { +#[async_trait::async_trait] +impl ToolExecutor for ValueToolExecutor { type Output = Value; type Params = T::Params;