refactor(agent): Implement recursive stream processing with improved tool execution

- Refactor agent stream processing to use a recursive approach for handling tool calls
- Enhance tool execution with more robust error handling and result tracking
- Improve stream chunk processing with detailed state management
- Add support for recursive thread generation based on tool call results
- Implement cloning for LiteLLMClient to support stream processing tasks
This commit is contained in:
dal 2025-02-07 11:22:43 -07:00
parent bb4e4ca9d8
commit 8b51618afd
No known key found for this signature in database
GPG Key ID: 16F4B0E1E9F61122
2 changed files with 125 additions and 127 deletions

View File

@ -10,6 +10,7 @@ use serde::Serialize;
use super::types::AgentThread; use super::types::AgentThread;
#[derive(Clone)]
/// The Agent struct is responsible for managing conversations with the LLM /// The Agent struct is responsible for managing conversations with the LLM
/// and coordinating tool executions. It maintains a registry of available tools /// and coordinating tool executions. It maintains a registry of available tools
/// and handles the recursive nature of tool calls. /// and handles the recursive nature of tool calls.
@ -151,9 +152,24 @@ impl Agent {
&self, &self,
thread: &AgentThread, thread: &AgentThread,
) -> Result<mpsc::Receiver<Result<Message>>> { ) -> Result<mpsc::Receiver<Result<Message>>> {
let (tx, rx) = mpsc::channel(100);
let tools_ref = self.tools.clone();
let model = self.model.clone();
let llm_client = self.llm_client.clone();
// Clone thread for task ownership
let thread = thread.clone();
tokio::spawn(async move {
async fn process_stream_recursive(
llm_client: &LiteLLMClient,
model: &str,
tools_ref: &Arc<HashMap<String, Box<dyn ToolExecutor<Output = Value>>>>,
thread: &AgentThread,
tx: &mpsc::Sender<Result<Message>>,
) -> Result<()> {
// Collect all registered tools and their schemas // Collect all registered tools and their schemas
let tools: Vec<Tool> = self let tools: Vec<Tool> = tools_ref
.tools
.iter() .iter()
.map(|(name, tool)| Tool { .map(|(name, tool)| Tool {
tool_type: "function".to_string(), tool_type: "function".to_string(),
@ -161,126 +177,67 @@ impl Agent {
}) })
.collect(); .collect();
println!("DEBUG: Starting stream_process_thread with {} tools", tools.len()); println!("DEBUG: Starting recursive stream with {} tools", tools.len());
println!("DEBUG: Tools registered: {:?}", tools);
// Create the request to send to the LLM // Create the request
let request = ChatCompletionRequest { let request = ChatCompletionRequest {
model: self.model.clone(), model: model.to_string(),
messages: thread.messages.clone(), messages: thread.messages.clone(),
tools: Some(tools), tools: Some(tools),
stream: Some(true), stream: Some(true),
..Default::default() ..Default::default()
}; };
println!("DEBUG: Created chat completion request with model: {}", self.model); // Get streaming response
let mut stream = llm_client.stream_chat_completion(request).await?;
// Get the streaming response from the LLM
let mut stream = self.llm_client.stream_chat_completion(request).await?;
let (tx, rx) = mpsc::channel(100);
let mut pending_tool_calls = HashMap::new(); let mut pending_tool_calls = HashMap::new();
// Clone the Arc for use in the spawned task
let tools_ref = self.tools.clone();
println!("DEBUG: Stream initialized, starting processing task");
// Process the stream in a separate task
tokio::spawn(async move {
let mut current_message = Message::assistant(None, None); let mut current_message = Message::assistant(None, None);
let mut has_tool_calls = false;
// Process stream chunks
while let Some(chunk_result) = stream.recv().await { while let Some(chunk_result) = stream.recv().await {
println!("DEBUG: Received new stream chunk");
match chunk_result { match chunk_result {
Ok(chunk) => { Ok(chunk) => {
let delta = &chunk.choices[0].delta; let delta = &chunk.choices[0].delta;
println!("DEBUG: Processing delta: {:?}", delta);
// Handle role changes // Handle role changes
if let Some(role) = &delta.role { if let Some(role) = &delta.role {
println!("DEBUG: Role change detected: {}", role);
match role.as_str() { match role.as_str() {
"assistant" => { "assistant" => {
current_message = Message::assistant(None, None); current_message = Message::assistant(None, None);
println!("DEBUG: Reset current_message for assistant"); let _ = tx.send(Ok(current_message.clone())).await;
}
"tool" => {
println!("DEBUG: Tool role detected, waiting for content");
continue;
} }
_ => continue, _ => continue,
} }
} }
// Handle tool calls (tool execution start) // Handle tool calls
if let Some(tool_calls) = &delta.tool_calls { if let Some(tool_calls) = &delta.tool_calls {
println!("DEBUG: Processing {} tool calls", tool_calls.len()); has_tool_calls = true;
for tool_call in tool_calls { for tool_call in tool_calls {
println!("DEBUG: Tool call detected - ID: {}, Name: {}", println!("DEBUG: Tool call detected - ID: {}", tool_call.id);
tool_call.id,
tool_call.function.name);
// Store or update the tool call
pending_tool_calls.insert(tool_call.id.clone(), tool_call.clone()); pending_tool_calls.insert(tool_call.id.clone(), tool_call.clone());
// Check if this tool call is complete and ready for execution // Update current message with tool calls
if let Some(complete_tool_call) = pending_tool_calls.get(&tool_call.id) { if let Message::Assistant { tool_calls: ref mut msg_tool_calls, .. } = current_message {
if let Some(tool) = tools_ref.get(&complete_tool_call.function.name) { *msg_tool_calls = Some(pending_tool_calls.values().cloned().collect());
println!("DEBUG: Executing tool: {}", complete_tool_call.function.name);
// Execute the tool
match tool.execute(complete_tool_call).await {
Ok(result) => {
let result_str = serde_json::to_string(&result)
.unwrap_or_else(|e| format!("Error serializing result: {}", e));
println!("DEBUG: Tool execution successful: {}", result_str);
// Send tool result message
let tool_result_msg = Message::tool(
result_str,
complete_tool_call.id.clone(),
);
let _ = tx.send(Ok(tool_result_msg)).await;
}
Err(e) => {
println!("DEBUG: Tool execution failed: {:?}", e);
let error_msg = format!("Tool execution failed: {:?}", e);
let tool_error_msg = Message::tool(
error_msg,
complete_tool_call.id.clone(),
);
let _ = tx.send(Ok(tool_error_msg)).await;
}
} }
// Remove the executed tool call let _ = tx.send(Ok(current_message.clone())).await;
pending_tool_calls.remove(&tool_call.id);
}
}
// Send the tool start message
let tool_start_msg = Message::assistant(None, Some(vec![tool_call.clone()]));
let _ = tx.send(Ok(tool_start_msg)).await;
} }
} }
// Handle content updates // Handle content updates
if let Some(content) = &delta.content { if let Some(content) = &delta.content {
println!("DEBUG: Content update received: {}", content);
match &mut current_message { match &mut current_message {
Message::Assistant { content: msg_content, tool_calls, .. } => { Message::Assistant { content: msg_content, .. } => {
*msg_content = Some(if let Some(existing) = msg_content { *msg_content = Some(
let combined = format!("{}{}", existing, content); if let Some(existing) = msg_content {
println!("DEBUG: Updated assistant content: {}", combined); format!("{}{}", existing, content)
combined
} else { } else {
println!("DEBUG: New assistant content: {}", content);
content.clone() content.clone()
});
} }
Message::Tool { content: msg_content, .. } => { );
println!("DEBUG: Updating tool content: {}", content);
*msg_content = content.clone();
} }
_ => {} _ => {}
} }
@ -288,15 +245,55 @@ impl Agent {
} }
} }
Err(e) => { Err(e) => {
println!("DEBUG: Error processing stream chunk: {:?}", e);
let _ = tx.send(Err(anyhow::Error::from(e))).await; let _ = tx.send(Err(anyhow::Error::from(e))).await;
return Ok(());
} }
} }
} }
println!("DEBUG: Stream processing completed");
// If we had tool calls, execute them and recurse
if has_tool_calls {
println!("DEBUG: Processing {} tool calls recursively", pending_tool_calls.len());
let mut tool_results = Vec::new();
// Execute all tools
for tool_call in pending_tool_calls.values() {
if let Some(tool) = tools_ref.get(&tool_call.function.name) {
match tool.execute(tool_call).await {
Ok(result) => {
let result_str = serde_json::to_string(&result)?;
let tool_result = Message::tool(result_str, tool_call.id.clone());
tool_results.push(tool_result.clone());
let _ = tx.send(Ok(tool_result)).await;
}
Err(e) => {
let error_msg = format!("Tool execution failed: {:?}", e);
let tool_error = Message::tool(error_msg, tool_call.id.clone());
tool_results.push(tool_error.clone());
let _ = tx.send(Ok(tool_error)).await;
}
}
}
}
// Create new thread with tool results
let mut new_thread = thread.clone();
new_thread.messages.push(current_message);
new_thread.messages.extend(tool_results);
// Recurse with new thread
Box::pin(process_stream_recursive(llm_client, model, tools_ref, &new_thread, tx)).await?;
}
Ok(())
}
// Start recursive processing
if let Err(e) = process_stream_recursive(&llm_client, &model, &tools_ref, &thread, &tx).await {
let _ = tx.send(Err(e)).await;
}
}); });
println!("DEBUG: Returning stream receiver");
Ok(rx) Ok(rx)
} }
} }

View File

@ -6,6 +6,7 @@ use tokio::sync::mpsc;
use super::types::*; use super::types::*;
#[derive(Clone)]
pub struct LiteLLMClient { pub struct LiteLLMClient {
client: Client, client: Client,
pub(crate) api_key: String, pub(crate) api_key: String,