Merge branch 'evals' into big-nate/bus-939-create-new-structure-for-chats

This commit is contained in:
Nate Kelley 2025-03-04 12:29:04 -07:00
commit 11bd8eb6f3
No known key found for this signature in database
GPG Key ID: FD90372AB8D98B4F
2 changed files with 15 additions and 6 deletions

View File

@ -44,7 +44,7 @@ pub struct Agent {
/// The current thread being processed, if any
current_thread: Arc<RwLock<Option<AgentThread>>>,
/// Sender for streaming messages from this agent and sub-agents
stream_tx: Arc<RwLock<broadcast::Sender<MessageResult>>>,
stream_tx: Arc<RwLock<Option<broadcast::Sender<MessageResult>>>>,
/// The user ID for the current thread
user_id: Uuid,
/// The session ID for the current thread
@ -69,9 +69,7 @@ impl Agent {
let llm_client = LiteLLMClient::new(Some(llm_api_key), Some(llm_base_url));
// Create a broadcast channel with buffer size 1000
let (tx, _rx) = broadcast::channel(1000);
// Create shutdown channel with buffer size 1
let (shutdown_tx, _) = broadcast::channel(1);
Self {
@ -80,7 +78,7 @@ impl Agent {
model,
state: Arc::new(RwLock::new(HashMap::new())),
current_thread: Arc::new(RwLock::new(None)),
stream_tx: Arc::new(RwLock::new(tx)),
stream_tx: Arc::new(RwLock::new(Some(tx))),
user_id,
session_id,
shutdown_tx: Arc::new(RwLock::new(shutdown_tx)),
@ -129,12 +127,12 @@ impl Agent {
/// Get a new receiver for the broadcast channel
pub async fn get_stream_receiver(&self) -> broadcast::Receiver<MessageResult> {
self.stream_tx.read().await.subscribe()
self.stream_tx.read().await.as_ref().unwrap().subscribe()
}
/// Get a clone of the current stream sender
pub async fn get_stream_sender(&self) -> broadcast::Sender<MessageResult> {
self.stream_tx.read().await.clone()
self.stream_tx.read().await.as_ref().unwrap().clone()
}
/// Get a value from the agent's state by key
@ -319,6 +317,7 @@ impl Agent {
Some(self.name.clone()),
);
self.get_stream_sender().await.send(Ok(message))?;
self.close().await;
return Ok(());
}
@ -481,6 +480,7 @@ impl Agent {
// If this is an auto response without tool calls, it means we're done
if final_tool_calls.is_none() {
self.close().await;
return Ok(());
}
@ -545,6 +545,12 @@ impl Agent {
> {
self.tools.read().await
}
// Add this new method alongside other channel-related methods
pub async fn close(&self) {
let mut tx = self.stream_tx.write().await;
*tx = None;
}
}
#[derive(Debug, Default, Clone)]

View File

@ -232,10 +232,13 @@ pub async fn post_chat_handler(
tracing::error!("Error receiving message from agent: {}", e);
// Don't return early, continue processing remaining messages
break;
}
}
}
println!("Finishing up the agent and moving on to store the final message");
let title = title_handle.await??;
let reasoning_duration = reasoning_duration.elapsed().as_secs();