refactor: Update Agent and Thread types for improved message handling

- Renamed `Thread` to `AgentThread` for clarity
- Modified `ToolExecutor` to return `serde_json::Value` instead of `String`
- Updated message processing to handle new message structures
- Improved content handling in streaming and tool call scenarios
- Simplified message content extraction and serialization
This commit is contained in:
dal 2025-01-26 08:45:49 -07:00
parent aeb1a02ba1
commit 692f8f7a1d
No known key found for this signature in database
GPG Key ID: 16F4B0E1E9F61122
2 changed files with 31 additions and 546 deletions

View File

@ -3,7 +3,7 @@ use anyhow::Result;
use std::collections::HashMap; use std::collections::HashMap;
use tokio::sync::mpsc; use tokio::sync::mpsc;
use super::types::{Thread, ToolExecutor}; use super::types::{AgentThread, ToolExecutor};
/// The Agent struct is responsible for managing conversations with the LLM /// The Agent struct is responsible for managing conversations with the LLM
/// and coordinating tool executions. It maintains a registry of available tools /// and coordinating tool executions. It maintains a registry of available tools
@ -48,7 +48,7 @@ impl Agent {
/// ///
/// # Returns /// # Returns
/// * A Result containing the final Message from the assistant /// * A Result containing the final Message from the assistant
pub async fn process_thread(&self, thread: &Thread) -> Result<Message> { pub async fn process_thread(&self, thread: &AgentThread) -> Result<Message> {
// Collect all registered tools and their schemas // Collect all registered tools and their schemas
let tools: Vec<Tool> = self let tools: Vec<Tool> = self
.tools .tools
@ -72,10 +72,21 @@ impl Agent {
let llm_message = &response.choices[0].message; let llm_message = &response.choices[0].message;
// Create the initial assistant message // Create the initial assistant message
let mut message = Message::assistant(None, llm_message.tool_calls.clone()); let mut message = match llm_message {
Message::Assistant {
content,
tool_calls,
..
} => Message::assistant(content.clone(), tool_calls.clone()),
_ => return Err(anyhow::anyhow!("Expected assistant message from LLM")),
};
// If the LLM wants to use tools, execute them // If the LLM wants to use tools, execute them
if let Some(tool_calls) = &llm_message.tool_calls { if let Message::Assistant {
tool_calls: Some(tool_calls),
..
} = &llm_message
{
let mut results = Vec::new(); let mut results = Vec::new();
// Execute each requested tool // Execute each requested tool
@ -83,7 +94,10 @@ impl Agent {
if let Some(tool) = self.tools.get(&tool_call.function.name) { if let Some(tool) = self.tools.get(&tool_call.function.name) {
let result = tool.execute(tool_call).await?; let result = tool.execute(tool_call).await?;
// Create a message for the tool's response // Create a message for the tool's response
results.push(Message::tool(result, tool_call.id.clone())); results.push(Message::tool(
serde_json::to_string(&result).unwrap(),
tool_call.id.clone(),
));
} }
} }
@ -94,12 +108,6 @@ impl Agent {
Box::pin(self.process_thread(&new_thread)).await Box::pin(self.process_thread(&new_thread)).await
} else { } else {
// If no tools were called, return the final response
message.content = if let Some(content) = &llm_message.content {
Some(content)
} else {
None
};
Ok(message) Ok(message)
} }
} }
@ -113,7 +121,7 @@ impl Agent {
/// * A Result containing a receiver for streamed messages /// * A Result containing a receiver for streamed messages
pub async fn stream_process_thread( pub async fn stream_process_thread(
&self, &self,
thread: &Thread, thread: &AgentThread,
) -> Result<mpsc::Receiver<Result<Message>>> { ) -> Result<mpsc::Receiver<Result<Message>>> {
// Collect all registered tools and their schemas // Collect all registered tools and their schemas
let tools: Vec<Tool> = self let tools: Vec<Tool> = self
@ -148,8 +156,14 @@ impl Agent {
let delta = &chunk.choices[0].delta; let delta = &chunk.choices[0].delta;
// Update the message with the new content // Update the message with the new content
if let Some(content) = &mut current_message.content { if let Message::Assistant { content, .. } = &mut current_message {
content[0].text.push_str(&delta.content); if let Some(new_content) = &delta.content {
*content = Some(if let Some(existing) = content {
format!("{}{}", existing, new_content)
} else {
new_content.clone()
});
}
} }
let _ = tx.send(Ok(current_message.clone())).await; let _ = tx.send(Ok(current_message.clone())).await;
@ -165,533 +179,3 @@ impl Agent {
Ok(rx) Ok(rx)
} }
} }
#[cfg(test)]
mod tests {
use crate::utils::clients::ai::litellm::ToolCall;
use super::*;
use async_trait::async_trait;
use mockito;
use serde_json::json;
/// Mock weather tools for testing.
///
/// Each unit struct gets a `ToolExecutor` impl below that returns canned
/// responses, so the agent's tool-dispatch loop can be exercised without
/// touching any real weather service.
struct CurrentWeatherTool; // canned current-conditions lookups
struct ForecastTool; // canned multi-day forecasts
struct WeatherAlertsTool; // canned severe-weather alerts
struct AirQualityTool; // canned AQI readings
#[async_trait]
impl ToolExecutor for CurrentWeatherTool {
    /// Return a canned current-conditions string for the requested city.
    ///
    /// The `city` argument is pulled from the tool call's JSON arguments;
    /// unknown cities get a "not available" message rather than an error.
    async fn execute(&self, tool_call: &ToolCall) -> Result<String> {
        let parsed: serde_json::Value = serde_json::from_str(&tool_call.function.arguments)?;
        let city = parsed["city"].as_str().unwrap_or("unknown");
        let normalized = city.to_lowercase();

        let reply = if normalized == "salt lake city" {
            "Current weather in Salt Lake City: 75°F, Sunny, Humidity: 45%".to_string()
        } else if normalized == "new york" {
            "Current weather in New York: 68°F, Cloudy, Humidity: 72%".to_string()
        } else {
            format!("Weather data not available for {}", city)
        };
        Ok(reply)
    }

    /// JSON schema: a single required string argument, `city`.
    fn get_schema(&self) -> serde_json::Value {
        json!({
            "type": "object",
            "properties": {
                "city": {
                    "type": "string",
                    "description": "The city to get current weather for"
                }
            },
            "required": ["city"]
        })
    }
}
#[async_trait]
impl ToolExecutor for ForecastTool {
    /// Return a canned forecast string for the requested city.
    ///
    /// `days` defaults to 7 when absent; only Salt Lake City has fixture
    /// data (note the fixture always lists three days regardless of `days`).
    async fn execute(&self, tool_call: &ToolCall) -> Result<String> {
        let parsed: serde_json::Value = serde_json::from_str(&tool_call.function.arguments)?;
        let city = parsed["city"].as_str().unwrap_or("unknown");
        let days = parsed["days"].as_i64().unwrap_or(7);

        let text = match city.to_lowercase().as_str() {
            "salt lake city" => format!("{}-day forecast for Salt Lake City:\nDay 1: 75°F, Sunny\nDay 2: 78°F, Partly Cloudy\nDay 3: 72°F, Sunny", days),
            _ => format!("{}-day forecast not available for {}", days, city),
        };
        Ok(text)
    }

    /// JSON schema: required `city` plus an optional bounded `days` count.
    fn get_schema(&self) -> serde_json::Value {
        json!({
            "type": "object",
            "properties": {
                "city": {
                    "type": "string",
                    "description": "The city to get forecast for"
                },
                "days": {
                    "type": "integer",
                    "description": "Number of days to forecast (1-7)",
                    "minimum": 1,
                    "maximum": 7
                }
            },
            "required": ["city"]
        })
    }
}
#[async_trait]
impl ToolExecutor for WeatherAlertsTool {
    /// Return a canned severe-weather alert string for the requested city.
    async fn execute(&self, tool_call: &ToolCall) -> Result<String> {
        let payload: serde_json::Value = serde_json::from_str(&tool_call.function.arguments)?;
        let city = payload["city"].as_str().unwrap_or("unknown");
        let lowered = city.to_lowercase();

        // Guard-style lookups for the two cities with fixture data.
        if lowered == "salt lake city" {
            return Ok("No current weather alerts for Salt Lake City".to_string());
        }
        if lowered == "miami" {
            return Ok("ALERT: Tropical Storm Warning for Miami area".to_string());
        }
        Ok(format!("Weather alerts not available for {}", city))
    }

    /// JSON schema: a single required string argument, `city`.
    fn get_schema(&self) -> serde_json::Value {
        json!({
            "type": "object",
            "properties": {
                "city": {
                    "type": "string",
                    "description": "The city to get weather alerts for"
                }
            },
            "required": ["city"]
        })
    }
}
#[async_trait]
impl ToolExecutor for AirQualityTool {
    /// Return a canned air-quality (AQI) reading for the requested city.
    async fn execute(&self, tool_call: &ToolCall) -> Result<String> {
        let payload: serde_json::Value = serde_json::from_str(&tool_call.function.arguments)?;
        let city = payload["city"].as_str().unwrap_or("unknown");

        // Fallback intentionally echoes the caller's original casing.
        let reading = match city.to_lowercase().as_str() {
            "salt lake city" => "Air Quality in Salt Lake City: Good (AQI: 42)".to_string(),
            "beijing" => "Air Quality in Beijing: Unhealthy (AQI: 152)".to_string(),
            _ => format!("Air quality data not available for {}", city),
        };
        Ok(reading)
    }

    /// JSON schema: a single required string argument, `city`.
    fn get_schema(&self) -> serde_json::Value {
        json!({
            "type": "object",
            "properties": {
                "city": {
                    "type": "string",
                    "description": "The city to get air quality for"
                }
            },
            "required": ["city"]
        })
    }
}
/// Helper function to create a mock server with a specific response.
///
/// Registers one `POST /v1/chat/completions` mock that replies 200 with
/// `response_body`; `content_type` defaults to `application/json`.
///
/// Returns the server guard together with its base URL. The caller must keep
/// the guard alive for the mock to keep serving.
async fn setup_mock_server(
    response_body: &str,
    content_type: Option<&str>,
) -> (mockito::ServerGuard, String) {
    let mut server = mockito::Server::new_async().await;
    // Underscore binding fixes the unused-variable warning: the mock stays
    // registered on the server after the handle drops (the tests hit the
    // server after this helper returns); the handle is only needed for
    // call-count assertions, which this helper does not make.
    let _mock = server
        .mock("POST", "/v1/chat/completions")
        .with_status(200)
        .with_header("content-type", content_type.unwrap_or("application/json"))
        .with_body(response_body)
        .create_async()
        .await;
    let url = server.url();
    (server, url)
}
/// Helper function to create an agent with weather tools.
///
/// Registers all four mock weather tools under their dispatch names and
/// wraps them in an `Agent` backed by the supplied client and "gpt-4".
fn create_weather_agent(client: LiteLLMClient) -> Agent {
    let mut registry: HashMap<String, Box<dyn ToolExecutor>> = HashMap::new();
    let entries: [(&str, Box<dyn ToolExecutor>); 4] = [
        ("get_current_weather", Box::new(CurrentWeatherTool)),
        ("get_forecast", Box::new(ForecastTool)),
        ("get_weather_alerts", Box::new(WeatherAlertsTool)),
        ("get_air_quality", Box::new(AirQualityTool)),
    ];
    for (name, tool) in entries {
        registry.insert(name.to_string(), tool);
    }
    Agent::new(client, "gpt-4".to_string(), registry)
}
/// Test 1: Direct text response without tool calls
///
/// The mock LLM replies with plain assistant text (no `tool_calls`), so the
/// agent should return that message unchanged without executing any tool.
#[tokio::test]
async fn test_direct_response() {
    // Canned chat-completion payload in OpenAI/LiteLLM response shape.
    let response = r#"{
        "id": "test-id",
        "object": "chat.completion",
        "created": 1234567890,
        "model": "gpt-4",
        "system_fingerprint": "fp_44709d6fcb",
        "choices": [{
            "index": 0,
            "message": {
                "role": "assistant",
                "content": [{
                    "text": "The weather service is currently available for Salt Lake City, New York, Miami, and Beijing.",
                    "type": "text"
                }]
            },
            "logprobs": null,
            "finish_reason": "stop"
        }],
        "service_tier": "default",
        "usage": {
            "prompt_tokens": 10,
            "completion_tokens": 20,
            "total_tokens": 30,
            "completion_tokens_details": {
                "reasoning_tokens": 0,
                "accepted_prediction_tokens": 0,
                "rejected_prediction_tokens": 0
            }
        }
    }"#;
    // `server` must stay bound: dropping the guard would shut the mock down.
    let (server, url) = setup_mock_server(response, None).await;
    let client = LiteLLMClient::new(Some("test-key".to_string()), Some(url));
    let agent = create_weather_agent(client);
    let thread = Thread {
        id: "test-thread".to_string(),
        messages: vec![Message::user("What cities can I get weather for?".to_string())],
    };
    // Unwrap by hand so a failure prints the full error chain before panicking.
    let result = match agent.process_thread(&thread).await {
        Ok(result) => result,
        Err(e) => {
            println!("Error processing thread: {:?}", e);
            println!("Error chain:");
            let mut source = e.source();
            while let Some(e) = source {
                println!(" Caused by: {}", e);
                source = e.source();
            }
            panic!("Test failed due to error");
        }
    };
    // The assistant text should mention at least one supported city.
    let content = result.content.unwrap();
    assert!(content[0].text.contains("Salt Lake City"));
}
/// Test 2: Single tool call then response
///
/// Turn 1: the mock LLM requests `get_current_weather`. The agent executes
/// the tool, appends the result, and re-prompts; turn 2 returns the final
/// assistant text.
#[tokio::test]
async fn test_single_tool_call() {
    // Turn 1: assistant message with empty content and one tool_call entry.
    let first_response = r#"{
        "id": "test-id",
        "object": "chat.completion",
        "created": 1234567890,
        "model": "gpt-4",
        "choices": [{
            "index": 0,
            "message": {
                "role": "assistant",
                "content": [],
                "tool_calls": [{
                    "id": "call_123",
                    "type": "function",
                    "function": {
                        "name": "get_current_weather",
                        "arguments": "{\"city\":\"Salt Lake City\"}",
                        "call_type": "function"
                    }
                }]
            },
            "finish_reason": "tool_calls"
        }],
        "usage": {
            "prompt_tokens": 10,
            "completion_tokens": 20,
            "total_tokens": 30
        }
    }"#;
    // Turn 2: the final assistant answer after the tool result is appended.
    let second_response = r#"{
        "id": "test-id",
        "object": "chat.completion",
        "created": 1234567890,
        "model": "gpt-4",
        "choices": [{
            "index": 0,
            "message": {
                "role": "assistant",
                "content": [{
                    "text": "The current weather in Salt Lake City is 75°F, Sunny, with 45% humidity.",
                    "type": "text"
                }]
            },
            "finish_reason": "stop"
        }],
        "usage": {
            "prompt_tokens": 15,
            "completion_tokens": 25,
            "total_tokens": 40
        }
    }"#;
    let mut server = mockito::Server::new_async().await;
    // NOTE(review): `.expect(1)` is only enforced when the mock is asserted
    // (`mock1.assert_async().await`), so these hit-count expectations are
    // never actually verified. Before adding asserts, confirm mockito's
    // matching order when two mocks share the same method/path.
    let mock1 = server
        .mock("POST", "/v1/chat/completions")
        .with_status(200)
        .with_header("content-type", "application/json")
        .with_body(first_response)
        .expect(1)
        .create_async()
        .await;
    let mock2 = server
        .mock("POST", "/v1/chat/completions")
        .with_status(200)
        .with_header("content-type", "application/json")
        .with_body(second_response)
        .expect(1)
        .create_async()
        .await;
    let client = LiteLLMClient::new(Some("test-key".to_string()), Some(server.url()));
    let agent = create_weather_agent(client);
    let thread = Thread {
        id: "test-thread".to_string(),
        messages: vec![Message::user(
            "What's the current weather in Salt Lake City?".to_string(),
        )],
    };
    let result = agent.process_thread(&thread).await.unwrap();
    // The final text must come from the second (post-tool) response.
    let content = result.content.unwrap();
    assert!(content[0].text.contains("75°F"));
}
/// Test 3: Multiple tool calls then response
///
/// Turn 1 requests two tools in one assistant message
/// (`get_current_weather` and `get_air_quality`); the agent must execute
/// both before re-prompting for the final answer in turn 2.
#[tokio::test]
async fn test_multiple_tool_calls() {
    // Turn 1: one assistant message carrying two tool_calls entries.
    let first_response = r#"{
        "id": "test-id",
        "object": "chat.completion",
        "created": 1234567890,
        "model": "gpt-4",
        "choices": [{
            "index": 0,
            "message": {
                "role": "assistant",
                "content": [],
                "tool_calls": [
                    {
                        "id": "call_123",
                        "type": "function",
                        "function": {
                            "name": "get_current_weather",
                            "arguments": "{\"city\":\"Salt Lake City\"}",
                            "call_type": "function"
                        }
                    },
                    {
                        "id": "call_124",
                        "type": "function",
                        "function": {
                            "name": "get_air_quality",
                            "arguments": "{\"city\":\"Salt Lake City\"}",
                            "call_type": "function"
                        }
                    }
                ]
            },
            "finish_reason": "tool_calls"
        }],
        "usage": {
            "prompt_tokens": 12,
            "completion_tokens": 22,
            "total_tokens": 34
        }
    }"#;
    // Turn 2: final answer combining both tool results.
    let second_response = r#"{
        "id": "test-id",
        "object": "chat.completion",
        "created": 1234567890,
        "model": "gpt-4",
        "choices": [{
            "index": 0,
            "message": {
                "role": "assistant",
                "content": [{
                    "text": "In Salt Lake City, it's currently 75°F and sunny with good air quality (AQI: 42).",
                    "type": "text"
                }]
            },
            "finish_reason": "stop"
        }],
        "usage": {
            "prompt_tokens": 18,
            "completion_tokens": 28,
            "total_tokens": 46
        }
    }"#;
    let mut server = mockito::Server::new_async().await;
    // NOTE(review): as in test_single_tool_call, `.expect(1)` is never
    // verified because the mock handles are unused; confirm mockito's
    // same-route matching order before asserting them.
    let mock1 = server
        .mock("POST", "/v1/chat/completions")
        .with_status(200)
        .with_header("content-type", "application/json")
        .with_body(first_response)
        .expect(1)
        .create_async()
        .await;
    let mock2 = server
        .mock("POST", "/v1/chat/completions")
        .with_status(200)
        .with_header("content-type", "application/json")
        .with_body(second_response)
        .expect(1)
        .create_async()
        .await;
    let client = LiteLLMClient::new(Some("test-key".to_string()), Some(server.url()));
    let agent = create_weather_agent(client);
    let thread = Thread {
        id: "test-thread".to_string(),
        messages: vec![Message::user(
            "What's the weather and air quality in Salt Lake City?".to_string(),
        )],
    };
    let result = agent.process_thread(&thread).await.unwrap();
    // Final text should reflect both the weather and the AQI tool outputs.
    let content = result.content.unwrap();
    assert!(content[0].text.contains("75°F") && content[0].text.contains("AQI"));
}
/// Test 4: Streaming response
///
/// Serves four SSE chunks that the agent must accumulate into one assistant
/// message; the last message drained from the channel must carry the fully
/// concatenated text.
#[tokio::test]
async fn test_streaming_response() {
    // Delta chunks in OpenAI streaming shape; only the first carries `role`,
    // and the last carries `finish_reason: "stop"`.
    let stream_responses = vec![
        serde_json::json!({
            "id": "test-id",
            "object": "chat.completion.chunk",
            "created": 1234567890,
            "model": "gpt-4",
            "choices": [{
                "index": 0,
                "delta": {
                    "role": "assistant",
                    "content": [{
                        "text": "The current",
                        "type": "text"
                    }]
                },
                "finish_reason": null
            }]
        }),
        serde_json::json!({
            "id": "test-id",
            "object": "chat.completion.chunk",
            "created": 1234567890,
            "model": "gpt-4",
            "choices": [{
                "index": 0,
                "delta": {
                    "content": [{
                        "text": " weather",
                        "type": "text"
                    }]
                },
                "finish_reason": null
            }]
        }),
        serde_json::json!({
            "id": "test-id",
            "object": "chat.completion.chunk",
            "created": 1234567890,
            "model": "gpt-4",
            "choices": [{
                "index": 0,
                "delta": {
                    "content": [{
                        "text": " in Salt Lake City",
                        "type": "text"
                    }]
                },
                "finish_reason": null
            }]
        }),
        serde_json::json!({
            "id": "test-id",
            "object": "chat.completion.chunk",
            "created": 1234567890,
            "model": "gpt-4",
            "choices": [{
                "index": 0,
                "delta": {
                    "content": [{
                        "text": " is sunny",
                        "type": "text"
                    }]
                },
                "finish_reason": "stop"
            }]
        }),
    ];
    // Encode as a server-sent-events body terminated by the [DONE] sentinel.
    let stream_body = stream_responses
        .iter()
        .map(|r| format!("data: {}\n\n", r.to_string()))
        .collect::<String>()
        + "data: [DONE]\n\n";
    let mut server = mockito::Server::new_async().await;
    // Underscore binding fixes the unused-variable warning; the mock stays
    // registered on the server after the handle drops, and no call-count
    // assertion is made here.
    let _mock = server
        .mock("POST", "/v1/chat/completions")
        .with_status(200)
        .with_header("content-type", "text/event-stream")
        .with_body(stream_body)
        .create_async()
        .await;
    let client = LiteLLMClient::new(Some("test-key".to_string()), Some(server.url()));
    let agent = create_weather_agent(client);
    let thread = Thread {
        id: "test-thread".to_string(),
        messages: vec![Message::user(
            "What's the current weather in Salt Lake City?".to_string(),
        )],
    };
    let mut stream = agent.stream_process_thread(&thread).await.unwrap();
    let mut received_content = String::new();
    // Each streamed message carries the accumulated text so far; keep the last.
    while let Some(message_result) = stream.recv().await {
        let message = message_result.unwrap();
        received_content = message.content.unwrap()[0].text.clone();
    }
    assert_eq!(
        received_content,
        "The current weather in Salt Lake City is sunny"
    );
}
}

View File

@ -1,13 +1,14 @@
use anyhow::Result; use anyhow::Result;
use async_trait::async_trait; use async_trait::async_trait;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_json::Value;
use crate::utils::clients::ai::litellm::{Message, ToolCall}; use crate::utils::clients::ai::litellm::{Message, ToolCall};
/// A Thread represents a conversation between a user and the AI agent. /// A Thread represents a conversation between a user and the AI agent.
/// It contains a sequence of messages in chronological order. /// It contains a sequence of messages in chronological order.
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Thread { pub struct AgentThread {
/// Unique identifier for the thread /// Unique identifier for the thread
pub id: String, pub id: String,
/// Ordered sequence of messages in the conversation /// Ordered sequence of messages in the conversation
@ -19,7 +20,7 @@ pub struct Thread {
#[async_trait] #[async_trait]
pub trait ToolExecutor: Send + Sync { pub trait ToolExecutor: Send + Sync {
/// Execute the tool with given arguments and return a result /// Execute the tool with given arguments and return a result
async fn execute(&self, tool_call: &ToolCall) -> Result<String>; async fn execute(&self, tool_call: &ToolCall) -> Result<Value>;
/// Return the JSON schema that describes this tool's interface /// Return the JSON schema that describes this tool's interface
fn get_schema(&self) -> serde_json::Value; fn get_schema(&self) -> serde_json::Value;