channel upgrade

This commit is contained in:
dal 2025-04-16 22:57:38 -06:00
parent 15e74d4575
commit 32dd1891a4
No known key found for this signature in database
GPG Key ID: 16F4B0E1E9F61122
1 changed files with 102 additions and 55 deletions

View File

@ -116,9 +116,14 @@ impl MessageBuffer {
);
// Continue on error with broadcast::error::SendError
if let Err(e) = agent.get_stream_sender().await.send(Ok(message)) {
// Log warning but don't fail the operation
tracing::warn!("Channel send error, message may be dropped: {}", e);
// Ensure we handle the Result from get_stream_sender first
if let Ok(sender) = agent.get_stream_sender().await {
if let Err(e) = sender.send(Ok(message)) {
// Log warning but don't fail the operation
tracing::warn!("Channel send error, message may be dropped: {}", e);
}
} else {
tracing::warn!("Stream sender not available, message dropped.");
}
// Update state
@ -195,7 +200,7 @@ impl Agent {
let llm_client = LiteLLMClient::new(api_key, base_url);
// When creating a new agent, initialize broadcast channel with higher capacity for better concurrency
let (tx, _rx) = broadcast::channel(5000);
let (tx, _rx) = broadcast::channel(10000);
// Increase shutdown channel capacity to avoid blocking
let (shutdown_tx, _) = broadcast::channel(100);
@ -266,14 +271,22 @@ impl Agent {
enabled_tools
}
/// Get a new receiver for the broadcast channel
pub async fn get_stream_receiver(&self) -> broadcast::Receiver<MessageResult> {
self.stream_tx.read().await.as_ref().unwrap().subscribe()
/// Get a new receiver for the broadcast channel.
/// Returns an error if the stream channel has been closed or was not initialized.
pub async fn get_stream_receiver(&self) -> Result<broadcast::Receiver<MessageResult>, AgentError> {
match self.stream_tx.read().await.as_ref() {
Some(tx) => Ok(tx.subscribe()),
None => Err(AgentError("Stream channel is closed or not initialized.".to_string()))
}
}
/// Get a clone of the current stream sender
pub async fn get_stream_sender(&self) -> broadcast::Sender<MessageResult> {
self.stream_tx.read().await.as_ref().unwrap().clone()
/// Get a clone of the current stream sender.
/// Returns an error if the stream channel has been closed or was not initialized.
pub async fn get_stream_sender(&self) -> Result<broadcast::Sender<MessageResult>, AgentError> {
match self.stream_tx.read().await.as_ref() {
Some(tx) => Ok(tx.clone()),
None => Err(AgentError("Stream channel is closed or not initialized.".to_string()))
}
}
/// Get a value from the agent's state by key
@ -467,13 +480,23 @@ impl Agent {
let err_msg = format!("Error processing thread: {:?}", e);
error!("{}", err_msg); // Log the error
// Use the clone created before select!
if let Err(send_err) = agent_clone_for_post_process.get_stream_sender().await.send(Err(AgentError(err_msg.clone()))) {
tracing::warn!("Failed to send error message to stream: {}", send_err);
// Handle the Result from get_stream_sender
if let Ok(sender) = agent_clone_for_post_process.get_stream_sender().await {
if let Err(send_err) = sender.send(Err(AgentError(err_msg.clone()))) {
tracing::warn!("Failed to send error message to stream: {}", send_err);
}
} else {
tracing::warn!("Stream sender not available when trying to send error message.");
}
}
// Use the clone created before select!
if let Err(e) = agent_clone_for_post_process.get_stream_sender().await.send(Ok(AgentMessage::Done)) {
tracing::debug!("Failed to send Done message, receiver likely dropped: {}", e);
// Use the clone created before select!
// Handle the Result from get_stream_sender
if let Ok(sender) = agent_clone_for_post_process.get_stream_sender().await {
if let Err(e) = sender.send(Ok(AgentMessage::Done)) {
tracing::debug!("Failed to send Done message, receiver likely dropped: {}", e);
}
} else {
tracing::debug!("Stream sender not available when trying to send Done message.");
}
},
_ = shutdown_rx.recv() => {
@ -487,18 +510,29 @@ impl Agent {
None,
Some(agent_clone_shutdown.name.clone()),
);
if let Err(e) = agent_clone_shutdown.get_stream_sender().await.send(Ok(shutdown_msg)) {
tracing::warn!("Failed to send shutdown notification: {}", e);
// Handle the Result from get_stream_sender
if let Ok(sender) = agent_clone_shutdown.get_stream_sender().await {
if let Err(e) = sender.send(Ok(shutdown_msg)) {
tracing::warn!("Failed to send shutdown notification: {}", e);
}
} else {
tracing::warn!("Stream sender not available when trying to send shutdown notification.");
}
if let Err(e) = agent_clone_for_post_process.clone().get_stream_sender().await.send(Ok(AgentMessage::Done)) {
tracing::debug!("Failed to send Done message after shutdown, receiver likely dropped: {}", e);
// Handle the Result from get_stream_sender
if let Ok(sender) = agent_clone_for_post_process.clone().get_stream_sender().await {
if let Err(e) = sender.send(Ok(AgentMessage::Done)) {
tracing::debug!("Failed to send Done message after shutdown, receiver likely dropped: {}", e);
}
} else {
tracing::debug!("Stream sender not available when trying to send Done message after shutdown.");
}
}
}
});
Ok(agent_for_ok.get_stream_receiver().await)
// Handle the Result from get_stream_receiver
agent_for_ok.get_stream_receiver().await.map_err(|e| e.into())
}
async fn process_thread_with_depth(
@ -577,11 +611,16 @@ impl Agent {
None,
Some(agent.name.clone()),
);
if let Err(e) = agent.get_stream_sender().await.send(Ok(message)) {
tracing::warn!(
"Channel send error when sending recursion limit message: {}",
e
);
// Handle the Result from get_stream_sender
if let Ok(sender) = agent.get_stream_sender().await {
if let Err(e) = sender.send(Ok(message)) {
tracing::warn!(
"Channel send error when sending recursion limit message: {}",
e
);
}
} else {
tracing::warn!("Stream sender not available when sending recursion limit message.");
}
agent.close().await; // Ensure stream is closed
return Ok(()); // Don't return error, just stop processing
@ -782,15 +821,16 @@ impl Agent {
// Broadcast the final assistant message
// Ensure we don't block if the receiver dropped
if let Err(e) = agent
.get_stream_sender()
.await
.send(Ok(final_message.clone()))
{
tracing::debug!(
"Failed to send final assistant message (receiver likely dropped): {}",
e
);
// Handle the Result from get_stream_sender
if let Ok(sender) = agent.get_stream_sender().await {
if let Err(e) = sender.send(Ok(final_message.clone())) {
tracing::debug!(
"Failed to send final assistant message (receiver likely dropped): {}",
e
);
}
} else {
tracing::debug!("Stream sender not available when sending final assistant message.");
}
// Update thread with assistant message
@ -949,15 +989,16 @@ impl Agent {
}
// Broadcast the tool message as soon as we receive it - use try_send to avoid blocking
if let Err(e) = agent
.get_stream_sender()
.await
.send(Ok(tool_message.clone()))
{
tracing::debug!(
"Failed to send tool message (receiver likely dropped): {}",
e
);
// Handle the Result from get_stream_sender
if let Ok(sender) = agent.get_stream_sender().await {
if let Err(e) = sender.send(Ok(tool_message.clone())) {
tracing::debug!(
"Failed to send tool message (receiver likely dropped): {}",
e
);
}
} else {
tracing::debug!("Stream sender not available when sending tool message.");
}
// Update thread with tool response BEFORE checking termination
@ -989,15 +1030,16 @@ impl Agent {
MessageProgress::Complete,
);
// Broadcast the error message
if let Err(e) = agent
.get_stream_sender()
.await
.send(Ok(error_result.clone()))
{
tracing::debug!(
"Failed to send tool error message (receiver likely dropped): {}",
e
);
// Handle the Result from get_stream_sender
if let Ok(sender) = agent.get_stream_sender().await {
if let Err(e) = sender.send(Ok(error_result.clone())) {
tracing::debug!(
"Failed to send tool error message (receiver likely dropped): {}",
e
);
}
} else {
tracing::debug!("Stream sender not available when sending tool error message.");
}
// Update thread and push the error result for the next LLM call
agent.update_current_thread(error_result.clone()).await?;
@ -1010,8 +1052,13 @@ impl Agent {
// Finish the trace without consuming it
agent.finish_trace(&trace_builder).await?;
// Send Done message
if let Err(e) = agent.get_stream_sender().await.send(Ok(AgentMessage::Done)) {
tracing::debug!("Failed to send Done message after tool termination (receiver likely dropped): {}", e);
// Handle the Result from get_stream_sender
if let Ok(sender) = agent.get_stream_sender().await {
if let Err(e) = sender.send(Ok(AgentMessage::Done)) {
tracing::debug!("Failed to send Done message after tool termination (receiver likely dropped): {}", e);
}
} else {
tracing::debug!("Stream sender not available when sending Done message after tool termination.");
}
return Ok(()); // Exit the function, preventing recursion
}
@ -1261,7 +1308,7 @@ mod tests {
) -> Result<()> {
let message =
AgentMessage::tool(None, content, tool_id, Some(self.get_name()), progress);
self.agent.get_stream_sender().await.send(Ok(message))?;
self.agent.get_stream_sender().await?.send(Ok(message))?;
Ok(())
}
}