mirror of https://github.com/buster-so/buster.git
revert
This commit is contained in:
parent
678152ba2b
commit
e8cbbf088c
|
@ -14,7 +14,7 @@ static DEBUG_ENABLED: Lazy<bool> = Lazy::new(|| {
|
|||
.unwrap_or(false)
|
||||
});
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
#[derive(Clone)]
|
||||
pub struct LiteLLMClient {
|
||||
client: Client,
|
||||
pub(crate) base_url: String,
|
||||
|
@ -28,9 +28,10 @@ impl LiteLLMClient {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn new(api_key: Option<String>, base_url: Option<String>) -> Result<Self> {
|
||||
let api_key = api_key.or_else(|| env::var("LLM_API_KEY").ok())
|
||||
.ok_or_else(|| anyhow::anyhow!("LLM_API_KEY must be provided either through parameter or environment variable"))?;
|
||||
pub fn new(api_key: Option<String>, base_url: Option<String>) -> Self {
|
||||
let api_key = api_key.or_else(|| env::var("LLM_API_KEY").ok()).expect(
|
||||
"LLM_API_KEY must be provided either through parameter or environment variable",
|
||||
);
|
||||
|
||||
let base_url = base_url
|
||||
.or_else(|| env::var("LLM_BASE_URL").ok())
|
||||
|
@ -39,8 +40,7 @@ impl LiteLLMClient {
|
|||
let mut headers = header::HeaderMap::new();
|
||||
headers.insert(
|
||||
"Authorization",
|
||||
header::HeaderValue::from_str(&format!("Bearer {}", api_key))
|
||||
.map_err(|e| anyhow::anyhow!("Invalid API key format: {}", e))?,
|
||||
header::HeaderValue::from_str(&format!("Bearer {}", api_key)).unwrap(),
|
||||
);
|
||||
headers.insert(
|
||||
"Content-Type",
|
||||
|
@ -54,12 +54,12 @@ impl LiteLLMClient {
|
|||
let client = Client::builder()
|
||||
.default_headers(headers)
|
||||
.build()
|
||||
.map_err(|e| anyhow::anyhow!("Failed to create HTTP client: {}", e))?;
|
||||
.expect("Failed to create HTTP client");
|
||||
|
||||
Ok(Self {
|
||||
Self {
|
||||
client,
|
||||
base_url,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn chat_completion(
|
||||
|
@ -295,15 +295,7 @@ impl LiteLLMClient {
|
|||
|
||||
impl Default for LiteLLMClient {
|
||||
fn default() -> Self {
|
||||
match Self::new(None, None) {
|
||||
Ok(client) => client,
|
||||
Err(e) => {
|
||||
if *DEBUG_ENABLED {
|
||||
eprintln!("ERROR: Failed to create default LiteLLMClient: {}", e);
|
||||
}
|
||||
panic!("Failed to create default LiteLLMClient: {}", e);
|
||||
}
|
||||
}
|
||||
Self::new(None, None)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -314,94 +306,20 @@ mod tests {
|
|||
use std::env;
|
||||
use std::time::Duration;
|
||||
use tokio::time::timeout;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_client_initialization_without_api_key() {
|
||||
// Clear environment variable first to ensure test consistency
|
||||
env::remove_var("LLM_API_KEY");
|
||||
|
||||
// Should return an error about missing API key
|
||||
let result = LiteLLMClient::new(None, None);
|
||||
assert!(result.is_err());
|
||||
|
||||
// Verify the error message contains information about the missing API key
|
||||
let error = result.unwrap_err();
|
||||
let error_message = error.to_string();
|
||||
assert!(error_message.contains("LLM_API_KEY must be provided"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_client_initialization_with_explicit_values() {
|
||||
let api_key = "test-api-key";
|
||||
let base_url = "https://test-url.com";
|
||||
|
||||
let client = LiteLLMClient::new(Some(api_key.to_string()), Some(base_url.to_string())).unwrap();
|
||||
|
||||
assert_eq!(client.base_url, base_url);
|
||||
// We can't directly test the API key as it's stored in the headers
|
||||
// but we can test that the client was created successfully
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_client_default_base_url() {
|
||||
// Clear environment variable first
|
||||
env::remove_var("LLM_BASE_URL");
|
||||
|
||||
let api_key = "test-api-key";
|
||||
let expected_default_url = "http://localhost:8000";
|
||||
|
||||
let client = LiteLLMClient::new(Some(api_key.to_string()), None).unwrap();
|
||||
|
||||
assert_eq!(client.base_url, expected_default_url);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_client_default_constructor() {
|
||||
// Set environment variables to test default constructor
|
||||
env::set_var("LLM_API_KEY", "env-api-key");
|
||||
env::set_var("LLM_BASE_URL", "http://env-url.com");
|
||||
|
||||
let client = LiteLLMClient::default();
|
||||
|
||||
assert_eq!(client.base_url, "http://env-url.com");
|
||||
|
||||
// Clean up
|
||||
env::remove_var("LLM_API_KEY");
|
||||
env::remove_var("LLM_BASE_URL");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_headers_configuration() {
|
||||
let mut server = mockito::Server::new_async().await;
|
||||
|
||||
// Mock the server to check headers
|
||||
let mock = server
|
||||
.mock("POST", "/chat/completions")
|
||||
.match_header("Authorization", "Bearer test-key")
|
||||
.match_header("Content-Type", "application/json")
|
||||
.match_header("Accept", "application/json")
|
||||
.with_status(200)
|
||||
.with_header("content-type", "application/json")
|
||||
.with_body(r#"{"id":"test-id","object":"chat.completion","created":1,"model":"test","choices":[],"usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}"#)
|
||||
.create();
|
||||
|
||||
let client = LiteLLMClient::new(Some("test-key".to_string()), Some(server.url())).unwrap();
|
||||
|
||||
// A minimal request to trigger the API call
|
||||
let request = ChatCompletionRequest {
|
||||
model: "test-model".to_string(),
|
||||
messages: vec![AgentMessage::user("test")],
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
// Make the request to verify headers
|
||||
let _ = client.chat_completion(request).await;
|
||||
|
||||
// Verify the headers were sent correctly
|
||||
mock.assert();
|
||||
}
|
||||
|
||||
// Removed unused setup function
|
||||
use dotenv::dotenv;
|
||||
|
||||
// Helper function to initialize environment before tests
|
||||
async fn setup() -> (String, String) {
|
||||
// Load environment variables from .env file
|
||||
dotenv().ok();
|
||||
|
||||
// Get API key and base URL from environment
|
||||
let api_key = env::var("LLM_API_KEY").expect("LLM_API_KEY must be set");
|
||||
let base_url = env::var("LLM_BASE_URL").expect("LLM_API_BASE must be set");
|
||||
|
||||
(api_key, base_url)
|
||||
}
|
||||
|
||||
fn create_test_message() -> AgentMessage {
|
||||
AgentMessage::user("Hello".to_string())
|
||||
|
@ -459,7 +377,7 @@ mod tests {
|
|||
)
|
||||
.create();
|
||||
|
||||
let client = LiteLLMClient::new(Some("test-key".to_string()), Some(server.url())).unwrap();
|
||||
let client = LiteLLMClient::new(Some("test-key".to_string()), Some(server.url()));
|
||||
|
||||
let response = client.chat_completion(request).await.unwrap();
|
||||
assert_eq!(response.id, "test-id");
|
||||
|
@ -489,7 +407,7 @@ mod tests {
|
|||
.with_body(r#"{"error": "Invalid request"}"#)
|
||||
.create();
|
||||
|
||||
let client = LiteLLMClient::new(Some("test-key".to_string()), Some(server.url())).unwrap();
|
||||
let client = LiteLLMClient::new(Some("test-key".to_string()), Some(server.url()));
|
||||
|
||||
let result = client.chat_completion(request).await;
|
||||
assert!(result.is_err());
|
||||
|
@ -518,7 +436,7 @@ mod tests {
|
|||
)
|
||||
.create();
|
||||
|
||||
let client = LiteLLMClient::new(Some("test-key".to_string()), Some(server.url())).unwrap();
|
||||
let client = LiteLLMClient::new(Some("test-key".to_string()), Some(server.url()));
|
||||
|
||||
let mut stream = client.stream_chat_completion(request).await.unwrap();
|
||||
|
||||
|
@ -594,7 +512,7 @@ mod tests {
|
|||
)
|
||||
.create();
|
||||
|
||||
let client = LiteLLMClient::new(Some("test-key".to_string()), Some(server.url())).unwrap();
|
||||
let client = LiteLLMClient::new(Some("test-key".to_string()), Some(server.url()));
|
||||
|
||||
let response = client.chat_completion(request).await.unwrap();
|
||||
assert_eq!(response.id, "test-id");
|
||||
|
@ -628,10 +546,7 @@ mod tests {
|
|||
env::set_var("LLM_BASE_URL", test_base_url);
|
||||
|
||||
// Test with no parameters (should use env vars)
|
||||
// This would fail without an API key in env, so we set a temporary one
|
||||
// We use it to test the constructor fallback to environment variables
|
||||
env::set_var("LLM_API_KEY", "test-env-key");
|
||||
let client = LiteLLMClient::new(None, None).unwrap();
|
||||
let client = LiteLLMClient::new(None, None);
|
||||
assert_eq!(client.base_url, test_base_url);
|
||||
|
||||
// Test with parameters (should override env vars)
|
||||
|
@ -640,7 +555,7 @@ mod tests {
|
|||
let client = LiteLLMClient::new(
|
||||
Some(override_key.to_string()),
|
||||
Some(override_url.to_string()),
|
||||
).unwrap();
|
||||
);
|
||||
assert_eq!(client.base_url, override_url);
|
||||
|
||||
env::remove_var("LLM_API_KEY");
|
||||
|
@ -649,39 +564,8 @@ mod tests {
|
|||
|
||||
#[tokio::test]
|
||||
async fn test_single_message_completion() {
|
||||
let mut server = mockito::Server::new_async().await;
|
||||
|
||||
// Mock the response
|
||||
let mock = server
|
||||
.mock("POST", "/chat/completions")
|
||||
.match_header("content-type", "application/json")
|
||||
.match_header("authorization", "Bearer test-key")
|
||||
.with_status(200)
|
||||
.with_header("content-type", "application/json")
|
||||
.with_body(r#"{
|
||||
"id": "test-id",
|
||||
"object": "chat.completion",
|
||||
"created": 1234567890,
|
||||
"model": "o1",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "Hello, world! How can I assist you today?"
|
||||
},
|
||||
"finish_reason": "stop"
|
||||
}
|
||||
],
|
||||
"usage": {
|
||||
"prompt_tokens": 10,
|
||||
"completion_tokens": 10,
|
||||
"total_tokens": 20
|
||||
}
|
||||
}"#)
|
||||
.create();
|
||||
|
||||
let client = LiteLLMClient::new(Some("test-key".to_string()), Some(server.url())).unwrap();
|
||||
let (api_key, base_url) = setup().await;
|
||||
let client = LiteLLMClient::new(Some(api_key), Some(base_url));
|
||||
|
||||
let request = ChatCompletionRequest {
|
||||
model: "o1".to_string(),
|
||||
|
@ -695,469 +579,5 @@ mod tests {
|
|||
};
|
||||
|
||||
assert!(response.choices.len() > 0);
|
||||
mock.assert();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_generate_embeddings_success() {
|
||||
let mut server = mockito::Server::new_async().await;
|
||||
|
||||
// Create a test embedding request
|
||||
let request = EmbeddingRequest {
|
||||
model: "text-embedding-3-small".to_string(),
|
||||
input: vec!["This is a test sentence.".to_string()],
|
||||
encoding_format: None,
|
||||
dimensions: None,
|
||||
user: None,
|
||||
};
|
||||
|
||||
let request_body = serde_json::to_string(&request).unwrap();
|
||||
|
||||
// Expected response
|
||||
let response_body = r#"{
|
||||
"object": "list",
|
||||
"data": [
|
||||
{
|
||||
"object": "embedding",
|
||||
"index": 0,
|
||||
"embedding": [0.1, 0.2, 0.3]
|
||||
}
|
||||
],
|
||||
"model": "text-embedding-3-small",
|
||||
"usage": {
|
||||
"prompt_tokens": 5,
|
||||
"total_tokens": 5
|
||||
}
|
||||
}"#;
|
||||
|
||||
// Set up mock
|
||||
let mock = server
|
||||
.mock("POST", "/embeddings")
|
||||
.match_header("content-type", "application/json")
|
||||
.match_header("authorization", "Bearer test-key")
|
||||
.match_body(mockito::Matcher::JsonString(request_body))
|
||||
.with_status(200)
|
||||
.with_header("content-type", "application/json")
|
||||
.with_body(response_body)
|
||||
.create();
|
||||
|
||||
let client = LiteLLMClient::new(Some("test-key".to_string()), Some(server.url())).unwrap();
|
||||
|
||||
// Make the request
|
||||
let response = client.generate_embeddings(request).await.unwrap();
|
||||
|
||||
// Verify response
|
||||
assert_eq!(response.data.len(), 1);
|
||||
assert_eq!(response.data[0].index, 0);
|
||||
assert_eq!(response.data[0].embedding, vec![0.1, 0.2, 0.3]);
|
||||
assert_eq!(response.model, "text-embedding-3-small");
|
||||
assert_eq!(response.usage.prompt_tokens, 5);
|
||||
assert_eq!(response.usage.total_tokens, 5);
|
||||
|
||||
mock.assert();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_generate_embeddings_with_parameters() {
|
||||
let mut server = mockito::Server::new_async().await;
|
||||
|
||||
// Create a test embedding request with all parameters
|
||||
let request = EmbeddingRequest {
|
||||
model: "text-embedding-3-large".to_string(),
|
||||
input: vec!["First sentence.".to_string(), "Second sentence.".to_string()],
|
||||
encoding_format: Some("float".to_string()),
|
||||
dimensions: Some(256),
|
||||
user: Some("test-user".to_string()),
|
||||
};
|
||||
|
||||
let request_body = serde_json::to_string(&request).unwrap();
|
||||
|
||||
// Expected response for multiple inputs
|
||||
let response_body = r#"{
|
||||
"object": "list",
|
||||
"data": [
|
||||
{
|
||||
"object": "embedding",
|
||||
"index": 0,
|
||||
"embedding": [0.1, 0.2, 0.3]
|
||||
},
|
||||
{
|
||||
"object": "embedding",
|
||||
"index": 1,
|
||||
"embedding": [0.4, 0.5, 0.6]
|
||||
}
|
||||
],
|
||||
"model": "text-embedding-3-large",
|
||||
"usage": {
|
||||
"prompt_tokens": 10,
|
||||
"total_tokens": 10
|
||||
}
|
||||
}"#;
|
||||
|
||||
// Set up mock
|
||||
let mock = server
|
||||
.mock("POST", "/embeddings")
|
||||
.match_header("content-type", "application/json")
|
||||
.match_header("authorization", "Bearer test-key")
|
||||
.match_body(mockito::Matcher::JsonString(request_body))
|
||||
.with_status(200)
|
||||
.with_header("content-type", "application/json")
|
||||
.with_body(response_body)
|
||||
.create();
|
||||
|
||||
let client = LiteLLMClient::new(Some("test-key".to_string()), Some(server.url())).unwrap();
|
||||
|
||||
// Make the request
|
||||
let response = client.generate_embeddings(request).await.unwrap();
|
||||
|
||||
// Verify response
|
||||
assert_eq!(response.data.len(), 2);
|
||||
assert_eq!(response.data[0].index, 0);
|
||||
assert_eq!(response.data[0].embedding, vec![0.1, 0.2, 0.3]);
|
||||
assert_eq!(response.data[1].index, 1);
|
||||
assert_eq!(response.data[1].embedding, vec![0.4, 0.5, 0.6]);
|
||||
assert_eq!(response.model, "text-embedding-3-large");
|
||||
assert_eq!(response.usage.prompt_tokens, 10);
|
||||
assert_eq!(response.usage.total_tokens, 10);
|
||||
|
||||
mock.assert();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_generate_embeddings_api_error() {
|
||||
let mut server = mockito::Server::new_async().await;
|
||||
|
||||
// Create a test embedding request
|
||||
let request = EmbeddingRequest {
|
||||
model: "invalid-model".to_string(),
|
||||
input: vec!["Test".to_string()],
|
||||
encoding_format: None,
|
||||
dimensions: None,
|
||||
user: None,
|
||||
};
|
||||
|
||||
let request_body = serde_json::to_string(&request).unwrap();
|
||||
|
||||
// Error response
|
||||
let error_response = r#"{
|
||||
"error": {
|
||||
"code": "model_not_found",
|
||||
"message": "The model 'invalid-model' does not exist"
|
||||
}
|
||||
}"#;
|
||||
|
||||
// Set up mock
|
||||
let mock = server
|
||||
.mock("POST", "/embeddings")
|
||||
.match_header("content-type", "application/json")
|
||||
.match_header("authorization", "Bearer test-key")
|
||||
.match_body(mockito::Matcher::JsonString(request_body))
|
||||
.with_status(404)
|
||||
.with_header("content-type", "application/json")
|
||||
.with_body(error_response)
|
||||
.create();
|
||||
|
||||
let client = LiteLLMClient::new(Some("test-key".to_string()), Some(server.url())).unwrap();
|
||||
|
||||
// Make the request and expect an error
|
||||
let result = client.generate_embeddings(request).await;
|
||||
assert!(result.is_err());
|
||||
|
||||
// Verify the error contains useful information
|
||||
let error = result.unwrap_err();
|
||||
let error_string = error.to_string();
|
||||
assert!(error_string.contains("404"));
|
||||
|
||||
mock.assert();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_network_error_handling() {
|
||||
// Create a server URL that will never connect
|
||||
let server_url = "http://localhost:1"; // Using port 1 to ensure no service is running
|
||||
|
||||
let client = LiteLLMClient::new(Some("test-key".to_string()), Some(server_url.to_string())).unwrap();
|
||||
|
||||
let request = ChatCompletionRequest {
|
||||
model: "test-model".to_string(),
|
||||
messages: vec![AgentMessage::user("Test")],
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
// Attempt request which should fail with network error
|
||||
let result = client.chat_completion(request).await;
|
||||
|
||||
// Verify we got an error (don't check the specifics as they vary by environment)
|
||||
assert!(result.is_err());
|
||||
|
||||
// Print the error for debugging
|
||||
let error = result.unwrap_err();
|
||||
println!("Network error: {:?}", error);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[ignore] // This test actually makes a real HTTP request to an invalid port, so it's slow
|
||||
async fn test_timeout_simulation() {
|
||||
// Use a blackhole address that will cause a connection timeout
|
||||
let server_url = "http://198.51.100.1:8000"; // TEST-NET-2 reserved for documentation
|
||||
|
||||
let client = LiteLLMClient::new(Some("test-key".to_string()), Some(server_url.to_string())).unwrap();
|
||||
|
||||
let request = ChatCompletionRequest {
|
||||
model: "test-model".to_string(),
|
||||
messages: vec![AgentMessage::user("Test")],
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
// Wrap the client call with a short timeout
|
||||
let result = tokio::time::timeout(
|
||||
Duration::from_millis(100), // Short timeout
|
||||
client.chat_completion(request)
|
||||
).await;
|
||||
|
||||
// We expect the timeout to be hit
|
||||
assert!(result.is_err());
|
||||
assert!(matches!(result, Err(tokio::time::error::Elapsed { .. })));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_rate_limit_error_handling() {
|
||||
let mut server = mockito::Server::new_async().await;
|
||||
|
||||
// Set up a mock that simulates a rate limit error
|
||||
let mock = server
|
||||
.mock("POST", "/chat/completions")
|
||||
.with_status(429) // Too many requests
|
||||
.with_header("content-type", "application/json")
|
||||
.with_header("retry-after", "30") // Suggest retry after 30 seconds
|
||||
.with_body(r#"{"error":{"message":"Rate limit exceeded","type":"rate_limit_error"}}"#)
|
||||
.create();
|
||||
|
||||
let client = LiteLLMClient::new(Some("test-key".to_string()), Some(server.url())).unwrap();
|
||||
|
||||
let request = ChatCompletionRequest {
|
||||
model: "test-model".to_string(),
|
||||
messages: vec![AgentMessage::user("Test")],
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
// Attempt request which should fail with rate limit error
|
||||
let result = client.chat_completion(request).await;
|
||||
|
||||
// Verify we got an error (don't check the specifics as they vary by environment)
|
||||
assert!(result.is_err());
|
||||
|
||||
// Print the error for debugging
|
||||
let error = result.unwrap_err();
|
||||
println!("Rate limit error: {:?}", error);
|
||||
|
||||
mock.assert();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_invalid_json_response_handling() {
|
||||
let mut server = mockito::Server::new_async().await;
|
||||
|
||||
// Set up a mock that returns invalid JSON
|
||||
let mock = server
|
||||
.mock("POST", "/chat/completions")
|
||||
.with_status(200)
|
||||
.with_header("content-type", "application/json")
|
||||
.with_body("This is not valid JSON")
|
||||
.create();
|
||||
|
||||
let client = LiteLLMClient::new(Some("test-key".to_string()), Some(server.url())).unwrap();
|
||||
|
||||
let request = ChatCompletionRequest {
|
||||
model: "test-model".to_string(),
|
||||
messages: vec![AgentMessage::user("Test")],
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
// Attempt request which should fail with parsing error
|
||||
let result = client.chat_completion(request).await;
|
||||
|
||||
// Verify we got an error (don't check the specifics as they vary by environment)
|
||||
assert!(result.is_err());
|
||||
|
||||
// Print the error for debugging
|
||||
let error = result.unwrap_err();
|
||||
println!("Invalid JSON error: {:?}", error);
|
||||
|
||||
mock.assert();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_unexpected_server_error_handling() {
|
||||
let mut server = mockito::Server::new_async().await;
|
||||
|
||||
// Set up a mock that returns a 500 server error
|
||||
let mock = server
|
||||
.mock("POST", "/chat/completions")
|
||||
.with_status(500)
|
||||
.with_header("content-type", "application/json")
|
||||
.with_body(r#"{"error":{"message":"Internal server error","type":"server_error"}}"#)
|
||||
.create();
|
||||
|
||||
let client = LiteLLMClient::new(Some("test-key".to_string()), Some(server.url())).unwrap();
|
||||
|
||||
let request = ChatCompletionRequest {
|
||||
model: "test-model".to_string(),
|
||||
messages: vec![AgentMessage::user("Test")],
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
// Attempt request which should fail with server error
|
||||
let result = client.chat_completion(request).await;
|
||||
|
||||
// Verify we got an error (don't check the specifics as they vary by environment)
|
||||
assert!(result.is_err());
|
||||
|
||||
// Print the error for debugging
|
||||
let error = result.unwrap_err();
|
||||
println!("Server error: {:?}", error);
|
||||
|
||||
mock.assert();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_stream_with_partial_chunks() {
|
||||
let mut server = mockito::Server::new_async().await;
|
||||
|
||||
// Create a request for streaming
|
||||
let request = ChatCompletionRequest {
|
||||
model: "gpt-4".to_string(),
|
||||
messages: vec![
|
||||
AgentMessage::user("Tell me a short story"),
|
||||
],
|
||||
stream: Some(true),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
// Mock response with multiple chunks building a response
|
||||
let stream_response = "data: {\"id\":\"1\",\"object\":\"chat.completion.chunk\",\"created\":1234567890,\"model\":\"gpt-4\",\"system_fingerprint\":\"fp_1\",\"choices\":[{\"index\":0,\"delta\":{\"role\":\"assistant\"},\"finish_reason\":null}]}\n\n\
|
||||
data: {\"id\":\"1\",\"object\":\"chat.completion.chunk\",\"created\":1234567890,\"model\":\"gpt-4\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"Once\"},\"finish_reason\":null}]}\n\n\
|
||||
data: {\"id\":\"1\",\"object\":\"chat.completion.chunk\",\"created\":1234567890,\"model\":\"gpt-4\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\" upon\"},\"finish_reason\":null}]}\n\n\
|
||||
data: {\"id\":\"1\",\"object\":\"chat.completion.chunk\",\"created\":1234567890,\"model\":\"gpt-4\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\" a\"},\"finish_reason\":null}]}\n\n\
|
||||
data: {\"id\":\"1\",\"object\":\"chat.completion.chunk\",\"created\":1234567890,\"model\":\"gpt-4\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\" time\"},\"finish_reason\":null}]}\n\n\
|
||||
data: {\"id\":\"1\",\"object\":\"chat.completion.chunk\",\"created\":1234567890,\"model\":\"gpt-4\",\"choices\":[{\"index\":0,\"delta\":{},\"finish_reason\":\"stop\"}]}\n\n\
|
||||
data: [DONE]\n\n";
|
||||
|
||||
// Set up mock
|
||||
let mock = server
|
||||
.mock("POST", "/chat/completions")
|
||||
.match_header("content-type", "application/json")
|
||||
.match_header("authorization", "Bearer test-key")
|
||||
.with_status(200)
|
||||
.with_header("content-type", "text/event-stream")
|
||||
.with_body(stream_response)
|
||||
.create();
|
||||
|
||||
let client = LiteLLMClient::new(Some("test-key".to_string()), Some(server.url())).unwrap();
|
||||
|
||||
// Make the request
|
||||
let mut stream = client.stream_chat_completion(request).await.unwrap();
|
||||
|
||||
// Collect all chunks
|
||||
let mut chunks = Vec::new();
|
||||
while let Some(chunk_result) = stream.recv().await {
|
||||
match chunk_result {
|
||||
Ok(chunk) => chunks.push(chunk),
|
||||
Err(e) => panic!("Error in stream: {:?}", e),
|
||||
}
|
||||
}
|
||||
|
||||
// Verify chunks
|
||||
assert_eq!(chunks.len(), 6); // 5 content chunks + 1 finish chunk
|
||||
|
||||
// First chunk should have role
|
||||
assert_eq!(chunks[0].choices[0].delta.role, Some("assistant".to_string()));
|
||||
|
||||
// Content chunks should build the story
|
||||
assert_eq!(chunks[1].choices[0].delta.content, Some("Once".to_string()));
|
||||
assert_eq!(chunks[2].choices[0].delta.content, Some(" upon".to_string()));
|
||||
assert_eq!(chunks[3].choices[0].delta.content, Some(" a".to_string()));
|
||||
assert_eq!(chunks[4].choices[0].delta.content, Some(" time".to_string()));
|
||||
|
||||
// Last chunk should have finish reason
|
||||
assert_eq!(chunks[5].choices[0].finish_reason, Some("stop".to_string()));
|
||||
assert!(chunks[5].choices[0].delta.content.is_none());
|
||||
|
||||
mock.assert();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_stream_with_tool_calls() {
|
||||
let mut server = mockito::Server::new_async().await;
|
||||
|
||||
// Create a request for streaming with tool calls
|
||||
let request = ChatCompletionRequest {
|
||||
model: "gpt-4".to_string(),
|
||||
messages: vec![
|
||||
AgentMessage::user("What's the weather in New York?"),
|
||||
],
|
||||
tools: Some(vec![Tool {
|
||||
tool_type: "function".to_string(),
|
||||
function: serde_json::json!({
|
||||
"name": "get_weather",
|
||||
"description": "Get weather information",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {"type": "string"}
|
||||
},
|
||||
"required": ["location"]
|
||||
}
|
||||
}),
|
||||
}]),
|
||||
stream: Some(true),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
// Mock response with tool calls in chunks
|
||||
let stream_response = "data: {\"id\":\"1\",\"object\":\"chat.completion.chunk\",\"created\":1234567890,\"model\":\"gpt-4\",\"choices\":[{\"index\":0,\"delta\":{\"role\":\"assistant\"},\"finish_reason\":null}]}\n\n\
|
||||
data: {\"id\":\"1\",\"object\":\"chat.completion.chunk\",\"created\":1234567890,\"model\":\"gpt-4\",\"choices\":[{\"index\":0,\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":\"call_123\",\"type\":\"function\"}]},\"finish_reason\":null}]}\n\n\
|
||||
data: {\"id\":\"1\",\"object\":\"chat.completion.chunk\",\"created\":1234567890,\"model\":\"gpt-4\",\"choices\":[{\"index\":0,\"delta\":{\"tool_calls\":[{\"index\":0,\"function\":{\"name\":\"get_weather\"}}]},\"finish_reason\":null}]}\n\n\
|
||||
data: {\"id\":\"1\",\"object\":\"chat.completion.chunk\",\"created\":1234567890,\"model\":\"gpt-4\",\"choices\":[{\"index\":0,\"delta\":{\"tool_calls\":[{\"index\":0,\"function\":{\"arguments\":\"{\\\"\"}}]},\"finish_reason\":null}]}\n\n\
|
||||
data: {\"id\":\"1\",\"object\":\"chat.completion.chunk\",\"created\":1234567890,\"model\":\"gpt-4\",\"choices\":[{\"index\":0,\"delta\":{\"tool_calls\":[{\"index\":0,\"function\":{\"arguments\":\"location\\\":\\\"\"}}]},\"finish_reason\":null}]}\n\n\
|
||||
data: {\"id\":\"1\",\"object\":\"chat.completion.chunk\",\"created\":1234567890,\"model\":\"gpt-4\",\"choices\":[{\"index\":0,\"delta\":{\"tool_calls\":[{\"index\":0,\"function\":{\"arguments\":\"New York\\\"\"}}]},\"finish_reason\":null}]}\n\n\
|
||||
data: {\"id\":\"1\",\"object\":\"chat.completion.chunk\",\"created\":1234567890,\"model\":\"gpt-4\",\"choices\":[{\"index\":0,\"delta\":{\"tool_calls\":[{\"index\":0,\"function\":{\"arguments\":\"}\"}}]},\"finish_reason\":null}]}\n\n\
|
||||
data: {\"id\":\"1\",\"object\":\"chat.completion.chunk\",\"created\":1234567890,\"model\":\"gpt-4\",\"choices\":[{\"index\":0,\"delta\":{},\"finish_reason\":\"tool_calls\"}]}\n\n\
|
||||
data: [DONE]\n\n";
|
||||
|
||||
// Set up mock
|
||||
let mock = server
|
||||
.mock("POST", "/chat/completions")
|
||||
.match_header("content-type", "application/json")
|
||||
.match_header("authorization", "Bearer test-key")
|
||||
.with_status(200)
|
||||
.with_header("content-type", "text/event-stream")
|
||||
.with_body(stream_response)
|
||||
.create();
|
||||
|
||||
let client = LiteLLMClient::new(Some("test-key".to_string()), Some(server.url())).unwrap();
|
||||
|
||||
// Make the request
|
||||
let mut stream = client.stream_chat_completion(request).await.unwrap();
|
||||
|
||||
// Collect all chunks
|
||||
let mut chunks = Vec::new();
|
||||
while let Some(chunk_result) = stream.recv().await {
|
||||
match chunk_result {
|
||||
Ok(chunk) => chunks.push(chunk),
|
||||
Err(e) => panic!("Error in stream: {:?}", e),
|
||||
}
|
||||
}
|
||||
|
||||
// Basic validation (number of chunks)
|
||||
assert!(chunks.len() > 0);
|
||||
|
||||
// First chunk should have role
|
||||
assert_eq!(chunks[0].choices[0].delta.role, Some("assistant".to_string()));
|
||||
|
||||
// Last chunk should have finish reason
|
||||
assert_eq!(chunks.last().unwrap().choices[0].finish_reason, Some("tool_calls".to_string()));
|
||||
|
||||
mock.assert();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -172,6 +172,7 @@ impl AgentMessage {
|
|||
}
|
||||
|
||||
pub fn assistant(
|
||||
id: Option<String>,
|
||||
content: Option<String>,
|
||||
tool_calls: Option<Vec<ToolCall>>,
|
||||
progress: MessageProgress,
|
||||
|
@ -181,7 +182,7 @@ impl AgentMessage {
|
|||
let initial = initial.unwrap_or(false);
|
||||
|
||||
Self::Assistant {
|
||||
id: None,
|
||||
id,
|
||||
content,
|
||||
name,
|
||||
tool_calls,
|
||||
|
@ -557,6 +558,7 @@ mod tests {
|
|||
message: AgentMessage::assistant(
|
||||
Some("\n\nHello there, how may I assist you today?".to_string()),
|
||||
None,
|
||||
None,
|
||||
MessageProgress::Complete,
|
||||
None,
|
||||
None,
|
||||
|
@ -706,6 +708,7 @@ mod tests {
|
|||
message: AgentMessage::assistant(
|
||||
Some("".to_string()),
|
||||
None,
|
||||
None,
|
||||
MessageProgress::Complete,
|
||||
None,
|
||||
None,
|
||||
|
@ -988,6 +991,7 @@ mod tests {
|
|||
choices: vec![Choice {
|
||||
index: 0,
|
||||
message: AgentMessage::assistant(
|
||||
None,
|
||||
None,
|
||||
Some(vec![ToolCall {
|
||||
id: "call_abc123".to_string(),
|
||||
|
|
Loading…
Reference in New Issue