mirror of https://github.com/buster-so/buster.git
refactor(agent): Implement recursive stream processing with improved tool execution
- Refactor agent stream processing to use a recursive approach for handling tool calls - Enhance tool execution with more robust error handling and result tracking - Improve stream chunk processing with detailed state management - Add support for recursive thread generation based on tool call results - Implement cloning for LiteLLMClient to support stream processing tasks
This commit is contained in:
parent
bb4e4ca9d8
commit
8b51618afd
|
@ -10,6 +10,7 @@ use serde::Serialize;
|
||||||
|
|
||||||
use super::types::AgentThread;
|
use super::types::AgentThread;
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
/// The Agent struct is responsible for managing conversations with the LLM
|
/// The Agent struct is responsible for managing conversations with the LLM
|
||||||
/// and coordinating tool executions. It maintains a registry of available tools
|
/// and coordinating tool executions. It maintains a registry of available tools
|
||||||
/// and handles the recursive nature of tool calls.
|
/// and handles the recursive nature of tool calls.
|
||||||
|
@ -151,152 +152,148 @@ impl Agent {
|
||||||
&self,
|
&self,
|
||||||
thread: &AgentThread,
|
thread: &AgentThread,
|
||||||
) -> Result<mpsc::Receiver<Result<Message>>> {
|
) -> Result<mpsc::Receiver<Result<Message>>> {
|
||||||
// Collect all registered tools and their schemas
|
|
||||||
let tools: Vec<Tool> = self
|
|
||||||
.tools
|
|
||||||
.iter()
|
|
||||||
.map(|(name, tool)| Tool {
|
|
||||||
tool_type: "function".to_string(),
|
|
||||||
function: tool.get_schema(),
|
|
||||||
})
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
println!("DEBUG: Starting stream_process_thread with {} tools", tools.len());
|
|
||||||
println!("DEBUG: Tools registered: {:?}", tools);
|
|
||||||
|
|
||||||
// Create the request to send to the LLM
|
|
||||||
let request = ChatCompletionRequest {
|
|
||||||
model: self.model.clone(),
|
|
||||||
messages: thread.messages.clone(),
|
|
||||||
tools: Some(tools),
|
|
||||||
stream: Some(true),
|
|
||||||
..Default::default()
|
|
||||||
};
|
|
||||||
|
|
||||||
println!("DEBUG: Created chat completion request with model: {}", self.model);
|
|
||||||
|
|
||||||
// Get the streaming response from the LLM
|
|
||||||
let mut stream = self.llm_client.stream_chat_completion(request).await?;
|
|
||||||
let (tx, rx) = mpsc::channel(100);
|
let (tx, rx) = mpsc::channel(100);
|
||||||
let mut pending_tool_calls = HashMap::new();
|
|
||||||
|
|
||||||
// Clone the Arc for use in the spawned task
|
|
||||||
let tools_ref = self.tools.clone();
|
let tools_ref = self.tools.clone();
|
||||||
|
let model = self.model.clone();
|
||||||
|
let llm_client = self.llm_client.clone();
|
||||||
|
|
||||||
println!("DEBUG: Stream initialized, starting processing task");
|
// Clone thread for task ownership
|
||||||
|
let thread = thread.clone();
|
||||||
|
|
||||||
// Process the stream in a separate task
|
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
let mut current_message = Message::assistant(None, None);
|
async fn process_stream_recursive(
|
||||||
|
llm_client: &LiteLLMClient,
|
||||||
|
model: &str,
|
||||||
|
tools_ref: &Arc<HashMap<String, Box<dyn ToolExecutor<Output = Value>>>>,
|
||||||
|
thread: &AgentThread,
|
||||||
|
tx: &mpsc::Sender<Result<Message>>,
|
||||||
|
) -> Result<()> {
|
||||||
|
// Collect all registered tools and their schemas
|
||||||
|
let tools: Vec<Tool> = tools_ref
|
||||||
|
.iter()
|
||||||
|
.map(|(name, tool)| Tool {
|
||||||
|
tool_type: "function".to_string(),
|
||||||
|
function: tool.get_schema(),
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
while let Some(chunk_result) = stream.recv().await {
|
println!("DEBUG: Starting recursive stream with {} tools", tools.len());
|
||||||
println!("DEBUG: Received new stream chunk");
|
|
||||||
match chunk_result {
|
|
||||||
Ok(chunk) => {
|
|
||||||
let delta = &chunk.choices[0].delta;
|
|
||||||
println!("DEBUG: Processing delta: {:?}", delta);
|
|
||||||
|
|
||||||
// Handle role changes
|
// Create the request
|
||||||
if let Some(role) = &delta.role {
|
let request = ChatCompletionRequest {
|
||||||
println!("DEBUG: Role change detected: {}", role);
|
model: model.to_string(),
|
||||||
match role.as_str() {
|
messages: thread.messages.clone(),
|
||||||
"assistant" => {
|
tools: Some(tools),
|
||||||
current_message = Message::assistant(None, None);
|
stream: Some(true),
|
||||||
println!("DEBUG: Reset current_message for assistant");
|
..Default::default()
|
||||||
}
|
};
|
||||||
"tool" => {
|
|
||||||
println!("DEBUG: Tool role detected, waiting for content");
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
_ => continue,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Handle tool calls (tool execution start)
|
// Get streaming response
|
||||||
if let Some(tool_calls) = &delta.tool_calls {
|
let mut stream = llm_client.stream_chat_completion(request).await?;
|
||||||
println!("DEBUG: Processing {} tool calls", tool_calls.len());
|
let mut pending_tool_calls = HashMap::new();
|
||||||
for tool_call in tool_calls {
|
let mut current_message = Message::assistant(None, None);
|
||||||
println!("DEBUG: Tool call detected - ID: {}, Name: {}",
|
let mut has_tool_calls = false;
|
||||||
tool_call.id,
|
|
||||||
tool_call.function.name);
|
// Process stream chunks
|
||||||
|
while let Some(chunk_result) = stream.recv().await {
|
||||||
// Store or update the tool call
|
match chunk_result {
|
||||||
pending_tool_calls.insert(tool_call.id.clone(), tool_call.clone());
|
Ok(chunk) => {
|
||||||
|
let delta = &chunk.choices[0].delta;
|
||||||
// Check if this tool call is complete and ready for execution
|
|
||||||
if let Some(complete_tool_call) = pending_tool_calls.get(&tool_call.id) {
|
// Handle role changes
|
||||||
if let Some(tool) = tools_ref.get(&complete_tool_call.function.name) {
|
if let Some(role) = &delta.role {
|
||||||
println!("DEBUG: Executing tool: {}", complete_tool_call.function.name);
|
match role.as_str() {
|
||||||
|
"assistant" => {
|
||||||
// Execute the tool
|
current_message = Message::assistant(None, None);
|
||||||
match tool.execute(complete_tool_call).await {
|
let _ = tx.send(Ok(current_message.clone())).await;
|
||||||
Ok(result) => {
|
|
||||||
let result_str = serde_json::to_string(&result)
|
|
||||||
.unwrap_or_else(|e| format!("Error serializing result: {}", e));
|
|
||||||
println!("DEBUG: Tool execution successful: {}", result_str);
|
|
||||||
|
|
||||||
// Send tool result message
|
|
||||||
let tool_result_msg = Message::tool(
|
|
||||||
result_str,
|
|
||||||
complete_tool_call.id.clone(),
|
|
||||||
);
|
|
||||||
let _ = tx.send(Ok(tool_result_msg)).await;
|
|
||||||
}
|
|
||||||
Err(e) => {
|
|
||||||
println!("DEBUG: Tool execution failed: {:?}", e);
|
|
||||||
let error_msg = format!("Tool execution failed: {:?}", e);
|
|
||||||
let tool_error_msg = Message::tool(
|
|
||||||
error_msg,
|
|
||||||
complete_tool_call.id.clone(),
|
|
||||||
);
|
|
||||||
let _ = tx.send(Ok(tool_error_msg)).await;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Remove the executed tool call
|
|
||||||
pending_tool_calls.remove(&tool_call.id);
|
|
||||||
}
|
}
|
||||||
|
_ => continue,
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Send the tool start message
|
// Handle tool calls
|
||||||
let tool_start_msg = Message::assistant(None, Some(vec![tool_call.clone()]));
|
if let Some(tool_calls) = &delta.tool_calls {
|
||||||
let _ = tx.send(Ok(tool_start_msg)).await;
|
has_tool_calls = true;
|
||||||
|
for tool_call in tool_calls {
|
||||||
|
println!("DEBUG: Tool call detected - ID: {}", tool_call.id);
|
||||||
|
pending_tool_calls.insert(tool_call.id.clone(), tool_call.clone());
|
||||||
|
|
||||||
|
// Update current message with tool calls
|
||||||
|
if let Message::Assistant { tool_calls: ref mut msg_tool_calls, .. } = current_message {
|
||||||
|
*msg_tool_calls = Some(pending_tool_calls.values().cloned().collect());
|
||||||
|
}
|
||||||
|
|
||||||
|
let _ = tx.send(Ok(current_message.clone())).await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle content updates
|
||||||
|
if let Some(content) = &delta.content {
|
||||||
|
match &mut current_message {
|
||||||
|
Message::Assistant { content: msg_content, .. } => {
|
||||||
|
*msg_content = Some(
|
||||||
|
if let Some(existing) = msg_content {
|
||||||
|
format!("{}{}", existing, content)
|
||||||
|
} else {
|
||||||
|
content.clone()
|
||||||
|
}
|
||||||
|
);
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
let _ = tx.send(Ok(current_message.clone())).await;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Err(e) => {
|
||||||
// Handle content updates
|
let _ = tx.send(Err(anyhow::Error::from(e))).await;
|
||||||
if let Some(content) = &delta.content {
|
return Ok(());
|
||||||
println!("DEBUG: Content update received: {}", content);
|
|
||||||
match &mut current_message {
|
|
||||||
Message::Assistant { content: msg_content, tool_calls, .. } => {
|
|
||||||
*msg_content = Some(if let Some(existing) = msg_content {
|
|
||||||
let combined = format!("{}{}", existing, content);
|
|
||||||
println!("DEBUG: Updated assistant content: {}", combined);
|
|
||||||
combined
|
|
||||||
} else {
|
|
||||||
println!("DEBUG: New assistant content: {}", content);
|
|
||||||
content.clone()
|
|
||||||
});
|
|
||||||
}
|
|
||||||
Message::Tool { content: msg_content, .. } => {
|
|
||||||
println!("DEBUG: Updating tool content: {}", content);
|
|
||||||
*msg_content = content.clone();
|
|
||||||
}
|
|
||||||
_ => {}
|
|
||||||
}
|
|
||||||
let _ = tx.send(Ok(current_message.clone())).await;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Err(e) => {
|
|
||||||
println!("DEBUG: Error processing stream chunk: {:?}", e);
|
|
||||||
let _ = tx.send(Err(anyhow::Error::from(e))).await;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// If we had tool calls, execute them and recurse
|
||||||
|
if has_tool_calls {
|
||||||
|
println!("DEBUG: Processing {} tool calls recursively", pending_tool_calls.len());
|
||||||
|
let mut tool_results = Vec::new();
|
||||||
|
|
||||||
|
// Execute all tools
|
||||||
|
for tool_call in pending_tool_calls.values() {
|
||||||
|
if let Some(tool) = tools_ref.get(&tool_call.function.name) {
|
||||||
|
match tool.execute(tool_call).await {
|
||||||
|
Ok(result) => {
|
||||||
|
let result_str = serde_json::to_string(&result)?;
|
||||||
|
let tool_result = Message::tool(result_str, tool_call.id.clone());
|
||||||
|
tool_results.push(tool_result.clone());
|
||||||
|
let _ = tx.send(Ok(tool_result)).await;
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
let error_msg = format!("Tool execution failed: {:?}", e);
|
||||||
|
let tool_error = Message::tool(error_msg, tool_call.id.clone());
|
||||||
|
tool_results.push(tool_error.clone());
|
||||||
|
let _ = tx.send(Ok(tool_error)).await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create new thread with tool results
|
||||||
|
let mut new_thread = thread.clone();
|
||||||
|
new_thread.messages.push(current_message);
|
||||||
|
new_thread.messages.extend(tool_results);
|
||||||
|
|
||||||
|
// Recurse with new thread
|
||||||
|
Box::pin(process_stream_recursive(llm_client, model, tools_ref, &new_thread, tx)).await?;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start recursive processing
|
||||||
|
if let Err(e) = process_stream_recursive(&llm_client, &model, &tools_ref, &thread, &tx).await {
|
||||||
|
let _ = tx.send(Err(e)).await;
|
||||||
}
|
}
|
||||||
println!("DEBUG: Stream processing completed");
|
|
||||||
});
|
});
|
||||||
|
|
||||||
println!("DEBUG: Returning stream receiver");
|
|
||||||
Ok(rx)
|
Ok(rx)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -6,6 +6,7 @@ use tokio::sync::mpsc;
|
||||||
|
|
||||||
use super::types::*;
|
use super::types::*;
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
pub struct LiteLLMClient {
|
pub struct LiteLLMClient {
|
||||||
client: Client,
|
client: Client,
|
||||||
pub(crate) api_key: String,
|
pub(crate) api_key: String,
|
||||||
|
|
Loading…
Reference in New Issue