From 4b743fe5ecda92a41c7e18a4540e52d20bd180a0 Mon Sep 17 00:00:00 2001 From: dal Date: Fri, 7 Feb 2025 11:35:13 -0700 Subject: [PATCH] feat(agent): Add recursion depth limit to prevent infinite processing - Implement a maximum recursion depth of 30 for agent thread processing - Add recursion depth tracking to prevent potential infinite loops - Provide user-friendly message when maximum recursion depth is reached - Update debug logging to include current recursion depth - Modify both synchronous and streaming thread processing methods --- api/src/utils/agent/agent.rs | 31 ++++++++++++++++++++++++++----- 1 file changed, 26 insertions(+), 5 deletions(-) diff --git a/api/src/utils/agent/agent.rs b/api/src/utils/agent/agent.rs index 9ec1c2cae..1943b6c24 100644 --- a/api/src/utils/agent/agent.rs +++ b/api/src/utils/agent/agent.rs @@ -77,6 +77,17 @@ impl Agent { /// # Returns /// * A Result containing the final Message from the assistant pub async fn process_thread(&self, thread: &AgentThread) -> Result { + self.process_thread_with_depth(thread, 0).await + } + + async fn process_thread_with_depth(&self, thread: &AgentThread, recursion_depth: u32) -> Result { + if recursion_depth >= 30 { + return Ok(Message::assistant( + Some("I apologize, but I've reached the maximum number of actions (30). Please try breaking your request into smaller parts.".to_string()), + None + )); + } + // Collect all registered tools and their schemas let tools: Vec = self .tools @@ -135,7 +146,7 @@ impl Agent { new_thread.messages.push(message); new_thread.messages.extend(results); - Box::pin(self.process_thread(&new_thread)).await + Box::pin(self.process_thread_with_depth(&new_thread, recursion_depth + 1)).await } else { Ok(message) } @@ -167,7 +178,17 @@ impl Agent { tools_ref: &Arc>>>, thread: &AgentThread, tx: &mpsc::Sender>, + recursion_depth: u32, ) -> Result<()> { + if recursion_depth >= 30 { + let limit_message = Message::assistant( + Some("I apologize, but I've reached the maximum number of actions (30). Please try breaking your request into smaller parts.".to_string()), + None + ); + let _ = tx.send(Ok(limit_message)).await; + return Ok(()); + } + // Collect all registered tools and their schemas let tools: Vec = tools_ref .iter() @@ -177,7 +198,7 @@ impl Agent { }) .collect(); - println!("DEBUG: Starting recursive stream with {} tools", tools.len()); + println!("DEBUG: Starting recursive stream with {} tools at depth {}", tools.len(), recursion_depth); // Create the request let request = ChatCompletionRequest { @@ -253,7 +274,7 @@ impl Agent { // If we had tool calls, execute them and recurse if has_tool_calls { - println!("DEBUG: Processing {} tool calls recursively", pending_tool_calls.len()); + println!("DEBUG: Processing {} tool calls recursively at depth {}", pending_tool_calls.len(), recursion_depth); let mut tool_results = Vec::new(); // Execute all tools @@ -282,14 +303,14 @@ impl Agent { new_thread.messages.extend(tool_results); // Recurse with new thread - Box::pin(process_stream_recursive(llm_client, model, tools_ref, &new_thread, tx)).await?; + Box::pin(process_stream_recursive(llm_client, model, tools_ref, &new_thread, tx, recursion_depth + 1)).await?; } Ok(()) } // Start recursive processing - if let Err(e) = process_stream_recursive(&llm_client, &model, &tools_ref, &thread, &tx).await { + if let Err(e) = process_stream_recursive(&llm_client, &model, &tools_ref, &thread, &tx, 0).await { let _ = tx.send(Err(e)).await; } });