channel upgrade

This commit is contained in:
dal 2025-04-16 22:57:38 -06:00
parent 15e74d4575
commit 32dd1891a4
No known key found for this signature in database
GPG Key ID: 16F4B0E1E9F61122
1 changed files with 102 additions and 55 deletions

View File

@ -116,10 +116,15 @@ impl MessageBuffer {
); );
// Continue on error with broadcast::error::SendError // Continue on error with broadcast::error::SendError
if let Err(e) = agent.get_stream_sender().await.send(Ok(message)) { // Ensure we handle the Result from get_stream_sender first
if let Ok(sender) = agent.get_stream_sender().await {
if let Err(e) = sender.send(Ok(message)) {
// Log warning but don't fail the operation // Log warning but don't fail the operation
tracing::warn!("Channel send error, message may be dropped: {}", e); tracing::warn!("Channel send error, message may be dropped: {}", e);
} }
} else {
tracing::warn!("Stream sender not available, message dropped.");
}
// Update state // Update state
self.first_message_sent = true; self.first_message_sent = true;
@ -195,7 +200,7 @@ impl Agent {
let llm_client = LiteLLMClient::new(api_key, base_url); let llm_client = LiteLLMClient::new(api_key, base_url);
// When creating a new agent, initialize broadcast channel with higher capacity for better concurrency // When creating a new agent, initialize broadcast channel with higher capacity for better concurrency
let (tx, _rx) = broadcast::channel(5000); let (tx, _rx) = broadcast::channel(10000);
// Increase shutdown channel capacity to avoid blocking // Increase shutdown channel capacity to avoid blocking
let (shutdown_tx, _) = broadcast::channel(100); let (shutdown_tx, _) = broadcast::channel(100);
@ -266,14 +271,22 @@ impl Agent {
enabled_tools enabled_tools
} }
/// Get a new receiver for the broadcast channel /// Get a new receiver for the broadcast channel.
pub async fn get_stream_receiver(&self) -> broadcast::Receiver<MessageResult> { /// Returns an error if the stream channel has been closed or was not initialized.
self.stream_tx.read().await.as_ref().unwrap().subscribe() pub async fn get_stream_receiver(&self) -> Result<broadcast::Receiver<MessageResult>, AgentError> {
match self.stream_tx.read().await.as_ref() {
Some(tx) => Ok(tx.subscribe()),
None => Err(AgentError("Stream channel is closed or not initialized.".to_string()))
}
} }
/// Get a clone of the current stream sender /// Get a clone of the current stream sender.
pub async fn get_stream_sender(&self) -> broadcast::Sender<MessageResult> { /// Returns an error if the stream channel has been closed or was not initialized.
self.stream_tx.read().await.as_ref().unwrap().clone() pub async fn get_stream_sender(&self) -> Result<broadcast::Sender<MessageResult>, AgentError> {
match self.stream_tx.read().await.as_ref() {
Some(tx) => Ok(tx.clone()),
None => Err(AgentError("Stream channel is closed or not initialized.".to_string()))
}
} }
/// Get a value from the agent's state by key /// Get a value from the agent's state by key
@ -467,14 +480,24 @@ impl Agent {
let err_msg = format!("Error processing thread: {:?}", e); let err_msg = format!("Error processing thread: {:?}", e);
error!("{}", err_msg); // Log the error error!("{}", err_msg); // Log the error
// Use the clone created before select! // Use the clone created before select!
if let Err(send_err) = agent_clone_for_post_process.get_stream_sender().await.send(Err(AgentError(err_msg.clone()))) { // Handle the Result from get_stream_sender
if let Ok(sender) = agent_clone_for_post_process.get_stream_sender().await {
if let Err(send_err) = sender.send(Err(AgentError(err_msg.clone()))) {
tracing::warn!("Failed to send error message to stream: {}", send_err); tracing::warn!("Failed to send error message to stream: {}", send_err);
} }
} else {
tracing::warn!("Stream sender not available when trying to send error message.");
}
} }
// Use the clone created before select! // Use the clone created before select!
if let Err(e) = agent_clone_for_post_process.get_stream_sender().await.send(Ok(AgentMessage::Done)) { // Handle the Result from get_stream_sender
if let Ok(sender) = agent_clone_for_post_process.get_stream_sender().await {
if let Err(e) = sender.send(Ok(AgentMessage::Done)) {
tracing::debug!("Failed to send Done message, receiver likely dropped: {}", e); tracing::debug!("Failed to send Done message, receiver likely dropped: {}", e);
} }
} else {
tracing::debug!("Stream sender not available when trying to send Done message.");
}
}, },
_ = shutdown_rx.recv() => { _ = shutdown_rx.recv() => {
// Use the clone created before select! // Use the clone created before select!
@ -487,18 +510,29 @@ impl Agent {
None, None,
Some(agent_clone_shutdown.name.clone()), Some(agent_clone_shutdown.name.clone()),
); );
if let Err(e) = agent_clone_shutdown.get_stream_sender().await.send(Ok(shutdown_msg)) { // Handle the Result from get_stream_sender
if let Ok(sender) = agent_clone_shutdown.get_stream_sender().await {
if let Err(e) = sender.send(Ok(shutdown_msg)) {
tracing::warn!("Failed to send shutdown notification: {}", e); tracing::warn!("Failed to send shutdown notification: {}", e);
} }
} else {
tracing::warn!("Stream sender not available when trying to send shutdown notification.");
}
if let Err(e) = agent_clone_for_post_process.clone().get_stream_sender().await.send(Ok(AgentMessage::Done)) { // Handle the Result from get_stream_sender
if let Ok(sender) = agent_clone_for_post_process.clone().get_stream_sender().await {
if let Err(e) = sender.send(Ok(AgentMessage::Done)) {
tracing::debug!("Failed to send Done message after shutdown, receiver likely dropped: {}", e); tracing::debug!("Failed to send Done message after shutdown, receiver likely dropped: {}", e);
} }
} else {
tracing::debug!("Stream sender not available when trying to send Done message after shutdown.");
}
} }
} }
}); });
Ok(agent_for_ok.get_stream_receiver().await) // Handle the Result from get_stream_receiver
agent_for_ok.get_stream_receiver().await.map_err(|e| e.into())
} }
async fn process_thread_with_depth( async fn process_thread_with_depth(
@ -577,12 +611,17 @@ impl Agent {
None, None,
Some(agent.name.clone()), Some(agent.name.clone()),
); );
if let Err(e) = agent.get_stream_sender().await.send(Ok(message)) { // Handle the Result from get_stream_sender
if let Ok(sender) = agent.get_stream_sender().await {
if let Err(e) = sender.send(Ok(message)) {
tracing::warn!( tracing::warn!(
"Channel send error when sending recursion limit message: {}", "Channel send error when sending recursion limit message: {}",
e e
); );
} }
} else {
tracing::warn!("Stream sender not available when sending recursion limit message.");
}
agent.close().await; // Ensure stream is closed agent.close().await; // Ensure stream is closed
return Ok(()); // Don't return error, just stop processing return Ok(()); // Don't return error, just stop processing
} }
@ -782,16 +821,17 @@ impl Agent {
// Broadcast the final assistant message // Broadcast the final assistant message
// Ensure we don't block if the receiver dropped // Ensure we don't block if the receiver dropped
if let Err(e) = agent // Handle the Result from get_stream_sender
.get_stream_sender() if let Ok(sender) = agent.get_stream_sender().await {
.await if let Err(e) = sender.send(Ok(final_message.clone())) {
.send(Ok(final_message.clone()))
{
tracing::debug!( tracing::debug!(
"Failed to send final assistant message (receiver likely dropped): {}", "Failed to send final assistant message (receiver likely dropped): {}",
e e
); );
} }
} else {
tracing::debug!("Stream sender not available when sending final assistant message.");
}
// Update thread with assistant message // Update thread with assistant message
agent.update_current_thread(final_message.clone()).await?; agent.update_current_thread(final_message.clone()).await?;
@ -949,16 +989,17 @@ impl Agent {
} }
// Broadcast the tool message as soon as we receive it - use try_send to avoid blocking // Broadcast the tool message as soon as we receive it - use try_send to avoid blocking
if let Err(e) = agent // Handle the Result from get_stream_sender
.get_stream_sender() if let Ok(sender) = agent.get_stream_sender().await {
.await if let Err(e) = sender.send(Ok(tool_message.clone())) {
.send(Ok(tool_message.clone()))
{
tracing::debug!( tracing::debug!(
"Failed to send tool message (receiver likely dropped): {}", "Failed to send tool message (receiver likely dropped): {}",
e e
); );
} }
} else {
tracing::debug!("Stream sender not available when sending tool message.");
}
// Update thread with tool response BEFORE checking termination // Update thread with tool response BEFORE checking termination
agent.update_current_thread(tool_message.clone()).await?; agent.update_current_thread(tool_message.clone()).await?;
@ -989,16 +1030,17 @@ impl Agent {
MessageProgress::Complete, MessageProgress::Complete,
); );
// Broadcast the error message // Broadcast the error message
if let Err(e) = agent // Handle the Result from get_stream_sender
.get_stream_sender() if let Ok(sender) = agent.get_stream_sender().await {
.await if let Err(e) = sender.send(Ok(error_result.clone())) {
.send(Ok(error_result.clone()))
{
tracing::debug!( tracing::debug!(
"Failed to send tool error message (receiver likely dropped): {}", "Failed to send tool error message (receiver likely dropped): {}",
e e
); );
} }
} else {
tracing::debug!("Stream sender not available when sending tool error message.");
}
// Update thread and push the error result for the next LLM call // Update thread and push the error result for the next LLM call
agent.update_current_thread(error_result.clone()).await?; agent.update_current_thread(error_result.clone()).await?;
// Continue processing other tool calls if any // Continue processing other tool calls if any
@ -1010,9 +1052,14 @@ impl Agent {
// Finish the trace without consuming it // Finish the trace without consuming it
agent.finish_trace(&trace_builder).await?; agent.finish_trace(&trace_builder).await?;
// Send Done message // Send Done message
if let Err(e) = agent.get_stream_sender().await.send(Ok(AgentMessage::Done)) { // Handle the Result from get_stream_sender
if let Ok(sender) = agent.get_stream_sender().await {
if let Err(e) = sender.send(Ok(AgentMessage::Done)) {
tracing::debug!("Failed to send Done message after tool termination (receiver likely dropped): {}", e); tracing::debug!("Failed to send Done message after tool termination (receiver likely dropped): {}", e);
} }
} else {
tracing::debug!("Stream sender not available when sending Done message after tool termination.");
}
return Ok(()); // Exit the function, preventing recursion return Ok(()); // Exit the function, preventing recursion
} }
@ -1261,7 +1308,7 @@ mod tests {
) -> Result<()> { ) -> Result<()> {
let message = let message =
AgentMessage::tool(None, content, tool_id, Some(self.get_name()), progress); AgentMessage::tool(None, content, tool_id, Some(self.get_name()), progress);
self.agent.get_stream_sender().await.send(Ok(message))?; self.agent.get_stream_sender().await?.send(Ok(message))?;
Ok(()) Ok(())
} }
} }