refactor(agent): Implement recursive stream processing with improved tool execution

- Refactor agent stream processing to use a recursive approach for handling tool calls
- Enhance tool execution with more robust error handling and result tracking
- Improve stream chunk processing with detailed state management
- After tool calls finish, append the assistant message and tool results to a cloned thread and recursively process a new stream for it
- Derive `Clone` for `Agent` and `LiteLLMClient` so the client can be cloned into the spawned stream processing task
This commit is contained in:
dal 2025-02-07 11:22:43 -07:00
parent bb4e4ca9d8
commit 8b51618afd
No known key found for this signature in database
GPG Key ID: 16F4B0E1E9F61122
2 changed files with 125 additions and 127 deletions

View File

@ -10,6 +10,7 @@ use serde::Serialize;
use super::types::AgentThread; use super::types::AgentThread;
#[derive(Clone)]
/// The Agent struct is responsible for managing conversations with the LLM /// The Agent struct is responsible for managing conversations with the LLM
/// and coordinating tool executions. It maintains a registry of available tools /// and coordinating tool executions. It maintains a registry of available tools
/// and handles the recursive nature of tool calls. /// and handles the recursive nature of tool calls.
@ -151,152 +152,148 @@ impl Agent {
&self, &self,
thread: &AgentThread, thread: &AgentThread,
) -> Result<mpsc::Receiver<Result<Message>>> { ) -> Result<mpsc::Receiver<Result<Message>>> {
// Collect all registered tools and their schemas
let tools: Vec<Tool> = self
.tools
.iter()
.map(|(name, tool)| Tool {
tool_type: "function".to_string(),
function: tool.get_schema(),
})
.collect();
println!("DEBUG: Starting stream_process_thread with {} tools", tools.len());
println!("DEBUG: Tools registered: {:?}", tools);
// Create the request to send to the LLM
let request = ChatCompletionRequest {
model: self.model.clone(),
messages: thread.messages.clone(),
tools: Some(tools),
stream: Some(true),
..Default::default()
};
println!("DEBUG: Created chat completion request with model: {}", self.model);
// Get the streaming response from the LLM
let mut stream = self.llm_client.stream_chat_completion(request).await?;
let (tx, rx) = mpsc::channel(100); let (tx, rx) = mpsc::channel(100);
let mut pending_tool_calls = HashMap::new();
// Clone the Arc for use in the spawned task
let tools_ref = self.tools.clone(); let tools_ref = self.tools.clone();
let model = self.model.clone();
let llm_client = self.llm_client.clone();
println!("DEBUG: Stream initialized, starting processing task"); // Clone thread for task ownership
let thread = thread.clone();
// Process the stream in a separate task
tokio::spawn(async move { tokio::spawn(async move {
let mut current_message = Message::assistant(None, None); async fn process_stream_recursive(
llm_client: &LiteLLMClient,
model: &str,
tools_ref: &Arc<HashMap<String, Box<dyn ToolExecutor<Output = Value>>>>,
thread: &AgentThread,
tx: &mpsc::Sender<Result<Message>>,
) -> Result<()> {
// Collect all registered tools and their schemas
let tools: Vec<Tool> = tools_ref
.iter()
.map(|(name, tool)| Tool {
tool_type: "function".to_string(),
function: tool.get_schema(),
})
.collect();
while let Some(chunk_result) = stream.recv().await { println!("DEBUG: Starting recursive stream with {} tools", tools.len());
println!("DEBUG: Received new stream chunk");
match chunk_result {
Ok(chunk) => {
let delta = &chunk.choices[0].delta;
println!("DEBUG: Processing delta: {:?}", delta);
// Handle role changes // Create the request
if let Some(role) = &delta.role { let request = ChatCompletionRequest {
println!("DEBUG: Role change detected: {}", role); model: model.to_string(),
match role.as_str() { messages: thread.messages.clone(),
"assistant" => { tools: Some(tools),
current_message = Message::assistant(None, None); stream: Some(true),
println!("DEBUG: Reset current_message for assistant"); ..Default::default()
} };
"tool" => {
println!("DEBUG: Tool role detected, waiting for content");
continue;
}
_ => continue,
}
}
// Handle tool calls (tool execution start) // Get streaming response
if let Some(tool_calls) = &delta.tool_calls { let mut stream = llm_client.stream_chat_completion(request).await?;
println!("DEBUG: Processing {} tool calls", tool_calls.len()); let mut pending_tool_calls = HashMap::new();
for tool_call in tool_calls { let mut current_message = Message::assistant(None, None);
println!("DEBUG: Tool call detected - ID: {}, Name: {}", let mut has_tool_calls = false;
tool_call.id,
tool_call.function.name);
// Store or update the tool call // Process stream chunks
pending_tool_calls.insert(tool_call.id.clone(), tool_call.clone()); while let Some(chunk_result) = stream.recv().await {
match chunk_result {
Ok(chunk) => {
let delta = &chunk.choices[0].delta;
// Check if this tool call is complete and ready for execution // Handle role changes
if let Some(complete_tool_call) = pending_tool_calls.get(&tool_call.id) { if let Some(role) = &delta.role {
if let Some(tool) = tools_ref.get(&complete_tool_call.function.name) { match role.as_str() {
println!("DEBUG: Executing tool: {}", complete_tool_call.function.name); "assistant" => {
current_message = Message::assistant(None, None);
// Execute the tool let _ = tx.send(Ok(current_message.clone())).await;
match tool.execute(complete_tool_call).await {
Ok(result) => {
let result_str = serde_json::to_string(&result)
.unwrap_or_else(|e| format!("Error serializing result: {}", e));
println!("DEBUG: Tool execution successful: {}", result_str);
// Send tool result message
let tool_result_msg = Message::tool(
result_str,
complete_tool_call.id.clone(),
);
let _ = tx.send(Ok(tool_result_msg)).await;
}
Err(e) => {
println!("DEBUG: Tool execution failed: {:?}", e);
let error_msg = format!("Tool execution failed: {:?}", e);
let tool_error_msg = Message::tool(
error_msg,
complete_tool_call.id.clone(),
);
let _ = tx.send(Ok(tool_error_msg)).await;
}
}
// Remove the executed tool call
pending_tool_calls.remove(&tool_call.id);
} }
_ => continue,
} }
}
// Send the tool start message // Handle tool calls
let tool_start_msg = Message::assistant(None, Some(vec![tool_call.clone()])); if let Some(tool_calls) = &delta.tool_calls {
let _ = tx.send(Ok(tool_start_msg)).await; has_tool_calls = true;
for tool_call in tool_calls {
println!("DEBUG: Tool call detected - ID: {}", tool_call.id);
pending_tool_calls.insert(tool_call.id.clone(), tool_call.clone());
// Update current message with tool calls
if let Message::Assistant { tool_calls: ref mut msg_tool_calls, .. } = current_message {
*msg_tool_calls = Some(pending_tool_calls.values().cloned().collect());
}
let _ = tx.send(Ok(current_message.clone())).await;
}
}
// Handle content updates
if let Some(content) = &delta.content {
match &mut current_message {
Message::Assistant { content: msg_content, .. } => {
*msg_content = Some(
if let Some(existing) = msg_content {
format!("{}{}", existing, content)
} else {
content.clone()
}
);
}
_ => {}
}
let _ = tx.send(Ok(current_message.clone())).await;
} }
} }
Err(e) => {
// Handle content updates let _ = tx.send(Err(anyhow::Error::from(e))).await;
if let Some(content) = &delta.content { return Ok(());
println!("DEBUG: Content update received: {}", content);
match &mut current_message {
Message::Assistant { content: msg_content, tool_calls, .. } => {
*msg_content = Some(if let Some(existing) = msg_content {
let combined = format!("{}{}", existing, content);
println!("DEBUG: Updated assistant content: {}", combined);
combined
} else {
println!("DEBUG: New assistant content: {}", content);
content.clone()
});
}
Message::Tool { content: msg_content, .. } => {
println!("DEBUG: Updating tool content: {}", content);
*msg_content = content.clone();
}
_ => {}
}
let _ = tx.send(Ok(current_message.clone())).await;
} }
} }
Err(e) => {
println!("DEBUG: Error processing stream chunk: {:?}", e);
let _ = tx.send(Err(anyhow::Error::from(e))).await;
}
} }
// If we had tool calls, execute them and recurse
if has_tool_calls {
println!("DEBUG: Processing {} tool calls recursively", pending_tool_calls.len());
let mut tool_results = Vec::new();
// Execute all tools
for tool_call in pending_tool_calls.values() {
if let Some(tool) = tools_ref.get(&tool_call.function.name) {
match tool.execute(tool_call).await {
Ok(result) => {
let result_str = serde_json::to_string(&result)?;
let tool_result = Message::tool(result_str, tool_call.id.clone());
tool_results.push(tool_result.clone());
let _ = tx.send(Ok(tool_result)).await;
}
Err(e) => {
let error_msg = format!("Tool execution failed: {:?}", e);
let tool_error = Message::tool(error_msg, tool_call.id.clone());
tool_results.push(tool_error.clone());
let _ = tx.send(Ok(tool_error)).await;
}
}
}
}
// Create new thread with tool results
let mut new_thread = thread.clone();
new_thread.messages.push(current_message);
new_thread.messages.extend(tool_results);
// Recurse with new thread
Box::pin(process_stream_recursive(llm_client, model, tools_ref, &new_thread, tx)).await?;
}
Ok(())
}
// Start recursive processing
if let Err(e) = process_stream_recursive(&llm_client, &model, &tools_ref, &thread, &tx).await {
let _ = tx.send(Err(e)).await;
} }
println!("DEBUG: Stream processing completed");
}); });
println!("DEBUG: Returning stream receiver");
Ok(rx) Ok(rx)
} }
} }

View File

@ -6,6 +6,7 @@ use tokio::sync::mpsc;
use super::types::*; use super::types::*;
#[derive(Clone)]
pub struct LiteLLMClient { pub struct LiteLLMClient {
client: Client, client: Client,
pub(crate) api_key: String, pub(crate) api_key: String,