From 5af9a8e4eb4f7666bdfa50f676ad0f2c3ee27de7 Mon Sep 17 00:00:00 2001 From: dal Date: Wed, 5 Mar 2025 11:21:11 -0700 Subject: [PATCH] buffering content on stream --- api/libs/agents/src/agent.rs | 153 ++++++++++++++++++++--------------- 1 file changed, 90 insertions(+), 63 deletions(-) diff --git a/api/libs/agents/src/agent.rs b/api/libs/agents/src/agent.rs index 651ca2ac9..b23c05f1b 100644 --- a/api/libs/agents/src/agent.rs +++ b/api/libs/agents/src/agent.rs @@ -8,6 +8,7 @@ use serde_json::Value; use std::{collections::HashMap, env, sync::Arc}; use tokio::sync::{broadcast, RwLock}; use uuid::Uuid; +use std::time::{Duration, Instant}; use crate::models::AgentThread; @@ -24,6 +25,81 @@ impl std::fmt::Display for AgentError { type MessageResult = Result; +#[derive(Debug)] +struct MessageBuffer { + content: String, + tool_calls: HashMap, + last_flush: Instant, + message_id: Option, + first_message_sent: bool, +} + + +impl MessageBuffer { + fn new() -> Self { + Self { + content: String::new(), + tool_calls: HashMap::new(), + last_flush: Instant::now(), + message_id: None, + first_message_sent: false, + } + } + + fn should_flush(&self) -> bool { + self.last_flush.elapsed() >= Duration::from_millis(50) + } + + fn has_changes(&self) -> bool { + !self.content.is_empty() || !self.tool_calls.is_empty() + } + + async fn flush(&mut self, agent: &Agent) -> Result<()> { + if !self.has_changes() { + return Ok(()); + } + + // Create tool calls vector if we have any + let tool_calls: Option> = if !self.tool_calls.is_empty() { + Some( + self.tool_calls + .values() + .filter_map(|p| { + if p.function_name.is_some() { + Some(p.clone().into_tool_call()) + } else { + None + } + }) + .collect(), + ) + } else { + None + }; + + // Create and send the message + let message = AgentMessage::assistant( + self.message_id.clone(), + if self.content.is_empty() { None } else { Some(self.content.clone()) }, + tool_calls, + MessageProgress::InProgress, + Some(!self.first_message_sent), + Some(agent.name.clone()), + ); + + agent.get_stream_sender().await.send(Ok(message))?; + + // Update state + self.first_message_sent = true; + self.last_flush = Instant::now(); + self.content.clear(); // Clear content but keep tool calls as they may still be accumulating + + Ok(()) + } +} + + + #[derive(Clone)] /// The Agent struct is responsible for managing conversations with the LLM /// and coordinating tool executions. It maintains a registry of available tools @@ -347,12 +423,8 @@ impl Agent { }; // Process the streaming chunks - let mut pending_tool_calls: HashMap = HashMap::new(); - let mut content_buffer = String::new(); + let mut buffer = MessageBuffer::new(); let mut is_complete = false; - let mut message_id: Option = None; - // Flag to track if we've sent the first message - let mut first_message_sent = false; while let Some(chunk_result) = stream_rx.recv().await { match chunk_result { @@ -361,34 +433,19 @@ impl Agent { continue; } - message_id = Some(chunk.id.clone()); - + buffer.message_id = Some(chunk.id.clone()); let delta = &chunk.choices[0].delta; // Accumulate content if present if let Some(content) = &delta.content { - content_buffer.push_str(content); - - // Stream the content update using the ID directly from this chunk's delta - let partial_message = AgentMessage::assistant( - message_id.clone(), - Some(content_buffer.clone()), - None, - MessageProgress::InProgress, - Some(!first_message_sent), - Some(self.name.clone()), - ); - - self.get_stream_sender().await.send(Ok(partial_message))?; - first_message_sent = true; + buffer.content.push_str(content); } // Process tool calls if present if let Some(tool_calls) = &delta.tool_calls { for tool_call in tool_calls { let id = tool_call.id.clone().unwrap_or_else(|| { - // If no ID is provided, use existing IDs or generate a new one - pending_tool_calls + buffer.tool_calls .keys() .next() .map(|s| s.clone()) @@ -396,44 +453,18 @@ impl Agent { }); // Get or create the pending tool call - let pending_call = pending_tool_calls + let pending_call = buffer.tool_calls .entry(id.clone()) .or_insert_with(PendingToolCall::new); // Update the pending call with the delta pending_call.update_from_delta(tool_call); } + } - // Stream the updated tool calls - let tool_calls_vec: Vec = pending_tool_calls - .values() - .filter_map(|p| { - if p.function_name.is_some() { - Some(p.clone().into_tool_call()) - } else { - None - } - }) - .collect(); - - if !tool_calls_vec.is_empty() { - let partial_message = AgentMessage::assistant( - message_id.clone(), - if content_buffer.is_empty() { - None - } else { - Some(content_buffer.clone()) - }, - Some(tool_calls_vec), - MessageProgress::InProgress, - Some(!first_message_sent), // Set initial=true only for the first message - Some(self.name.clone()), - ); - - self.get_stream_sender().await.send(Ok(partial_message))?; - // Mark that we've sent the first message - first_message_sent = true; - } + // Check if we should flush the buffer + if buffer.should_flush() { + buffer.flush(self).await?; } // Check if this is the final chunk @@ -445,10 +476,10 @@ impl Agent { } } - // Create the final assistant message - let final_tool_calls: Option> = if !pending_tool_calls.is_empty() { + // Create and send the final message + let final_tool_calls: Option> = if !buffer.tool_calls.is_empty() { Some( - pending_tool_calls + buffer.tool_calls .values() .map(|p| p.clone().into_tool_call()) .collect(), @@ -458,12 +489,8 @@ impl Agent { }; let final_message = AgentMessage::assistant( - message_id.clone(), - if content_buffer.is_empty() { - None - } else { - Some(content_buffer) - }, + buffer.message_id, + if buffer.content.is_empty() { None } else { Some(buffer.content) }, final_tool_calls.clone(), MessageProgress::Complete, Some(false),