feat: Add LLM configuration support with environment variable fallback

- Updated `.env.example` with new LLM configuration variables
- Modified `LiteLLMClient` to support API key and base URL from environment variables
- Enhanced client initialization to use env vars with optional parameter overrides
- Added comprehensive test case for environment variable configuration
This commit is contained in:
dal 2025-01-25 15:32:02 -07:00
parent 085bc17a4d
commit 5235aa977d
No known key found for this signature in database
GPG Key ID: 16F4B0E1E9F61122
2 changed files with 37 additions and 14 deletions

View File

@ -15,6 +15,8 @@ BUSTER_WH_TOKEN="buster-wh-token"
EMBEDDING_PROVIDER="ollama" EMBEDDING_PROVIDER="ollama"
EMBEDDING_MODEL="mxbai-embed-large" EMBEDDING_MODEL="mxbai-embed-large"
COHERE_API_KEY="" COHERE_API_KEY=""
LLM_API_KEY="test-key"
LLM_BASE_URL="http://localhost:8000"

View File

@ -2,6 +2,7 @@ use reqwest::{Client, header};
use futures_util::StreamExt; use futures_util::StreamExt;
use tokio::sync::mpsc; use tokio::sync::mpsc;
use anyhow::Result; use anyhow::Result;
use std::env;
use super::types::*; use super::types::*;
@ -12,7 +13,15 @@ pub struct LiteLLMClient {
} }
impl LiteLLMClient { impl LiteLLMClient {
pub fn new(api_key: String, base_url: Option<String>) -> Self { pub fn new(api_key: Option<String>, base_url: Option<String>) -> Self {
let api_key = api_key
.or_else(|| env::var("LLM_API_KEY").ok())
.expect("LLM_API_KEY must be provided either through parameter or environment variable");
let base_url = base_url
.or_else(|| env::var("LLM_BASE_URL").ok())
.unwrap_or_else(|| "http://localhost:8000".to_string());
let mut headers = header::HeaderMap::new(); let mut headers = header::HeaderMap::new();
headers.insert( headers.insert(
"Authorization", "Authorization",
@ -35,7 +44,7 @@ impl LiteLLMClient {
Self { Self {
client, client,
api_key, api_key,
base_url: base_url.unwrap_or_else(|| "http://localhost:8000".to_string()), base_url,
} }
} }
@ -110,6 +119,7 @@ mod tests {
use mockito; use mockito;
use tokio::time::timeout; use tokio::time::timeout;
use std::time::Duration; use std::time::Duration;
use std::env;
fn create_test_message() -> Message { fn create_test_message() -> Message {
Message { Message {
@ -166,7 +176,7 @@ mod tests {
.create(); .create();
let client = LiteLLMClient::new( let client = LiteLLMClient::new(
"test-key".to_string(), Some("test-key".to_string()),
Some(server.url()), Some(server.url()),
); );
@ -194,7 +204,7 @@ mod tests {
.create(); .create();
let client = LiteLLMClient::new( let client = LiteLLMClient::new(
"test-key".to_string(), Some("test-key".to_string()),
Some(server.url()), Some(server.url()),
); );
@ -226,7 +236,7 @@ mod tests {
.create(); .create();
let client = LiteLLMClient::new( let client = LiteLLMClient::new(
"test-key".to_string(), Some("test-key".to_string()),
Some(server.url()), Some(server.url()),
); );
@ -247,15 +257,26 @@ mod tests {
} }
#[test] #[test]
fn test_client_initialization() { fn test_client_initialization_with_env_vars() {
let api_key = "test-key".to_string(); let test_api_key = "test-env-key";
let base_url = "http://custom.url".to_string(); let test_base_url = "http://test-env-url";
let client = LiteLLMClient::new(api_key.clone(), Some(base_url.clone())); env::set_var("LLM_API_KEY", test_api_key);
assert_eq!(client.api_key, api_key); env::set_var("LLM_BASE_URL", test_base_url);
assert_eq!(client.base_url, base_url);
let client = LiteLLMClient::new(api_key.clone(), None); // Test with no parameters (should use env vars)
assert_eq!(client.base_url, "http://localhost:8000"); let client = LiteLLMClient::new(None, None);
assert_eq!(client.api_key, test_api_key);
assert_eq!(client.base_url, test_base_url);
// Test with parameters (should override env vars)
let override_key = "override-key";
let override_url = "http://override-url";
let client = LiteLLMClient::new(Some(override_key.to_string()), Some(override_url.to_string()));
assert_eq!(client.api_key, override_key);
assert_eq!(client.base_url, override_url);
env::remove_var("LLM_API_KEY");
env::remove_var("LLM_BASE_URL");
} }
} }