mirror of https://github.com/buster-so/buster.git
chore: Update API project configuration and add new modules
- Added dev dependencies for Mockito and Tokio async testing
- Updated .cursorrules with the async-test requirement
- Expanded utils modules with new serde_helpers and tools
- Added a new AI client module for LiteLLM
This commit is contained in:
parent
39dfc053fc
commit
087867032d
|
@ -1 +1,2 @@
|
|||
this is an axum web server
|
||||
- this is an axum web server
|
||||
- all tests need to be tokio async tests
|
|
@ -92,5 +92,9 @@ diesel_migrations = "2.0.0"
|
|||
serde_yaml = "0.9.34"
|
||||
html-escape = "0.2.13"
|
||||
|
||||
# Test-only dependencies; not compiled into release artifacts.
[dev-dependencies]
mockito = "1.2"
# NOTE(review): tokio is presumably already a regular dependency of this axum
# server; this entry adds the "test-util" feature for tests — confirm the two
# version specs unify cleanly.
tokio = { version = "1.0", features = ["full", "test-util"] }

[profile.release]
# Strip debug info from release binaries.
debug = false
|
||||
|
|
|
@ -0,0 +1,248 @@
|
|||
use reqwest::{Client, header};
|
||||
use futures_util::StreamExt;
|
||||
use tokio::sync::mpsc;
|
||||
use anyhow::Result;
|
||||
|
||||
use super::types::*;
|
||||
|
||||
/// Thin HTTP client for a LiteLLM server's OpenAI-compatible
/// `/chat/completions` API.
pub struct LiteLLMClient {
    /// reqwest client pre-configured with Authorization/Content-Type/Accept
    /// default headers (see `new`).
    client: Client,
    /// Bearer token sent on every request; crate-visible for tests.
    pub(crate) api_key: String,
    /// Server base URL, e.g. "http://localhost:8000"; crate-visible for tests.
    pub(crate) base_url: String,
}
|
||||
|
||||
impl LiteLLMClient {
|
||||
pub fn new(api_key: String, base_url: Option<String>) -> Self {
|
||||
let mut headers = header::HeaderMap::new();
|
||||
headers.insert(
|
||||
"Authorization",
|
||||
header::HeaderValue::from_str(&format!("Bearer {}", api_key)).unwrap(),
|
||||
);
|
||||
headers.insert(
|
||||
"Content-Type",
|
||||
header::HeaderValue::from_static("application/json"),
|
||||
);
|
||||
headers.insert(
|
||||
"Accept",
|
||||
header::HeaderValue::from_static("application/json"),
|
||||
);
|
||||
|
||||
let client = Client::builder()
|
||||
.default_headers(headers)
|
||||
.build()
|
||||
.expect("Failed to create HTTP client");
|
||||
|
||||
Self {
|
||||
client,
|
||||
api_key,
|
||||
base_url: base_url.unwrap_or_else(|| "http://localhost:8000".to_string()),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn chat_completion(&self, request: ChatCompletionRequest) -> Result<ChatCompletionResponse> {
|
||||
let url = format!("{}/chat/completions", self.base_url);
|
||||
let response = self.client
|
||||
.post(&url)
|
||||
.json(&request)
|
||||
.send()
|
||||
.await?
|
||||
.json::<ChatCompletionResponse>()
|
||||
.await?;
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
pub async fn stream_chat_completion(
|
||||
&self,
|
||||
request: ChatCompletionRequest,
|
||||
) -> Result<mpsc::Receiver<Result<ChatCompletionChunk>>> {
|
||||
let url = format!("{}/chat/completions", self.base_url);
|
||||
let mut stream = self.client
|
||||
.post(&url)
|
||||
.json(&ChatCompletionRequest {
|
||||
stream: Some(true),
|
||||
..request
|
||||
})
|
||||
.send()
|
||||
.await?
|
||||
.bytes_stream();
|
||||
|
||||
let (tx, rx) = mpsc::channel(100);
|
||||
|
||||
tokio::spawn(async move {
|
||||
let mut buffer = String::new();
|
||||
|
||||
while let Some(chunk_result) = stream.next().await {
|
||||
match chunk_result {
|
||||
Ok(chunk) => {
|
||||
let chunk_str = String::from_utf8_lossy(&chunk);
|
||||
buffer.push_str(&chunk_str);
|
||||
|
||||
while let Some(pos) = buffer.find("\n\n") {
|
||||
let line = buffer[..pos].trim().to_string();
|
||||
buffer = buffer[pos + 2..].to_string();
|
||||
|
||||
if line.starts_with("data: ") {
|
||||
let data = &line["data: ".len()..];
|
||||
if data == "[DONE]" {
|
||||
break;
|
||||
}
|
||||
|
||||
if let Ok(response) = serde_json::from_str::<ChatCompletionChunk>(data) {
|
||||
let _ = tx.send(Ok(response)).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
let _ = tx.send(Err(anyhow::Error::from(e))).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
Ok(rx)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use mockito;
    use tokio::time::timeout;
    use std::time::Duration;

    /// Single "user"/"Hello" message shared by every request fixture.
    fn create_test_message() -> Message {
        Message {
            role: "user".to_string(),
            content: "Hello".to_string(),
            name: None,
            tool_calls: None,
            tool_call_id: None,
        }
    }

    /// Minimal gpt-4 request with one message and a temperature set.
    fn create_test_request() -> ChatCompletionRequest {
        ChatCompletionRequest {
            model: "gpt-4".to_string(),
            messages: vec![create_test_message()],
            temperature: Some(0.7),
            ..Default::default()
        }
    }

    #[tokio::test]
    async fn test_chat_completion_success() {
        // mockito's blocking Server::new() panics when called from inside a
        // tokio runtime; async tests must use the async constructor.
        let mut server = mockito::Server::new_async().await;

        let mock = server.mock("POST", "/chat/completions")
            .match_header("content-type", "application/json")
            .match_header("Authorization", "Bearer test-key")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(r#"{
                "id": "test-id",
                "object": "chat.completion",
                "created": 1234567890,
                "model": "gpt-4",
                "choices": [{
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": "Hello there!"
                    },
                    "finish_reason": "stop"
                }],
                "usage": {
                    "prompt_tokens": 10,
                    "completion_tokens": 20,
                    "total_tokens": 30
                }
            }"#)
            .create_async()
            .await;

        let client = LiteLLMClient::new(
            "test-key".to_string(),
            Some(server.url()),
        );

        let response = client.chat_completion(create_test_request()).await.unwrap();
        assert_eq!(response.id, "test-id");
        assert_eq!(response.choices[0].message.content, "Hello there!");

        mock.assert_async().await;
    }

    #[tokio::test]
    async fn test_chat_completion_error() {
        let mut server = mockito::Server::new_async().await;

        let mock = server.mock("POST", "/chat/completions")
            .match_header("content-type", "application/json")
            .with_status(400)
            .with_body(r#"{"error": "Invalid request"}"#)
            .create_async()
            .await;

        let client = LiteLLMClient::new(
            "test-key".to_string(),
            Some(server.url()),
        );

        let result = client.chat_completion(create_test_request()).await;
        assert!(result.is_err());

        mock.assert_async().await;
    }

    #[tokio::test]
    async fn test_stream_chat_completion() {
        let mut server = mockito::Server::new_async().await;

        let mock = server.mock("POST", "/chat/completions")
            .match_header("content-type", "application/json")
            // The serialized request also carries model/messages/temperature,
            // so an exact JsonString match could never fire; match only the
            // field under test.
            .match_body(mockito::Matcher::PartialJsonString(r#"{"stream":true}"#.to_string()))
            .with_status(200)
            .with_header("content-type", "text/event-stream")
            .with_body(
                "data: {\"id\":\"1\",\"object\":\"chat.completion.chunk\",\"created\":1234567890,\"model\":\"gpt-4\",\"choices\":[{\"index\":0,\"delta\":{\"role\":\"assistant\",\"content\":\"Hello\"},\"finish_reason\":null}]}\n\n\
                 data: {\"id\":\"1\",\"object\":\"chat.completion.chunk\",\"created\":1234567890,\"model\":\"gpt-4\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\" world\"},\"finish_reason\":null}]}\n\n\
                 data: [DONE]\n\n"
            )
            .create_async()
            .await;

        let client = LiteLLMClient::new(
            "test-key".to_string(),
            Some(server.url()),
        );

        let mut request = create_test_request();
        request.stream = Some(true);

        let mut stream = client.stream_chat_completion(request).await.unwrap();

        let mut chunks = Vec::new();
        // Bound each recv so a silent mock mismatch fails fast instead of
        // hanging the test.
        while let Ok(Some(chunk)) = timeout(Duration::from_secs(1), stream.recv()).await {
            if let Ok(chunk) = chunk {
                chunks.push(chunk);
            }
        }

        assert_eq!(chunks.len(), 2);
        assert_eq!(chunks[0].choices[0].delta.content, Some("Hello".to_string()));
        assert_eq!(chunks[1].choices[0].delta.content, Some(" world".to_string()));

        mock.assert_async().await;
    }

    /// Constructor stores the key verbatim and falls back to localhost:8000.
    /// Sync test: no server or runtime needed.
    #[test]
    fn test_client_initialization() {
        let api_key = "test-key".to_string();
        let base_url = "http://custom.url".to_string();

        let client = LiteLLMClient::new(api_key.clone(), Some(base_url.clone()));
        assert_eq!(client.api_key, api_key);
        assert_eq!(client.base_url, base_url);

        let client = LiteLLMClient::new(api_key.clone(), None);
        assert_eq!(client.base_url, "http://localhost:8000");
    }
}
|
|
@ -0,0 +1,5 @@
|
|||
// LiteLLM client submodules.
mod types;
mod client;

// Flatten the public surface: callers import `LiteLLMClient` and all
// request/response types directly from this module.
pub use types::*;
pub use client::LiteLLMClient;
|
|
@ -0,0 +1,250 @@
|
|||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// Request body for the OpenAI-compatible `POST /chat/completions` endpoint.
///
/// Every `Option` field is omitted from the serialized JSON when `None`, so
/// the payload carries only explicitly-set parameters.
#[derive(Debug, Serialize, Deserialize)]
pub struct ChatCompletionRequest {
    /// Model identifier, e.g. "gpt-4".
    pub model: String,
    /// Conversation so far, in order.
    pub messages: Vec<Message>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub frequency_penalty: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logit_bias: Option<HashMap<String, i32>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub n: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub presence_penalty: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_format: Option<ResponseFormat>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub seed: Option<i64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<Vec<String>>,
    /// When `Some(true)` the server streams SSE chunks; forced on by
    /// `LiteLLMClient::stream_chat_completion`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
    /// Tools (functions) the model may call.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<Tool>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user: Option<String>,
}
|
||||
|
||||
impl Default for ChatCompletionRequest {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
model: String::new(),
|
||||
messages: Vec::new(),
|
||||
frequency_penalty: None,
|
||||
logit_bias: None,
|
||||
max_tokens: None,
|
||||
n: None,
|
||||
presence_penalty: None,
|
||||
response_format: None,
|
||||
seed: None,
|
||||
stop: None,
|
||||
stream: None,
|
||||
temperature: None,
|
||||
top_p: None,
|
||||
tools: None,
|
||||
tool_choice: None,
|
||||
user: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A single chat message in a request or response.
#[derive(Debug, Serialize, Deserialize)]
pub struct Message {
    /// Sender role, e.g. "user" or "assistant" (not validated here).
    pub role: String,
    /// Message text.
    pub content: String,
    /// Optional participant name.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    /// Tool invocations attached to the message.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCall>>,
    // Presumably the id of the tool call this message responds to —
    // TODO(review): confirm against the server contract.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,
}
|
||||
|
||||
/// Desired response format; serialized as `{"type": "..."}`.
#[derive(Debug, Serialize, Deserialize)]
pub struct ResponseFormat {
    #[serde(rename = "type")]
    pub format_type: String,
}

/// A tool the model may invoke: a `type` tag plus a function definition.
#[derive(Debug, Serialize, Deserialize)]
pub struct Tool {
    #[serde(rename = "type")]
    pub tool_type: String,
    pub function: Function,
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct Function {
|
||||
pub name: String,
|
||||
pub description: Option<String>,
|
||||
pub parameters: serde_json::Value,
|
||||
}
|
||||
|
||||
/// Tool-selection directive; untagged, so it serializes as either a bare
/// string or a function-call object.
///
/// NOTE(review): `None` and `Auto` are structurally identical `String`
/// variants — untagged deserialization of any bare string always yields the
/// first match (`None`), never `Auto`. Confirm this asymmetry is intended.
#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ToolChoice {
    None(String),
    Auto(String),
    Function(FunctionCall),
}

/// A typed function reference, e.g. `{"type":"function","function":{...}}`.
#[derive(Debug, Serialize, Deserialize)]
pub struct FunctionCall {
    #[serde(rename = "type")]
    pub call_type: String,
    pub function: Function,
}

/// A tool call attached to a message.
///
/// NOTE(review): nesting a full `FunctionCall` (which wraps a `Function`
/// including its `parameters` schema) looks unusual for a call payload —
/// verify against the API's actual tool_call shape.
#[derive(Debug, Serialize, Deserialize)]
pub struct ToolCall {
    pub id: String,
    #[serde(rename = "type")]
    pub call_type: String,
    pub function: FunctionCall,
}
|
||||
|
||||
/// Full (non-streaming) response from `/chat/completions`.
#[derive(Debug, Serialize, Deserialize)]
pub struct ChatCompletionResponse {
    pub id: String,
    /// e.g. "chat.completion".
    pub object: String,
    /// Creation time as reported by the server.
    pub created: i64,
    pub model: String,
    pub choices: Vec<Choice>,
    pub usage: Usage,
}

/// One generated completion within a response.
#[derive(Debug, Serialize, Deserialize)]
pub struct Choice {
    pub index: i32,
    pub message: Message,
    /// e.g. "stop"; `None` while generation is unfinished.
    pub finish_reason: Option<String>,
}

/// Token accounting reported by the server.
#[derive(Debug, Serialize, Deserialize)]
pub struct Usage {
    pub prompt_tokens: i32,
    pub completion_tokens: i32,
    pub total_tokens: i32,
}

/// One SSE chunk of a streamed completion.
#[derive(Debug, Serialize, Deserialize)]
pub struct ChatCompletionChunk {
    pub id: String,
    /// e.g. "chat.completion.chunk".
    pub object: String,
    pub created: i64,
    pub model: String,
    pub choices: Vec<StreamChoice>,
}

/// Streaming counterpart of `Choice`: carries a delta, not a full message.
#[derive(Debug, Serialize, Deserialize)]
pub struct StreamChoice {
    pub index: i32,
    pub delta: DeltaMessage,
    pub finish_reason: Option<String>,
}

/// Incremental message fragment; in the stream test fixtures `role` arrives
/// only on the first chunk.
#[derive(Debug, Serialize, Deserialize)]
pub struct DeltaMessage {
    pub role: Option<String>,
    pub content: Option<String>,
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    //! Serialization round-trip tests for the LiteLLM request/response types.
    use super::*;

    /// `None` optionals must be omitted from the serialized request JSON.
    #[test]
    fn test_chat_completion_request_serialization() {
        let request = ChatCompletionRequest {
            model: "gpt-4".to_string(),
            messages: vec![Message {
                role: "user".to_string(),
                content: "Hello".to_string(),
                name: None,
                tool_calls: None,
                tool_call_id: None,
            }],
            temperature: Some(0.7),
            ..Default::default()
        };

        let json = serde_json::to_string(&request).unwrap();
        assert!(json.contains("\"model\":\"gpt-4\""));
        assert!(json.contains("\"temperature\":0.7"));
        assert!(!json.contains("frequency_penalty")); // Optional fields should be omitted
    }

    /// Missing optional fields deserialize to `None`.
    #[test]
    fn test_chat_completion_request_deserialization() {
        let json = r#"{
            "model": "gpt-4",
            "messages": [{
                "role": "user",
                "content": "Hello"
            }],
            "temperature": 0.7
        }"#;

        let request: ChatCompletionRequest = serde_json::from_str(json).unwrap();
        assert_eq!(request.model, "gpt-4");
        assert_eq!(request.messages[0].role, "user");
        assert_eq!(request.temperature, Some(0.7));
        assert_eq!(request.frequency_penalty, None);
    }

    /// Untagged `ToolChoice` serializes as a bare string or an object with a
    /// renamed "type" key.
    #[test]
    fn test_tool_choice_serialization() {
        let none_choice = ToolChoice::None("none".to_string());
        let json = serde_json::to_string(&none_choice).unwrap();
        assert_eq!(json, "\"none\"");

        let function_choice = ToolChoice::Function(FunctionCall {
            call_type: "function".to_string(),
            function: Function {
                name: "test".to_string(),
                description: Some("test desc".to_string()),
                parameters: serde_json::json!({}),
            },
        });
        let json = serde_json::to_string(&function_choice).unwrap();
        assert!(json.contains("\"type\":\"function\""));
    }

    /// A complete server response parses into nested structs.
    #[test]
    fn test_chat_completion_response_deserialization() {
        let json = r#"{
            "id": "test-id",
            "object": "chat.completion",
            "created": 1234567890,
            "model": "gpt-4",
            "choices": [{
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": "Hello there!"
                },
                "finish_reason": "stop"
            }],
            "usage": {
                "prompt_tokens": 10,
                "completion_tokens": 20,
                "total_tokens": 30
            }
        }"#;

        let response: ChatCompletionResponse = serde_json::from_str(json).unwrap();
        assert_eq!(response.id, "test-id");
        assert_eq!(response.choices[0].message.content, "Hello there!");
        assert_eq!(response.usage.total_tokens, 30);
    }
}
|
|
@ -2,6 +2,7 @@ mod anthropic;
|
|||
pub mod embedding_router;
mod hugging_face;
pub mod langfuse;
// Client for a LiteLLM proxy (OpenAI-compatible chat completions).
pub mod litellm;
pub mod llm_router;
pub mod ollama;
pub mod openai;
|
||||
|
|
|
@ -6,6 +6,7 @@ pub mod prompts;
|
|||
pub mod query_engine;
|
||||
pub mod search_engine;
|
||||
pub mod security;
|
||||
pub mod serde_helpers;
|
||||
pub mod sharing;
|
||||
pub mod tools;
|
||||
pub mod user;
|
||||
pub mod serde_helpers;
|
Loading…
Reference in New Issue