mirror of https://github.com/buster-so/buster.git
buffering content on stream
This commit is contained in:
parent
6ec195be52
commit
5af9a8e4eb
|
@ -8,6 +8,7 @@ use serde_json::Value;
|
|||
use std::{collections::HashMap, env, sync::Arc};
|
||||
use tokio::sync::{broadcast, RwLock};
|
||||
use uuid::Uuid;
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
use crate::models::AgentThread;
|
||||
|
||||
|
@ -24,6 +25,81 @@ impl std::fmt::Display for AgentError {
|
|||
|
||||
type MessageResult = Result<AgentMessage, AgentError>;
|
||||
|
||||
#[derive(Debug)]
|
||||
struct MessageBuffer {
|
||||
content: String,
|
||||
tool_calls: HashMap<String, PendingToolCall>,
|
||||
last_flush: Instant,
|
||||
message_id: Option<String>,
|
||||
first_message_sent: bool,
|
||||
}
|
||||
|
||||
|
||||
impl MessageBuffer {
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
content: String::new(),
|
||||
tool_calls: HashMap::new(),
|
||||
last_flush: Instant::now(),
|
||||
message_id: None,
|
||||
first_message_sent: false,
|
||||
}
|
||||
}
|
||||
|
||||
fn should_flush(&self) -> bool {
|
||||
self.last_flush.elapsed() >= Duration::from_millis(50)
|
||||
}
|
||||
|
||||
fn has_changes(&self) -> bool {
|
||||
!self.content.is_empty() || !self.tool_calls.is_empty()
|
||||
}
|
||||
|
||||
async fn flush(&mut self, agent: &Agent) -> Result<()> {
|
||||
if !self.has_changes() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Create tool calls vector if we have any
|
||||
let tool_calls: Option<Vec<ToolCall>> = if !self.tool_calls.is_empty() {
|
||||
Some(
|
||||
self.tool_calls
|
||||
.values()
|
||||
.filter_map(|p| {
|
||||
if p.function_name.is_some() {
|
||||
Some(p.clone().into_tool_call())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect(),
|
||||
)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// Create and send the message
|
||||
let message = AgentMessage::assistant(
|
||||
self.message_id.clone(),
|
||||
if self.content.is_empty() { None } else { Some(self.content.clone()) },
|
||||
tool_calls,
|
||||
MessageProgress::InProgress,
|
||||
Some(!self.first_message_sent),
|
||||
Some(agent.name.clone()),
|
||||
);
|
||||
|
||||
agent.get_stream_sender().await.send(Ok(message))?;
|
||||
|
||||
// Update state
|
||||
self.first_message_sent = true;
|
||||
self.last_flush = Instant::now();
|
||||
self.content.clear(); // Clear content but keep tool calls as they may still be accumulating
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
#[derive(Clone)]
|
||||
/// The Agent struct is responsible for managing conversations with the LLM
|
||||
/// and coordinating tool executions. It maintains a registry of available tools
|
||||
|
@ -347,12 +423,8 @@ impl Agent {
|
|||
};
|
||||
|
||||
// Process the streaming chunks
|
||||
let mut pending_tool_calls: HashMap<String, PendingToolCall> = HashMap::new();
|
||||
let mut content_buffer = String::new();
|
||||
let mut buffer = MessageBuffer::new();
|
||||
let mut is_complete = false;
|
||||
let mut message_id: Option<String> = None;
|
||||
// Flag to track if we've sent the first message
|
||||
let mut first_message_sent = false;
|
||||
|
||||
while let Some(chunk_result) = stream_rx.recv().await {
|
||||
match chunk_result {
|
||||
|
@ -361,34 +433,19 @@ impl Agent {
|
|||
continue;
|
||||
}
|
||||
|
||||
message_id = Some(chunk.id.clone());
|
||||
|
||||
buffer.message_id = Some(chunk.id.clone());
|
||||
let delta = &chunk.choices[0].delta;
|
||||
|
||||
// Accumulate content if present
|
||||
if let Some(content) = &delta.content {
|
||||
content_buffer.push_str(content);
|
||||
|
||||
// Stream the content update using the ID directly from this chunk's delta
|
||||
let partial_message = AgentMessage::assistant(
|
||||
message_id.clone(),
|
||||
Some(content_buffer.clone()),
|
||||
None,
|
||||
MessageProgress::InProgress,
|
||||
Some(!first_message_sent),
|
||||
Some(self.name.clone()),
|
||||
);
|
||||
|
||||
self.get_stream_sender().await.send(Ok(partial_message))?;
|
||||
first_message_sent = true;
|
||||
buffer.content.push_str(content);
|
||||
}
|
||||
|
||||
// Process tool calls if present
|
||||
if let Some(tool_calls) = &delta.tool_calls {
|
||||
for tool_call in tool_calls {
|
||||
let id = tool_call.id.clone().unwrap_or_else(|| {
|
||||
// If no ID is provided, use existing IDs or generate a new one
|
||||
pending_tool_calls
|
||||
buffer.tool_calls
|
||||
.keys()
|
||||
.next()
|
||||
.map(|s| s.clone())
|
||||
|
@ -396,44 +453,18 @@ impl Agent {
|
|||
});
|
||||
|
||||
// Get or create the pending tool call
|
||||
let pending_call = pending_tool_calls
|
||||
let pending_call = buffer.tool_calls
|
||||
.entry(id.clone())
|
||||
.or_insert_with(PendingToolCall::new);
|
||||
|
||||
// Update the pending call with the delta
|
||||
pending_call.update_from_delta(tool_call);
|
||||
}
|
||||
|
||||
// Stream the updated tool calls
|
||||
let tool_calls_vec: Vec<ToolCall> = pending_tool_calls
|
||||
.values()
|
||||
.filter_map(|p| {
|
||||
if p.function_name.is_some() {
|
||||
Some(p.clone().into_tool_call())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
if !tool_calls_vec.is_empty() {
|
||||
let partial_message = AgentMessage::assistant(
|
||||
message_id.clone(),
|
||||
if content_buffer.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(content_buffer.clone())
|
||||
},
|
||||
Some(tool_calls_vec),
|
||||
MessageProgress::InProgress,
|
||||
Some(!first_message_sent), // Set initial=true only for the first message
|
||||
Some(self.name.clone()),
|
||||
);
|
||||
|
||||
self.get_stream_sender().await.send(Ok(partial_message))?;
|
||||
// Mark that we've sent the first message
|
||||
first_message_sent = true;
|
||||
}
|
||||
// Check if we should flush the buffer
|
||||
if buffer.should_flush() {
|
||||
buffer.flush(self).await?;
|
||||
}
|
||||
|
||||
// Check if this is the final chunk
|
||||
|
@ -445,10 +476,10 @@ impl Agent {
|
|||
}
|
||||
}
|
||||
|
||||
// Create the final assistant message
|
||||
let final_tool_calls: Option<Vec<ToolCall>> = if !pending_tool_calls.is_empty() {
|
||||
// Create and send the final message
|
||||
let final_tool_calls: Option<Vec<ToolCall>> = if !buffer.tool_calls.is_empty() {
|
||||
Some(
|
||||
pending_tool_calls
|
||||
buffer.tool_calls
|
||||
.values()
|
||||
.map(|p| p.clone().into_tool_call())
|
||||
.collect(),
|
||||
|
@ -458,12 +489,8 @@ impl Agent {
|
|||
};
|
||||
|
||||
let final_message = AgentMessage::assistant(
|
||||
message_id.clone(),
|
||||
if content_buffer.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(content_buffer)
|
||||
},
|
||||
buffer.message_id,
|
||||
if buffer.content.is_empty() { None } else { Some(buffer.content) },
|
||||
final_tool_calls.clone(),
|
||||
MessageProgress::Complete,
|
||||
Some(false),
|
||||
|
|
Loading…
Reference in New Issue