chore: Update API project configuration and add new modules

- Added dev dependencies for Mockito and Tokio async testing
- Updated .cursorrules with async test requirement
- Expanded utils modules with new serde_helpers and tools
- Added new AI client module for LiteLLM
dal 2025-01-25 15:17:21 -07:00
parent 39dfc053fc
commit 087867032d
No known key found for this signature in database
GPG Key ID: 16F4B0E1E9F61122
8 changed files with 512 additions and 2 deletions

View File

@@ -1 +1,2 @@
-this is an axum web server
+- this is an axum web server
+- all tests need to be tokio async tests

View File

@@ -92,5 +92,9 @@ diesel_migrations = "2.0.0"
 serde_yaml = "0.9.34"
 html-escape = "0.2.13"
 
+[dev-dependencies]
+mockito = "1.2"
+tokio = { version = "1.0", features = ["full", "test-util"] }
+
 [profile.release]
 debug = false
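
The "test-util" feature is what allows tokio's clock to be paused and advanced inside tests, which suits the async-test rule added to .cursorrules. As a hedged illustration (this test is hypothetical, not part of the commit), it enables something like:

#[tokio::test(start_paused = true)]
async fn advances_time_without_sleeping() {
    // With tokio's test-util feature, the paused clock can be advanced
    // instantly instead of actually sleeping for five seconds.
    let start = tokio::time::Instant::now();
    tokio::time::advance(std::time::Duration::from_secs(5)).await;
    assert!(start.elapsed() >= std::time::Duration::from_secs(5));
}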

View File

@@ -0,0 +1,248 @@
use reqwest::{Client, header};
use futures_util::StreamExt;
use tokio::sync::mpsc;
use anyhow::Result;

use super::types::*;

pub struct LiteLLMClient {
    client: Client,
    pub(crate) api_key: String,
    pub(crate) base_url: String,
}
impl LiteLLMClient {
    pub fn new(api_key: String, base_url: Option<String>) -> Self {
        let mut headers = header::HeaderMap::new();
        headers.insert(
            "Authorization",
            header::HeaderValue::from_str(&format!("Bearer {}", api_key))
                .expect("API key contains characters that are invalid in an HTTP header"),
        );
        headers.insert(
            "Content-Type",
            header::HeaderValue::from_static("application/json"),
        );
        headers.insert(
            "Accept",
            header::HeaderValue::from_static("application/json"),
        );

        let client = Client::builder()
            .default_headers(headers)
            .build()
            .expect("Failed to create HTTP client");

        Self {
            client,
            api_key,
            base_url: base_url.unwrap_or_else(|| "http://localhost:8000".to_string()),
        }
    }
    pub async fn chat_completion(&self, request: ChatCompletionRequest) -> Result<ChatCompletionResponse> {
        let url = format!("{}/chat/completions", self.base_url);
        let response = self.client
            .post(&url)
            .json(&request)
            .send()
            .await?
            // Report 4xx/5xx responses as HTTP errors rather than JSON decode failures.
            .error_for_status()?
            .json::<ChatCompletionResponse>()
            .await?;
        Ok(response)
    }
    pub async fn stream_chat_completion(
        &self,
        request: ChatCompletionRequest,
    ) -> Result<mpsc::Receiver<Result<ChatCompletionChunk>>> {
        let url = format!("{}/chat/completions", self.base_url);
        let mut stream = self.client
            .post(&url)
            .json(&ChatCompletionRequest {
                // Force streaming regardless of what the caller set.
                stream: Some(true),
                ..request
            })
            .send()
            .await?
            .bytes_stream();

        let (tx, rx) = mpsc::channel(100);

        tokio::spawn(async move {
            // Server-sent events are delimited by a blank line ("\n\n"); bytes are
            // accumulated until at least one complete event is buffered.
            let mut buffer = String::new();
            while let Some(chunk_result) = stream.next().await {
                match chunk_result {
                    Ok(chunk) => {
                        let chunk_str = String::from_utf8_lossy(&chunk);
                        buffer.push_str(&chunk_str);
                        while let Some(pos) = buffer.find("\n\n") {
                            let line = buffer[..pos].trim().to_string();
                            buffer = buffer[pos + 2..].to_string();
                            if line.starts_with("data: ") {
                                let data = &line["data: ".len()..];
                                if data == "[DONE]" {
                                    // End-of-stream sentinel from the server.
                                    break;
                                }
                                if let Ok(response) = serde_json::from_str::<ChatCompletionChunk>(data) {
                                    let _ = tx.send(Ok(response)).await;
                                }
                            }
                        }
                    }
                    Err(e) => {
                        let _ = tx.send(Err(anyhow::Error::from(e))).await;
                    }
                }
            }
        });

        Ok(rx)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use mockito;
    use tokio::time::timeout;
    use std::time::Duration;

    fn create_test_message() -> Message {
        Message {
            role: "user".to_string(),
            content: "Hello".to_string(),
            name: None,
            tool_calls: None,
            tool_call_id: None,
        }
    }

    fn create_test_request() -> ChatCompletionRequest {
        ChatCompletionRequest {
            model: "gpt-4".to_string(),
            messages: vec![create_test_message()],
            temperature: Some(0.7),
            ..Default::default()
        }
    }

    #[tokio::test]
    async fn test_chat_completion_success() {
        // mockito's blocking Server::new() panics when called from inside a tokio
        // runtime, so the async constructors are used in the async tests below.
        let mut server = mockito::Server::new_async().await;
        let mock = server.mock("POST", "/chat/completions")
            .match_header("content-type", "application/json")
            .match_header("Authorization", "Bearer test-key")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(r#"{
                "id": "test-id",
                "object": "chat.completion",
                "created": 1234567890,
                "model": "gpt-4",
                "choices": [{
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": "Hello there!"
                    },
                    "finish_reason": "stop"
                }],
                "usage": {
                    "prompt_tokens": 10,
                    "completion_tokens": 20,
                    "total_tokens": 30
                }
            }"#)
            .create_async()
            .await;

        let client = LiteLLMClient::new("test-key".to_string(), Some(server.url()));

        let response = client.chat_completion(create_test_request()).await.unwrap();

        assert_eq!(response.id, "test-id");
        assert_eq!(response.choices[0].message.content, "Hello there!");
        mock.assert_async().await;
    }

    #[tokio::test]
    async fn test_chat_completion_error() {
        let mut server = mockito::Server::new_async().await;
        let mock = server.mock("POST", "/chat/completions")
            .match_header("content-type", "application/json")
            .with_status(400)
            .with_body(r#"{"error": "Invalid request"}"#)
            .create_async()
            .await;

        let client = LiteLLMClient::new("test-key".to_string(), Some(server.url()));

        let result = client.chat_completion(create_test_request()).await;
        assert!(result.is_err());
        mock.assert_async().await;
    }

    #[tokio::test]
    async fn test_stream_chat_completion() {
        let mut server = mockito::Server::new_async().await;
        // PartialJsonString only requires the listed fields to be present; an
        // exact JsonString match would fail because the real body also carries
        // model, messages, and temperature.
        let mock = server.mock("POST", "/chat/completions")
            .match_header("content-type", "application/json")
            .match_body(mockito::Matcher::PartialJsonString(r#"{"stream":true}"#.to_string()))
            .with_status(200)
            .with_header("content-type", "text/event-stream")
            .with_body(
                "data: {\"id\":\"1\",\"object\":\"chat.completion.chunk\",\"created\":1234567890,\"model\":\"gpt-4\",\"choices\":[{\"index\":0,\"delta\":{\"role\":\"assistant\",\"content\":\"Hello\"},\"finish_reason\":null}]}\n\n\
                 data: {\"id\":\"1\",\"object\":\"chat.completion.chunk\",\"created\":1234567890,\"model\":\"gpt-4\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\" world\"},\"finish_reason\":null}]}\n\n\
                 data: [DONE]\n\n"
            )
            .create_async()
            .await;

        let client = LiteLLMClient::new("test-key".to_string(), Some(server.url()));

        let mut request = create_test_request();
        request.stream = Some(true);

        let mut stream = client.stream_chat_completion(request).await.unwrap();

        let mut chunks = Vec::new();
        while let Ok(Some(chunk)) = timeout(Duration::from_secs(1), stream.recv()).await {
            if let Ok(chunk) = chunk {
                chunks.push(chunk);
            }
        }

        assert_eq!(chunks.len(), 2);
        assert_eq!(chunks[0].choices[0].delta.content, Some("Hello".to_string()));
        assert_eq!(chunks[1].choices[0].delta.content, Some(" world".to_string()));
        mock.assert_async().await;
    }

    #[test]
    fn test_client_initialization() {
        let api_key = "test-key".to_string();
        let base_url = "http://custom.url".to_string();

        let client = LiteLLMClient::new(api_key.clone(), Some(base_url.clone()));
        assert_eq!(client.api_key, api_key);
        assert_eq!(client.base_url, base_url);

        let client = LiteLLMClient::new(api_key, None);
        assert_eq!(client.base_url, "http://localhost:8000");
    }
}
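
For orientation, a usage sketch (illustrative only, not part of this commit) of how a caller might consume the streaming API defined above; the function name, model, and prompt are placeholders:

async fn print_stream(client: &LiteLLMClient) -> anyhow::Result<()> {
    let request = ChatCompletionRequest {
        model: "gpt-4".to_string(),
        messages: vec![Message {
            role: "user".to_string(),
            content: "Say hello".to_string(),
            name: None,
            tool_calls: None,
            tool_call_id: None,
        }],
        ..Default::default()
    };
    // stream_chat_completion sets stream: Some(true) on the request itself.
    let mut rx = client.stream_chat_completion(request).await?;
    while let Some(chunk) = rx.recv().await {
        for choice in chunk?.choices {
            if let Some(text) = choice.delta.content {
                print!("{}", text);
            }
        }
    }
    Ok(())
}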

View File

@@ -0,0 +1,5 @@
mod types;
mod client;

pub use types::*;
pub use client::LiteLLMClient;

View File

@@ -0,0 +1,250 @@
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

#[derive(Debug, Serialize, Deserialize)]
pub struct ChatCompletionRequest {
    pub model: String,
    pub messages: Vec<Message>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub frequency_penalty: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logit_bias: Option<HashMap<String, i32>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub n: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub presence_penalty: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_format: Option<ResponseFormat>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub seed: Option<i64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<Vec<String>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<Tool>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user: Option<String>,
}
impl Default for ChatCompletionRequest {
    fn default() -> Self {
        Self {
            model: String::new(),
            messages: Vec::new(),
            frequency_penalty: None,
            logit_bias: None,
            max_tokens: None,
            n: None,
            presence_penalty: None,
            response_format: None,
            seed: None,
            stop: None,
            stream: None,
            temperature: None,
            top_p: None,
            tools: None,
            tool_choice: None,
            user: None,
        }
    }
}
#[derive(Debug, Serialize, Deserialize)]
pub struct Message {
    pub role: String,
    pub content: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCall>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,
}

#[derive(Debug, Serialize, Deserialize)]
pub struct ResponseFormat {
    #[serde(rename = "type")]
    pub format_type: String,
}

#[derive(Debug, Serialize, Deserialize)]
pub struct Tool {
    #[serde(rename = "type")]
    pub tool_type: String,
    pub function: Function,
}

#[derive(Debug, Serialize, Deserialize)]
pub struct Function {
    pub name: String,
    pub description: Option<String>,
    pub parameters: serde_json::Value,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ToolChoice {
    None(String),
    Auto(String),
    Function(FunctionCall),
}

#[derive(Debug, Serialize, Deserialize)]
pub struct FunctionCall {
    #[serde(rename = "type")]
    pub call_type: String,
    pub function: Function,
}

#[derive(Debug, Serialize, Deserialize)]
pub struct ToolCall {
    pub id: String,
    #[serde(rename = "type")]
    pub call_type: String,
    pub function: FunctionCall,
}

#[derive(Debug, Serialize, Deserialize)]
pub struct ChatCompletionResponse {
    pub id: String,
    pub object: String,
    pub created: i64,
    pub model: String,
    pub choices: Vec<Choice>,
    pub usage: Usage,
}

#[derive(Debug, Serialize, Deserialize)]
pub struct Choice {
    pub index: i32,
    pub message: Message,
    pub finish_reason: Option<String>,
}

#[derive(Debug, Serialize, Deserialize)]
pub struct Usage {
    pub prompt_tokens: i32,
    pub completion_tokens: i32,
    pub total_tokens: i32,
}

#[derive(Debug, Serialize, Deserialize)]
pub struct ChatCompletionChunk {
    pub id: String,
    pub object: String,
    pub created: i64,
    pub model: String,
    pub choices: Vec<StreamChoice>,
}

#[derive(Debug, Serialize, Deserialize)]
pub struct StreamChoice {
    pub index: i32,
    pub delta: DeltaMessage,
    pub finish_reason: Option<String>,
}

#[derive(Debug, Serialize, Deserialize)]
pub struct DeltaMessage {
    pub role: Option<String>,
    pub content: Option<String>,
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_chat_completion_request_serialization() {
        let request = ChatCompletionRequest {
            model: "gpt-4".to_string(),
            messages: vec![Message {
                role: "user".to_string(),
                content: "Hello".to_string(),
                name: None,
                tool_calls: None,
                tool_call_id: None,
            }],
            temperature: Some(0.7),
            ..Default::default()
        };

        let json = serde_json::to_string(&request).unwrap();
        assert!(json.contains("\"model\":\"gpt-4\""));
        assert!(json.contains("\"temperature\":0.7"));
        assert!(!json.contains("frequency_penalty")); // Optional fields should be omitted
    }

    #[test]
    fn test_chat_completion_request_deserialization() {
        let json = r#"{
            "model": "gpt-4",
            "messages": [{
                "role": "user",
                "content": "Hello"
            }],
            "temperature": 0.7
        }"#;

        let request: ChatCompletionRequest = serde_json::from_str(json).unwrap();
        assert_eq!(request.model, "gpt-4");
        assert_eq!(request.messages[0].role, "user");
        assert_eq!(request.temperature, Some(0.7));
        assert_eq!(request.frequency_penalty, None);
    }

    #[test]
    fn test_tool_choice_serialization() {
        let none_choice = ToolChoice::None("none".to_string());
        let json = serde_json::to_string(&none_choice).unwrap();
        assert_eq!(json, "\"none\"");

        let function_choice = ToolChoice::Function(FunctionCall {
            call_type: "function".to_string(),
            function: Function {
                name: "test".to_string(),
                description: Some("test desc".to_string()),
                parameters: serde_json::json!({}),
            },
        });
        let json = serde_json::to_string(&function_choice).unwrap();
        assert!(json.contains("\"type\":\"function\""));
    }

    #[test]
    fn test_chat_completion_response_deserialization() {
        let json = r#"{
            "id": "test-id",
            "object": "chat.completion",
            "created": 1234567890,
            "model": "gpt-4",
            "choices": [{
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": "Hello there!"
                },
                "finish_reason": "stop"
            }],
            "usage": {
                "prompt_tokens": 10,
                "completion_tokens": 20,
                "total_tokens": 30
            }
        }"#;

        let response: ChatCompletionResponse = serde_json::from_str(json).unwrap();
        assert_eq!(response.id, "test-id");
        assert_eq!(response.choices[0].message.content, "Hello there!");
        assert_eq!(response.usage.total_tokens, 30);
    }
}
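
As a hedged sketch of how these types compose (the tool name, helper function, and JSON schema below are invented for illustration, not part of this commit):

fn weather_tool() -> Tool {
    Tool {
        tool_type: "function".to_string(),
        function: Function {
            name: "get_weather".to_string(),
            description: Some("Look up the current weather for a city".to_string()),
            // parameters holds an arbitrary JSON Schema describing the arguments.
            parameters: serde_json::json!({
                "type": "object",
                "properties": { "city": { "type": "string" } },
                "required": ["city"]
            }),
        },
    }
}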

View File

@@ -2,6 +2,7 @@ mod anthropic;
 pub mod embedding_router;
 mod hugging_face;
 pub mod langfuse;
+pub mod litellm;
 pub mod llm_router;
 pub mod ollama;
 pub mod openai;

View File

@@ -6,6 +6,7 @@ pub mod prompts;
 pub mod query_engine;
 pub mod search_engine;
 pub mod security;
+pub mod serde_helpers;
 pub mod sharing;
 pub mod tools;
 pub mod user;

View File