Merge branch 'evals' into big-nate/bus-939-create-new-structure-for-chats

This commit is contained in:
Nate Kelley 2025-03-05 12:21:34 -07:00
commit d9c62c7043
No known key found for this signature in database
GPG Key ID: FD90372AB8D98B4F
1 changed files with 90 additions and 63 deletions

View File

@ -8,6 +8,7 @@ use serde_json::Value;
use std::{collections::HashMap, env, sync::Arc};
use tokio::sync::{broadcast, RwLock};
use uuid::Uuid;
use std::time::{Duration, Instant};
use crate::models::AgentThread;
@ -24,6 +25,81 @@ impl std::fmt::Display for AgentError {
type MessageResult = Result<AgentMessage, AgentError>;
#[derive(Debug)]
struct MessageBuffer {
content: String,
tool_calls: HashMap<String, PendingToolCall>,
last_flush: Instant,
message_id: Option<String>,
first_message_sent: bool,
}
impl MessageBuffer {
fn new() -> Self {
Self {
content: String::new(),
tool_calls: HashMap::new(),
last_flush: Instant::now(),
message_id: None,
first_message_sent: false,
}
}
fn should_flush(&self) -> bool {
self.last_flush.elapsed() >= Duration::from_millis(50)
}
fn has_changes(&self) -> bool {
!self.content.is_empty() || !self.tool_calls.is_empty()
}
async fn flush(&mut self, agent: &Agent) -> Result<()> {
if !self.has_changes() {
return Ok(());
}
// Create tool calls vector if we have any
let tool_calls: Option<Vec<ToolCall>> = if !self.tool_calls.is_empty() {
Some(
self.tool_calls
.values()
.filter_map(|p| {
if p.function_name.is_some() {
Some(p.clone().into_tool_call())
} else {
None
}
})
.collect(),
)
} else {
None
};
// Create and send the message
let message = AgentMessage::assistant(
self.message_id.clone(),
if self.content.is_empty() { None } else { Some(self.content.clone()) },
tool_calls,
MessageProgress::InProgress,
Some(!self.first_message_sent),
Some(agent.name.clone()),
);
agent.get_stream_sender().await.send(Ok(message))?;
// Update state
self.first_message_sent = true;
self.last_flush = Instant::now();
self.content.clear(); // Clear content but keep tool calls as they may still be accumulating
Ok(())
}
}
#[derive(Clone)]
/// The Agent struct is responsible for managing conversations with the LLM
/// and coordinating tool executions. It maintains a registry of available tools
@ -347,12 +423,8 @@ impl Agent {
};
// Process the streaming chunks
let mut pending_tool_calls: HashMap<String, PendingToolCall> = HashMap::new();
let mut content_buffer = String::new();
let mut buffer = MessageBuffer::new();
let mut is_complete = false;
let mut message_id: Option<String> = None;
// Flag to track if we've sent the first message
let mut first_message_sent = false;
while let Some(chunk_result) = stream_rx.recv().await {
match chunk_result {
@ -361,34 +433,19 @@ impl Agent {
continue;
}
message_id = Some(chunk.id.clone());
buffer.message_id = Some(chunk.id.clone());
let delta = &chunk.choices[0].delta;
// Accumulate content if present
if let Some(content) = &delta.content {
content_buffer.push_str(content);
// Stream the content update using the ID directly from this chunk's delta
let partial_message = AgentMessage::assistant(
message_id.clone(),
Some(content_buffer.clone()),
None,
MessageProgress::InProgress,
Some(!first_message_sent),
Some(self.name.clone()),
);
self.get_stream_sender().await.send(Ok(partial_message))?;
first_message_sent = true;
buffer.content.push_str(content);
}
// Process tool calls if present
if let Some(tool_calls) = &delta.tool_calls {
for tool_call in tool_calls {
let id = tool_call.id.clone().unwrap_or_else(|| {
// If no ID is provided, use existing IDs or generate a new one
pending_tool_calls
buffer.tool_calls
.keys()
.next()
.map(|s| s.clone())
@ -396,44 +453,18 @@ impl Agent {
});
// Get or create the pending tool call
let pending_call = pending_tool_calls
let pending_call = buffer.tool_calls
.entry(id.clone())
.or_insert_with(PendingToolCall::new);
// Update the pending call with the delta
pending_call.update_from_delta(tool_call);
}
}
// Stream the updated tool calls
let tool_calls_vec: Vec<ToolCall> = pending_tool_calls
.values()
.filter_map(|p| {
if p.function_name.is_some() {
Some(p.clone().into_tool_call())
} else {
None
}
})
.collect();
if !tool_calls_vec.is_empty() {
let partial_message = AgentMessage::assistant(
message_id.clone(),
if content_buffer.is_empty() {
None
} else {
Some(content_buffer.clone())
},
Some(tool_calls_vec),
MessageProgress::InProgress,
Some(!first_message_sent), // Set initial=true only for the first message
Some(self.name.clone()),
);
self.get_stream_sender().await.send(Ok(partial_message))?;
// Mark that we've sent the first message
first_message_sent = true;
}
// Check if we should flush the buffer
if buffer.should_flush() {
buffer.flush(self).await?;
}
// Check if this is the final chunk
@ -445,10 +476,10 @@ impl Agent {
}
}
// Create the final assistant message
let final_tool_calls: Option<Vec<ToolCall>> = if !pending_tool_calls.is_empty() {
// Create and send the final message
let final_tool_calls: Option<Vec<ToolCall>> = if !buffer.tool_calls.is_empty() {
Some(
pending_tool_calls
buffer.tool_calls
.values()
.map(|p| p.clone().into_tool_call())
.collect(),
@ -458,12 +489,8 @@ impl Agent {
};
let final_message = AgentMessage::assistant(
message_id.clone(),
if content_buffer.is_empty() {
None
} else {
Some(content_buffer)
},
buffer.message_id,
if buffer.content.is_empty() { None } else { Some(buffer.content) },
final_tool_calls.clone(),
MessageProgress::Complete,
Some(false),