mirror of https://github.com/buster-so/buster.git
Merge branch 'evals' into big-nate/bus-939-create-new-structure-for-chats
This commit is contained in:
commit
11bd8eb6f3
|
@ -44,7 +44,7 @@ pub struct Agent {
|
|||
/// The current thread being processed, if any
|
||||
current_thread: Arc<RwLock<Option<AgentThread>>>,
|
||||
/// Sender for streaming messages from this agent and sub-agents
|
||||
stream_tx: Arc<RwLock<broadcast::Sender<MessageResult>>>,
|
||||
stream_tx: Arc<RwLock<Option<broadcast::Sender<MessageResult>>>>,
|
||||
/// The user ID for the current thread
|
||||
user_id: Uuid,
|
||||
/// The session ID for the current thread
|
||||
|
@ -69,9 +69,7 @@ impl Agent {
|
|||
|
||||
let llm_client = LiteLLMClient::new(Some(llm_api_key), Some(llm_base_url));
|
||||
|
||||
// Create a broadcast channel with buffer size 1000
|
||||
let (tx, _rx) = broadcast::channel(1000);
|
||||
// Create shutdown channel with buffer size 1
|
||||
let (shutdown_tx, _) = broadcast::channel(1);
|
||||
|
||||
Self {
|
||||
|
@ -80,7 +78,7 @@ impl Agent {
|
|||
model,
|
||||
state: Arc::new(RwLock::new(HashMap::new())),
|
||||
current_thread: Arc::new(RwLock::new(None)),
|
||||
stream_tx: Arc::new(RwLock::new(tx)),
|
||||
stream_tx: Arc::new(RwLock::new(Some(tx))),
|
||||
user_id,
|
||||
session_id,
|
||||
shutdown_tx: Arc::new(RwLock::new(shutdown_tx)),
|
||||
|
@ -129,12 +127,12 @@ impl Agent {
|
|||
|
||||
/// Get a new receiver for the broadcast channel
|
||||
pub async fn get_stream_receiver(&self) -> broadcast::Receiver<MessageResult> {
|
||||
self.stream_tx.read().await.subscribe()
|
||||
self.stream_tx.read().await.as_ref().unwrap().subscribe()
|
||||
}
|
||||
|
||||
/// Get a clone of the current stream sender
|
||||
pub async fn get_stream_sender(&self) -> broadcast::Sender<MessageResult> {
|
||||
self.stream_tx.read().await.clone()
|
||||
self.stream_tx.read().await.as_ref().unwrap().clone()
|
||||
}
|
||||
|
||||
/// Get a value from the agent's state by key
|
||||
|
@ -319,6 +317,7 @@ impl Agent {
|
|||
Some(self.name.clone()),
|
||||
);
|
||||
self.get_stream_sender().await.send(Ok(message))?;
|
||||
self.close().await;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
|
@ -481,6 +480,7 @@ impl Agent {
|
|||
|
||||
// If this is an auto response without tool calls, it means we're done
|
||||
if final_tool_calls.is_none() {
|
||||
self.close().await;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
|
@ -545,6 +545,12 @@ impl Agent {
|
|||
> {
|
||||
self.tools.read().await
|
||||
}
|
||||
|
||||
// Add this new method alongside other channel-related methods
|
||||
pub async fn close(&self) {
|
||||
let mut tx = self.stream_tx.write().await;
|
||||
*tx = None;
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Default, Clone)]
|
||||
|
|
|
@ -232,10 +232,13 @@ pub async fn post_chat_handler(
|
|||
|
||||
tracing::error!("Error receiving message from agent: {}", e);
|
||||
// Don't return early, continue processing remaining messages
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
println!("Finishing up the agent and moving on to store the final message");
|
||||
|
||||
let title = title_handle.await??;
|
||||
let reasoning_duration = reasoning_duration.elapsed().as_secs();
|
||||
|
||||
|
|
Loading…
Reference in New Issue