From 84bada0dd1659c42e7f56d74f24374c38377a9bd Mon Sep 17 00:00:00 2001 From: dal Date: Mon, 24 Feb 2025 07:01:25 -0700 Subject: [PATCH] simplify agents pulling off shutdown function --- .../utils/agent/agents/exploratory_agent.rs | 71 +------------------ api/src/utils/agent/agents/manager_agent.rs | 42 +---------- api/src/utils/agent/agents/metric_agent.rs | 71 +------------------ 3 files changed, 6 insertions(+), 178 deletions(-) diff --git a/api/src/utils/agent/agents/exploratory_agent.rs b/api/src/utils/agent/agents/exploratory_agent.rs index 8fc2945b2..5ece9e3cf 100644 --- a/api/src/utils/agent/agents/exploratory_agent.rs +++ b/api/src/utils/agent/agents/exploratory_agent.rs @@ -51,76 +51,9 @@ impl ExploratoryAgent { thread.set_developer_message(EXPLORATORY_AGENT_PROMPT.to_string()); // Get shutdown receiver - let mut shutdown_rx = self.get_agent().get_shutdown_receiver().await; - let mut rx = self.stream_process_thread(thread).await?; + let rx = self.stream_process_thread(thread).await?; - // Clone what we need for the processing task - let agent = Arc::clone(self.get_agent()); - - let rx_return = rx.resubscribe(); - - tokio::spawn(async move { - loop { - tokio::select! { - recv_result = rx.recv() => { - match recv_result { - Ok(msg_result) => { - match msg_result { - Ok(msg) => { - // Forward message to stream sender - let sender = agent.get_stream_sender().await; - if let Err(e) = sender.send(Ok(msg.clone())) { - let err_msg = format!("Error forwarding message: {:?}", e); - let _ = sender.send(Err(AgentError(err_msg))); - continue; - } - - if let Some(content) = msg.get_content() { - if content == "AGENT_COMPLETE" { - break; - } - } - } - Err(e) => { - let err_msg = format!("Error processing message: {:?}", e); - let _ = agent.get_stream_sender().await.send(Err(AgentError(err_msg))); - continue; - } - } - } - Err(e) => { - let err_msg = format!("Error receiving message: {:?}", e); - let _ = agent.get_stream_sender().await.send(Err(AgentError(err_msg))); - continue; - } - } - } - _ = shutdown_rx.recv() => { - // Handle shutdown gracefully - let tools = agent.get_tools().await; - for (_, tool) in tools.iter() { - if let Err(e) = tool.handle_shutdown().await { - let err_msg = format!("Error shutting down tool: {:?}", e); - let _ = agent.get_stream_sender().await.send(Err(AgentError(err_msg))); - } - } - - let _ = agent.get_stream_sender().await.send( - Ok(AgentMessage::assistant( - Some("shutdown_message".to_string()), - Some("Exploratory agent shutting down gracefully".to_string()), - None, - None, - None, - )) - ); - break; - } - } - } - }); - - Ok(rx_return) + Ok(rx) } } diff --git a/api/src/utils/agent/agents/manager_agent.rs b/api/src/utils/agent/agents/manager_agent.rs index 9735bb542..b9a82b809 100644 --- a/api/src/utils/agent/agents/manager_agent.rs +++ b/api/src/utils/agent/agents/manager_agent.rs @@ -125,47 +125,9 @@ impl ManagerAgent { thread: &mut AgentThread, ) -> Result>> { thread.set_developer_message(MANAGER_AGENT_PROMPT.to_string()); - - // Use existing channel - important for sub-agents - let rx = self.get_agent().get_stream_receiver().await; - - // Get shutdown receiver - let mut shutdown_rx = self.get_agent().get_shutdown_receiver().await; - - // Clone only what we need - let agent = Arc::clone(self.get_agent()); - let thread = thread.clone(); - - tokio::spawn(async move { - tokio::select! { - result = agent.process_thread(&thread) => { - if let Err(e) = result { - let err_msg = format!("Manager agent processing failed: {:?}", e); - let _ = agent.get_stream_sender().await.send(Err(AgentError(err_msg))); - } - } - _ = shutdown_rx.recv() => { - // Shutdown all tools - let tools = agent.get_tools().await; - for (_, tool) in tools.iter() { - if let Err(e) = tool.handle_shutdown().await { - let err_msg = format!("Error shutting down tool: {:?}", e); - let _ = agent.get_stream_sender().await.send(Err(AgentError(err_msg))); - } - } - let _ = agent.get_stream_sender().await.send( - Ok(AgentMessage::assistant( - Some("shutdown_message".to_string()), - Some("Manager agent shutting down gracefully".to_string()), - None, - None, - None, - )) - ); - } - } - }); + // Get shutdown receiver + let rx = self.stream_process_thread(thread).await?; Ok(rx) } diff --git a/api/src/utils/agent/agents/metric_agent.rs b/api/src/utils/agent/agents/metric_agent.rs index 1db654582..80295c7a5 100644 --- a/api/src/utils/agent/agents/metric_agent.rs +++ b/api/src/utils/agent/agents/metric_agent.rs @@ -71,76 +71,9 @@ impl MetricAgent { thread.set_developer_message(METRIC_AGENT_PROMPT.to_string()); // Get shutdown receiver - let mut shutdown_rx = self.get_agent().get_shutdown_receiver().await; - let mut rx = self.stream_process_thread(thread).await?; + let rx = self.stream_process_thread(thread).await?; - // Clone what we need for the processing task - let agent = Arc::clone(self.get_agent()); - - let rx_return = rx.resubscribe(); - - tokio::spawn(async move { - loop { - tokio::select! { - recv_result = rx.recv() => { - match recv_result { - Ok(msg_result) => { - match msg_result { - Ok(msg) => { - // Forward message to stream sender - let sender = agent.get_stream_sender().await; - if let Err(e) = sender.send(Ok(msg.clone())) { - let err_msg = format!("Error forwarding message: {:?}", e); - let _ = sender.send(Err(AgentError(err_msg))); - continue; - } - - if let Some(content) = msg.get_content() { - if content == "AGENT_COMPLETE" { - break; - } - } - } - Err(e) => { - let err_msg = format!("Error processing message: {:?}", e); - let _ = agent.get_stream_sender().await.send(Err(AgentError(err_msg))); - continue; - } - } - } - Err(e) => { - let err_msg = format!("Error receiving message: {:?}", e); - let _ = agent.get_stream_sender().await.send(Err(AgentError(err_msg))); - continue; - } - } - } - _ = shutdown_rx.recv() => { - // Handle shutdown gracefully - let tools = agent.get_tools().await; - for (_, tool) in tools.iter() { - if let Err(e) = tool.handle_shutdown().await { - let err_msg = format!("Error shutting down tool: {:?}", e); - let _ = agent.get_stream_sender().await.send(Err(AgentError(err_msg))); - } - } - - let _ = agent.get_stream_sender().await.send( - Ok(AgentMessage::assistant( - Some("shutdown_message".to_string()), - Some("Metric agent shutting down gracefully".to_string()), - None, - None, - None, - )) - ); - break; - } - } - } - }); - - Ok(rx_return) + Ok(rx) } }