mirror of https://github.com/buster-so/buster.git
refactor(agent): Implement recursive stream processing with improved tool execution
- Refactor agent stream processing to use a recursive approach for handling tool calls - Enhance tool execution with more robust error handling and result tracking - Improve stream chunk processing with detailed state management - Add support for recursive thread generation based on tool call results - Implement cloning for LiteLLMClient to support stream processing tasks
This commit is contained in:
parent
bb4e4ca9d8
commit
8b51618afd
|
@ -10,6 +10,7 @@ use serde::Serialize;
|
|||
|
||||
use super::types::AgentThread;
|
||||
|
||||
#[derive(Clone)]
|
||||
/// The Agent struct is responsible for managing conversations with the LLM
|
||||
/// and coordinating tool executions. It maintains a registry of available tools
|
||||
/// and handles the recursive nature of tool calls.
|
||||
|
@ -151,152 +152,148 @@ impl Agent {
|
|||
&self,
|
||||
thread: &AgentThread,
|
||||
) -> Result<mpsc::Receiver<Result<Message>>> {
|
||||
// Collect all registered tools and their schemas
|
||||
let tools: Vec<Tool> = self
|
||||
.tools
|
||||
.iter()
|
||||
.map(|(name, tool)| Tool {
|
||||
tool_type: "function".to_string(),
|
||||
function: tool.get_schema(),
|
||||
})
|
||||
.collect();
|
||||
|
||||
println!("DEBUG: Starting stream_process_thread with {} tools", tools.len());
|
||||
println!("DEBUG: Tools registered: {:?}", tools);
|
||||
|
||||
// Create the request to send to the LLM
|
||||
let request = ChatCompletionRequest {
|
||||
model: self.model.clone(),
|
||||
messages: thread.messages.clone(),
|
||||
tools: Some(tools),
|
||||
stream: Some(true),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
println!("DEBUG: Created chat completion request with model: {}", self.model);
|
||||
|
||||
// Get the streaming response from the LLM
|
||||
let mut stream = self.llm_client.stream_chat_completion(request).await?;
|
||||
let (tx, rx) = mpsc::channel(100);
|
||||
let mut pending_tool_calls = HashMap::new();
|
||||
|
||||
// Clone the Arc for use in the spawned task
|
||||
let tools_ref = self.tools.clone();
|
||||
let model = self.model.clone();
|
||||
let llm_client = self.llm_client.clone();
|
||||
|
||||
println!("DEBUG: Stream initialized, starting processing task");
|
||||
// Clone thread for task ownership
|
||||
let thread = thread.clone();
|
||||
|
||||
// Process the stream in a separate task
|
||||
tokio::spawn(async move {
|
||||
let mut current_message = Message::assistant(None, None);
|
||||
async fn process_stream_recursive(
|
||||
llm_client: &LiteLLMClient,
|
||||
model: &str,
|
||||
tools_ref: &Arc<HashMap<String, Box<dyn ToolExecutor<Output = Value>>>>,
|
||||
thread: &AgentThread,
|
||||
tx: &mpsc::Sender<Result<Message>>,
|
||||
) -> Result<()> {
|
||||
// Collect all registered tools and their schemas
|
||||
let tools: Vec<Tool> = tools_ref
|
||||
.iter()
|
||||
.map(|(name, tool)| Tool {
|
||||
tool_type: "function".to_string(),
|
||||
function: tool.get_schema(),
|
||||
})
|
||||
.collect();
|
||||
|
||||
while let Some(chunk_result) = stream.recv().await {
|
||||
println!("DEBUG: Received new stream chunk");
|
||||
match chunk_result {
|
||||
Ok(chunk) => {
|
||||
let delta = &chunk.choices[0].delta;
|
||||
println!("DEBUG: Processing delta: {:?}", delta);
|
||||
println!("DEBUG: Starting recursive stream with {} tools", tools.len());
|
||||
|
||||
// Handle role changes
|
||||
if let Some(role) = &delta.role {
|
||||
println!("DEBUG: Role change detected: {}", role);
|
||||
match role.as_str() {
|
||||
"assistant" => {
|
||||
current_message = Message::assistant(None, None);
|
||||
println!("DEBUG: Reset current_message for assistant");
|
||||
}
|
||||
"tool" => {
|
||||
println!("DEBUG: Tool role detected, waiting for content");
|
||||
continue;
|
||||
}
|
||||
_ => continue,
|
||||
}
|
||||
}
|
||||
// Create the request
|
||||
let request = ChatCompletionRequest {
|
||||
model: model.to_string(),
|
||||
messages: thread.messages.clone(),
|
||||
tools: Some(tools),
|
||||
stream: Some(true),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
// Handle tool calls (tool execution start)
|
||||
if let Some(tool_calls) = &delta.tool_calls {
|
||||
println!("DEBUG: Processing {} tool calls", tool_calls.len());
|
||||
for tool_call in tool_calls {
|
||||
println!("DEBUG: Tool call detected - ID: {}, Name: {}",
|
||||
tool_call.id,
|
||||
tool_call.function.name);
|
||||
|
||||
// Store or update the tool call
|
||||
pending_tool_calls.insert(tool_call.id.clone(), tool_call.clone());
|
||||
|
||||
// Check if this tool call is complete and ready for execution
|
||||
if let Some(complete_tool_call) = pending_tool_calls.get(&tool_call.id) {
|
||||
if let Some(tool) = tools_ref.get(&complete_tool_call.function.name) {
|
||||
println!("DEBUG: Executing tool: {}", complete_tool_call.function.name);
|
||||
|
||||
// Execute the tool
|
||||
match tool.execute(complete_tool_call).await {
|
||||
Ok(result) => {
|
||||
let result_str = serde_json::to_string(&result)
|
||||
.unwrap_or_else(|e| format!("Error serializing result: {}", e));
|
||||
println!("DEBUG: Tool execution successful: {}", result_str);
|
||||
|
||||
// Send tool result message
|
||||
let tool_result_msg = Message::tool(
|
||||
result_str,
|
||||
complete_tool_call.id.clone(),
|
||||
);
|
||||
let _ = tx.send(Ok(tool_result_msg)).await;
|
||||
}
|
||||
Err(e) => {
|
||||
println!("DEBUG: Tool execution failed: {:?}", e);
|
||||
let error_msg = format!("Tool execution failed: {:?}", e);
|
||||
let tool_error_msg = Message::tool(
|
||||
error_msg,
|
||||
complete_tool_call.id.clone(),
|
||||
);
|
||||
let _ = tx.send(Ok(tool_error_msg)).await;
|
||||
}
|
||||
}
|
||||
|
||||
// Remove the executed tool call
|
||||
pending_tool_calls.remove(&tool_call.id);
|
||||
// Get streaming response
|
||||
let mut stream = llm_client.stream_chat_completion(request).await?;
|
||||
let mut pending_tool_calls = HashMap::new();
|
||||
let mut current_message = Message::assistant(None, None);
|
||||
let mut has_tool_calls = false;
|
||||
|
||||
// Process stream chunks
|
||||
while let Some(chunk_result) = stream.recv().await {
|
||||
match chunk_result {
|
||||
Ok(chunk) => {
|
||||
let delta = &chunk.choices[0].delta;
|
||||
|
||||
// Handle role changes
|
||||
if let Some(role) = &delta.role {
|
||||
match role.as_str() {
|
||||
"assistant" => {
|
||||
current_message = Message::assistant(None, None);
|
||||
let _ = tx.send(Ok(current_message.clone())).await;
|
||||
}
|
||||
_ => continue,
|
||||
}
|
||||
}
|
||||
|
||||
// Send the tool start message
|
||||
let tool_start_msg = Message::assistant(None, Some(vec![tool_call.clone()]));
|
||||
let _ = tx.send(Ok(tool_start_msg)).await;
|
||||
// Handle tool calls
|
||||
if let Some(tool_calls) = &delta.tool_calls {
|
||||
has_tool_calls = true;
|
||||
for tool_call in tool_calls {
|
||||
println!("DEBUG: Tool call detected - ID: {}", tool_call.id);
|
||||
pending_tool_calls.insert(tool_call.id.clone(), tool_call.clone());
|
||||
|
||||
// Update current message with tool calls
|
||||
if let Message::Assistant { tool_calls: ref mut msg_tool_calls, .. } = current_message {
|
||||
*msg_tool_calls = Some(pending_tool_calls.values().cloned().collect());
|
||||
}
|
||||
|
||||
let _ = tx.send(Ok(current_message.clone())).await;
|
||||
}
|
||||
}
|
||||
|
||||
// Handle content updates
|
||||
if let Some(content) = &delta.content {
|
||||
match &mut current_message {
|
||||
Message::Assistant { content: msg_content, .. } => {
|
||||
*msg_content = Some(
|
||||
if let Some(existing) = msg_content {
|
||||
format!("{}{}", existing, content)
|
||||
} else {
|
||||
content.clone()
|
||||
}
|
||||
);
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
let _ = tx.send(Ok(current_message.clone())).await;
|
||||
}
|
||||
}
|
||||
|
||||
// Handle content updates
|
||||
if let Some(content) = &delta.content {
|
||||
println!("DEBUG: Content update received: {}", content);
|
||||
match &mut current_message {
|
||||
Message::Assistant { content: msg_content, tool_calls, .. } => {
|
||||
*msg_content = Some(if let Some(existing) = msg_content {
|
||||
let combined = format!("{}{}", existing, content);
|
||||
println!("DEBUG: Updated assistant content: {}", combined);
|
||||
combined
|
||||
} else {
|
||||
println!("DEBUG: New assistant content: {}", content);
|
||||
content.clone()
|
||||
});
|
||||
}
|
||||
Message::Tool { content: msg_content, .. } => {
|
||||
println!("DEBUG: Updating tool content: {}", content);
|
||||
*msg_content = content.clone();
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
let _ = tx.send(Ok(current_message.clone())).await;
|
||||
Err(e) => {
|
||||
let _ = tx.send(Err(anyhow::Error::from(e))).await;
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
println!("DEBUG: Error processing stream chunk: {:?}", e);
|
||||
let _ = tx.send(Err(anyhow::Error::from(e))).await;
|
||||
}
|
||||
}
|
||||
|
||||
// If we had tool calls, execute them and recurse
|
||||
if has_tool_calls {
|
||||
println!("DEBUG: Processing {} tool calls recursively", pending_tool_calls.len());
|
||||
let mut tool_results = Vec::new();
|
||||
|
||||
// Execute all tools
|
||||
for tool_call in pending_tool_calls.values() {
|
||||
if let Some(tool) = tools_ref.get(&tool_call.function.name) {
|
||||
match tool.execute(tool_call).await {
|
||||
Ok(result) => {
|
||||
let result_str = serde_json::to_string(&result)?;
|
||||
let tool_result = Message::tool(result_str, tool_call.id.clone());
|
||||
tool_results.push(tool_result.clone());
|
||||
let _ = tx.send(Ok(tool_result)).await;
|
||||
}
|
||||
Err(e) => {
|
||||
let error_msg = format!("Tool execution failed: {:?}", e);
|
||||
let tool_error = Message::tool(error_msg, tool_call.id.clone());
|
||||
tool_results.push(tool_error.clone());
|
||||
let _ = tx.send(Ok(tool_error)).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Create new thread with tool results
|
||||
let mut new_thread = thread.clone();
|
||||
new_thread.messages.push(current_message);
|
||||
new_thread.messages.extend(tool_results);
|
||||
|
||||
// Recurse with new thread
|
||||
Box::pin(process_stream_recursive(llm_client, model, tools_ref, &new_thread, tx)).await?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// Start recursive processing
|
||||
if let Err(e) = process_stream_recursive(&llm_client, &model, &tools_ref, &thread, &tx).await {
|
||||
let _ = tx.send(Err(e)).await;
|
||||
}
|
||||
println!("DEBUG: Stream processing completed");
|
||||
});
|
||||
|
||||
println!("DEBUG: Returning stream receiver");
|
||||
Ok(rx)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -6,6 +6,7 @@ use tokio::sync::mpsc;
|
|||
|
||||
use super::types::*;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct LiteLLMClient {
|
||||
client: Client,
|
||||
pub(crate) api_key: String,
|
||||
|
|
Loading…
Reference in New Issue