mirror of https://github.com/buster-so/buster.git
refactor: Enhance Agent and Tool management with new methods and tests
- Added environment variable-based LLM client initialization in `Agent::new()` - Introduced `add_tool()` and `add_tools()` methods for more flexible tool registration - Implemented new `get_name()` method for `ToolExecutor` trait - Added comprehensive test cases for Agent with and without tools - Updated `AgentThread` with a convenient constructor method - Temporarily commented out unused tool modules - Added debug print in LiteLLM client for response logging
This commit is contained in:
parent
6a73b59aa1
commit
ec04a5e98e
|
@ -94,7 +94,8 @@ html-escape = "0.2.13"
|
||||||
async-trait = "0.1.85"
|
async-trait = "0.1.85"
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
mockito = "1.2"
|
mockito = "1.2.0"
|
||||||
|
async-trait = "0.1.77"
|
||||||
tokio = { version = "1.0", features = ["full", "test-util"] }
|
tokio = { version = "1.0", features = ["full", "test-util"] }
|
||||||
|
|
||||||
[profile.release]
|
[profile.release]
|
||||||
|
|
|
@ -1,9 +1,13 @@
|
||||||
use crate::utils::{clients::ai::litellm::{ChatCompletionRequest, LiteLLMClient, Message, Tool}, tools::ToolExecutor};
|
use crate::utils::{
|
||||||
|
clients::ai::litellm::{ChatCompletionRequest, LiteLLMClient, Message, Tool},
|
||||||
|
tools::ToolExecutor,
|
||||||
|
};
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use std::collections::HashMap;
|
use std::{collections::HashMap, env};
|
||||||
use tokio::sync::mpsc;
|
use tokio::sync::mpsc;
|
||||||
|
use async_trait::async_trait;
|
||||||
|
|
||||||
use super::types::{AgentThread};
|
use super::types::AgentThread;
|
||||||
|
|
||||||
/// The Agent struct is responsible for managing conversations with the LLM
|
/// The Agent struct is responsible for managing conversations with the LLM
|
||||||
/// and coordinating tool executions. It maintains a registry of available tools
|
/// and coordinating tool executions. It maintains a registry of available tools
|
||||||
|
@ -19,11 +23,12 @@ pub struct Agent {
|
||||||
|
|
||||||
impl Agent {
|
impl Agent {
|
||||||
/// Create a new Agent instance with a specific LLM client and model
|
/// Create a new Agent instance with a specific LLM client and model
|
||||||
pub fn new(
|
pub fn new(model: String, tools: HashMap<String, Box<dyn ToolExecutor>>) -> Self {
|
||||||
llm_client: LiteLLMClient,
|
let llm_api_key = env::var("LLM_API_KEY").expect("LLM_API_KEY must be set");
|
||||||
model: String,
|
let llm_base_url = env::var("LLM_BASE_URL").expect("LLM_API_BASE must be set");
|
||||||
tools: HashMap<String, Box<dyn ToolExecutor>>,
|
|
||||||
) -> Self {
|
let llm_client = LiteLLMClient::new(Some(llm_api_key), Some(llm_base_url));
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
llm_client,
|
llm_client,
|
||||||
tools,
|
tools,
|
||||||
|
@ -31,13 +36,23 @@ impl Agent {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Register a new tool with the agent
|
/// Add a new tool with the agent
|
||||||
///
|
///
|
||||||
/// # Arguments
|
/// # Arguments
|
||||||
/// * `name` - The name of the tool, used to identify it in tool calls
|
/// * `name` - The name of the tool, used to identify it in tool calls
|
||||||
/// * `tool` - The tool implementation that will be executed
|
/// * `tool` - The tool implementation that will be executed
|
||||||
pub fn register_tool(&mut self, name: String, tool: Box<dyn ToolExecutor>) {
|
pub fn add_tool<T: ToolExecutor + 'static>(&mut self, name: String, tool: T) {
|
||||||
self.tools.insert(name, tool);
|
self.tools.insert(name, Box::new(tool));
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Add multiple tools to the agent at once
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
/// * `tools` - HashMap of tool names and their implementations
|
||||||
|
pub fn add_tools<T: ToolExecutor + 'static>(&mut self, tools: HashMap<String, T>) {
|
||||||
|
for (name, tool) in tools {
|
||||||
|
self.tools.insert(name, Box::new(tool));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Process a thread of conversation, potentially executing tools and continuing
|
/// Process a thread of conversation, potentially executing tools and continuing
|
||||||
|
@ -63,16 +78,20 @@ impl Agent {
|
||||||
let request = ChatCompletionRequest {
|
let request = ChatCompletionRequest {
|
||||||
model: self.model.clone(),
|
model: self.model.clone(),
|
||||||
messages: thread.messages.clone(),
|
messages: thread.messages.clone(),
|
||||||
tools: Some(tools),
|
tools: if tools.is_empty() { None } else { Some(tools) },
|
||||||
..Default::default()
|
..Default::default()
|
||||||
};
|
};
|
||||||
|
|
||||||
// Get the response from the LLM
|
// Get the response from the LLM
|
||||||
let response = self.llm_client.chat_completion(request).await?;
|
let response = match self.llm_client.chat_completion(request).await {
|
||||||
|
Ok(response) => response,
|
||||||
|
Err(e) => return Err(anyhow::anyhow!("Error processing thread: {:?}", e)),
|
||||||
|
};
|
||||||
|
|
||||||
let llm_message = &response.choices[0].message;
|
let llm_message = &response.choices[0].message;
|
||||||
|
|
||||||
// Create the initial assistant message
|
// Create the initial assistant message
|
||||||
let mut message = match llm_message {
|
let message = match llm_message {
|
||||||
Message::Assistant {
|
Message::Assistant {
|
||||||
content,
|
content,
|
||||||
tool_calls,
|
tool_calls,
|
||||||
|
@ -178,4 +197,124 @@ impl Agent {
|
||||||
|
|
||||||
Ok(rx)
|
Ok(rx)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use crate::utils::clients::ai::litellm::ToolCall;
|
||||||
|
|
||||||
|
use super::*;
|
||||||
|
use dotenv::dotenv;
|
||||||
|
use serde_json::{json, Value};
|
||||||
|
|
||||||
|
fn setup() {
|
||||||
|
dotenv().ok();
|
||||||
|
}
|
||||||
|
|
||||||
|
struct WeatherTool;
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl ToolExecutor for WeatherTool {
|
||||||
|
async fn execute(&self, tool_call: &ToolCall) -> Result<Value> {
|
||||||
|
Ok(json!({
|
||||||
|
"temperature": 20,
|
||||||
|
"unit": "fahrenheit"
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_schema(&self) -> Value {
|
||||||
|
json!({
|
||||||
|
"name": "get_weather",
|
||||||
|
"description": "Get current weather information for a specific location",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"location": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The city and state, e.g., San Francisco, CA"
|
||||||
|
},
|
||||||
|
"unit": {
|
||||||
|
"type": "string",
|
||||||
|
"enum": ["celsius", "fahrenheit"],
|
||||||
|
"description": "The temperature unit to use"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["location"]
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_name(&self) -> String {
|
||||||
|
"get_weather".to_string()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_agent_convo_no_tools() {
|
||||||
|
setup();
|
||||||
|
|
||||||
|
// Create LLM client and agent
|
||||||
|
let agent = Agent::new("o1".to_string(), HashMap::new());
|
||||||
|
|
||||||
|
let thread = AgentThread::new(None, vec![Message::user("Hello, world!".to_string())]);
|
||||||
|
|
||||||
|
let response = match agent.process_thread(&thread).await {
|
||||||
|
Ok(response) => response,
|
||||||
|
Err(e) => panic!("Error processing thread: {:?}", e),
|
||||||
|
};
|
||||||
|
|
||||||
|
println!("Response: {:?}", response);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_agent_convo_with_tools() {
|
||||||
|
setup();
|
||||||
|
|
||||||
|
// Create LLM client and agent
|
||||||
|
let mut agent = Agent::new("o1".to_string(), HashMap::new());
|
||||||
|
|
||||||
|
let weather_tool = WeatherTool;
|
||||||
|
|
||||||
|
agent.add_tool(weather_tool.get_name(), weather_tool);
|
||||||
|
|
||||||
|
let thread = AgentThread::new(
|
||||||
|
None,
|
||||||
|
vec![Message::user(
|
||||||
|
"What is the weather in vineyard ut?".to_string(),
|
||||||
|
)],
|
||||||
|
);
|
||||||
|
|
||||||
|
let response = match agent.process_thread(&thread).await {
|
||||||
|
Ok(response) => response,
|
||||||
|
Err(e) => panic!("Error processing thread: {:?}", e),
|
||||||
|
};
|
||||||
|
|
||||||
|
println!("Response: {:?}", response);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_agent_with_multiple_steps() {
|
||||||
|
setup();
|
||||||
|
|
||||||
|
// Create LLM client and agent
|
||||||
|
let mut agent = Agent::new("o1".to_string(), HashMap::new());
|
||||||
|
|
||||||
|
let weather_tool = WeatherTool;
|
||||||
|
|
||||||
|
agent.add_tool(weather_tool.get_name(), weather_tool);
|
||||||
|
|
||||||
|
let thread = AgentThread::new(
|
||||||
|
None,
|
||||||
|
vec![Message::user(
|
||||||
|
"What is the weather in vineyard ut and san francisco?".to_string(),
|
||||||
|
)],
|
||||||
|
);
|
||||||
|
|
||||||
|
let response = match agent.process_thread(&thread).await {
|
||||||
|
Ok(response) => response,
|
||||||
|
Err(e) => panic!("Error processing thread: {:?}", e),
|
||||||
|
};
|
||||||
|
|
||||||
|
println!("Response: {:?}", response);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -11,3 +11,12 @@ pub struct AgentThread {
|
||||||
/// Ordered sequence of messages in the conversation
|
/// Ordered sequence of messages in the conversation
|
||||||
pub messages: Vec<Message>,
|
pub messages: Vec<Message>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl AgentThread {
|
||||||
|
pub fn new(id: Option<String>, messages: Vec<Message>) -> Self {
|
||||||
|
Self {
|
||||||
|
id: id.unwrap_or(uuid::Uuid::new_v4().to_string()),
|
||||||
|
messages,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -61,6 +61,9 @@ impl LiteLLMClient {
|
||||||
.await?
|
.await?
|
||||||
.json::<ChatCompletionResponse>()
|
.json::<ChatCompletionResponse>()
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
|
println!("Response: {:?}", response);
|
||||||
|
|
||||||
Ok(response)
|
Ok(response)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -134,6 +137,20 @@ mod tests {
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
use tokio::time::timeout;
|
use tokio::time::timeout;
|
||||||
|
|
||||||
|
use dotenv::dotenv;
|
||||||
|
|
||||||
|
// Helper function to initialize environment before tests
|
||||||
|
async fn setup() -> (String, String) {
|
||||||
|
// Load environment variables from .env file
|
||||||
|
dotenv().ok();
|
||||||
|
|
||||||
|
// Get API key and base URL from environment
|
||||||
|
let api_key = env::var("LLM_API_KEY").expect("LLM_API_KEY must be set");
|
||||||
|
let base_url = env::var("LLM_BASE_URL").expect("LLM_API_BASE must be set");
|
||||||
|
|
||||||
|
(api_key, base_url)
|
||||||
|
}
|
||||||
|
|
||||||
fn create_test_message() -> Message {
|
fn create_test_message() -> Message {
|
||||||
Message::user("Hello".to_string())
|
Message::user("Hello".to_string())
|
||||||
}
|
}
|
||||||
|
@ -329,12 +346,20 @@ mod tests {
|
||||||
|
|
||||||
let response = client.chat_completion(request).await.unwrap();
|
let response = client.chat_completion(request).await.unwrap();
|
||||||
assert_eq!(response.id, "test-id");
|
assert_eq!(response.id, "test-id");
|
||||||
if let Message::Assistant { content, tool_calls, .. } = &response.choices[0].message {
|
if let Message::Assistant {
|
||||||
|
content,
|
||||||
|
tool_calls,
|
||||||
|
..
|
||||||
|
} = &response.choices[0].message
|
||||||
|
{
|
||||||
assert!(content.is_none());
|
assert!(content.is_none());
|
||||||
let tool_calls = tool_calls.as_ref().unwrap();
|
let tool_calls = tool_calls.as_ref().unwrap();
|
||||||
assert_eq!(tool_calls[0].id, "call_123");
|
assert_eq!(tool_calls[0].id, "call_123");
|
||||||
assert_eq!(tool_calls[0].function.name, "get_current_weather");
|
assert_eq!(tool_calls[0].function.name, "get_current_weather");
|
||||||
assert_eq!(tool_calls[0].function.arguments, "{\"location\":\"Boston, MA\"}");
|
assert_eq!(
|
||||||
|
tool_calls[0].function.arguments,
|
||||||
|
"{\"location\":\"Boston, MA\"}"
|
||||||
|
);
|
||||||
} else {
|
} else {
|
||||||
panic!("Expected assistant message");
|
panic!("Expected assistant message");
|
||||||
}
|
}
|
||||||
|
@ -368,4 +393,23 @@ mod tests {
|
||||||
env::remove_var("LLM_API_KEY");
|
env::remove_var("LLM_API_KEY");
|
||||||
env::remove_var("LLM_BASE_URL");
|
env::remove_var("LLM_BASE_URL");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_single_message_completion() {
|
||||||
|
let (api_key, base_url) = setup().await;
|
||||||
|
let client = LiteLLMClient::new(Some(api_key), Some(base_url));
|
||||||
|
|
||||||
|
let request = ChatCompletionRequest {
|
||||||
|
model: "o1".to_string(),
|
||||||
|
messages: vec![Message::user("Hello, world!".to_string())],
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
|
||||||
|
let response = match client.chat_completion(request).await {
|
||||||
|
Ok(response) => response,
|
||||||
|
Err(e) => panic!("Error processing thread: {:?}", e),
|
||||||
|
};
|
||||||
|
|
||||||
|
assert!(response.choices.len() > 0);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,19 +4,19 @@ use serde_json::Value;
|
||||||
|
|
||||||
use crate::utils::clients::ai::litellm::ToolCall;
|
use crate::utils::clients::ai::litellm::ToolCall;
|
||||||
|
|
||||||
mod search_files;
|
// mod bulk_modify_files;
|
||||||
mod create_files;
|
// mod create_files;
|
||||||
mod bulk_modify_files;
|
// mod open_files;
|
||||||
mod search_data_catalog;
|
// mod search_data_catalog;
|
||||||
mod open_files;
|
// mod search_files;
|
||||||
mod send_to_user;
|
// mod send_to_user;
|
||||||
|
|
||||||
pub use search_files::SearchFilesTool;
|
// pub use bulk_modify_files::BulkModifyFilesTool;
|
||||||
pub use create_files::CreateFilesTool;
|
// pub use create_files::CreateFilesTool;
|
||||||
pub use bulk_modify_files::BulkModifyFilesTool;
|
// pub use open_files::OpenFilesTool;
|
||||||
pub use search_data_catalog::SearchDataCatalogTool;
|
// pub use search_data_catalog::SearchDataCatalogTool;
|
||||||
pub use open_files::OpenFilesTool;
|
// pub use search_files::SearchFilesTool;
|
||||||
pub use send_to_user::SendToUserTool;
|
// pub use send_to_user::SendToUserTool;
|
||||||
|
|
||||||
/// A trait that defines how tools should be implemented.
|
/// A trait that defines how tools should be implemented.
|
||||||
/// Any struct that wants to be used as a tool must implement this trait.
|
/// Any struct that wants to be used as a tool must implement this trait.
|
||||||
|
@ -27,4 +27,17 @@ pub trait ToolExecutor: Send + Sync {
|
||||||
|
|
||||||
/// Return the JSON schema that describes this tool's interface
|
/// Return the JSON schema that describes this tool's interface
|
||||||
fn get_schema(&self) -> serde_json::Value;
|
fn get_schema(&self) -> serde_json::Value;
|
||||||
|
|
||||||
|
/// Return the name of the tool
|
||||||
|
fn get_name(&self) -> String;
|
||||||
|
}
|
||||||
|
|
||||||
|
trait IntoBoxedTool {
|
||||||
|
fn boxed(self) -> Box<dyn ToolExecutor>;
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: ToolExecutor + 'static> IntoBoxedTool for T {
|
||||||
|
fn boxed(self) -> Box<dyn ToolExecutor> {
|
||||||
|
Box::new(self)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue