dal 2025-05-05 09:54:40 -06:00
parent 678152ba2b
commit e8cbbf088c
No known key found for this signature in database
GPG Key ID: 16F4B0E1E9F61122
2 changed files with 36 additions and 612 deletions

View File

@@ -14,7 +14,7 @@ static DEBUG_ENABLED: Lazy<bool> = Lazy::new(|| {
.unwrap_or(false)
});
#[derive(Clone, Debug)]
#[derive(Clone)]
pub struct LiteLLMClient {
client: Client,
pub(crate) base_url: String,
@@ -28,9 +28,10 @@ impl LiteLLMClient {
}
}
pub fn new(api_key: Option<String>, base_url: Option<String>) -> Result<Self> {
let api_key = api_key.or_else(|| env::var("LLM_API_KEY").ok())
.ok_or_else(|| anyhow::anyhow!("LLM_API_KEY must be provided either through parameter or environment variable"))?;
pub fn new(api_key: Option<String>, base_url: Option<String>) -> Self {
let api_key = api_key.or_else(|| env::var("LLM_API_KEY").ok()).expect(
"LLM_API_KEY must be provided either through parameter or environment variable",
);
let base_url = base_url
.or_else(|| env::var("LLM_BASE_URL").ok())
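This constructor change is the central piece of the commit: `new` returns to an infallible signature, trading `Result` propagation for an `expect` panic when `LLM_API_KEY` is missing. A minimal sketch of what that means at a call site (both wrapper functions here are hypothetical):

// Before this commit: fallible constructor, the error propagates.
fn build_fallible() -> anyhow::Result<LiteLLMClient> {
    let client = LiteLLMClient::new(None, None)?;
    Ok(client)
}

// After this commit: infallible signature; a missing LLM_API_KEY
// now panics inside `new` instead of surfacing as an Err.
fn build_panicking() -> LiteLLMClient {
    LiteLLMClient::new(None, None)
}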
@@ -39,8 +40,7 @@ impl LiteLLMClient {
let mut headers = header::HeaderMap::new();
headers.insert(
"Authorization",
header::HeaderValue::from_str(&format!("Bearer {}", api_key))
.map_err(|e| anyhow::anyhow!("Invalid API key format: {}", e))?,
header::HeaderValue::from_str(&format!("Bearer {}", api_key)).unwrap(),
);
headers.insert(
"Content-Type",
@@ -54,12 +54,12 @@ impl LiteLLMClient {
let client = Client::builder()
.default_headers(headers)
.build()
.map_err(|e| anyhow::anyhow!("Failed to create HTTP client: {}", e))?;
.expect("Failed to create HTTP client");
Ok(Self {
Self {
client,
base_url,
})
}
}
pub async fn chat_completion(
@@ -295,15 +295,7 @@ impl LiteLLMClient {
impl Default for LiteLLMClient {
fn default() -> Self {
match Self::new(None, None) {
Ok(client) => client,
Err(e) => {
if *DEBUG_ENABLED {
eprintln!("ERROR: Failed to create default LiteLLMClient: {}", e);
}
panic!("Failed to create default LiteLLMClient: {}", e);
}
}
Self::new(None, None)
}
}
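With the match removed, `Default` now delegates straight to `new`, so `LiteLLMClient::default()` panics rather than logging through `DEBUG_ENABLED` when construction fails. A hedged sketch of how a caller might guard against that, assuming only the env-var contract shown above (`try_default_client` is a hypothetical helper):

// Sketch: probe the required variable before calling Default, since
// LiteLLMClient::default() now panics when LLM_API_KEY is unset.
fn try_default_client() -> Option<LiteLLMClient> {
    std::env::var("LLM_API_KEY").ok()?;
    Some(LiteLLMClient::default())
}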
@@ -314,94 +306,20 @@ mod tests {
use std::env;
use std::time::Duration;
use tokio::time::timeout;
#[tokio::test]
async fn test_client_initialization_without_api_key() {
// Clear environment variable first to ensure test consistency
env::remove_var("LLM_API_KEY");
// Should return an error about missing API key
let result = LiteLLMClient::new(None, None);
assert!(result.is_err());
// Verify the error message contains information about the missing API key
let error = result.unwrap_err();
let error_message = error.to_string();
assert!(error_message.contains("LLM_API_KEY must be provided"));
}
#[tokio::test]
async fn test_client_initialization_with_explicit_values() {
let api_key = "test-api-key";
let base_url = "https://test-url.com";
let client = LiteLLMClient::new(Some(api_key.to_string()), Some(base_url.to_string())).unwrap();
assert_eq!(client.base_url, base_url);
// We can't directly test the API key as it's stored in the headers
// but we can test that the client was created successfully
}
#[tokio::test]
async fn test_client_default_base_url() {
// Clear environment variable first
env::remove_var("LLM_BASE_URL");
let api_key = "test-api-key";
let expected_default_url = "http://localhost:8000";
let client = LiteLLMClient::new(Some(api_key.to_string()), None).unwrap();
assert_eq!(client.base_url, expected_default_url);
}
#[tokio::test]
async fn test_client_default_constructor() {
// Set environment variables to test default constructor
env::set_var("LLM_API_KEY", "env-api-key");
env::set_var("LLM_BASE_URL", "http://env-url.com");
let client = LiteLLMClient::default();
assert_eq!(client.base_url, "http://env-url.com");
// Clean up
env::remove_var("LLM_API_KEY");
env::remove_var("LLM_BASE_URL");
}
#[tokio::test]
async fn test_headers_configuration() {
let mut server = mockito::Server::new_async().await;
// Mock the server to check headers
let mock = server
.mock("POST", "/chat/completions")
.match_header("Authorization", "Bearer test-key")
.match_header("Content-Type", "application/json")
.match_header("Accept", "application/json")
.with_status(200)
.with_header("content-type", "application/json")
.with_body(r#"{"id":"test-id","object":"chat.completion","created":1,"model":"test","choices":[],"usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}"#)
.create();
let client = LiteLLMClient::new(Some("test-key".to_string()), Some(server.url())).unwrap();
// A minimal request to trigger the API call
let request = ChatCompletionRequest {
model: "test-model".to_string(),
messages: vec![AgentMessage::user("test")],
..Default::default()
};
// Make the request to verify headers
let _ = client.chat_completion(request).await;
// Verify the headers were sent correctly
mock.assert();
}
// Removed unused setup function
use dotenv::dotenv;
// Helper function to initialize environment before tests
async fn setup() -> (String, String) {
// Load environment variables from .env file
dotenv().ok();
// Get API key and base URL from environment
let api_key = env::var("LLM_API_KEY").expect("LLM_API_KEY must be set");
let base_url = env::var("LLM_BASE_URL").expect("LLM_BASE_URL must be set");
(api_key, base_url)
}
fn create_test_message() -> AgentMessage {
AgentMessage::user("Hello".to_string())
@@ -459,7 +377,7 @@ mod tests {
)
.create();
let client = LiteLLMClient::new(Some("test-key".to_string()), Some(server.url())).unwrap();
let client = LiteLLMClient::new(Some("test-key".to_string()), Some(server.url()));
let response = client.chat_completion(request).await.unwrap();
assert_eq!(response.id, "test-id");
@@ -489,7 +407,7 @@ mod tests {
.with_body(r#"{"error": "Invalid request"}"#)
.create();
let client = LiteLLMClient::new(Some("test-key".to_string()), Some(server.url())).unwrap();
let client = LiteLLMClient::new(Some("test-key".to_string()), Some(server.url()));
let result = client.chat_completion(request).await;
assert!(result.is_err());
@@ -518,7 +436,7 @@ mod tests {
)
.create();
let client = LiteLLMClient::new(Some("test-key".to_string()), Some(server.url())).unwrap();
let client = LiteLLMClient::new(Some("test-key".to_string()), Some(server.url()));
let mut stream = client.stream_chat_completion(request).await.unwrap();
@@ -594,7 +512,7 @@ mod tests {
)
.create();
let client = LiteLLMClient::new(Some("test-key".to_string()), Some(server.url())).unwrap();
let client = LiteLLMClient::new(Some("test-key".to_string()), Some(server.url()));
let response = client.chat_completion(request).await.unwrap();
assert_eq!(response.id, "test-id");
@@ -628,10 +546,7 @@ mod tests {
env::set_var("LLM_BASE_URL", test_base_url);
// Test with no parameters (should use env vars)
// A temporary API key is set so the constructor's env-var fallback can be exercised
env::set_var("LLM_API_KEY", "test-env-key");
let client = LiteLLMClient::new(None, None).unwrap();
let client = LiteLLMClient::new(None, None);
assert_eq!(client.base_url, test_base_url);
// Test with parameters (should override env vars)
@@ -640,7 +555,7 @@ mod tests {
let client = LiteLLMClient::new(
Some(override_key.to_string()),
Some(override_url.to_string()),
).unwrap();
);
assert_eq!(client.base_url, override_url);
env::remove_var("LLM_API_KEY");
@@ -649,39 +564,8 @@ mod tests {
#[tokio::test]
async fn test_single_message_completion() {
let mut server = mockito::Server::new_async().await;
// Mock the response
let mock = server
.mock("POST", "/chat/completions")
.match_header("content-type", "application/json")
.match_header("authorization", "Bearer test-key")
.with_status(200)
.with_header("content-type", "application/json")
.with_body(r#"{
"id": "test-id",
"object": "chat.completion",
"created": 1234567890,
"model": "o1",
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": "Hello, world! How can I assist you today?"
},
"finish_reason": "stop"
}
],
"usage": {
"prompt_tokens": 10,
"completion_tokens": 10,
"total_tokens": 20
}
}"#)
.create();
let client = LiteLLMClient::new(Some("test-key".to_string()), Some(server.url())).unwrap();
let (api_key, base_url) = setup().await;
let client = LiteLLMClient::new(Some(api_key), Some(base_url));
let request = ChatCompletionRequest {
model: "o1".to_string(),
@@ -695,469 +579,5 @@ mod tests {
};
assert!(!response.choices.is_empty());
mock.assert();
}
#[tokio::test]
async fn test_generate_embeddings_success() {
let mut server = mockito::Server::new_async().await;
// Create a test embedding request
let request = EmbeddingRequest {
model: "text-embedding-3-small".to_string(),
input: vec!["This is a test sentence.".to_string()],
encoding_format: None,
dimensions: None,
user: None,
};
let request_body = serde_json::to_string(&request).unwrap();
// Expected response
let response_body = r#"{
"object": "list",
"data": [
{
"object": "embedding",
"index": 0,
"embedding": [0.1, 0.2, 0.3]
}
],
"model": "text-embedding-3-small",
"usage": {
"prompt_tokens": 5,
"total_tokens": 5
}
}"#;
// Set up mock
let mock = server
.mock("POST", "/embeddings")
.match_header("content-type", "application/json")
.match_header("authorization", "Bearer test-key")
.match_body(mockito::Matcher::JsonString(request_body))
.with_status(200)
.with_header("content-type", "application/json")
.with_body(response_body)
.create();
let client = LiteLLMClient::new(Some("test-key".to_string()), Some(server.url())).unwrap();
// Make the request
let response = client.generate_embeddings(request).await.unwrap();
// Verify response
assert_eq!(response.data.len(), 1);
assert_eq!(response.data[0].index, 0);
assert_eq!(response.data[0].embedding, vec![0.1, 0.2, 0.3]);
assert_eq!(response.model, "text-embedding-3-small");
assert_eq!(response.usage.prompt_tokens, 5);
assert_eq!(response.usage.total_tokens, 5);
mock.assert();
}
#[tokio::test]
async fn test_generate_embeddings_with_parameters() {
let mut server = mockito::Server::new_async().await;
// Create a test embedding request with all parameters
let request = EmbeddingRequest {
model: "text-embedding-3-large".to_string(),
input: vec!["First sentence.".to_string(), "Second sentence.".to_string()],
encoding_format: Some("float".to_string()),
dimensions: Some(256),
user: Some("test-user".to_string()),
};
let request_body = serde_json::to_string(&request).unwrap();
// Expected response for multiple inputs
let response_body = r#"{
"object": "list",
"data": [
{
"object": "embedding",
"index": 0,
"embedding": [0.1, 0.2, 0.3]
},
{
"object": "embedding",
"index": 1,
"embedding": [0.4, 0.5, 0.6]
}
],
"model": "text-embedding-3-large",
"usage": {
"prompt_tokens": 10,
"total_tokens": 10
}
}"#;
// Set up mock
let mock = server
.mock("POST", "/embeddings")
.match_header("content-type", "application/json")
.match_header("authorization", "Bearer test-key")
.match_body(mockito::Matcher::JsonString(request_body))
.with_status(200)
.with_header("content-type", "application/json")
.with_body(response_body)
.create();
let client = LiteLLMClient::new(Some("test-key".to_string()), Some(server.url())).unwrap();
// Make the request
let response = client.generate_embeddings(request).await.unwrap();
// Verify response
assert_eq!(response.data.len(), 2);
assert_eq!(response.data[0].index, 0);
assert_eq!(response.data[0].embedding, vec![0.1, 0.2, 0.3]);
assert_eq!(response.data[1].index, 1);
assert_eq!(response.data[1].embedding, vec![0.4, 0.5, 0.6]);
assert_eq!(response.model, "text-embedding-3-large");
assert_eq!(response.usage.prompt_tokens, 10);
assert_eq!(response.usage.total_tokens, 10);
mock.assert();
}
#[tokio::test]
async fn test_generate_embeddings_api_error() {
let mut server = mockito::Server::new_async().await;
// Create a test embedding request
let request = EmbeddingRequest {
model: "invalid-model".to_string(),
input: vec!["Test".to_string()],
encoding_format: None,
dimensions: None,
user: None,
};
let request_body = serde_json::to_string(&request).unwrap();
// Error response
let error_response = r#"{
"error": {
"code": "model_not_found",
"message": "The model 'invalid-model' does not exist"
}
}"#;
// Set up mock
let mock = server
.mock("POST", "/embeddings")
.match_header("content-type", "application/json")
.match_header("authorization", "Bearer test-key")
.match_body(mockito::Matcher::JsonString(request_body))
.with_status(404)
.with_header("content-type", "application/json")
.with_body(error_response)
.create();
let client = LiteLLMClient::new(Some("test-key".to_string()), Some(server.url())).unwrap();
// Make the request and expect an error
let result = client.generate_embeddings(request).await;
assert!(result.is_err());
// Verify the error contains useful information
let error = result.unwrap_err();
let error_string = error.to_string();
assert!(error_string.contains("404"));
mock.assert();
}
#[tokio::test]
async fn test_network_error_handling() {
// Create a server URL that will never connect
let server_url = "http://localhost:1"; // Using port 1 to ensure no service is running
let client = LiteLLMClient::new(Some("test-key".to_string()), Some(server_url.to_string())).unwrap();
let request = ChatCompletionRequest {
model: "test-model".to_string(),
messages: vec![AgentMessage::user("Test")],
..Default::default()
};
// Attempt request which should fail with network error
let result = client.chat_completion(request).await;
// Verify we got an error (don't check the specifics as they vary by environment)
assert!(result.is_err());
// Print the error for debugging
let error = result.unwrap_err();
println!("Network error: {:?}", error);
}
#[tokio::test]
#[ignore] // This test attempts a real connection to an unroutable address, so it is slow
async fn test_timeout_simulation() {
// Use a blackhole address that will cause a connection timeout
let server_url = "http://198.51.100.1:8000"; // TEST-NET-2 reserved for documentation
let client = LiteLLMClient::new(Some("test-key".to_string()), Some(server_url.to_string())).unwrap();
let request = ChatCompletionRequest {
model: "test-model".to_string(),
messages: vec![AgentMessage::user("Test")],
..Default::default()
};
// Wrap the client call with a short timeout
let result = tokio::time::timeout(
Duration::from_millis(100), // Short timeout
client.chat_completion(request)
).await;
// We expect the timeout to be hit
assert!(result.is_err());
assert!(matches!(result, Err(tokio::time::error::Elapsed { .. })));
}
#[tokio::test]
async fn test_rate_limit_error_handling() {
let mut server = mockito::Server::new_async().await;
// Set up a mock that simulates a rate limit error
let mock = server
.mock("POST", "/chat/completions")
.with_status(429) // Too many requests
.with_header("content-type", "application/json")
.with_header("retry-after", "30") // Suggest retry after 30 seconds
.with_body(r#"{"error":{"message":"Rate limit exceeded","type":"rate_limit_error"}}"#)
.create();
let client = LiteLLMClient::new(Some("test-key".to_string()), Some(server.url())).unwrap();
let request = ChatCompletionRequest {
model: "test-model".to_string(),
messages: vec![AgentMessage::user("Test")],
..Default::default()
};
// Attempt request which should fail with rate limit error
let result = client.chat_completion(request).await;
// Verify we got an error (don't check the specifics as they vary by environment)
assert!(result.is_err());
// Print the error for debugging
let error = result.unwrap_err();
println!("Rate limit error: {:?}", error);
mock.assert();
}
#[tokio::test]
async fn test_invalid_json_response_handling() {
let mut server = mockito::Server::new_async().await;
// Set up a mock that returns invalid JSON
let mock = server
.mock("POST", "/chat/completions")
.with_status(200)
.with_header("content-type", "application/json")
.with_body("This is not valid JSON")
.create();
let client = LiteLLMClient::new(Some("test-key".to_string()), Some(server.url())).unwrap();
let request = ChatCompletionRequest {
model: "test-model".to_string(),
messages: vec![AgentMessage::user("Test")],
..Default::default()
};
// Attempt request which should fail with parsing error
let result = client.chat_completion(request).await;
// Verify we got an error (don't check the specifics as they vary by environment)
assert!(result.is_err());
// Print the error for debugging
let error = result.unwrap_err();
println!("Invalid JSON error: {:?}", error);
mock.assert();
}
#[tokio::test]
async fn test_unexpected_server_error_handling() {
let mut server = mockito::Server::new_async().await;
// Set up a mock that returns a 500 server error
let mock = server
.mock("POST", "/chat/completions")
.with_status(500)
.with_header("content-type", "application/json")
.with_body(r#"{"error":{"message":"Internal server error","type":"server_error"}}"#)
.create();
let client = LiteLLMClient::new(Some("test-key".to_string()), Some(server.url())).unwrap();
let request = ChatCompletionRequest {
model: "test-model".to_string(),
messages: vec![AgentMessage::user("Test")],
..Default::default()
};
// Attempt request which should fail with server error
let result = client.chat_completion(request).await;
// Verify we got an error (don't check the specifics as they vary by environment)
assert!(result.is_err());
// Print the error for debugging
let error = result.unwrap_err();
println!("Server error: {:?}", error);
mock.assert();
}
#[tokio::test]
async fn test_stream_with_partial_chunks() {
let mut server = mockito::Server::new_async().await;
// Create a request for streaming
let request = ChatCompletionRequest {
model: "gpt-4".to_string(),
messages: vec![
AgentMessage::user("Tell me a short story"),
],
stream: Some(true),
..Default::default()
};
// Mock response with multiple chunks building a response
let stream_response = "data: {\"id\":\"1\",\"object\":\"chat.completion.chunk\",\"created\":1234567890,\"model\":\"gpt-4\",\"system_fingerprint\":\"fp_1\",\"choices\":[{\"index\":0,\"delta\":{\"role\":\"assistant\"},\"finish_reason\":null}]}\n\n\
data: {\"id\":\"1\",\"object\":\"chat.completion.chunk\",\"created\":1234567890,\"model\":\"gpt-4\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"Once\"},\"finish_reason\":null}]}\n\n\
data: {\"id\":\"1\",\"object\":\"chat.completion.chunk\",\"created\":1234567890,\"model\":\"gpt-4\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\" upon\"},\"finish_reason\":null}]}\n\n\
data: {\"id\":\"1\",\"object\":\"chat.completion.chunk\",\"created\":1234567890,\"model\":\"gpt-4\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\" a\"},\"finish_reason\":null}]}\n\n\
data: {\"id\":\"1\",\"object\":\"chat.completion.chunk\",\"created\":1234567890,\"model\":\"gpt-4\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\" time\"},\"finish_reason\":null}]}\n\n\
data: {\"id\":\"1\",\"object\":\"chat.completion.chunk\",\"created\":1234567890,\"model\":\"gpt-4\",\"choices\":[{\"index\":0,\"delta\":{},\"finish_reason\":\"stop\"}]}\n\n\
data: [DONE]\n\n";
// Set up mock
let mock = server
.mock("POST", "/chat/completions")
.match_header("content-type", "application/json")
.match_header("authorization", "Bearer test-key")
.with_status(200)
.with_header("content-type", "text/event-stream")
.with_body(stream_response)
.create();
let client = LiteLLMClient::new(Some("test-key".to_string()), Some(server.url())).unwrap();
// Make the request
let mut stream = client.stream_chat_completion(request).await.unwrap();
// Collect all chunks
let mut chunks = Vec::new();
while let Some(chunk_result) = stream.recv().await {
match chunk_result {
Ok(chunk) => chunks.push(chunk),
Err(e) => panic!("Error in stream: {:?}", e),
}
}
// Verify chunks
assert_eq!(chunks.len(), 6); // 1 role chunk + 4 content chunks + 1 finish chunk
// First chunk should have role
assert_eq!(chunks[0].choices[0].delta.role, Some("assistant".to_string()));
// Content chunks should build the story
assert_eq!(chunks[1].choices[0].delta.content, Some("Once".to_string()));
assert_eq!(chunks[2].choices[0].delta.content, Some(" upon".to_string()));
assert_eq!(chunks[3].choices[0].delta.content, Some(" a".to_string()));
assert_eq!(chunks[4].choices[0].delta.content, Some(" time".to_string()));
// Last chunk should have finish reason
assert_eq!(chunks[5].choices[0].finish_reason, Some("stop".to_string()));
assert!(chunks[5].choices[0].delta.content.is_none());
mock.assert();
}
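For orientation: the deltas asserted above fold back into the text "Once upon a time". A minimal sketch of that folding step, using simplified stand-ins for the crate's chunk types (shapes inferred from the test's assertions, not the real definitions):

// Simplified stand-ins; the crate's actual streamed-chunk structs
// may differ (these are inferred from the assertions above).
struct Delta { content: Option<String> }
struct StreamChoice { delta: Delta }
struct StreamChunk { choices: Vec<StreamChoice> }

// Concatenate every chunk's delta content into the final message text.
fn assemble_content(chunks: &[StreamChunk]) -> String {
    chunks
        .iter()
        .filter_map(|chunk| chunk.choices.first())
        .filter_map(|choice| choice.delta.content.as_deref())
        .collect()
}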
#[tokio::test]
async fn test_stream_with_tool_calls() {
let mut server = mockito::Server::new_async().await;
// Create a request for streaming with tool calls
let request = ChatCompletionRequest {
model: "gpt-4".to_string(),
messages: vec![
AgentMessage::user("What's the weather in New York?"),
],
tools: Some(vec![Tool {
tool_type: "function".to_string(),
function: serde_json::json!({
"name": "get_weather",
"description": "Get weather information",
"parameters": {
"type": "object",
"properties": {
"location": {"type": "string"}
},
"required": ["location"]
}
}),
}]),
stream: Some(true),
..Default::default()
};
// Mock response with tool calls in chunks
let stream_response = "data: {\"id\":\"1\",\"object\":\"chat.completion.chunk\",\"created\":1234567890,\"model\":\"gpt-4\",\"choices\":[{\"index\":0,\"delta\":{\"role\":\"assistant\"},\"finish_reason\":null}]}\n\n\
data: {\"id\":\"1\",\"object\":\"chat.completion.chunk\",\"created\":1234567890,\"model\":\"gpt-4\",\"choices\":[{\"index\":0,\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":\"call_123\",\"type\":\"function\"}]},\"finish_reason\":null}]}\n\n\
data: {\"id\":\"1\",\"object\":\"chat.completion.chunk\",\"created\":1234567890,\"model\":\"gpt-4\",\"choices\":[{\"index\":0,\"delta\":{\"tool_calls\":[{\"index\":0,\"function\":{\"name\":\"get_weather\"}}]},\"finish_reason\":null}]}\n\n\
data: {\"id\":\"1\",\"object\":\"chat.completion.chunk\",\"created\":1234567890,\"model\":\"gpt-4\",\"choices\":[{\"index\":0,\"delta\":{\"tool_calls\":[{\"index\":0,\"function\":{\"arguments\":\"{\\\"\"}}]},\"finish_reason\":null}]}\n\n\
data: {\"id\":\"1\",\"object\":\"chat.completion.chunk\",\"created\":1234567890,\"model\":\"gpt-4\",\"choices\":[{\"index\":0,\"delta\":{\"tool_calls\":[{\"index\":0,\"function\":{\"arguments\":\"location\\\":\\\"\"}}]},\"finish_reason\":null}]}\n\n\
data: {\"id\":\"1\",\"object\":\"chat.completion.chunk\",\"created\":1234567890,\"model\":\"gpt-4\",\"choices\":[{\"index\":0,\"delta\":{\"tool_calls\":[{\"index\":0,\"function\":{\"arguments\":\"New York\\\"\"}}]},\"finish_reason\":null}]}\n\n\
data: {\"id\":\"1\",\"object\":\"chat.completion.chunk\",\"created\":1234567890,\"model\":\"gpt-4\",\"choices\":[{\"index\":0,\"delta\":{\"tool_calls\":[{\"index\":0,\"function\":{\"arguments\":\"}\"}}]},\"finish_reason\":null}]}\n\n\
data: {\"id\":\"1\",\"object\":\"chat.completion.chunk\",\"created\":1234567890,\"model\":\"gpt-4\",\"choices\":[{\"index\":0,\"delta\":{},\"finish_reason\":\"tool_calls\"}]}\n\n\
data: [DONE]\n\n";
// Set up mock
let mock = server
.mock("POST", "/chat/completions")
.match_header("content-type", "application/json")
.match_header("authorization", "Bearer test-key")
.with_status(200)
.with_header("content-type", "text/event-stream")
.with_body(stream_response)
.create();
let client = LiteLLMClient::new(Some("test-key".to_string()), Some(server.url())).unwrap();
// Make the request
let mut stream = client.stream_chat_completion(request).await.unwrap();
// Collect all chunks
let mut chunks = Vec::new();
while let Some(chunk_result) = stream.recv().await {
match chunk_result {
Ok(chunk) => chunks.push(chunk),
Err(e) => panic!("Error in stream: {:?}", e),
}
}
// Basic validation (number of chunks)
assert!(!chunks.is_empty());
// First chunk should have role
assert_eq!(chunks[0].choices[0].delta.role, Some("assistant".to_string()));
// Last chunk should have finish reason
assert_eq!(chunks.last().unwrap().choices[0].finish_reason, Some("tool_calls".to_string()));
mock.assert();
}
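The argument fragments streamed in this test concatenate, in order, to {"location":"New York"}. A short sketch of the final parse step once the stream reports finish_reason "tool_calls" (parse_tool_arguments is a hypothetical helper; serde_json itself is already in use above):

// Hypothetical helper: join the streamed argument fragments and
// parse the result as a JSON value.
fn parse_tool_arguments(fragments: &[&str]) -> serde_json::Result<serde_json::Value> {
    serde_json::from_str(&fragments.concat())
}

// e.g. parse_tool_arguments(&["{\"", "location\":\"", "New York\"", "}"])
// yields Ok(json!({"location": "New York"})).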
}

View File

@@ -172,6 +172,7 @@ impl AgentMessage {
}
pub fn assistant(
id: Option<String>,
content: Option<String>,
tool_calls: Option<Vec<ToolCall>>,
progress: MessageProgress,
@@ -181,7 +182,7 @@ impl AgentMessage {
let initial = initial.unwrap_or(false);
Self::Assistant {
id: None,
id,
content,
name,
tool_calls,
@@ -557,6 +558,7 @@ mod tests {
message: AgentMessage::assistant(
Some("\n\nHello there, how may I assist you today?".to_string()),
None,
None,
MessageProgress::Complete,
None,
None,
@@ -706,6 +708,7 @@ mod tests {
message: AgentMessage::assistant(
Some("".to_string()),
None,
None,
MessageProgress::Complete,
None,
None,
@@ -988,6 +991,7 @@ mod tests {
choices: vec![Choice {
index: 0,
message: AgentMessage::assistant(
None,
None,
Some(vec![ToolCall {
id: "call_abc123".to_string(),