feat(agent): Add recursion depth limit to prevent infinite processing

- Implement a maximum recursion depth of 30 for agent thread processing
- Add recursion depth tracking to prevent potential infinite loops
- Provide user-friendly message when maximum recursion depth is reached
- Update debug logging to include current recursion depth
- Modify both synchronous and streaming thread processing methods
This commit is contained in:
dal 2025-02-07 11:35:13 -07:00
parent 8b51618afd
commit 4b743fe5ec
No known key found for this signature in database
GPG Key ID: 16F4B0E1E9F61122
1 changed files with 26 additions and 5 deletions

View File

@ -77,6 +77,17 @@ impl Agent {
/// # Returns
/// * A Result containing the final Message from the assistant
pub async fn process_thread(&self, thread: &AgentThread) -> Result<Message> {
self.process_thread_with_depth(thread, 0).await
}
async fn process_thread_with_depth(&self, thread: &AgentThread, recursion_depth: u32) -> Result<Message> {
if recursion_depth >= 30 {
return Ok(Message::assistant(
Some("I apologize, but I've reached the maximum number of actions (30). Please try breaking your request into smaller parts.".to_string()),
None
));
}
// Collect all registered tools and their schemas
let tools: Vec<Tool> = self
.tools
@ -135,7 +146,7 @@ impl Agent {
new_thread.messages.push(message);
new_thread.messages.extend(results);
Box::pin(self.process_thread(&new_thread)).await
Box::pin(self.process_thread_with_depth(&new_thread, recursion_depth + 1)).await
} else {
Ok(message)
}
@ -167,7 +178,17 @@ impl Agent {
tools_ref: &Arc<HashMap<String, Box<dyn ToolExecutor<Output = Value>>>>,
thread: &AgentThread,
tx: &mpsc::Sender<Result<Message>>,
recursion_depth: u32,
) -> Result<()> {
if recursion_depth >= 30 {
let limit_message = Message::assistant(
Some("I apologize, but I've reached the maximum number of actions (30). Please try breaking your request into smaller parts.".to_string()),
None
);
let _ = tx.send(Ok(limit_message)).await;
return Ok(());
}
// Collect all registered tools and their schemas
let tools: Vec<Tool> = tools_ref
.iter()
@ -177,7 +198,7 @@ impl Agent {
})
.collect();
println!("DEBUG: Starting recursive stream with {} tools", tools.len());
println!("DEBUG: Starting recursive stream with {} tools at depth {}", tools.len(), recursion_depth);
// Create the request
let request = ChatCompletionRequest {
@ -253,7 +274,7 @@ impl Agent {
// If we had tool calls, execute them and recurse
if has_tool_calls {
println!("DEBUG: Processing {} tool calls recursively", pending_tool_calls.len());
println!("DEBUG: Processing {} tool calls recursively at depth {}", pending_tool_calls.len(), recursion_depth);
let mut tool_results = Vec::new();
// Execute all tools
@ -282,14 +303,14 @@ impl Agent {
new_thread.messages.extend(tool_results);
// Recurse with new thread
Box::pin(process_stream_recursive(llm_client, model, tools_ref, &new_thread, tx)).await?;
Box::pin(process_stream_recursive(llm_client, model, tools_ref, &new_thread, tx, recursion_depth + 1)).await?;
}
Ok(())
}
// Start recursive processing
if let Err(e) = process_stream_recursive(&llm_client, &model, &tools_ref, &thread, &tx).await {
if let Err(e) = process_stream_recursive(&llm_client, &model, &tools_ref, &thread, &tx, 0).await {
let _ = tx.send(Err(e)).await;
}
});