diff --git a/.env.example b/.env.example
index c9853eccc..dda5314d8 100644
--- a/.env.example
+++ b/.env.example
@@ -15,7 +15,7 @@ RERANK_API_KEY="your_rerank_api_key"
 RERANK_MODEL="rerank-v3.5"
 RERANK_BASE_URL="https://api.cohere.com/v2/rerank"
 LLM_API_KEY="your_llm_api_key"
-LLM_BASE_URL="http://buster-litellm:4000"
+LLM_BASE_URL="http://buster-litellm:4001"
 
 # WEB VARS
 NEXT_PUBLIC_API_URL="http://localhost:3001" # External URL for the API service (buster-api)
diff --git a/cli/cli/src/commands/config_utils.rs b/cli/cli/src/commands/config_utils.rs
index 24f7c84e5..1581b9a26 100644
--- a/cli/cli/src/commands/config_utils.rs
+++ b/cli/cli/src/commands/config_utils.rs
@@ -177,13 +177,49 @@ pub fn update_env_file(
     })
 }
 
+#[derive(Debug, Deserialize, Serialize, Clone, Default)]
+pub struct ModelInfo {
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub id: Option<String>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub mode: Option<String>, // e.g., "embedding", "chat"
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub input_cost_per_token: Option<f64>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub output_cost_per_token: Option<f64>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub max_tokens: Option<u32>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub base_model: Option<String>, // e.g., "gpt-3.5-turbo"
+    // For any other custom key-value pairs in model_info
+    #[serde(flatten, skip_serializing_if = "Option::is_none")]
+    pub extras: Option<HashMap<String, serde_yaml::Value>>,
+}
+
 #[derive(Debug, Deserialize, Serialize, Clone)]
 pub struct LiteLLMModelConfig {
-    pub model_name: String,
-    pub api_base: Option<String>,
-    pub api_key: Option<String>,
+    pub model_name: String, // Alias for the model, e.g., "my-gpt4"
+
+    // Parameters for LiteLLM to connect to and use the model.
+    // This should be a YAML map including the actual model identifier, API key, base URL, etc.
+    // Example for OpenAI:
+    //   litellm_params:
+    //     model: "gpt-4-turbo" // or "openai/gpt-4-turbo"
+    //     api_key: "sk-..."
+    //     api_base: "https://api.openai.com/v1"
+    // Example for Ollama:
+    //   litellm_params:
+    //     model: "ollama/mistral"
+    //     api_base: "http://localhost:11434"
+    pub litellm_params: serde_yaml::Value,
+
     #[serde(skip_serializing_if = "Option::is_none")]
-    pub litellm_params: Option<serde_yaml::Value>,
+    pub model_info: Option<ModelInfo>,
+
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub tpm: Option<u64>, // Tokens Per Minute
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub rpm: Option<u64>, // Requests Per Minute
 }
 
 #[derive(Debug, Deserialize, Serialize, Clone)]
@@ -223,11 +259,30 @@ pub fn create_litellm_yaml(
     // Build model list
     let model_list: Vec<LiteLLMModelConfig> = OPENAI_MODELS
         .iter()
-        .map(|model_name| LiteLLMModelConfig {
-            model_name: model_name.to_string(),
-            api_base: api_base.map(|s| s.to_string()),
-            api_key: Some(api_key.to_string()),
-            litellm_params: None,
+        .map(|model_name| {
+            let mut params_map = serde_yaml::Mapping::new();
+            params_map.insert(
+                serde_yaml::Value::String("model".to_string()),
+                serde_yaml::Value::String(model_name.to_string()),
+            );
+            params_map.insert(
+                serde_yaml::Value::String("api_key".to_string()),
+                serde_yaml::Value::String(api_key.to_string()),
+            );
+            if let Some(base) = api_base {
+                params_map.insert(
+                    serde_yaml::Value::String("api_base".to_string()),
+                    serde_yaml::Value::String(base.to_string()),
+                );
+            }
+
+            LiteLLMModelConfig {
+                model_name: model_name.to_string(),
+                litellm_params: serde_yaml::Value::Mapping(params_map),
+                model_info: Some(ModelInfo::default()), // Or None if preferred
+                tpm: None,
+                rpm: None,
+            }
         })
         .collect();
 
@@ -310,18 +365,67 @@ pub fn update_litellm_yaml(
     let mut found = false;
     for model_config in &mut config.model_list {
         if &model_config.model_name == model_name {
-            model_config.api_key = Some(api_key.to_string());
-            model_config.api_base = api_base.map(|s| s.to_string());
+            // Ensure litellm_params is a mutable mapping
+            if let serde_yaml::Value::Mapping(params_map) = &mut model_config.litellm_params {
+                params_map.insert(
+                    serde_yaml::Value::String("api_key".to_string()),
+                    serde_yaml::Value::String(api_key.to_string()),
+                );
+                if let Some(base) = api_base {
+                    params_map.insert(
+                        serde_yaml::Value::String("api_base".to_string()),
+                        serde_yaml::Value::String(base.to_string()),
+                    );
+                } else {
+                    params_map.remove(&serde_yaml::Value::String("api_base".to_string()));
+                }
+            } else {
+                // This case should ideally not happen if params are always created as Mappings.
+                // For robustness, one might recreate it:
+                let mut params_map = serde_yaml::Mapping::new();
+                params_map.insert(
+                    serde_yaml::Value::String("model".to_string()),
+                    serde_yaml::Value::String(model_name.to_string()),
+                );
+                params_map.insert(
+                    serde_yaml::Value::String("api_key".to_string()),
+                    serde_yaml::Value::String(api_key.to_string()),
+                );
+                if let Some(base) = api_base {
+                    params_map.insert(
+                        serde_yaml::Value::String("api_base".to_string()),
+                        serde_yaml::Value::String(base.to_string()),
+                    );
+                }
+                model_config.litellm_params = serde_yaml::Value::Mapping(params_map);
+            }
             found = true;
             break;
         }
     }
     if !found {
+        let mut params_map = serde_yaml::Mapping::new();
+        params_map.insert(
+            serde_yaml::Value::String("model".to_string()),
+            serde_yaml::Value::String(model_name.to_string()),
+        );
+        params_map.insert(
+            serde_yaml::Value::String("api_key".to_string()),
+            serde_yaml::Value::String(api_key.to_string()),
+        );
+        if let Some(base) = api_base {
+            params_map.insert(
+                serde_yaml::Value::String("api_base".to_string()),
+                serde_yaml::Value::String(base.to_string()),
+            );
+        }
+
         config.model_list.push(LiteLLMModelConfig {
             model_name: model_name.to_string(),
-            api_base: api_base.map(|s| s.to_string()),
-            api_key: Some(api_key.to_string()),
-            litellm_params: None,
+            litellm_params: serde_yaml::Value::Mapping(params_map),
+            model_info: Some(ModelInfo::default()), // Or None
+            tpm: None,
+            rpm: None,
        });
     }
 }
diff --git a/docker-compose.yml b/docker-compose.yml
index cda590f3a..17a7f10b7 100644
--- a/docker-compose.yml
+++ b/docker-compose.yml
@@ -71,8 +71,8 @@ services:
     image: ghcr.io/berriai/litellm:main-latest
     container_name: buster-litellm
     volumes:
-      - ./litellm_vertex_config.yaml:/litellm_vertex_config.yaml
-    command: ["--config", "/litellm_vertex_config.yaml", "--port", "4001"]
+      - ./litellm_config/config.yaml:/config.yaml
+    command: ["--config", "/config.yaml", "--port", "4001"]
     ports:
       - "4001:4001"
     healthcheck:
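For reference, a minimal sketch of the litellm_config/config.yaml that create_litellm_yaml would now emit. The field names follow the serde serialization of LiteLLMModelConfig above (None fields such as tpm/rpm are skipped); the gpt-4o alias is a hypothetical stand-in for an entry of OPENAI_MODELS, whose contents are not shown in this diff, and the key/URL values are the .env.example placeholders:

model_list:
  - model_name: gpt-4o            # alias; hypothetical OPENAI_MODELS entry
    litellm_params:
      model: gpt-4o
      api_key: your_llm_api_key
      api_base: https://api.openai.com/v1   # only present when an api_base was supplied
    model_info: {}                # ModelInfo::default() serializes to an empty map

This is the model_list/litellm_params shape the LiteLLM proxy reads natively, and the docker-compose change mounts the generated file at /config.yaml, where the container's --config flag now points.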