From 4207d8f714d8ff7678c979a73f99253623a48a02 Mon Sep 17 00:00:00 2001 From: dal Date: Tue, 4 Mar 2025 12:25:26 -0700 Subject: [PATCH] kill channel when agent finishes. --- api/libs/agents/src/agent.rs | 18 ++++++++++++------ .../handlers/src/chats/post_chat_handler.rs | 3 +++ 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/api/libs/agents/src/agent.rs b/api/libs/agents/src/agent.rs index 6fa7af58b..5a50c7970 100644 --- a/api/libs/agents/src/agent.rs +++ b/api/libs/agents/src/agent.rs @@ -44,7 +44,7 @@ pub struct Agent { /// The current thread being processed, if any current_thread: Arc>>, /// Sender for streaming messages from this agent and sub-agents - stream_tx: Arc>>, + stream_tx: Arc>>>, /// The user ID for the current thread user_id: Uuid, /// The session ID for the current thread @@ -69,9 +69,7 @@ impl Agent { let llm_client = LiteLLMClient::new(Some(llm_api_key), Some(llm_base_url)); - // Create a broadcast channel with buffer size 1000 let (tx, _rx) = broadcast::channel(1000); - // Create shutdown channel with buffer size 1 let (shutdown_tx, _) = broadcast::channel(1); Self { @@ -80,7 +78,7 @@ impl Agent { model, state: Arc::new(RwLock::new(HashMap::new())), current_thread: Arc::new(RwLock::new(None)), - stream_tx: Arc::new(RwLock::new(tx)), + stream_tx: Arc::new(RwLock::new(Some(tx))), user_id, session_id, shutdown_tx: Arc::new(RwLock::new(shutdown_tx)), @@ -129,12 +127,12 @@ impl Agent { /// Get a new receiver for the broadcast channel pub async fn get_stream_receiver(&self) -> broadcast::Receiver { - self.stream_tx.read().await.subscribe() + self.stream_tx.read().await.as_ref().unwrap().subscribe() } /// Get a clone of the current stream sender pub async fn get_stream_sender(&self) -> broadcast::Sender { - self.stream_tx.read().await.clone() + self.stream_tx.read().await.as_ref().unwrap().clone() } /// Get a value from the agent's state by key @@ -319,6 +317,7 @@ impl Agent { Some(self.name.clone()), ); self.get_stream_sender().await.send(Ok(message))?; + self.close().await; return Ok(()); } @@ -481,6 +480,7 @@ impl Agent { // If this is an auto response without tool calls, it means we're done if final_tool_calls.is_none() { + self.close().await; return Ok(()); } @@ -545,6 +545,12 @@ impl Agent { > { self.tools.read().await } + + // Add this new method alongside other channel-related methods + pub async fn close(&self) { + let mut tx = self.stream_tx.write().await; + *tx = None; + } } #[derive(Debug, Default, Clone)] diff --git a/api/libs/handlers/src/chats/post_chat_handler.rs b/api/libs/handlers/src/chats/post_chat_handler.rs index e3def36be..17bacee3d 100644 --- a/api/libs/handlers/src/chats/post_chat_handler.rs +++ b/api/libs/handlers/src/chats/post_chat_handler.rs @@ -232,10 +232,13 @@ pub async fn post_chat_handler( tracing::error!("Error receiving message from agent: {}", e); // Don't return early, continue processing remaining messages + break; } } } + println!("Finishing up the agent and moving on to store the final message"); + let title = title_handle.await??; let reasoning_duration = reasoning_duration.elapsed().as_secs();