diff --git a/.env.example b/.env.example index 3b96fc7bf..c97ee3d28 100644 --- a/.env.example +++ b/.env.example @@ -15,6 +15,8 @@ BUSTER_WH_TOKEN="buster-wh-token" EMBEDDING_PROVIDER="ollama" EMBEDDING_MODEL="mxbai-embed-large" COHERE_API_KEY="" +LLM_API_KEY="test-key" +LLM_BASE_URL="http://localhost:8000" diff --git a/api/src/utils/clients/ai/litellm/client.rs b/api/src/utils/clients/ai/litellm/client.rs index b7ffbc066..b4dd36ed5 100644 --- a/api/src/utils/clients/ai/litellm/client.rs +++ b/api/src/utils/clients/ai/litellm/client.rs @@ -2,6 +2,7 @@ use reqwest::{Client, header}; use futures_util::StreamExt; use tokio::sync::mpsc; use anyhow::Result; +use std::env; use super::types::*; @@ -12,7 +13,15 @@ pub struct LiteLLMClient { } impl LiteLLMClient { - pub fn new(api_key: String, base_url: Option) -> Self { + pub fn new(api_key: Option, base_url: Option) -> Self { + let api_key = api_key + .or_else(|| env::var("LLM_API_KEY").ok()) + .expect("LLM_API_KEY must be provided either through parameter or environment variable"); + + let base_url = base_url + .or_else(|| env::var("LLM_BASE_URL").ok()) + .unwrap_or_else(|| "http://localhost:8000".to_string()); + let mut headers = header::HeaderMap::new(); headers.insert( "Authorization", @@ -35,7 +44,7 @@ impl LiteLLMClient { Self { client, api_key, - base_url: base_url.unwrap_or_else(|| "http://localhost:8000".to_string()), + base_url, } } @@ -110,6 +119,7 @@ mod tests { use mockito; use tokio::time::timeout; use std::time::Duration; + use std::env; fn create_test_message() -> Message { Message { @@ -166,7 +176,7 @@ mod tests { .create(); let client = LiteLLMClient::new( - "test-key".to_string(), + Some("test-key".to_string()), Some(server.url()), ); @@ -194,7 +204,7 @@ mod tests { .create(); let client = LiteLLMClient::new( - "test-key".to_string(), + Some("test-key".to_string()), Some(server.url()), ); @@ -226,7 +236,7 @@ mod tests { .create(); let client = LiteLLMClient::new( - "test-key".to_string(), + Some("test-key".to_string()), Some(server.url()), ); @@ -247,15 +257,26 @@ mod tests { } #[test] - fn test_client_initialization() { - let api_key = "test-key".to_string(); - let base_url = "http://custom.url".to_string(); + fn test_client_initialization_with_env_vars() { + let test_api_key = "test-env-key"; + let test_base_url = "http://test-env-url"; - let client = LiteLLMClient::new(api_key.clone(), Some(base_url.clone())); - assert_eq!(client.api_key, api_key); - assert_eq!(client.base_url, base_url); - - let client = LiteLLMClient::new(api_key.clone(), None); - assert_eq!(client.base_url, "http://localhost:8000"); + env::set_var("LLM_API_KEY", test_api_key); + env::set_var("LLM_BASE_URL", test_base_url); + + // Test with no parameters (should use env vars) + let client = LiteLLMClient::new(None, None); + assert_eq!(client.api_key, test_api_key); + assert_eq!(client.base_url, test_base_url); + + // Test with parameters (should override env vars) + let override_key = "override-key"; + let override_url = "http://override-url"; + let client = LiteLLMClient::new(Some(override_key.to_string()), Some(override_url.to_string())); + assert_eq!(client.api_key, override_key); + assert_eq!(client.base_url, override_url); + + env::remove_var("LLM_API_KEY"); + env::remove_var("LLM_BASE_URL"); } } \ No newline at end of file