mirror of https://github.com/buster-so/buster.git
refactor: Update Agent and Thread types for improved message handling
- Renamed `Thread` to `AgentThread` for clarity - Modified `ToolExecutor` to return `serde_json::Value` instead of `String` - Updated message processing to handle new message structures - Improved content handling in streaming and tool call scenarios - Simplified message content extraction and serialization
This commit is contained in:
parent
aeb1a02ba1
commit
692f8f7a1d
|
@ -3,7 +3,7 @@ use anyhow::Result;
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use tokio::sync::mpsc;
|
use tokio::sync::mpsc;
|
||||||
|
|
||||||
use super::types::{Thread, ToolExecutor};
|
use super::types::{AgentThread, ToolExecutor};
|
||||||
|
|
||||||
/// The Agent struct is responsible for managing conversations with the LLM
|
/// The Agent struct is responsible for managing conversations with the LLM
|
||||||
/// and coordinating tool executions. It maintains a registry of available tools
|
/// and coordinating tool executions. It maintains a registry of available tools
|
||||||
|
@ -48,7 +48,7 @@ impl Agent {
|
||||||
///
|
///
|
||||||
/// # Returns
|
/// # Returns
|
||||||
/// * A Result containing the final Message from the assistant
|
/// * A Result containing the final Message from the assistant
|
||||||
pub async fn process_thread(&self, thread: &Thread) -> Result<Message> {
|
pub async fn process_thread(&self, thread: &AgentThread) -> Result<Message> {
|
||||||
// Collect all registered tools and their schemas
|
// Collect all registered tools and their schemas
|
||||||
let tools: Vec<Tool> = self
|
let tools: Vec<Tool> = self
|
||||||
.tools
|
.tools
|
||||||
|
@ -72,10 +72,21 @@ impl Agent {
|
||||||
let llm_message = &response.choices[0].message;
|
let llm_message = &response.choices[0].message;
|
||||||
|
|
||||||
// Create the initial assistant message
|
// Create the initial assistant message
|
||||||
let mut message = Message::assistant(None, llm_message.tool_calls.clone());
|
let mut message = match llm_message {
|
||||||
|
Message::Assistant {
|
||||||
|
content,
|
||||||
|
tool_calls,
|
||||||
|
..
|
||||||
|
} => Message::assistant(content.clone(), tool_calls.clone()),
|
||||||
|
_ => return Err(anyhow::anyhow!("Expected assistant message from LLM")),
|
||||||
|
};
|
||||||
|
|
||||||
// If the LLM wants to use tools, execute them
|
// If the LLM wants to use tools, execute them
|
||||||
if let Some(tool_calls) = &llm_message.tool_calls {
|
if let Message::Assistant {
|
||||||
|
tool_calls: Some(tool_calls),
|
||||||
|
..
|
||||||
|
} = &llm_message
|
||||||
|
{
|
||||||
let mut results = Vec::new();
|
let mut results = Vec::new();
|
||||||
|
|
||||||
// Execute each requested tool
|
// Execute each requested tool
|
||||||
|
@ -83,7 +94,10 @@ impl Agent {
|
||||||
if let Some(tool) = self.tools.get(&tool_call.function.name) {
|
if let Some(tool) = self.tools.get(&tool_call.function.name) {
|
||||||
let result = tool.execute(tool_call).await?;
|
let result = tool.execute(tool_call).await?;
|
||||||
// Create a message for the tool's response
|
// Create a message for the tool's response
|
||||||
results.push(Message::tool(result, tool_call.id.clone()));
|
results.push(Message::tool(
|
||||||
|
serde_json::to_string(&result).unwrap(),
|
||||||
|
tool_call.id.clone(),
|
||||||
|
));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -94,12 +108,6 @@ impl Agent {
|
||||||
|
|
||||||
Box::pin(self.process_thread(&new_thread)).await
|
Box::pin(self.process_thread(&new_thread)).await
|
||||||
} else {
|
} else {
|
||||||
// If no tools were called, return the final response
|
|
||||||
message.content = if let Some(content) = &llm_message.content {
|
|
||||||
Some(content)
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
};
|
|
||||||
Ok(message)
|
Ok(message)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -113,7 +121,7 @@ impl Agent {
|
||||||
/// * A Result containing a receiver for streamed messages
|
/// * A Result containing a receiver for streamed messages
|
||||||
pub async fn stream_process_thread(
|
pub async fn stream_process_thread(
|
||||||
&self,
|
&self,
|
||||||
thread: &Thread,
|
thread: &AgentThread,
|
||||||
) -> Result<mpsc::Receiver<Result<Message>>> {
|
) -> Result<mpsc::Receiver<Result<Message>>> {
|
||||||
// Collect all registered tools and their schemas
|
// Collect all registered tools and their schemas
|
||||||
let tools: Vec<Tool> = self
|
let tools: Vec<Tool> = self
|
||||||
|
@ -148,8 +156,14 @@ impl Agent {
|
||||||
let delta = &chunk.choices[0].delta;
|
let delta = &chunk.choices[0].delta;
|
||||||
|
|
||||||
// Update the message with the new content
|
// Update the message with the new content
|
||||||
if let Some(content) = &mut current_message.content {
|
if let Message::Assistant { content, .. } = &mut current_message {
|
||||||
content[0].text.push_str(&delta.content);
|
if let Some(new_content) = &delta.content {
|
||||||
|
*content = Some(if let Some(existing) = content {
|
||||||
|
format!("{}{}", existing, new_content)
|
||||||
|
} else {
|
||||||
|
new_content.clone()
|
||||||
|
});
|
||||||
|
}
|
||||||
}
|
}
|
||||||
let _ = tx.send(Ok(current_message.clone())).await;
|
let _ = tx.send(Ok(current_message.clone())).await;
|
||||||
|
|
||||||
|
@ -165,533 +179,3 @@ impl Agent {
|
||||||
Ok(rx)
|
Ok(rx)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
|
||||||
mod tests {
|
|
||||||
use crate::utils::clients::ai::litellm::ToolCall;
|
|
||||||
|
|
||||||
use super::*;
|
|
||||||
use async_trait::async_trait;
|
|
||||||
use mockito;
|
|
||||||
use serde_json::json;
|
|
||||||
|
|
||||||
/// Mock weather tools for testing
|
|
||||||
struct CurrentWeatherTool;
|
|
||||||
struct ForecastTool;
|
|
||||||
struct WeatherAlertsTool;
|
|
||||||
struct AirQualityTool;
|
|
||||||
|
|
||||||
#[async_trait]
|
|
||||||
impl ToolExecutor for CurrentWeatherTool {
|
|
||||||
async fn execute(&self, tool_call: &ToolCall) -> Result<String> {
|
|
||||||
let args: serde_json::Value = serde_json::from_str(&tool_call.function.arguments)?;
|
|
||||||
let city = args["city"].as_str().unwrap_or("unknown");
|
|
||||||
|
|
||||||
match city.to_lowercase().as_str() {
|
|
||||||
"salt lake city" => {
|
|
||||||
Ok("Current weather in Salt Lake City: 75°F, Sunny, Humidity: 45%".to_string())
|
|
||||||
}
|
|
||||||
"new york" => {
|
|
||||||
Ok("Current weather in New York: 68°F, Cloudy, Humidity: 72%".to_string())
|
|
||||||
}
|
|
||||||
_ => Ok(format!("Weather data not available for {}", city)),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn get_schema(&self) -> serde_json::Value {
|
|
||||||
json!({
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"city": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "The city to get current weather for"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"required": ["city"]
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[async_trait]
|
|
||||||
impl ToolExecutor for ForecastTool {
|
|
||||||
async fn execute(&self, tool_call: &ToolCall) -> Result<String> {
|
|
||||||
let args: serde_json::Value = serde_json::from_str(&tool_call.function.arguments)?;
|
|
||||||
let city = args["city"].as_str().unwrap_or("unknown");
|
|
||||||
let days = args["days"].as_i64().unwrap_or(7);
|
|
||||||
|
|
||||||
if city.to_lowercase() == "salt lake city" {
|
|
||||||
Ok(format!("{}-day forecast for Salt Lake City:\nDay 1: 75°F, Sunny\nDay 2: 78°F, Partly Cloudy\nDay 3: 72°F, Sunny", days))
|
|
||||||
} else {
|
|
||||||
Ok(format!("{}-day forecast not available for {}", days, city))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn get_schema(&self) -> serde_json::Value {
|
|
||||||
json!({
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"city": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "The city to get forecast for"
|
|
||||||
},
|
|
||||||
"days": {
|
|
||||||
"type": "integer",
|
|
||||||
"description": "Number of days to forecast (1-7)",
|
|
||||||
"minimum": 1,
|
|
||||||
"maximum": 7
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"required": ["city"]
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[async_trait]
|
|
||||||
impl ToolExecutor for WeatherAlertsTool {
|
|
||||||
async fn execute(&self, tool_call: &ToolCall) -> Result<String> {
|
|
||||||
let args: serde_json::Value = serde_json::from_str(&tool_call.function.arguments)?;
|
|
||||||
let city = args["city"].as_str().unwrap_or("unknown");
|
|
||||||
|
|
||||||
match city.to_lowercase().as_str() {
|
|
||||||
"salt lake city" => Ok("No current weather alerts for Salt Lake City".to_string()),
|
|
||||||
"miami" => Ok("ALERT: Tropical Storm Warning for Miami area".to_string()),
|
|
||||||
_ => Ok(format!("Weather alerts not available for {}", city)),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn get_schema(&self) -> serde_json::Value {
|
|
||||||
json!({
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"city": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "The city to get weather alerts for"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"required": ["city"]
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[async_trait]
|
|
||||||
impl ToolExecutor for AirQualityTool {
|
|
||||||
async fn execute(&self, tool_call: &ToolCall) -> Result<String> {
|
|
||||||
let args: serde_json::Value = serde_json::from_str(&tool_call.function.arguments)?;
|
|
||||||
let city = args["city"].as_str().unwrap_or("unknown");
|
|
||||||
|
|
||||||
match city.to_lowercase().as_str() {
|
|
||||||
"salt lake city" => Ok("Air Quality in Salt Lake City: Good (AQI: 42)".to_string()),
|
|
||||||
"beijing" => Ok("Air Quality in Beijing: Unhealthy (AQI: 152)".to_string()),
|
|
||||||
_ => Ok(format!("Air quality data not available for {}", city)),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn get_schema(&self) -> serde_json::Value {
|
|
||||||
json!({
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"city": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "The city to get air quality for"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"required": ["city"]
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Helper function to create a mock server with a specific response
|
|
||||||
async fn setup_mock_server(
|
|
||||||
response_body: &str,
|
|
||||||
content_type: Option<&str>,
|
|
||||||
) -> (mockito::ServerGuard, String) {
|
|
||||||
let mut server = mockito::Server::new_async().await;
|
|
||||||
let mock = server
|
|
||||||
.mock("POST", "/v1/chat/completions")
|
|
||||||
.with_status(200)
|
|
||||||
.with_header("content-type", content_type.unwrap_or("application/json"))
|
|
||||||
.with_body(response_body)
|
|
||||||
.create_async()
|
|
||||||
.await;
|
|
||||||
|
|
||||||
let url = server.url();
|
|
||||||
|
|
||||||
(server, url)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Helper function to create an agent with weather tools
|
|
||||||
fn create_weather_agent(client: LiteLLMClient) -> Agent {
|
|
||||||
let mut tools: HashMap<String, Box<dyn ToolExecutor>> = HashMap::new();
|
|
||||||
tools.insert(
|
|
||||||
"get_current_weather".to_string(),
|
|
||||||
Box::new(CurrentWeatherTool),
|
|
||||||
);
|
|
||||||
tools.insert("get_forecast".to_string(), Box::new(ForecastTool));
|
|
||||||
tools.insert(
|
|
||||||
"get_weather_alerts".to_string(),
|
|
||||||
Box::new(WeatherAlertsTool),
|
|
||||||
);
|
|
||||||
tools.insert("get_air_quality".to_string(), Box::new(AirQualityTool));
|
|
||||||
|
|
||||||
Agent::new(client, "gpt-4".to_string(), tools)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Test 1: Direct text response without tool calls
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_direct_response() {
|
|
||||||
let response = r#"{
|
|
||||||
"id": "test-id",
|
|
||||||
"object": "chat.completion",
|
|
||||||
"created": 1234567890,
|
|
||||||
"model": "gpt-4",
|
|
||||||
"system_fingerprint": "fp_44709d6fcb",
|
|
||||||
"choices": [{
|
|
||||||
"index": 0,
|
|
||||||
"message": {
|
|
||||||
"role": "assistant",
|
|
||||||
"content": [{
|
|
||||||
"text": "The weather service is currently available for Salt Lake City, New York, Miami, and Beijing.",
|
|
||||||
"type": "text"
|
|
||||||
}]
|
|
||||||
},
|
|
||||||
"logprobs": null,
|
|
||||||
"finish_reason": "stop"
|
|
||||||
}],
|
|
||||||
"service_tier": "default",
|
|
||||||
"usage": {
|
|
||||||
"prompt_tokens": 10,
|
|
||||||
"completion_tokens": 20,
|
|
||||||
"total_tokens": 30,
|
|
||||||
"completion_tokens_details": {
|
|
||||||
"reasoning_tokens": 0,
|
|
||||||
"accepted_prediction_tokens": 0,
|
|
||||||
"rejected_prediction_tokens": 0
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}"#;
|
|
||||||
|
|
||||||
let (server, url) = setup_mock_server(response, None).await;
|
|
||||||
let client = LiteLLMClient::new(Some("test-key".to_string()), Some(url));
|
|
||||||
let agent = create_weather_agent(client);
|
|
||||||
|
|
||||||
let thread = Thread {
|
|
||||||
id: "test-thread".to_string(),
|
|
||||||
messages: vec![Message::user("What cities can I get weather for?".to_string())],
|
|
||||||
};
|
|
||||||
|
|
||||||
let result = match agent.process_thread(&thread).await {
|
|
||||||
Ok(result) => result,
|
|
||||||
Err(e) => {
|
|
||||||
println!("Error processing thread: {:?}", e);
|
|
||||||
println!("Error chain:");
|
|
||||||
let mut source = e.source();
|
|
||||||
while let Some(e) = source {
|
|
||||||
println!(" Caused by: {}", e);
|
|
||||||
source = e.source();
|
|
||||||
}
|
|
||||||
panic!("Test failed due to error");
|
|
||||||
}
|
|
||||||
};
|
|
||||||
let content = result.content.unwrap();
|
|
||||||
assert!(content[0].text.contains("Salt Lake City"));
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Test 2: Single tool call then response
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_single_tool_call() {
|
|
||||||
let first_response = r#"{
|
|
||||||
"id": "test-id",
|
|
||||||
"object": "chat.completion",
|
|
||||||
"created": 1234567890,
|
|
||||||
"model": "gpt-4",
|
|
||||||
"choices": [{
|
|
||||||
"index": 0,
|
|
||||||
"message": {
|
|
||||||
"role": "assistant",
|
|
||||||
"content": [],
|
|
||||||
"tool_calls": [{
|
|
||||||
"id": "call_123",
|
|
||||||
"type": "function",
|
|
||||||
"function": {
|
|
||||||
"name": "get_current_weather",
|
|
||||||
"arguments": "{\"city\":\"Salt Lake City\"}",
|
|
||||||
"call_type": "function"
|
|
||||||
}
|
|
||||||
}]
|
|
||||||
},
|
|
||||||
"finish_reason": "tool_calls"
|
|
||||||
}],
|
|
||||||
"usage": {
|
|
||||||
"prompt_tokens": 10,
|
|
||||||
"completion_tokens": 20,
|
|
||||||
"total_tokens": 30
|
|
||||||
}
|
|
||||||
}"#;
|
|
||||||
|
|
||||||
let second_response = r#"{
|
|
||||||
"id": "test-id",
|
|
||||||
"object": "chat.completion",
|
|
||||||
"created": 1234567890,
|
|
||||||
"model": "gpt-4",
|
|
||||||
"choices": [{
|
|
||||||
"index": 0,
|
|
||||||
"message": {
|
|
||||||
"role": "assistant",
|
|
||||||
"content": [{
|
|
||||||
"text": "The current weather in Salt Lake City is 75°F, Sunny, with 45% humidity.",
|
|
||||||
"type": "text"
|
|
||||||
}]
|
|
||||||
},
|
|
||||||
"finish_reason": "stop"
|
|
||||||
}],
|
|
||||||
"usage": {
|
|
||||||
"prompt_tokens": 15,
|
|
||||||
"completion_tokens": 25,
|
|
||||||
"total_tokens": 40
|
|
||||||
}
|
|
||||||
}"#;
|
|
||||||
|
|
||||||
let mut server = mockito::Server::new_async().await;
|
|
||||||
let mock1 = server
|
|
||||||
.mock("POST", "/v1/chat/completions")
|
|
||||||
.with_status(200)
|
|
||||||
.with_header("content-type", "application/json")
|
|
||||||
.with_body(first_response)
|
|
||||||
.expect(1)
|
|
||||||
.create_async()
|
|
||||||
.await;
|
|
||||||
|
|
||||||
let mock2 = server
|
|
||||||
.mock("POST", "/v1/chat/completions")
|
|
||||||
.with_status(200)
|
|
||||||
.with_header("content-type", "application/json")
|
|
||||||
.with_body(second_response)
|
|
||||||
.expect(1)
|
|
||||||
.create_async()
|
|
||||||
.await;
|
|
||||||
|
|
||||||
let client = LiteLLMClient::new(Some("test-key".to_string()), Some(server.url()));
|
|
||||||
let agent = create_weather_agent(client);
|
|
||||||
|
|
||||||
let thread = Thread {
|
|
||||||
id: "test-thread".to_string(),
|
|
||||||
messages: vec![Message::user(
|
|
||||||
"What's the current weather in Salt Lake City?".to_string(),
|
|
||||||
)],
|
|
||||||
};
|
|
||||||
|
|
||||||
let result = agent.process_thread(&thread).await.unwrap();
|
|
||||||
let content = result.content.unwrap();
|
|
||||||
assert!(content[0].text.contains("75°F"));
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Test 3: Multiple tool calls then response
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_multiple_tool_calls() {
|
|
||||||
let first_response = r#"{
|
|
||||||
"id": "test-id",
|
|
||||||
"object": "chat.completion",
|
|
||||||
"created": 1234567890,
|
|
||||||
"model": "gpt-4",
|
|
||||||
"choices": [{
|
|
||||||
"index": 0,
|
|
||||||
"message": {
|
|
||||||
"role": "assistant",
|
|
||||||
"content": [],
|
|
||||||
"tool_calls": [
|
|
||||||
{
|
|
||||||
"id": "call_123",
|
|
||||||
"type": "function",
|
|
||||||
"function": {
|
|
||||||
"name": "get_current_weather",
|
|
||||||
"arguments": "{\"city\":\"Salt Lake City\"}",
|
|
||||||
"call_type": "function"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "call_124",
|
|
||||||
"type": "function",
|
|
||||||
"function": {
|
|
||||||
"name": "get_air_quality",
|
|
||||||
"arguments": "{\"city\":\"Salt Lake City\"}",
|
|
||||||
"call_type": "function"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"finish_reason": "tool_calls"
|
|
||||||
}],
|
|
||||||
"usage": {
|
|
||||||
"prompt_tokens": 12,
|
|
||||||
"completion_tokens": 22,
|
|
||||||
"total_tokens": 34
|
|
||||||
}
|
|
||||||
}"#;
|
|
||||||
|
|
||||||
let second_response = r#"{
|
|
||||||
"id": "test-id",
|
|
||||||
"object": "chat.completion",
|
|
||||||
"created": 1234567890,
|
|
||||||
"model": "gpt-4",
|
|
||||||
"choices": [{
|
|
||||||
"index": 0,
|
|
||||||
"message": {
|
|
||||||
"role": "assistant",
|
|
||||||
"content": [{
|
|
||||||
"text": "In Salt Lake City, it's currently 75°F and sunny with good air quality (AQI: 42).",
|
|
||||||
"type": "text"
|
|
||||||
}]
|
|
||||||
},
|
|
||||||
"finish_reason": "stop"
|
|
||||||
}],
|
|
||||||
"usage": {
|
|
||||||
"prompt_tokens": 18,
|
|
||||||
"completion_tokens": 28,
|
|
||||||
"total_tokens": 46
|
|
||||||
}
|
|
||||||
}"#;
|
|
||||||
|
|
||||||
let mut server = mockito::Server::new_async().await;
|
|
||||||
let mock1 = server
|
|
||||||
.mock("POST", "/v1/chat/completions")
|
|
||||||
.with_status(200)
|
|
||||||
.with_header("content-type", "application/json")
|
|
||||||
.with_body(first_response)
|
|
||||||
.expect(1)
|
|
||||||
.create_async()
|
|
||||||
.await;
|
|
||||||
|
|
||||||
let mock2 = server
|
|
||||||
.mock("POST", "/v1/chat/completions")
|
|
||||||
.with_status(200)
|
|
||||||
.with_header("content-type", "application/json")
|
|
||||||
.with_body(second_response)
|
|
||||||
.expect(1)
|
|
||||||
.create_async()
|
|
||||||
.await;
|
|
||||||
|
|
||||||
let client = LiteLLMClient::new(Some("test-key".to_string()), Some(server.url()));
|
|
||||||
let agent = create_weather_agent(client);
|
|
||||||
|
|
||||||
let thread = Thread {
|
|
||||||
id: "test-thread".to_string(),
|
|
||||||
messages: vec![Message::user(
|
|
||||||
"What's the weather and air quality in Salt Lake City?".to_string(),
|
|
||||||
)],
|
|
||||||
};
|
|
||||||
|
|
||||||
let result = agent.process_thread(&thread).await.unwrap();
|
|
||||||
let content = result.content.unwrap();
|
|
||||||
assert!(content[0].text.contains("75°F") && content[0].text.contains("AQI"));
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Test 4: Streaming response
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_streaming_response() {
|
|
||||||
let stream_responses = vec![
|
|
||||||
serde_json::json!({
|
|
||||||
"id": "test-id",
|
|
||||||
"object": "chat.completion.chunk",
|
|
||||||
"created": 1234567890,
|
|
||||||
"model": "gpt-4",
|
|
||||||
"choices": [{
|
|
||||||
"index": 0,
|
|
||||||
"delta": {
|
|
||||||
"role": "assistant",
|
|
||||||
"content": [{
|
|
||||||
"text": "The current",
|
|
||||||
"type": "text"
|
|
||||||
}]
|
|
||||||
},
|
|
||||||
"finish_reason": null
|
|
||||||
}]
|
|
||||||
}),
|
|
||||||
serde_json::json!({
|
|
||||||
"id": "test-id",
|
|
||||||
"object": "chat.completion.chunk",
|
|
||||||
"created": 1234567890,
|
|
||||||
"model": "gpt-4",
|
|
||||||
"choices": [{
|
|
||||||
"index": 0,
|
|
||||||
"delta": {
|
|
||||||
"content": [{
|
|
||||||
"text": " weather",
|
|
||||||
"type": "text"
|
|
||||||
}]
|
|
||||||
},
|
|
||||||
"finish_reason": null
|
|
||||||
}]
|
|
||||||
}),
|
|
||||||
serde_json::json!({
|
|
||||||
"id": "test-id",
|
|
||||||
"object": "chat.completion.chunk",
|
|
||||||
"created": 1234567890,
|
|
||||||
"model": "gpt-4",
|
|
||||||
"choices": [{
|
|
||||||
"index": 0,
|
|
||||||
"delta": {
|
|
||||||
"content": [{
|
|
||||||
"text": " in Salt Lake City",
|
|
||||||
"type": "text"
|
|
||||||
}]
|
|
||||||
},
|
|
||||||
"finish_reason": null
|
|
||||||
}]
|
|
||||||
}),
|
|
||||||
serde_json::json!({
|
|
||||||
"id": "test-id",
|
|
||||||
"object": "chat.completion.chunk",
|
|
||||||
"created": 1234567890,
|
|
||||||
"model": "gpt-4",
|
|
||||||
"choices": [{
|
|
||||||
"index": 0,
|
|
||||||
"delta": {
|
|
||||||
"content": [{
|
|
||||||
"text": " is sunny",
|
|
||||||
"type": "text"
|
|
||||||
}]
|
|
||||||
},
|
|
||||||
"finish_reason": "stop"
|
|
||||||
}]
|
|
||||||
}),
|
|
||||||
];
|
|
||||||
|
|
||||||
let stream_body = stream_responses
|
|
||||||
.iter()
|
|
||||||
.map(|r| format!("data: {}\n\n", r.to_string()))
|
|
||||||
.collect::<String>()
|
|
||||||
+ "data: [DONE]\n\n";
|
|
||||||
|
|
||||||
let mut server = mockito::Server::new_async().await;
|
|
||||||
let mock = server
|
|
||||||
.mock("POST", "/v1/chat/completions")
|
|
||||||
.with_status(200)
|
|
||||||
.with_header("content-type", "text/event-stream")
|
|
||||||
.with_body(stream_body)
|
|
||||||
.create_async()
|
|
||||||
.await;
|
|
||||||
|
|
||||||
let client = LiteLLMClient::new(Some("test-key".to_string()), Some(server.url()));
|
|
||||||
let agent = create_weather_agent(client);
|
|
||||||
|
|
||||||
let thread = Thread {
|
|
||||||
id: "test-thread".to_string(),
|
|
||||||
messages: vec![Message::user(
|
|
||||||
"What's the current weather in Salt Lake City?".to_string(),
|
|
||||||
)],
|
|
||||||
};
|
|
||||||
|
|
||||||
let mut stream = agent.stream_process_thread(&thread).await.unwrap();
|
|
||||||
let mut received_content = String::new();
|
|
||||||
|
|
||||||
while let Some(message_result) = stream.recv().await {
|
|
||||||
let message = message_result.unwrap();
|
|
||||||
received_content = message.content.unwrap()[0].text.clone();
|
|
||||||
}
|
|
||||||
|
|
||||||
assert_eq!(
|
|
||||||
received_content,
|
|
||||||
"The current weather in Salt Lake City is sunny"
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
|
@ -1,13 +1,14 @@
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
use serde_json::Value;
|
||||||
|
|
||||||
use crate::utils::clients::ai::litellm::{Message, ToolCall};
|
use crate::utils::clients::ai::litellm::{Message, ToolCall};
|
||||||
|
|
||||||
/// A Thread represents a conversation between a user and the AI agent.
|
/// A Thread represents a conversation between a user and the AI agent.
|
||||||
/// It contains a sequence of messages in chronological order.
|
/// It contains a sequence of messages in chronological order.
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
pub struct Thread {
|
pub struct AgentThread {
|
||||||
/// Unique identifier for the thread
|
/// Unique identifier for the thread
|
||||||
pub id: String,
|
pub id: String,
|
||||||
/// Ordered sequence of messages in the conversation
|
/// Ordered sequence of messages in the conversation
|
||||||
|
@ -19,7 +20,7 @@ pub struct Thread {
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
pub trait ToolExecutor: Send + Sync {
|
pub trait ToolExecutor: Send + Sync {
|
||||||
/// Execute the tool with given arguments and return a result
|
/// Execute the tool with given arguments and return a result
|
||||||
async fn execute(&self, tool_call: &ToolCall) -> Result<String>;
|
async fn execute(&self, tool_call: &ToolCall) -> Result<Value>;
|
||||||
|
|
||||||
/// Return the JSON schema that describes this tool's interface
|
/// Return the JSON schema that describes this tool's interface
|
||||||
fn get_schema(&self) -> serde_json::Value;
|
fn get_schema(&self) -> serde_json::Value;
|
||||||
|
|
Loading…
Reference in New Issue