From 087867032dc068c21fd16ec0c4e5a25b51cd9b94 Mon Sep 17 00:00:00 2001
From: dal
Date: Sat, 25 Jan 2025 15:17:21 -0700
Subject: [PATCH] chore: Update API project configuration and add new modules

- Added dev dependencies for Mockito and Tokio async testing
- Updated .cursorrules with async test requirement
- Expanded utils modules with new serde_helpers and tools
- Added new AI client module for LiteLLM
---
 api/.cursorrules                           |   3 +-
 api/Cargo.toml                             |   4 +
 api/src/utils/clients/ai/litellm/client.rs | 248 ++++++++++++++++++++
 api/src/utils/clients/ai/litellm/mod.rs    |   5 +
 api/src/utils/clients/ai/litellm/types.rs  | 250 +++++++++++++++++++++
 api/src/utils/clients/ai/mod.rs            |   1 +
 api/src/utils/mod.rs                       |   3 +-
 api/src/utils/tools/mod.rs                 |   0
 8 files changed, 512 insertions(+), 2 deletions(-)
 create mode 100644 api/src/utils/clients/ai/litellm/client.rs
 create mode 100644 api/src/utils/clients/ai/litellm/mod.rs
 create mode 100644 api/src/utils/clients/ai/litellm/types.rs
 create mode 100644 api/src/utils/tools/mod.rs

diff --git a/api/.cursorrules b/api/.cursorrules
index f28e11382..2549ff4a3 100644
--- a/api/.cursorrules
+++ b/api/.cursorrules
@@ -1 +1,2 @@
-this is an axum web server
\ No newline at end of file
+- this is an axum web server
+- all tests need to be tokio async tests
\ No newline at end of file
diff --git a/api/Cargo.toml b/api/Cargo.toml
index 2fc38d93d..0c8627b4f 100644
--- a/api/Cargo.toml
+++ b/api/Cargo.toml
@@ -92,5 +92,9 @@ diesel_migrations = "2.0.0"
 serde_yaml = "0.9.34"
 html-escape = "0.2.13"
 
+[dev-dependencies]
+mockito = "1.2"
+tokio = { version = "1.0", features = ["full", "test-util"] }
+
 [profile.release]
 debug = false
diff --git a/api/src/utils/clients/ai/litellm/client.rs b/api/src/utils/clients/ai/litellm/client.rs
new file mode 100644
index 000000000..683ee7dba
--- /dev/null
+++ b/api/src/utils/clients/ai/litellm/client.rs
@@ -0,0 +1,248 @@
+use reqwest::{Client, header};
+use futures_util::StreamExt;
+use tokio::sync::mpsc;
+use anyhow::Result;
+
+use super::types::*;
+
+pub struct LiteLLMClient {
+    client: Client,
+    pub(crate) api_key: String,
+    pub(crate) base_url: String,
+}
+
+impl LiteLLMClient {
+    pub fn new(api_key: String, base_url: Option<String>) -> Self {
+        let mut headers = header::HeaderMap::new();
+        headers.insert(
+            "Authorization",
+            header::HeaderValue::from_str(&format!("Bearer {}", api_key)).unwrap(),
+        );
+        headers.insert(
+            "Content-Type",
+            header::HeaderValue::from_static("application/json"),
+        );
+        headers.insert(
+            "Accept",
+            header::HeaderValue::from_static("application/json"),
+        );
+
+        let client = Client::builder()
+            .default_headers(headers)
+            .build()
+            .expect("Failed to create HTTP client");
+
+        Self {
+            client,
+            api_key,
+            base_url: base_url.unwrap_or_else(|| "http://localhost:8000".to_string()),
+        }
+    }
+
+    pub async fn chat_completion(&self, request: ChatCompletionRequest) -> Result<ChatCompletionResponse> {
+        let url = format!("{}/chat/completions", self.base_url);
+        let response = self.client
+            .post(&url)
+            .json(&request)
+            .send()
+            .await?
+            .json::<ChatCompletionResponse>()
+            .await?;
+        Ok(response)
+    }
+
+    pub async fn stream_chat_completion(
+        &self,
+        request: ChatCompletionRequest,
+    ) -> Result<mpsc::Receiver<Result<ChatCompletionChunk>>> {
+        let url = format!("{}/chat/completions", self.base_url);
+        let mut stream = self.client
+            .post(&url)
+            .json(&ChatCompletionRequest {
+                stream: Some(true),
+                ..request
+            })
+            .send()
+            .await?
+            .bytes_stream();
+
+        let (tx, rx) = mpsc::channel(100);
+
+        tokio::spawn(async move {
+            let mut buffer = String::new();
+
+            while let Some(chunk_result) = stream.next().await {
+                match chunk_result {
+                    Ok(chunk) => {
+                        let chunk_str = String::from_utf8_lossy(&chunk);
+                        buffer.push_str(&chunk_str);
+
+                        while let Some(pos) = buffer.find("\n\n") {
+                            let line = buffer[..pos].trim().to_string();
+                            buffer = buffer[pos + 2..].to_string();
+
+                            if line.starts_with("data: ") {
+                                let data = &line["data: ".len()..];
+                                if data == "[DONE]" {
+                                    break;
+                                }
+
+                                if let Ok(response) = serde_json::from_str::<ChatCompletionChunk>(data) {
+                                    let _ = tx.send(Ok(response)).await;
+                                }
+                            }
+                        }
+                    }
+                    Err(e) => {
+                        let _ = tx.send(Err(anyhow::Error::from(e))).await;
+                    }
+                }
+            }
+        });
+
+        Ok(rx)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use mockito;
+    use tokio::time::timeout;
+    use std::time::Duration;
+
+    fn create_test_message() -> Message {
+        Message {
+            role: "user".to_string(),
+            content: "Hello".to_string(),
+            name: None,
+            tool_calls: None,
+            tool_call_id: None,
+        }
+    }
+
+    fn create_test_request() -> ChatCompletionRequest {
+        ChatCompletionRequest {
+            model: "gpt-4".to_string(),
+            messages: vec![create_test_message()],
+            temperature: Some(0.7),
+            ..Default::default()
+        }
+    }
+
+    #[tokio::test]
+    async fn test_chat_completion_success() {
+        // Server::new_async is required inside a tokio test; the blocking constructor panics there.
+        let mut server = mockito::Server::new_async().await;
+
+        let mock = server.mock("POST", "/chat/completions")
+            .match_header("content-type", "application/json")
+            .match_header("Authorization", "Bearer test-key")
+            .with_status(200)
+            .with_header("content-type", "application/json")
+            .with_body(r#"{
+                "id": "test-id",
+                "object": "chat.completion",
+                "created": 1234567890,
+                "model": "gpt-4",
+                "choices": [{
+                    "index": 0,
+                    "message": {
+                        "role": "assistant",
+                        "content": "Hello there!"
+                    },
+                    "finish_reason": "stop"
+                }],
+                "usage": {
+                    "prompt_tokens": 10,
+                    "completion_tokens": 20,
+                    "total_tokens": 30
+                }
+            }"#)
+            .create_async().await;
+
+        let client = LiteLLMClient::new(
+            "test-key".to_string(),
+            Some(server.url()),
+        );
+
+        let response = client.chat_completion(create_test_request()).await.unwrap();
+        assert_eq!(response.id, "test-id");
+        assert_eq!(response.choices[0].message.content, "Hello there!");
+
+        mock.assert_async().await;
+    }
+
+    #[tokio::test]
+    async fn test_chat_completion_error() {
+        let mut server = mockito::Server::new_async().await;
+
+        let mock = server.mock("POST", "/chat/completions")
+            .match_header("content-type", "application/json")
+            .with_status(400)
+            .with_body(r#"{"error": "Invalid request"}"#)
+            .create_async().await;
+
+        let client = LiteLLMClient::new(
+            "test-key".to_string(),
+            Some(server.url()),
+        );
+
+        let result = client.chat_completion(create_test_request()).await;
+        assert!(result.is_err());
+
+        mock.assert_async().await;
+    }
+
+    #[tokio::test]
+    async fn test_stream_chat_completion() {
+        let mut server = mockito::Server::new_async().await;
+
+        // PartialJsonString matches on the stream flag without requiring the whole request body.
+        let mock = server.mock("POST", "/chat/completions")
+            .match_header("content-type", "application/json")
+            .match_body(mockito::Matcher::PartialJsonString(r#"{"stream":true}"#.to_string()))
+            .with_status(200)
+            .with_header("content-type", "text/event-stream")
+            .with_body(
+                "data: {\"id\":\"1\",\"object\":\"chat.completion.chunk\",\"created\":1234567890,\"model\":\"gpt-4\",\"choices\":[{\"index\":0,\"delta\":{\"role\":\"assistant\",\"content\":\"Hello\"},\"finish_reason\":null}]}\n\n\
+                 data: {\"id\":\"1\",\"object\":\"chat.completion.chunk\",\"created\":1234567890,\"model\":\"gpt-4\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\" world\"},\"finish_reason\":null}]}\n\n\
+                 data: [DONE]\n\n"
+            )
+            .create_async().await;
+
+        let client = LiteLLMClient::new(
+            "test-key".to_string(),
+            Some(server.url()),
+        );
+
+        let mut request = create_test_request();
+        request.stream = Some(true);
+
+        let mut stream = client.stream_chat_completion(request).await.unwrap();
+
+        let mut chunks = Vec::new();
+        while let Ok(Some(chunk)) = timeout(Duration::from_secs(1), stream.recv()).await {
+            if let Ok(chunk) = chunk {
+                chunks.push(chunk);
+            }
+        }
+
+        assert_eq!(chunks.len(), 2);
+        assert_eq!(chunks[0].choices[0].delta.content, Some("Hello".to_string()));
+        assert_eq!(chunks[1].choices[0].delta.content, Some(" world".to_string()));
+
+        mock.assert_async().await;
+    }
+
+    #[test]
+    fn test_client_initialization() {
+        let api_key = "test-key".to_string();
+        let base_url = "http://custom.url".to_string();
+
+        let client = LiteLLMClient::new(api_key.clone(), Some(base_url.clone()));
+        assert_eq!(client.api_key, api_key);
+        assert_eq!(client.base_url, base_url);
+
+        let client = LiteLLMClient::new(api_key.clone(), None);
+        assert_eq!(client.base_url, "http://localhost:8000");
+    }
+}
\ No newline at end of file
diff --git a/api/src/utils/clients/ai/litellm/mod.rs b/api/src/utils/clients/ai/litellm/mod.rs
new file mode 100644
index 000000000..83e22b07b
--- /dev/null
+++ b/api/src/utils/clients/ai/litellm/mod.rs
@@ -0,0 +1,5 @@
+mod types;
+mod client;
+
+pub use types::*;
+pub use client::LiteLLMClient;
\ No newline at end of file
diff --git a/api/src/utils/clients/ai/litellm/types.rs b/api/src/utils/clients/ai/litellm/types.rs
new file mode 100644
index 000000000..ac28e1729
--- /dev/null
+++ b/api/src/utils/clients/ai/litellm/types.rs
@@ -0,0 +1,250 @@
+use serde::{Deserialize, Serialize};
+use std::collections::HashMap;
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct ChatCompletionRequest {
+    pub model: String,
+    pub messages: Vec<Message>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub frequency_penalty: Option<f32>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub logit_bias: Option<HashMap<String, f32>>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub max_tokens: Option<i32>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub n: Option<i32>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub presence_penalty: Option<f32>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub response_format: Option<ResponseFormat>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub seed: Option<i64>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub stop: Option<Vec<String>>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub stream: Option<bool>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub temperature: Option<f32>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub top_p: Option<f32>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub tools: Option<Vec<Tool>>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub tool_choice: Option<ToolChoice>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub user: Option<String>,
+}
+
+impl Default for ChatCompletionRequest {
+    fn default() -> Self {
+        Self {
+            model: String::new(),
+            messages: Vec::new(),
+            frequency_penalty: None,
+            logit_bias: None,
+            max_tokens: None,
+            n: None,
+            presence_penalty: None,
+            response_format: None,
+            seed: None,
+            stop: None,
+            stream: None,
+            temperature: None,
+            top_p: None,
+            tools: None,
+            tool_choice: None,
+            user: None,
+        }
+    }
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct Message {
+    pub role: String,
+    pub content: String,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub name: Option<String>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub tool_calls: Option<Vec<ToolCall>>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub tool_call_id: Option<String>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct ResponseFormat {
+    #[serde(rename = "type")]
+    pub format_type: String,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct Tool {
+    #[serde(rename = "type")]
+    pub tool_type: String,
+    pub function: Function,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct Function {
+    pub name: String,
+    pub description: Option<String>,
+    pub parameters: serde_json::Value,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(untagged)]
+pub enum ToolChoice {
+    None(String),
+    Auto(String),
+    Function(FunctionCall),
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct FunctionCall {
+    #[serde(rename = "type")]
+    pub call_type: String,
+    pub function: Function,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct ToolCall {
+    pub id: String,
+    #[serde(rename = "type")]
+    pub call_type: String,
+    pub function: FunctionCall,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct ChatCompletionResponse {
+    pub id: String,
+    pub object: String,
+    pub created: i64,
+    pub model: String,
+    pub choices: Vec<Choice>,
+    pub usage: Usage,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct Choice {
+    pub index: i32,
+    pub message: Message,
+    pub finish_reason: Option<String>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct Usage {
+    pub prompt_tokens: i32,
+    pub completion_tokens: i32,
+    pub total_tokens: i32,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct ChatCompletionChunk {
+    pub id: String,
+    pub object: String,
+    pub created: i64,
+    pub model: String,
+    pub choices: Vec<StreamChoice>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct StreamChoice {
+    pub index: i32,
+    pub delta: DeltaMessage,
+    pub finish_reason: Option<String>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct DeltaMessage {
+    pub role: Option<String>,
+    pub content: Option<String>,
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_chat_completion_request_serialization() {
+        let request = ChatCompletionRequest {
+            model: "gpt-4".to_string(),
+            messages: vec![Message {
+                role: "user".to_string(),
+                content: "Hello".to_string(),
+                name: None,
+                tool_calls: None,
+                tool_call_id: None,
+            }],
+            temperature: Some(0.7),
+            ..Default::default()
+        };
+
+        let json = serde_json::to_string(&request).unwrap();
+        assert!(json.contains("\"model\":\"gpt-4\""));
+        assert!(json.contains("\"temperature\":0.7"));
+        assert!(!json.contains("frequency_penalty")); // Optional fields should be omitted
+    }
+
+    #[test]
+    fn test_chat_completion_request_deserialization() {
+        let json = r#"{
+            "model": "gpt-4",
+            "messages": [{
+                "role": "user",
+                "content": "Hello"
+            }],
+            "temperature": 0.7
+        }"#;
+
+        let request: ChatCompletionRequest = serde_json::from_str(json).unwrap();
+        assert_eq!(request.model, "gpt-4");
+        assert_eq!(request.messages[0].role, "user");
+        assert_eq!(request.temperature, Some(0.7));
+        assert_eq!(request.frequency_penalty, None);
+    }
+
+    #[test]
+    fn test_tool_choice_serialization() {
+        let none_choice = ToolChoice::None("none".to_string());
+        let json = serde_json::to_string(&none_choice).unwrap();
+        assert_eq!(json, "\"none\"");
+
+        let function_choice = ToolChoice::Function(FunctionCall {
+            call_type: "function".to_string(),
+            function: Function {
+                name: "test".to_string(),
+                description: Some("test desc".to_string()),
+                parameters: serde_json::json!({}),
+            },
+        });
+        let json = serde_json::to_string(&function_choice).unwrap();
+        assert!(json.contains("\"type\":\"function\""));
+    }
+
+    #[test]
+    fn test_chat_completion_response_deserialization() {
+        let json = r#"{
+            "id": "test-id",
+            "object": "chat.completion",
+            "created": 1234567890,
+            "model": "gpt-4",
+            "choices": [{
+                "index": 0,
+                "message": {
+                    "role": "assistant",
+                    "content": "Hello there!"
+                },
+                "finish_reason": "stop"
+            }],
+            "usage": {
+                "prompt_tokens": 10,
+                "completion_tokens": 20,
+                "total_tokens": 30
+            }
+        }"#;
+
+        let response: ChatCompletionResponse = serde_json::from_str(json).unwrap();
+        assert_eq!(response.id, "test-id");
+        assert_eq!(response.choices[0].message.content, "Hello there!");
+        assert_eq!(response.usage.total_tokens, 30);
+    }
+}
\ No newline at end of file
diff --git a/api/src/utils/clients/ai/mod.rs b/api/src/utils/clients/ai/mod.rs
index 5b2393145..3c63e2349 100644
--- a/api/src/utils/clients/ai/mod.rs
+++ b/api/src/utils/clients/ai/mod.rs
@@ -2,6 +2,7 @@ mod anthropic;
 pub mod embedding_router;
 mod hugging_face;
 pub mod langfuse;
+pub mod litellm;
 pub mod llm_router;
 pub mod ollama;
 pub mod openai;
diff --git a/api/src/utils/mod.rs b/api/src/utils/mod.rs
index 410ec4ff3..16ac1fc8c 100644
--- a/api/src/utils/mod.rs
+++ b/api/src/utils/mod.rs
@@ -6,6 +6,7 @@ pub mod prompts;
 pub mod query_engine;
 pub mod search_engine;
 pub mod security;
+pub mod serde_helpers;
 pub mod sharing;
+pub mod tools;
 pub mod user;
-pub mod serde_helpers;
\ No newline at end of file
diff --git a/api/src/utils/tools/mod.rs b/api/src/utils/tools/mod.rs
new file mode 100644
index 000000000..e69de29bb
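
Usage sketch (not part of the patch): given the re-exports in litellm/mod.rs, a caller elsewhere in the API could construct the new client and issue a non-streaming completion roughly as below. The LITELLM_API_KEY / LITELLM_BASE_URL environment variable names and the ask helper are illustrative assumptions, not defined by this diff.

    use anyhow::Result;

    use crate::utils::clients::ai::litellm::{ChatCompletionRequest, LiteLLMClient, Message};

    // Hypothetical helper: sends one user message and returns the assistant reply.
    async fn ask(prompt: &str) -> Result<String> {
        // Assumed env var names; base_url falls back to http://localhost:8000 when None.
        let client = LiteLLMClient::new(
            std::env::var("LITELLM_API_KEY")?,
            std::env::var("LITELLM_BASE_URL").ok(),
        );

        let request = ChatCompletionRequest {
            model: "gpt-4".to_string(),
            messages: vec![Message {
                role: "user".to_string(),
                content: prompt.to_string(),
                name: None,
                tool_calls: None,
                tool_call_id: None,
            }],
            temperature: Some(0.7),
            ..Default::default()
        };

        let response = client.chat_completion(request).await?;
        Ok(response.choices[0].message.content.clone())
    }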
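The streaming path hands back a tokio::sync::mpsc::Receiver of already-parsed chunks rather than raw SSE bytes, so a consumer just drains it with recv(). A minimal sketch under the same assumptions (the print_stream name is illustrative):

    use anyhow::Result;

    use crate::utils::clients::ai::litellm::{ChatCompletionRequest, LiteLLMClient};

    // Hypothetical consumer: prints streamed tokens as they arrive.
    async fn print_stream(client: &LiteLLMClient, request: ChatCompletionRequest) -> Result<()> {
        // stream_chat_completion sets stream: Some(true) on the request before sending it.
        let mut rx = client.stream_chat_completion(request).await?;

        // The channel closes once the spawned reader task finishes, i.e. after the SSE stream ends.
        while let Some(chunk) = rx.recv().await {
            let chunk = chunk?; // each item is a Result<ChatCompletionChunk>
            if let Some(content) = chunk.choices.first().and_then(|c| c.delta.content.as_ref()) {
                print!("{}", content);
            }
        }
        Ok(())
    }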