diff --git a/api/src/utils/agent/agent.rs b/api/src/utils/agent/agent.rs index 1099f1bf9..9ec1c2cae 100644 --- a/api/src/utils/agent/agent.rs +++ b/api/src/utils/agent/agent.rs @@ -10,6 +10,7 @@ use serde::Serialize; use super::types::AgentThread; +#[derive(Clone)] /// The Agent struct is responsible for managing conversations with the LLM /// and coordinating tool executions. It maintains a registry of available tools /// and handles the recursive nature of tool calls. @@ -151,152 +152,148 @@ impl Agent { &self, thread: &AgentThread, ) -> Result>> { - // Collect all registered tools and their schemas - let tools: Vec = self - .tools - .iter() - .map(|(name, tool)| Tool { - tool_type: "function".to_string(), - function: tool.get_schema(), - }) - .collect(); - - println!("DEBUG: Starting stream_process_thread with {} tools", tools.len()); - println!("DEBUG: Tools registered: {:?}", tools); - - // Create the request to send to the LLM - let request = ChatCompletionRequest { - model: self.model.clone(), - messages: thread.messages.clone(), - tools: Some(tools), - stream: Some(true), - ..Default::default() - }; - - println!("DEBUG: Created chat completion request with model: {}", self.model); - - // Get the streaming response from the LLM - let mut stream = self.llm_client.stream_chat_completion(request).await?; let (tx, rx) = mpsc::channel(100); - let mut pending_tool_calls = HashMap::new(); - - // Clone the Arc for use in the spawned task let tools_ref = self.tools.clone(); + let model = self.model.clone(); + let llm_client = self.llm_client.clone(); - println!("DEBUG: Stream initialized, starting processing task"); + // Clone thread for task ownership + let thread = thread.clone(); - // Process the stream in a separate task tokio::spawn(async move { - let mut current_message = Message::assistant(None, None); + async fn process_stream_recursive( + llm_client: &LiteLLMClient, + model: &str, + tools_ref: &Arc>>>, + thread: &AgentThread, + tx: &mpsc::Sender>, + ) -> Result<()> { + // Collect all registered tools and their schemas + let tools: Vec = tools_ref + .iter() + .map(|(name, tool)| Tool { + tool_type: "function".to_string(), + function: tool.get_schema(), + }) + .collect(); - while let Some(chunk_result) = stream.recv().await { - println!("DEBUG: Received new stream chunk"); - match chunk_result { - Ok(chunk) => { - let delta = &chunk.choices[0].delta; - println!("DEBUG: Processing delta: {:?}", delta); + println!("DEBUG: Starting recursive stream with {} tools", tools.len()); - // Handle role changes - if let Some(role) = &delta.role { - println!("DEBUG: Role change detected: {}", role); - match role.as_str() { - "assistant" => { - current_message = Message::assistant(None, None); - println!("DEBUG: Reset current_message for assistant"); - } - "tool" => { - println!("DEBUG: Tool role detected, waiting for content"); - continue; - } - _ => continue, - } - } + // Create the request + let request = ChatCompletionRequest { + model: model.to_string(), + messages: thread.messages.clone(), + tools: Some(tools), + stream: Some(true), + ..Default::default() + }; - // Handle tool calls (tool execution start) - if let Some(tool_calls) = &delta.tool_calls { - println!("DEBUG: Processing {} tool calls", tool_calls.len()); - for tool_call in tool_calls { - println!("DEBUG: Tool call detected - ID: {}, Name: {}", - tool_call.id, - tool_call.function.name); - - // Store or update the tool call - pending_tool_calls.insert(tool_call.id.clone(), tool_call.clone()); - - // Check if this tool call is complete and ready for execution - if let Some(complete_tool_call) = pending_tool_calls.get(&tool_call.id) { - if let Some(tool) = tools_ref.get(&complete_tool_call.function.name) { - println!("DEBUG: Executing tool: {}", complete_tool_call.function.name); - - // Execute the tool - match tool.execute(complete_tool_call).await { - Ok(result) => { - let result_str = serde_json::to_string(&result) - .unwrap_or_else(|e| format!("Error serializing result: {}", e)); - println!("DEBUG: Tool execution successful: {}", result_str); - - // Send tool result message - let tool_result_msg = Message::tool( - result_str, - complete_tool_call.id.clone(), - ); - let _ = tx.send(Ok(tool_result_msg)).await; - } - Err(e) => { - println!("DEBUG: Tool execution failed: {:?}", e); - let error_msg = format!("Tool execution failed: {:?}", e); - let tool_error_msg = Message::tool( - error_msg, - complete_tool_call.id.clone(), - ); - let _ = tx.send(Ok(tool_error_msg)).await; - } - } - - // Remove the executed tool call - pending_tool_calls.remove(&tool_call.id); + // Get streaming response + let mut stream = llm_client.stream_chat_completion(request).await?; + let mut pending_tool_calls = HashMap::new(); + let mut current_message = Message::assistant(None, None); + let mut has_tool_calls = false; + + // Process stream chunks + while let Some(chunk_result) = stream.recv().await { + match chunk_result { + Ok(chunk) => { + let delta = &chunk.choices[0].delta; + + // Handle role changes + if let Some(role) = &delta.role { + match role.as_str() { + "assistant" => { + current_message = Message::assistant(None, None); + let _ = tx.send(Ok(current_message.clone())).await; } + _ => continue, } + } - // Send the tool start message - let tool_start_msg = Message::assistant(None, Some(vec![tool_call.clone()])); - let _ = tx.send(Ok(tool_start_msg)).await; + // Handle tool calls + if let Some(tool_calls) = &delta.tool_calls { + has_tool_calls = true; + for tool_call in tool_calls { + println!("DEBUG: Tool call detected - ID: {}", tool_call.id); + pending_tool_calls.insert(tool_call.id.clone(), tool_call.clone()); + + // Update current message with tool calls + if let Message::Assistant { tool_calls: ref mut msg_tool_calls, .. } = current_message { + *msg_tool_calls = Some(pending_tool_calls.values().cloned().collect()); + } + + let _ = tx.send(Ok(current_message.clone())).await; + } + } + + // Handle content updates + if let Some(content) = &delta.content { + match &mut current_message { + Message::Assistant { content: msg_content, .. } => { + *msg_content = Some( + if let Some(existing) = msg_content { + format!("{}{}", existing, content) + } else { + content.clone() + } + ); + } + _ => {} + } + let _ = tx.send(Ok(current_message.clone())).await; } } - - // Handle content updates - if let Some(content) = &delta.content { - println!("DEBUG: Content update received: {}", content); - match &mut current_message { - Message::Assistant { content: msg_content, tool_calls, .. } => { - *msg_content = Some(if let Some(existing) = msg_content { - let combined = format!("{}{}", existing, content); - println!("DEBUG: Updated assistant content: {}", combined); - combined - } else { - println!("DEBUG: New assistant content: {}", content); - content.clone() - }); - } - Message::Tool { content: msg_content, .. } => { - println!("DEBUG: Updating tool content: {}", content); - *msg_content = content.clone(); - } - _ => {} - } - let _ = tx.send(Ok(current_message.clone())).await; + Err(e) => { + let _ = tx.send(Err(anyhow::Error::from(e))).await; + return Ok(()); } } - Err(e) => { - println!("DEBUG: Error processing stream chunk: {:?}", e); - let _ = tx.send(Err(anyhow::Error::from(e))).await; - } } + + // If we had tool calls, execute them and recurse + if has_tool_calls { + println!("DEBUG: Processing {} tool calls recursively", pending_tool_calls.len()); + let mut tool_results = Vec::new(); + + // Execute all tools + for tool_call in pending_tool_calls.values() { + if let Some(tool) = tools_ref.get(&tool_call.function.name) { + match tool.execute(tool_call).await { + Ok(result) => { + let result_str = serde_json::to_string(&result)?; + let tool_result = Message::tool(result_str, tool_call.id.clone()); + tool_results.push(tool_result.clone()); + let _ = tx.send(Ok(tool_result)).await; + } + Err(e) => { + let error_msg = format!("Tool execution failed: {:?}", e); + let tool_error = Message::tool(error_msg, tool_call.id.clone()); + tool_results.push(tool_error.clone()); + let _ = tx.send(Ok(tool_error)).await; + } + } + } + } + + // Create new thread with tool results + let mut new_thread = thread.clone(); + new_thread.messages.push(current_message); + new_thread.messages.extend(tool_results); + + // Recurse with new thread + Box::pin(process_stream_recursive(llm_client, model, tools_ref, &new_thread, tx)).await?; + } + + Ok(()) + } + + // Start recursive processing + if let Err(e) = process_stream_recursive(&llm_client, &model, &tools_ref, &thread, &tx).await { + let _ = tx.send(Err(e)).await; } - println!("DEBUG: Stream processing completed"); }); - println!("DEBUG: Returning stream receiver"); Ok(rx) } } diff --git a/api/src/utils/clients/ai/litellm/client.rs b/api/src/utils/clients/ai/litellm/client.rs index 5aecbeb06..fe99a85fe 100644 --- a/api/src/utils/clients/ai/litellm/client.rs +++ b/api/src/utils/clients/ai/litellm/client.rs @@ -6,6 +6,7 @@ use tokio::sync::mpsc; use super::types::*; +#[derive(Clone)] pub struct LiteLLMClient { client: Client, pub(crate) api_key: String,