// buster/api/src/utils/agent/agent.rs

use crate::utils::{
clients::ai::litellm::{
ChatCompletionRequest, DeltaToolCall, FunctionCall, LiteLLMClient,
Message, MessageProgress, Tool, ToolCall, ToolChoice,
},
tools::ToolExecutor,
};
use anyhow::Result;
use serde::Serialize;
use serde_json::Value;
use std::{collections::HashMap, env, sync::Arc};
use tokio::sync::mpsc;
use super::types::AgentThread;
/// The Agent struct is responsible for managing conversations with the LLM
/// and coordinating tool executions. It maintains a registry of available tools
/// and handles the recursive nature of tool calls.
#[derive(Clone)]
pub struct Agent {
/// Client for communicating with the LLM provider
llm_client: LiteLLMClient,
/// Registry of available tools, mapped by their names
tools: Arc<HashMap<String, Box<dyn ToolExecutor<Output = Value>>>>,
/// The model identifier to use (e.g., "gpt-4")
model: String,
}
impl Agent {
/// Create a new Agent instance with a specific LLM client and model
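///
/// Reads `LLM_API_KEY` and `LLM_BASE_URL` from the environment and panics if
/// either is missing.
///
/// # Example
///
/// A minimal sketch, assuming both environment variables are set:
///
/// ```ignore
/// let agent = Agent::new("gpt-4".to_string(), HashMap::new());
/// ```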
pub fn new(
model: String,
tools: HashMap<String, Box<dyn ToolExecutor<Output = Value>>>,
) -> Self {
let llm_api_key = env::var("LLM_API_KEY").expect("LLM_API_KEY must be set");
let llm_base_url = env::var("LLM_BASE_URL").expect("LLM_BASE_URL must be set");
let llm_client = LiteLLMClient::new(Some(llm_api_key), Some(llm_base_url));
Self {
llm_client,
tools: Arc::new(tools),
model,
}
}
/// Add a new tool to the agent
///
/// # Arguments
/// * `name` - The name of the tool, used to identify it in tool calls
/// * `tool` - The tool implementation that will be executed
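///
/// # Panics
///
/// Panics if the agent has been cloned: the tool registry is shared behind an
/// `Arc`, so it can only be mutated while this agent is the sole owner.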
pub fn add_tool(&mut self, name: String, tool: impl ToolExecutor<Output = Value> + 'static) {
// Get a mutable reference to the HashMap inside the Arc
Arc::get_mut(&mut self.tools)
.expect("Failed to get mutable reference to tools")
.insert(name, Box::new(tool));
}
/// Add multiple tools to the agent at once
///
/// # Arguments
/// * `tools` - HashMap of tool names and their implementations
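///
/// # Panics
///
/// Like [`Agent::add_tool`], panics if the agent has been cloned.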
pub fn add_tools<E: ToolExecutor<Output = Value> + 'static>(
&mut self,
tools: HashMap<String, E>,
) {
let tools_map =
Arc::get_mut(&mut self.tools).expect("Failed to get mutable reference to tools");
for (name, tool) in tools {
tools_map.insert(name, Box::new(tool));
}
}
/// Process a thread of conversation, potentially executing tools and continuing
/// the conversation recursively until a final response is reached.
///
/// # Arguments
/// * `thread` - The conversation thread to process
///
/// # Returns
/// * A Result containing the final Message from the assistant
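///
/// # Example
///
/// A minimal sketch, assuming a configured agent (the model name here is
/// illustrative):
///
/// ```ignore
/// let agent = Agent::new("gpt-4".to_string(), HashMap::new());
/// let thread = AgentThread::new(None, vec![Message::user("Hello!".to_string())]);
/// let response = agent.process_thread(&thread).await?;
/// ```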
pub async fn process_thread(&self, thread: &AgentThread) -> Result<Message> {
self.process_thread_with_depth(thread, 0).await
}
async fn process_thread_with_depth(
&self,
thread: &AgentThread,
recursion_depth: u32,
) -> Result<Message> {
if recursion_depth >= 30 {
return Ok(Message::assistant(
Some("I apologize, but I've reached the maximum number of actions (30). Please try breaking your request into smaller parts.".to_string()),
None,
None,
));
}
// Collect all registered tools and their schemas
let tools: Vec<Tool> = self
.tools
.values()
.map(|tool| Tool {
tool_type: "function".to_string(),
function: tool.get_schema(),
})
.collect();
// First pass: request with tool_choice "none" so the model produces its reasoning as plain content before any tools run
let initial_request = ChatCompletionRequest {
model: self.model.clone(),
messages: thread.messages.clone(),
tools: if tools.is_empty() {
None
} else {
Some(tools.clone())
},
tool_choice: Some(ToolChoice::None("none".to_string())),
..Default::default()
};
// Get initial response
let initial_response = self.llm_client.chat_completion(initial_request).await?;
let initial_message = &initial_response.choices[0].message;
// Ensure we have content from the initial message
let initial_content = match initial_message {
Message::Assistant { content, .. } => content.clone().unwrap_or_default(),
_ => return Err(anyhow::anyhow!("Expected assistant message from LLM")),
};
// Create a new thread with the initial response (ensuring content is present)
let mut tool_thread = thread.clone();
tool_thread
.messages
.push(Message::assistant(Some(initial_content), None, None));
// Create the tool-enabled request
let request = ChatCompletionRequest {
model: self.model.clone(),
messages: tool_thread.messages.clone(),
tools: if tools.is_empty() { None } else { Some(tools) },
tool_choice: Some(ToolChoice::Auto("auto".to_string())),
..Default::default()
};
// Get the response from the LLM
let response = match self.llm_client.chat_completion(request).await {
Ok(response) => response,
Err(e) => return Err(anyhow::anyhow!("Error processing thread: {:?}", e)),
};
let llm_message = &response.choices[0].message;
// Create the assistant message
let message = match llm_message {
Message::Assistant {
content,
tool_calls,
..
} => Message::assistant(content.clone(), tool_calls.clone(), None),
_ => return Err(anyhow::anyhow!("Expected assistant message from LLM")),
};
// If the LLM requested tool calls, execute them and continue recursively;
// otherwise the assistant message is the final response
if let Message::Assistant {
tool_calls: Some(tool_calls),
..
} = &llm_message
{
let mut results = Vec::new();
// Execute each requested tool
for tool_call in tool_calls {
if let Some(tool) = self.tools.get(&tool_call.function.name) {
let result = tool.execute(tool_call).await?;
let result_str = serde_json::to_string(&result)?;
results.push(Message::tool(result_str, tool_call.id.clone(), None));
}
}
// Create a new thread with the tool results and continue recursively
let mut new_thread = thread.clone();
new_thread.messages.push(message);
new_thread.messages.extend(results);
Box::pin(self.process_thread_with_depth(&new_thread, recursion_depth + 1)).await
} else {
Ok(message)
}
}
/// Process a thread of conversation with streaming responses
///
/// # Arguments
/// * `thread` - The conversation thread to process
///
/// # Returns
/// * A Result containing a receiver for streamed messages
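///
/// # Example
///
/// A minimal sketch of draining the stream, assuming a configured agent:
///
/// ```ignore
/// let mut rx = agent.stream_process_thread(&thread).await?;
/// while let Some(message) = rx.recv().await {
///     match message {
///         Ok(msg) => println!("chunk: {:?}", msg),
///         Err(e) => eprintln!("stream error: {:?}", e),
///     }
/// }
/// ```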
pub async fn stream_process_thread(
&self,
thread: &AgentThread,
) -> Result<mpsc::Receiver<Result<Message>>> {
let (tx, rx) = mpsc::channel(100);
let tools_ref = self.tools.clone();
let model = self.model.clone();
let llm_client = self.llm_client.clone();
// Clone thread for task ownership
let thread = thread.clone();
tokio::spawn(async move {
async fn process_stream_recursive(
llm_client: &LiteLLMClient,
model: &str,
tools_ref: &Arc<HashMap<String, Box<dyn ToolExecutor<Output = Value>>>>,
thread: &AgentThread,
tx: &mpsc::Sender<Result<Message>>,
recursion_depth: u32,
) -> Result<()> {
if recursion_depth >= 30 {
let limit_message = Message::assistant(
Some("I apologize, but I've reached the maximum number of actions (30). Please try breaking your request into smaller parts.".to_string()),
None,
None,
);
let _ = tx.send(Ok(limit_message)).await;
return Ok(());
}
// Collect all registered tools and their schemas
let tools: Vec<Tool> = tools_ref
.values()
.map(|tool| Tool {
tool_type: "function".to_string(),
function: tool.get_schema(),
})
.collect();
// First pass: request with tool_choice "none" so the model streams its reasoning before any tools run
let initial_request = ChatCompletionRequest {
model: model.to_string(),
messages: thread.messages.clone(),
tools: if tools.is_empty() {
None
} else {
Some(tools.clone())
},
tool_choice: Some(ToolChoice::None("none".to_string())),
stream: Some(true),
..Default::default()
};
// Get streaming response for initial thoughts
let mut initial_stream = llm_client.stream_chat_completion(initial_request).await?;
let mut initial_message = Message::assistant(Some(String::new()), None, None);
let mut has_started = false;
// Process initial stream chunks
while let Some(chunk_result) = initial_stream.recv().await {
match chunk_result {
Ok(chunk) => {
let delta = &chunk.choices[0].delta;
// Handle content updates - send delta directly
if let Some(content) = &delta.content {
// Send the delta chunk immediately with InProgress
let _ = tx
.send(Ok(Message::assistant(
Some(content.clone()),
None,
Some(MessageProgress::InProgress),
)))
.await;
// Also accumulate for our thread history
if let Message::Assistant {
content: msg_content,
..
} = &mut initial_message
{
if let Some(existing) = msg_content {
existing.push_str(content);
}
}
}
}
Err(e) => {
let _ = tx.send(Err(anyhow::Error::from(e))).await;
return Ok(());
}
}
}
// Ensure we have content in the initial message
let initial_content = match &initial_message {
Message::Assistant { content, .. } => content.clone().unwrap_or_default(),
_ => String::new(),
};
// Create new thread with initial response (ensuring content is present)
let mut tool_thread = thread.clone();
tool_thread
.messages
.push(Message::assistant(Some(initial_content), None, None));
// Create the tool-enabled request
let request = ChatCompletionRequest {
model: model.to_string(),
messages: tool_thread.messages.clone(),
tools: if tools.is_empty() { None } else { Some(tools) },
tool_choice: Some(ToolChoice::Auto("auto".to_string())),
stream: Some(true),
..Default::default()
};
// Get streaming response
let mut stream = llm_client.stream_chat_completion(request).await?;
let mut current_message = Message::assistant(Some(String::new()), None, None);
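// Tool-call deltas arrive in fragments across chunks; accumulate them in a
// PendingToolCall until finish_reason signals the call is complete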
let mut current_pending_tool: Option<PendingToolCall> = None;
let mut has_tool_calls = false;
let mut tool_results = Vec::new();
// Process stream chunks
while let Some(chunk_result) = stream.recv().await {
match chunk_result {
Ok(chunk) => {
let delta = &chunk.choices[0].delta;
// Check for tool call completion
if let Some(finish_reason) = &chunk.choices[0].finish_reason {
if finish_reason == "tool_calls" {
has_tool_calls = true;
// Tool call is complete - execute it
if let Some(pending) = current_pending_tool.take() {
let tool_call = pending.into_tool_call();
// Create and preserve the assistant message with the tool call
let assistant_tool_message = Message::assistant(
None,
Some(vec![tool_call.clone()]),
Some(MessageProgress::Complete),
);
let _ = tx.send(Ok(assistant_tool_message.clone())).await;
// Execute the tool
if let Some(tool) = tools_ref.get(&tool_call.function.name)
{
match tool.execute(&tool_call).await {
Ok(result) => {
let result_str =
serde_json::to_string(&result)?;
let tool_result = Message::tool(
result_str,
tool_call.id.clone(),
Some(MessageProgress::Complete),
);
let _ = tx.send(Ok(tool_result.clone())).await;
// Store both the assistant tool message and the tool result
tool_results.push(assistant_tool_message);
tool_results.push(tool_result);
}
Err(e) => {
let error_msg =
format!("Tool execution failed: {:?}", e);
let tool_error = Message::tool(
error_msg,
tool_call.id.clone(),
Some(MessageProgress::Complete),
);
let _ = tx.send(Ok(tool_error.clone())).await;
// Store both the assistant tool message and the error
tool_results.push(assistant_tool_message);
tool_results.push(tool_error);
}
}
}
}
continue;
}
}
// Handle content updates: always accumulate so whitespace-only
// chunks aren't lost from the thread history, but only forward
// chunks that carry visible content
if let Some(content) = &delta.content {
if let Message::Assistant {
content: msg_content,
..
} = &mut current_message
{
if let Some(existing) = msg_content {
existing.push_str(content);
}
}
if !content.trim().is_empty() {
let _ = tx
.send(Ok(Message::assistant(
Some(content.clone()),
None,
Some(MessageProgress::InProgress),
)))
.await;
}
}
// Handle tool calls - only send when we have meaningful tool call data
if let Some(tool_calls) = &delta.tool_calls {
has_tool_calls = true;
if current_pending_tool.is_none() {
current_pending_tool = Some(PendingToolCall::new());
}
if let Some(pending) = &mut current_pending_tool {
for tool_call in tool_calls {
pending.update_from_delta(tool_call);
// Send an update if we have a name, regardless of arguments
if let Some(name) = &pending.function_name {
let temp_tool_call = ToolCall {
id: pending.id.clone().unwrap_or_default(),
function: FunctionCall {
name: name.clone(),
arguments: pending.arguments.clone(),
},
call_type: pending.call_type.clone().unwrap_or_default(),
code_interpreter: None,
retrieval: None,
};
let _ = tx
.send(Ok(Message::assistant(
None,
Some(vec![temp_tool_call]),
Some(MessageProgress::InProgress),
)))
.await;
}
}
}
}
}
Err(e) => {
let _ = tx.send(Err(anyhow::Error::from(e))).await;
return Ok(());
}
}
}
// If we didn't get any tool calls in the auto response, we're done
if !has_tool_calls {
// Send the accumulated message as complete if it has any content
if let Message::Assistant {
content: Some(content),
..
} = &current_message
{
if !content.trim().is_empty() {
let complete_message = Message::assistant(
Some(content.clone()),
None,
Some(MessageProgress::Complete),
);
let _ = tx.send(Ok(complete_message)).await;
}
}
return Ok(());
}
// Create new thread with tool results and recurse
let mut new_thread = thread.clone();
// Only include current_message if it has content
if let Message::Assistant {
content: Some(content),
..
} = &current_message
{
if !content.trim().is_empty() {
new_thread.messages.push(current_message);
}
}
new_thread.messages.extend(tool_results);
// Recurse with new thread
Box::pin(process_stream_recursive(
llm_client,
model,
tools_ref,
&new_thread,
tx,
recursion_depth + 1,
))
.await?;
Ok(())
}
// Start recursive processing
if let Err(e) =
process_stream_recursive(&llm_client, &model, &tools_ref, &thread, &tx, 0).await
{
let _ = tx.send(Err(e)).await;
}
});
Ok(rx)
}
}
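/// Accumulates streamed tool-call deltas until the call is complete.
///
/// Streaming responses deliver a tool call in fragments (id, name, and
/// argument chunks spread across deltas); this struct collects them so a
/// full `ToolCall` can be assembled once the stream reports completion.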
#[derive(Debug, Default)]
struct PendingToolCall {
id: Option<String>,
call_type: Option<String>,
function_name: Option<String>,
arguments: String,
code_interpreter: Option<Value>,
retrieval: Option<Value>,
}
impl PendingToolCall {
fn new() -> Self {
Self::default()
}
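/// Merge a streamed delta into the pending call, appending argument fragments.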
fn update_from_delta(&mut self, tool_call: &DeltaToolCall) {
if let Some(id) = &tool_call.id {
self.id = Some(id.clone());
}
if let Some(call_type) = &tool_call.call_type {
self.call_type = Some(call_type.clone());
}
if let Some(function) = &tool_call.function {
if let Some(name) = &function.name {
self.function_name = Some(name.clone());
}
if let Some(args) = &function.arguments {
self.arguments.push_str(args);
}
}
// Preserve any extra payloads carried on the delta
if let Some(code_interpreter) = &tool_call.code_interpreter {
self.code_interpreter = Some(code_interpreter.clone());
}
if let Some(retrieval) = &tool_call.retrieval {
self.retrieval = Some(retrieval.clone());
}
}
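/// Convert the accumulated fragments into a complete `ToolCall`.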
fn into_tool_call(self) -> ToolCall {
ToolCall {
id: self.id.unwrap_or_default(),
function: FunctionCall {
name: self.function_name.unwrap_or_default(),
arguments: self.arguments,
},
call_type: self.call_type.unwrap_or_default(),
code_interpreter: self.code_interpreter,
retrieval: self.retrieval,
}
}
}
#[cfg(test)]
mod tests {
use crate::utils::clients::ai::litellm::ToolCall;
use super::*;
use axum::async_trait;
use dotenv::dotenv;
use serde_json::{json, Value};
fn setup() {
dotenv().ok();
}
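/// Mock tool that returns a fixed weather reading for testing.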
struct WeatherTool;
#[async_trait]
impl ToolExecutor for WeatherTool {
type Output = Value;
async fn execute(&self, _tool_call: &ToolCall) -> Result<Self::Output> {
Ok(json!({
"temperature": 20,
"unit": "fahrenheit"
}))
}
fn get_schema(&self) -> Value {
json!({
"name": "get_weather",
"description": "Get current weather information for a specific location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g., San Francisco, CA"
},
"unit": {
"type": "string",
"enum": ["celsius", "fahrenheit"],
"description": "The temperature unit to use"
}
},
"required": ["location"]
}
})
}
fn get_name(&self) -> String {
"get_weather".to_string()
}
}
#[tokio::test]
async fn test_agent_convo_no_tools() {
setup();
// Create LLM client and agent
let agent = Agent::new("o1".to_string(), HashMap::new());
let thread = AgentThread::new(None, vec![Message::user("Hello, world!".to_string())]);
let response = match agent.process_thread(&thread).await {
Ok(response) => response,
Err(e) => panic!("Error processing thread: {:?}", e),
};
println!("Response: {:?}", response);
}
#[tokio::test]
async fn test_agent_convo_with_tools() {
setup();
// Create LLM client and agent
let mut agent = Agent::new("o1".to_string(), HashMap::new());
let weather_tool = WeatherTool;
agent.add_tool(weather_tool.get_name(), weather_tool);
let thread = AgentThread::new(
None,
vec![Message::user(
"What is the weather in vineyard ut?".to_string(),
)],
);
let response = match agent.process_thread(&thread).await {
Ok(response) => response,
Err(e) => panic!("Error processing thread: {:?}", e),
};
println!("Response: {:?}", response);
}
#[tokio::test]
async fn test_agent_with_multiple_steps() {
setup();
// Create LLM client and agent
let mut agent = Agent::new("o1".to_string(), HashMap::new());
let weather_tool = WeatherTool;
agent.add_tool(weather_tool.get_name(), weather_tool);
let thread = AgentThread::new(
None,
vec![Message::user(
"What is the weather in vineyard ut and san francisco?".to_string(),
)],
);
let response = match agent.process_thread(&thread).await {
Ok(response) => response,
Err(e) => panic!("Error processing thread: {:?}", e),
};
println!("Response: {:?}", response);
}
}