diff --git a/api/Cargo.toml b/api/Cargo.toml index 543417abb..1d553b370 100644 --- a/api/Cargo.toml +++ b/api/Cargo.toml @@ -94,7 +94,8 @@ html-escape = "0.2.13" async-trait = "0.1.85" [dev-dependencies] -mockito = "1.2" +mockito = "1.2.0" +async-trait = "0.1.77" tokio = { version = "1.0", features = ["full", "test-util"] } [profile.release] diff --git a/api/src/utils/agent/agent.rs b/api/src/utils/agent/agent.rs index afa2ee682..69b664aa5 100644 --- a/api/src/utils/agent/agent.rs +++ b/api/src/utils/agent/agent.rs @@ -1,9 +1,13 @@ -use crate::utils::{clients::ai::litellm::{ChatCompletionRequest, LiteLLMClient, Message, Tool}, tools::ToolExecutor}; +use crate::utils::{ + clients::ai::litellm::{ChatCompletionRequest, LiteLLMClient, Message, Tool}, + tools::ToolExecutor, +}; use anyhow::Result; -use std::collections::HashMap; +use std::{collections::HashMap, env}; use tokio::sync::mpsc; +use async_trait::async_trait; -use super::types::{AgentThread}; +use super::types::AgentThread; /// The Agent struct is responsible for managing conversations with the LLM /// and coordinating tool executions. 
It maintains a registry of available tools @@ -19,11 +23,12 @@ pub struct Agent { impl Agent { /// Create a new Agent instance with a specific LLM client and model - pub fn new( - llm_client: LiteLLMClient, - model: String, - tools: HashMap>, - ) -> Self { + pub fn new(model: String, tools: HashMap>) -> Self { + let llm_api_key = env::var("LLM_API_KEY").expect("LLM_API_KEY must be set"); + let llm_base_url = env::var("LLM_BASE_URL").expect("LLM_BASE_URL must be set"); + + let llm_client = LiteLLMClient::new(Some(llm_api_key), Some(llm_base_url)); + Self { llm_client, tools, @@ -31,13 +36,23 @@ impl Agent { } } - /// Register a new tool with the agent + /// Add a new tool to the agent /// /// # Arguments /// * `name` - The name of the tool, used to identify it in tool calls /// * `tool` - The tool implementation that will be executed - pub fn register_tool(&mut self, name: String, tool: Box) { - self.tools.insert(name, tool); + pub fn add_tool(&mut self, name: String, tool: T) { + self.tools.insert(name, Box::new(tool)); + } + + /// Add multiple tools to the agent at once + /// + /// # Arguments + /// * `tools` - HashMap of tool names and their implementations + pub fn add_tools(&mut self, tools: HashMap) { + for (name, tool) in tools { + self.tools.insert(name, Box::new(tool)); + } } /// Process a thread of conversation, potentially executing tools and continuing @@ -63,16 +78,20 @@ impl Agent { let request = ChatCompletionRequest { model: self.model.clone(), messages: thread.messages.clone(), - tools: Some(tools), + tools: if tools.is_empty() { None } else { Some(tools) }, ..Default::default() }; // Get the response from the LLM - let response = self.llm_client.chat_completion(request).await?; + let response = match self.llm_client.chat_completion(request).await { + Ok(response) => response, + Err(e) => return Err(anyhow::anyhow!("Error processing thread: {:?}", e)), + }; + let llm_message = &response.choices[0].message; // Create the initial assistant message - 
let mut message = match llm_message { + let message = match llm_message { Message::Assistant { content, tool_calls, @@ -178,4 +197,124 @@ impl Agent { Ok(rx) } -} \ No newline at end of file +} + +#[cfg(test)] +mod tests { + use crate::utils::clients::ai::litellm::ToolCall; + + use super::*; + use dotenv::dotenv; + use serde_json::{json, Value}; + + fn setup() { + dotenv().ok(); + } + + struct WeatherTool; + + #[async_trait] + impl ToolExecutor for WeatherTool { + async fn execute(&self, tool_call: &ToolCall) -> Result { + Ok(json!({ + "temperature": 20, + "unit": "fahrenheit" + })) + } + + fn get_schema(&self) -> Value { + json!({ + "name": "get_weather", + "description": "Get current weather information for a specific location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g., San Francisco, CA" + }, + "unit": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + "description": "The temperature unit to use" + } + }, + "required": ["location"] + } + }) + } + + fn get_name(&self) -> String { + "get_weather".to_string() + } + } + + #[tokio::test] + async fn test_agent_convo_no_tools() { + setup(); + + // Create LLM client and agent + let agent = Agent::new("o1".to_string(), HashMap::new()); + + let thread = AgentThread::new(None, vec![Message::user("Hello, world!".to_string())]); + + let response = match agent.process_thread(&thread).await { + Ok(response) => response, + Err(e) => panic!("Error processing thread: {:?}", e), + }; + + println!("Response: {:?}", response); + } + + #[tokio::test] + async fn test_agent_convo_with_tools() { + setup(); + + // Create LLM client and agent + let mut agent = Agent::new("o1".to_string(), HashMap::new()); + + let weather_tool = WeatherTool; + + agent.add_tool(weather_tool.get_name(), weather_tool); + + let thread = AgentThread::new( + None, + vec![Message::user( + "What is the weather in vineyard ut?".to_string(), + )], + ); + + let 
response = match agent.process_thread(&thread).await { + Ok(response) => response, + Err(e) => panic!("Error processing thread: {:?}", e), + }; + + println!("Response: {:?}", response); + } + + #[tokio::test] + async fn test_agent_with_multiple_steps() { + setup(); + + // Create LLM client and agent + let mut agent = Agent::new("o1".to_string(), HashMap::new()); + + let weather_tool = WeatherTool; + + agent.add_tool(weather_tool.get_name(), weather_tool); + + let thread = AgentThread::new( + None, + vec![Message::user( + "What is the weather in vineyard ut and san francisco?".to_string(), + )], + ); + + let response = match agent.process_thread(&thread).await { + Ok(response) => response, + Err(e) => panic!("Error processing thread: {:?}", e), + }; + + println!("Response: {:?}", response); + } +} diff --git a/api/src/utils/agent/types.rs b/api/src/utils/agent/types.rs index b52cfcc2a..39583c23d 100644 --- a/api/src/utils/agent/types.rs +++ b/api/src/utils/agent/types.rs @@ -11,3 +11,12 @@ pub struct AgentThread { /// Ordered sequence of messages in the conversation pub messages: Vec, } + +impl AgentThread { + pub fn new(id: Option, messages: Vec) -> Self { + Self { + id: id.unwrap_or(uuid::Uuid::new_v4().to_string()), + messages, + } + } +} diff --git a/api/src/utils/clients/ai/litellm/client.rs b/api/src/utils/clients/ai/litellm/client.rs index 14ed8de3d..91436659c 100644 --- a/api/src/utils/clients/ai/litellm/client.rs +++ b/api/src/utils/clients/ai/litellm/client.rs @@ -61,6 +61,9 @@ impl LiteLLMClient { .await? 
.json::() .await?; + + println!("Response: {:?}", response); + Ok(response) } @@ -134,6 +137,20 @@ mod tests { use std::time::Duration; use tokio::time::timeout; + use dotenv::dotenv; + + // Helper function to initialize environment before tests + async fn setup() -> (String, String) { + // Load environment variables from .env file + dotenv().ok(); + + // Get API key and base URL from environment + let api_key = env::var("LLM_API_KEY").expect("LLM_API_KEY must be set"); + let base_url = env::var("LLM_BASE_URL").expect("LLM_BASE_URL must be set"); + + (api_key, base_url) + } + fn create_test_message() -> Message { Message::user("Hello".to_string()) } @@ -329,12 +346,20 @@ mod tests { let response = client.chat_completion(request).await.unwrap(); assert_eq!(response.id, "test-id"); - if let Message::Assistant { content, tool_calls, .. } = &response.choices[0].message { + if let Message::Assistant { + content, + tool_calls, + .. + } = &response.choices[0].message + { assert!(content.is_none()); let tool_calls = tool_calls.as_ref().unwrap(); assert_eq!(tool_calls[0].id, "call_123"); assert_eq!(tool_calls[0].function.name, "get_current_weather"); - assert_eq!(tool_calls[0].function.arguments, "{\"location\":\"Boston, MA\"}"); + assert_eq!( + tool_calls[0].function.arguments, + "{\"location\":\"Boston, MA\"}" + ); } else { panic!("Expected assistant message"); } @@ -368,4 +393,23 @@ mod tests { env::remove_var("LLM_API_KEY"); env::remove_var("LLM_BASE_URL"); } + + #[tokio::test] + async fn test_single_message_completion() { + let (api_key, base_url) = setup().await; + let client = LiteLLMClient::new(Some(api_key), Some(base_url)); + + let request = ChatCompletionRequest { + model: "o1".to_string(), + messages: vec![Message::user("Hello, world!".to_string())], + ..Default::default() + }; + + let response = match client.chat_completion(request).await { + Ok(response) => response, + Err(e) => panic!("Error processing thread: {:?}", e), + }; + + 
assert!(response.choices.len() > 0); + } } diff --git a/api/src/utils/tools/mod.rs b/api/src/utils/tools/mod.rs index 2c1bff28f..ec43f5e85 100644 --- a/api/src/utils/tools/mod.rs +++ b/api/src/utils/tools/mod.rs @@ -4,19 +4,19 @@ use serde_json::Value; use crate::utils::clients::ai::litellm::ToolCall; -mod search_files; -mod create_files; -mod bulk_modify_files; -mod search_data_catalog; -mod open_files; -mod send_to_user; +// mod bulk_modify_files; +// mod create_files; +// mod open_files; +// mod search_data_catalog; +// mod search_files; +// mod send_to_user; -pub use search_files::SearchFilesTool; -pub use create_files::CreateFilesTool; -pub use bulk_modify_files::BulkModifyFilesTool; -pub use search_data_catalog::SearchDataCatalogTool; -pub use open_files::OpenFilesTool; -pub use send_to_user::SendToUserTool; +// pub use bulk_modify_files::BulkModifyFilesTool; +// pub use create_files::CreateFilesTool; +// pub use open_files::OpenFilesTool; +// pub use search_data_catalog::SearchDataCatalogTool; +// pub use search_files::SearchFilesTool; +// pub use send_to_user::SendToUserTool; /// A trait that defines how tools should be implemented. /// Any struct that wants to be used as a tool must implement this trait. @@ -27,4 +27,17 @@ pub trait ToolExecutor: Send + Sync { /// Return the JSON schema that describes this tool's interface fn get_schema(&self) -> serde_json::Value; + + /// Return the name of the tool + fn get_name(&self) -> String; +} + +trait IntoBoxedTool { + fn boxed(self) -> Box; +} + +impl IntoBoxedTool for T { + fn boxed(self) -> Box { + Box::new(self) + } }