refactor: Enhance Agent and Tool management with new methods and tests

- Added environment variable-based LLM client initialization in `Agent::new()`
- Introduced `add_tool()` and `add_tools()` methods for more flexible tool registration
- Implemented new `get_name()` method for `ToolExecutor` trait
- Added comprehensive test cases for Agent with and without tools
- Updated `AgentThread` with a convenient constructor method
- Temporarily commented out unused tool modules
- Added debug print in LiteLLM client for response logging
This commit is contained in:
dal 2025-01-30 14:12:59 -07:00
parent 6a73b59aa1
commit ec04a5e98e
No known key found for this signature in database
GPG Key ID: 16F4B0E1E9F61122
5 changed files with 236 additions and 30 deletions

View File

@ -94,7 +94,8 @@ html-escape = "0.2.13"
async-trait = "0.1.85"
[dev-dependencies]
mockito = "1.2"
mockito = "1.2.0"
async-trait = "0.1.77"
tokio = { version = "1.0", features = ["full", "test-util"] }
[profile.release]

View File

@ -1,9 +1,13 @@
use crate::utils::{clients::ai::litellm::{ChatCompletionRequest, LiteLLMClient, Message, Tool}, tools::ToolExecutor};
use crate::utils::{
clients::ai::litellm::{ChatCompletionRequest, LiteLLMClient, Message, Tool},
tools::ToolExecutor,
};
use anyhow::Result;
use std::collections::HashMap;
use std::{collections::HashMap, env};
use tokio::sync::mpsc;
use async_trait::async_trait;
use super::types::{AgentThread};
use super::types::AgentThread;
/// The Agent struct is responsible for managing conversations with the LLM
/// and coordinating tool executions. It maintains a registry of available tools
@ -19,11 +23,12 @@ pub struct Agent {
impl Agent {
/// Create a new Agent instance with a specific LLM client and model
pub fn new(
llm_client: LiteLLMClient,
model: String,
tools: HashMap<String, Box<dyn ToolExecutor>>,
) -> Self {
pub fn new(model: String, tools: HashMap<String, Box<dyn ToolExecutor>>) -> Self {
let llm_api_key = env::var("LLM_API_KEY").expect("LLM_API_KEY must be set");
let llm_base_url = env::var("LLM_BASE_URL").expect("LLM_API_BASE must be set");
let llm_client = LiteLLMClient::new(Some(llm_api_key), Some(llm_base_url));
Self {
llm_client,
tools,
@ -31,13 +36,23 @@ impl Agent {
}
}
/// Register a new tool with the agent
/// Add a new tool to the agent
///
/// # Arguments
/// * `name` - The name of the tool, used to identify it in tool calls
/// * `tool` - The tool implementation that will be executed
pub fn register_tool(&mut self, name: String, tool: Box<dyn ToolExecutor>) {
self.tools.insert(name, tool);
pub fn add_tool<T: ToolExecutor + 'static>(&mut self, name: String, tool: T) {
self.tools.insert(name, Box::new(tool));
}
/// Register several tools on the agent in a single call.
///
/// # Arguments
/// * `tools` - HashMap of tool names and their implementations
pub fn add_tools<T: ToolExecutor + 'static>(&mut self, tools: HashMap<String, T>) {
    // Boxing each tool up front lets HashMap::extend do the insertion;
    // duplicate names overwrite existing entries, same as repeated insert().
    self.tools.extend(
        tools
            .into_iter()
            .map(|(name, tool)| (name, Box::new(tool) as Box<dyn ToolExecutor>)),
    );
}
/// Process a thread of conversation, potentially executing tools and continuing
@ -63,16 +78,20 @@ impl Agent {
let request = ChatCompletionRequest {
model: self.model.clone(),
messages: thread.messages.clone(),
tools: Some(tools),
tools: if tools.is_empty() { None } else { Some(tools) },
..Default::default()
};
// Get the response from the LLM
let response = self.llm_client.chat_completion(request).await?;
let response = match self.llm_client.chat_completion(request).await {
Ok(response) => response,
Err(e) => return Err(anyhow::anyhow!("Error processing thread: {:?}", e)),
};
let llm_message = &response.choices[0].message;
// Create the initial assistant message
let mut message = match llm_message {
let message = match llm_message {
Message::Assistant {
content,
tool_calls,
@ -179,3 +198,123 @@ impl Agent {
Ok(rx)
}
}
#[cfg(test)]
mod tests {
use crate::utils::clients::ai::litellm::ToolCall;
use super::*;
use dotenv::dotenv;
use serde_json::{json, Value};
// Load variables from a local `.env` file into the process environment
// before each test; `.ok()` deliberately ignores the error when no
// `.env` file is present.
fn setup() {
    dotenv().ok();
}
/// Mock weather tool used by the tests below to exercise the agent's
/// tool-dispatch path without any real network calls.
struct WeatherTool;

#[async_trait]
impl ToolExecutor for WeatherTool {
    /// Returns a canned reading. The call arguments are ignored — the tests
    /// only verify dispatch plumbing — so the parameter is underscore-prefixed
    /// to silence the unused-variable warning.
    async fn execute(&self, _tool_call: &ToolCall) -> Result<Value> {
        Ok(json!({
            "temperature": 20,
            "unit": "fahrenheit"
        }))
    }

    /// JSON schema advertised to the LLM for this tool.
    fn get_schema(&self) -> Value {
        json!({
            "name": "get_weather",
            "description": "Get current weather information for a specific location",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The city and state, e.g., San Francisco, CA"
                    },
                    "unit": {
                        "type": "string",
                        "enum": ["celsius", "fahrenheit"],
                        "description": "The temperature unit to use"
                    }
                },
                "required": ["location"]
            }
        })
    }

    /// Name under which the tool is registered and invoked.
    fn get_name(&self) -> String {
        "get_weather".to_string()
    }
}
#[tokio::test]
async fn test_agent_convo_no_tools() {
    setup();

    // Agent with an empty tool registry; LLM credentials are read from the
    // environment inside Agent::new.
    let agent = Agent::new("o1".to_string(), HashMap::new());
    let thread = AgentThread::new(None, vec![Message::user("Hello, world!".to_string())]);

    let response = agent
        .process_thread(&thread)
        .await
        .unwrap_or_else(|e| panic!("Error processing thread: {:?}", e));

    println!("Response: {:?}", response);
}
#[tokio::test]
async fn test_agent_convo_with_tools() {
    setup();

    // Register the mock weather tool under the name it reports for itself.
    let mut agent = Agent::new("o1".to_string(), HashMap::new());
    let weather_tool = WeatherTool;
    agent.add_tool(weather_tool.get_name(), weather_tool);

    let prompt = Message::user("What is the weather in vineyard ut?".to_string());
    let thread = AgentThread::new(None, vec![prompt]);

    let response = agent
        .process_thread(&thread)
        .await
        .unwrap_or_else(|e| panic!("Error processing thread: {:?}", e));

    println!("Response: {:?}", response);
}
#[tokio::test]
async fn test_agent_with_multiple_steps() {
    setup();

    // A prompt mentioning two locations should drive more than one tool call.
    let mut agent = Agent::new("o1".to_string(), HashMap::new());
    let weather_tool = WeatherTool;
    agent.add_tool(weather_tool.get_name(), weather_tool);

    let prompt = Message::user("What is the weather in vineyard ut and san francisco?".to_string());
    let thread = AgentThread::new(None, vec![prompt]);

    let response = agent
        .process_thread(&thread)
        .await
        .unwrap_or_else(|e| panic!("Error processing thread: {:?}", e));

    println!("Response: {:?}", response);
}
}

View File

@ -11,3 +11,12 @@ pub struct AgentThread {
/// Ordered sequence of messages in the conversation
pub messages: Vec<Message>,
}
impl AgentThread {
    /// Create a conversation thread from an optional id and an initial
    /// message list.
    ///
    /// # Arguments
    /// * `id` - Existing thread id; a fresh UUID v4 is generated when `None`
    /// * `messages` - Ordered conversation history to seed the thread with
    pub fn new(id: Option<String>, messages: Vec<Message>) -> Self {
        Self {
            // unwrap_or_else keeps UUID generation lazy: unwrap_or would
            // generate (and discard) a UUID even when an id was supplied.
            id: id.unwrap_or_else(|| uuid::Uuid::new_v4().to_string()),
            messages,
        }
    }
}

View File

@ -61,6 +61,9 @@ impl LiteLLMClient {
.await?
.json::<ChatCompletionResponse>()
.await?;
println!("Response: {:?}", response);
Ok(response)
}
@ -134,6 +137,20 @@ mod tests {
use std::time::Duration;
use tokio::time::timeout;
use dotenv::dotenv;
// Helper function to initialize environment before tests
async fn setup() -> (String, String) {
    // Load environment variables from .env file (missing file is fine)
    dotenv().ok();

    // Get API key and base URL from environment
    let api_key = env::var("LLM_API_KEY").expect("LLM_API_KEY must be set");
    // Bug fix: the expect message previously named the wrong variable
    // ("LLM_API_BASE") — it now matches the variable actually read.
    let base_url = env::var("LLM_BASE_URL").expect("LLM_BASE_URL must be set");

    (api_key, base_url)
}
/// Build the canonical single user message used by the request tests.
fn create_test_message() -> Message {
    Message::user(String::from("Hello"))
}
@ -329,12 +346,20 @@ mod tests {
let response = client.chat_completion(request).await.unwrap();
assert_eq!(response.id, "test-id");
if let Message::Assistant { content, tool_calls, .. } = &response.choices[0].message {
if let Message::Assistant {
content,
tool_calls,
..
} = &response.choices[0].message
{
assert!(content.is_none());
let tool_calls = tool_calls.as_ref().unwrap();
assert_eq!(tool_calls[0].id, "call_123");
assert_eq!(tool_calls[0].function.name, "get_current_weather");
assert_eq!(tool_calls[0].function.arguments, "{\"location\":\"Boston, MA\"}");
assert_eq!(
tool_calls[0].function.arguments,
"{\"location\":\"Boston, MA\"}"
);
} else {
panic!("Expected assistant message");
}
@ -368,4 +393,23 @@ mod tests {
env::remove_var("LLM_API_KEY");
env::remove_var("LLM_BASE_URL");
}
/// End-to-end smoke test: a single user message should yield at least one
/// completion choice from the live endpoint configured via the environment.
#[tokio::test]
async fn test_single_message_completion() {
    let (api_key, base_url) = setup().await;
    let client = LiteLLMClient::new(Some(api_key), Some(base_url));

    let request = ChatCompletionRequest {
        model: "o1".to_string(),
        messages: vec![Message::user("Hello, world!".to_string())],
        ..Default::default()
    };

    let response = match client.chat_completion(request).await {
        Ok(response) => response,
        Err(e) => panic!("Error processing thread: {:?}", e),
    };

    // is_empty() is the idiomatic emptiness check (clippy: len_zero).
    assert!(!response.choices.is_empty());
}
}

View File

@ -4,19 +4,19 @@ use serde_json::Value;
use crate::utils::clients::ai::litellm::ToolCall;
mod search_files;
mod create_files;
mod bulk_modify_files;
mod search_data_catalog;
mod open_files;
mod send_to_user;
// mod bulk_modify_files;
// mod create_files;
// mod open_files;
// mod search_data_catalog;
// mod search_files;
// mod send_to_user;
pub use search_files::SearchFilesTool;
pub use create_files::CreateFilesTool;
pub use bulk_modify_files::BulkModifyFilesTool;
pub use search_data_catalog::SearchDataCatalogTool;
pub use open_files::OpenFilesTool;
pub use send_to_user::SendToUserTool;
// pub use bulk_modify_files::BulkModifyFilesTool;
// pub use create_files::CreateFilesTool;
// pub use open_files::OpenFilesTool;
// pub use search_data_catalog::SearchDataCatalogTool;
// pub use search_files::SearchFilesTool;
// pub use send_to_user::SendToUserTool;
/// A trait that defines how tools should be implemented.
/// Any struct that wants to be used as a tool must implement this trait.
@ -27,4 +27,17 @@ pub trait ToolExecutor: Send + Sync {
/// Return the JSON schema that describes this tool's interface
fn get_schema(&self) -> serde_json::Value;
/// Return the name of the tool
fn get_name(&self) -> String;
}
// Internal helper: converts any concrete tool into the boxed trait-object
// form (`Box<dyn ToolExecutor>`) that the agent's registry stores.
// NOTE(review): this trait is private and has no callers in the visible
// code — confirm it is used elsewhere before keeping it.
trait IntoBoxedTool {
    fn boxed(self) -> Box<dyn ToolExecutor>;
}

// Blanket implementation: every `'static` ToolExecutor can be boxed.
impl<T: ToolExecutor + 'static> IntoBoxedTool for T {
    fn boxed(self) -> Box<dyn ToolExecutor> {
        Box::new(self)
    }
}