mirror of https://github.com/buster-so/buster.git
refactor: Enhance Agent and Tool management with new methods and tests
- Added environment variable-based LLM client initialization in `Agent::new()` - Introduced `add_tool()` and `add_tools()` methods for more flexible tool registration - Implemented new `get_name()` method for `ToolExecutor` trait - Added comprehensive test cases for Agent with and without tools - Updated `AgentThread` with a convenient constructor method - Temporarily commented out unused tool modules - Added debug print in LiteLLM client for response logging
This commit is contained in:
parent
6a73b59aa1
commit
ec04a5e98e
|
@ -94,7 +94,8 @@ html-escape = "0.2.13"
|
|||
async-trait = "0.1.85"
|
||||
|
||||
[dev-dependencies]
|
||||
mockito = "1.2"
|
||||
mockito = "1.2.0"
|
||||
async-trait = "0.1.77"
|
||||
tokio = { version = "1.0", features = ["full", "test-util"] }
|
||||
|
||||
[profile.release]
|
||||
|
|
|
@ -1,9 +1,13 @@
|
|||
use crate::utils::{clients::ai::litellm::{ChatCompletionRequest, LiteLLMClient, Message, Tool}, tools::ToolExecutor};
|
||||
use crate::utils::{
|
||||
clients::ai::litellm::{ChatCompletionRequest, LiteLLMClient, Message, Tool},
|
||||
tools::ToolExecutor,
|
||||
};
|
||||
use anyhow::Result;
|
||||
use std::collections::HashMap;
|
||||
use std::{collections::HashMap, env};
|
||||
use tokio::sync::mpsc;
|
||||
use async_trait::async_trait;
|
||||
|
||||
use super::types::{AgentThread};
|
||||
use super::types::AgentThread;
|
||||
|
||||
/// The Agent struct is responsible for managing conversations with the LLM
|
||||
/// and coordinating tool executions. It maintains a registry of available tools
|
||||
|
@ -19,11 +23,12 @@ pub struct Agent {
|
|||
|
||||
impl Agent {
|
||||
/// Create a new Agent instance with a specific LLM client and model
|
||||
pub fn new(
|
||||
llm_client: LiteLLMClient,
|
||||
model: String,
|
||||
tools: HashMap<String, Box<dyn ToolExecutor>>,
|
||||
) -> Self {
|
||||
pub fn new(model: String, tools: HashMap<String, Box<dyn ToolExecutor>>) -> Self {
|
||||
let llm_api_key = env::var("LLM_API_KEY").expect("LLM_API_KEY must be set");
|
||||
let llm_base_url = env::var("LLM_BASE_URL").expect("LLM_API_BASE must be set");
|
||||
|
||||
let llm_client = LiteLLMClient::new(Some(llm_api_key), Some(llm_base_url));
|
||||
|
||||
Self {
|
||||
llm_client,
|
||||
tools,
|
||||
|
@ -31,13 +36,23 @@ impl Agent {
|
|||
}
|
||||
}
|
||||
|
||||
/// Register a new tool with the agent
|
||||
/// Add a new tool with the agent
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `name` - The name of the tool, used to identify it in tool calls
|
||||
/// * `tool` - The tool implementation that will be executed
|
||||
pub fn register_tool(&mut self, name: String, tool: Box<dyn ToolExecutor>) {
|
||||
self.tools.insert(name, tool);
|
||||
pub fn add_tool<T: ToolExecutor + 'static>(&mut self, name: String, tool: T) {
|
||||
self.tools.insert(name, Box::new(tool));
|
||||
}
|
||||
|
||||
/// Add multiple tools to the agent at once
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `tools` - HashMap of tool names and their implementations
|
||||
pub fn add_tools<T: ToolExecutor + 'static>(&mut self, tools: HashMap<String, T>) {
|
||||
for (name, tool) in tools {
|
||||
self.tools.insert(name, Box::new(tool));
|
||||
}
|
||||
}
|
||||
|
||||
/// Process a thread of conversation, potentially executing tools and continuing
|
||||
|
@ -63,16 +78,20 @@ impl Agent {
|
|||
let request = ChatCompletionRequest {
|
||||
model: self.model.clone(),
|
||||
messages: thread.messages.clone(),
|
||||
tools: Some(tools),
|
||||
tools: if tools.is_empty() { None } else { Some(tools) },
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
// Get the response from the LLM
|
||||
let response = self.llm_client.chat_completion(request).await?;
|
||||
let response = match self.llm_client.chat_completion(request).await {
|
||||
Ok(response) => response,
|
||||
Err(e) => return Err(anyhow::anyhow!("Error processing thread: {:?}", e)),
|
||||
};
|
||||
|
||||
let llm_message = &response.choices[0].message;
|
||||
|
||||
// Create the initial assistant message
|
||||
let mut message = match llm_message {
|
||||
let message = match llm_message {
|
||||
Message::Assistant {
|
||||
content,
|
||||
tool_calls,
|
||||
|
@ -179,3 +198,123 @@ impl Agent {
|
|||
Ok(rx)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::utils::clients::ai::litellm::ToolCall;
|
||||
|
||||
use super::*;
|
||||
use dotenv::dotenv;
|
||||
use serde_json::{json, Value};
|
||||
|
||||
fn setup() {
|
||||
dotenv().ok();
|
||||
}
|
||||
|
||||
struct WeatherTool;
|
||||
|
||||
#[async_trait]
|
||||
impl ToolExecutor for WeatherTool {
|
||||
async fn execute(&self, tool_call: &ToolCall) -> Result<Value> {
|
||||
Ok(json!({
|
||||
"temperature": 20,
|
||||
"unit": "fahrenheit"
|
||||
}))
|
||||
}
|
||||
|
||||
fn get_schema(&self) -> Value {
|
||||
json!({
|
||||
"name": "get_weather",
|
||||
"description": "Get current weather information for a specific location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state, e.g., San Francisco, CA"
|
||||
},
|
||||
"unit": {
|
||||
"type": "string",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
"description": "The temperature unit to use"
|
||||
}
|
||||
},
|
||||
"required": ["location"]
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn get_name(&self) -> String {
|
||||
"get_weather".to_string()
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_agent_convo_no_tools() {
|
||||
setup();
|
||||
|
||||
// Create LLM client and agent
|
||||
let agent = Agent::new("o1".to_string(), HashMap::new());
|
||||
|
||||
let thread = AgentThread::new(None, vec![Message::user("Hello, world!".to_string())]);
|
||||
|
||||
let response = match agent.process_thread(&thread).await {
|
||||
Ok(response) => response,
|
||||
Err(e) => panic!("Error processing thread: {:?}", e),
|
||||
};
|
||||
|
||||
println!("Response: {:?}", response);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_agent_convo_with_tools() {
|
||||
setup();
|
||||
|
||||
// Create LLM client and agent
|
||||
let mut agent = Agent::new("o1".to_string(), HashMap::new());
|
||||
|
||||
let weather_tool = WeatherTool;
|
||||
|
||||
agent.add_tool(weather_tool.get_name(), weather_tool);
|
||||
|
||||
let thread = AgentThread::new(
|
||||
None,
|
||||
vec![Message::user(
|
||||
"What is the weather in vineyard ut?".to_string(),
|
||||
)],
|
||||
);
|
||||
|
||||
let response = match agent.process_thread(&thread).await {
|
||||
Ok(response) => response,
|
||||
Err(e) => panic!("Error processing thread: {:?}", e),
|
||||
};
|
||||
|
||||
println!("Response: {:?}", response);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_agent_with_multiple_steps() {
|
||||
setup();
|
||||
|
||||
// Create LLM client and agent
|
||||
let mut agent = Agent::new("o1".to_string(), HashMap::new());
|
||||
|
||||
let weather_tool = WeatherTool;
|
||||
|
||||
agent.add_tool(weather_tool.get_name(), weather_tool);
|
||||
|
||||
let thread = AgentThread::new(
|
||||
None,
|
||||
vec![Message::user(
|
||||
"What is the weather in vineyard ut and san francisco?".to_string(),
|
||||
)],
|
||||
);
|
||||
|
||||
let response = match agent.process_thread(&thread).await {
|
||||
Ok(response) => response,
|
||||
Err(e) => panic!("Error processing thread: {:?}", e),
|
||||
};
|
||||
|
||||
println!("Response: {:?}", response);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -11,3 +11,12 @@ pub struct AgentThread {
|
|||
/// Ordered sequence of messages in the conversation
|
||||
pub messages: Vec<Message>,
|
||||
}
|
||||
|
||||
impl AgentThread {
|
||||
pub fn new(id: Option<String>, messages: Vec<Message>) -> Self {
|
||||
Self {
|
||||
id: id.unwrap_or(uuid::Uuid::new_v4().to_string()),
|
||||
messages,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -61,6 +61,9 @@ impl LiteLLMClient {
|
|||
.await?
|
||||
.json::<ChatCompletionResponse>()
|
||||
.await?;
|
||||
|
||||
println!("Response: {:?}", response);
|
||||
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
|
@ -134,6 +137,20 @@ mod tests {
|
|||
use std::time::Duration;
|
||||
use tokio::time::timeout;
|
||||
|
||||
use dotenv::dotenv;
|
||||
|
||||
// Helper function to initialize environment before tests
|
||||
async fn setup() -> (String, String) {
|
||||
// Load environment variables from .env file
|
||||
dotenv().ok();
|
||||
|
||||
// Get API key and base URL from environment
|
||||
let api_key = env::var("LLM_API_KEY").expect("LLM_API_KEY must be set");
|
||||
let base_url = env::var("LLM_BASE_URL").expect("LLM_API_BASE must be set");
|
||||
|
||||
(api_key, base_url)
|
||||
}
|
||||
|
||||
fn create_test_message() -> Message {
|
||||
Message::user("Hello".to_string())
|
||||
}
|
||||
|
@ -329,12 +346,20 @@ mod tests {
|
|||
|
||||
let response = client.chat_completion(request).await.unwrap();
|
||||
assert_eq!(response.id, "test-id");
|
||||
if let Message::Assistant { content, tool_calls, .. } = &response.choices[0].message {
|
||||
if let Message::Assistant {
|
||||
content,
|
||||
tool_calls,
|
||||
..
|
||||
} = &response.choices[0].message
|
||||
{
|
||||
assert!(content.is_none());
|
||||
let tool_calls = tool_calls.as_ref().unwrap();
|
||||
assert_eq!(tool_calls[0].id, "call_123");
|
||||
assert_eq!(tool_calls[0].function.name, "get_current_weather");
|
||||
assert_eq!(tool_calls[0].function.arguments, "{\"location\":\"Boston, MA\"}");
|
||||
assert_eq!(
|
||||
tool_calls[0].function.arguments,
|
||||
"{\"location\":\"Boston, MA\"}"
|
||||
);
|
||||
} else {
|
||||
panic!("Expected assistant message");
|
||||
}
|
||||
|
@ -368,4 +393,23 @@ mod tests {
|
|||
env::remove_var("LLM_API_KEY");
|
||||
env::remove_var("LLM_BASE_URL");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_single_message_completion() {
|
||||
let (api_key, base_url) = setup().await;
|
||||
let client = LiteLLMClient::new(Some(api_key), Some(base_url));
|
||||
|
||||
let request = ChatCompletionRequest {
|
||||
model: "o1".to_string(),
|
||||
messages: vec![Message::user("Hello, world!".to_string())],
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let response = match client.chat_completion(request).await {
|
||||
Ok(response) => response,
|
||||
Err(e) => panic!("Error processing thread: {:?}", e),
|
||||
};
|
||||
|
||||
assert!(response.choices.len() > 0);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -4,19 +4,19 @@ use serde_json::Value;
|
|||
|
||||
use crate::utils::clients::ai::litellm::ToolCall;
|
||||
|
||||
mod search_files;
|
||||
mod create_files;
|
||||
mod bulk_modify_files;
|
||||
mod search_data_catalog;
|
||||
mod open_files;
|
||||
mod send_to_user;
|
||||
// mod bulk_modify_files;
|
||||
// mod create_files;
|
||||
// mod open_files;
|
||||
// mod search_data_catalog;
|
||||
// mod search_files;
|
||||
// mod send_to_user;
|
||||
|
||||
pub use search_files::SearchFilesTool;
|
||||
pub use create_files::CreateFilesTool;
|
||||
pub use bulk_modify_files::BulkModifyFilesTool;
|
||||
pub use search_data_catalog::SearchDataCatalogTool;
|
||||
pub use open_files::OpenFilesTool;
|
||||
pub use send_to_user::SendToUserTool;
|
||||
// pub use bulk_modify_files::BulkModifyFilesTool;
|
||||
// pub use create_files::CreateFilesTool;
|
||||
// pub use open_files::OpenFilesTool;
|
||||
// pub use search_data_catalog::SearchDataCatalogTool;
|
||||
// pub use search_files::SearchFilesTool;
|
||||
// pub use send_to_user::SendToUserTool;
|
||||
|
||||
/// A trait that defines how tools should be implemented.
|
||||
/// Any struct that wants to be used as a tool must implement this trait.
|
||||
|
@ -27,4 +27,17 @@ pub trait ToolExecutor: Send + Sync {
|
|||
|
||||
/// Return the JSON schema that describes this tool's interface
|
||||
fn get_schema(&self) -> serde_json::Value;
|
||||
|
||||
/// Return the name of the tool
|
||||
fn get_name(&self) -> String;
|
||||
}
|
||||
|
||||
trait IntoBoxedTool {
|
||||
fn boxed(self) -> Box<dyn ToolExecutor>;
|
||||
}
|
||||
|
||||
impl<T: ToolExecutor + 'static> IntoBoxedTool for T {
|
||||
fn boxed(self) -> Box<dyn ToolExecutor> {
|
||||
Box::new(self)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue