OpenAI and LiteLLM config

dal authored 2025-05-09 07:46:09 -06:00
parent 7da963d9bb
commit 64d97b5e17
No known key found for this signature in database
GPG Key ID: 16F4B0E1E9F61122
15 changed files with 638 additions and 112 deletions

View File

@ -15,7 +15,7 @@ RERANK_API_KEY="your_rerank_api_key"
RERANK_MODEL="rerank-v3.5"
RERANK_BASE_URL="https://api.cohere.com/v2/rerank"
LLM_API_KEY="your_llm_api_key"
LLM_BASE_URL="https://api.openai.com/v1"
LLM_BASE_URL="http://buster-litellm:4000"
# WEB VARS
NEXT_PUBLIC_API_URL="http://localhost:3001" # External URL for the API service (buster-api)
@ -23,4 +23,4 @@ NEXT_PUBLIC_URL="http://localhost:3000" # External URL for the Web service
NEXT_PUBLIC_SUPABASE_URL="http://kong:8000" # External URL for Supabase (Kong proxy)
NEXT_PUBLIC_WS_URL="ws://localhost:3001"
NEXT_PUBLIC_SUPABASE_ANON_KEY="eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.ey AgCiAgICAicm9sZSI6ICJhbm9uIiwKICAgICJpc3MiOiAic3VwYWJhc2UtZGVtbyIsCiAgICAiaWF0IjogMTY0MTc2OTIwMCwKICAgICJleHAiOiAxNzk5NTM1NjAwCn0.dc_X5iR_VP_qT0zsiyj_I_OZ2T9FtRU2BBNWN8Bu4GE"
NEXT_PRIVATE_SUPABASE_SERVICE_ROLE_KEY="eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.ey AgCiAgICAicm9sZSI6ICJzZXJ2aWNlX3JvbGUiLAogICAgImlzcyI6ICJzdXBhYmFzZS1kZW1vIiwKICAgICJpYXQiOiAxNjQxNzY5MjAwLAogICAgImV4cCI6IDE3OTk1MzU2MDAKfQ.DaYlNEoUrrEn2Ig7tqibS-PHK5vgusbcbo7X36XVt4Q"
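
With this change the API container no longer calls OpenAI directly: LLM_BASE_URL points at the buster-litellm service on the compose network, so chat completions are routed through the LiteLLM proxy. A minimal sketch (not the crate's actual code) of how that endpoint is resolved on the Rust side, assuming the .env above has been loaded into the process environment; the real lookup lives in LiteLLMClient::new later in this commit:

use std::env;

// Hypothetical helper mirroring the client's fallback order: explicit argument
// first, then LLM_BASE_URL from the environment, then an assumed default.
fn resolve_llm_base_url(explicit: Option<String>) -> String {
    explicit
        .or_else(|| env::var("LLM_BASE_URL").ok())
        .unwrap_or_else(|| "https://api.openai.com/v1".to_string())
}

fn main() {
    // With LLM_BASE_URL="http://buster-litellm:4000" exported, this prints that URL.
    println!("{}", resolve_llm_base_url(None));
}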

View File

@ -680,7 +680,9 @@ impl Agent {
event_id: None, // Raindrop assigns this
timestamp: Some(chrono::Utc::now()),
};
if let Err(e) = client.track_events(vec![event]).await {}
if let Err(e) = client.track_events(vec![event]).await {
tracing::error!(agent_name = %user_id, session_id = %session_id, "Error tracking llm_request with Raindrop: {:?}", e);
}
});
}
// --- End Track Request ---
@ -931,7 +933,9 @@ impl Agent {
event_id: None, // Raindrop assigns this
timestamp: Some(chrono::Utc::now()),
};
if let Err(e) = client.track_events(vec![event]).await {}
if let Err(e) = client.track_events(vec![event]).await {
tracing::error!(agent_name = %user_id, session_id = %session_id, "Error tracking llm_response with Raindrop: {:?}", e);
}
});
}
// --- End Track Response ---

View File

@ -29,7 +29,7 @@ pub fn get_configuration(
// 2. Define the model for this mode (From original MODEL const)
let model =
if env::var("ENVIRONMENT").unwrap_or_else(|_| "development".to_string()) == "local" {
"o4-mini".to_string()
"gpt-4.1-mini".to_string()
} else {
"gemini-2.0-flash-001".to_string()
};

View File

@ -981,7 +981,7 @@ async fn llm_filter_helper(
let llm_client = LiteLLMClient::new(None, None);
let model = if env::var("ENVIRONMENT").unwrap_or_else(|_| "development".to_string()) == "local" {
"gpt-4.1-mini".to_string()
"gpt-4.1-nano".to_string()
} else {
"gemini-2.0-flash-001".to_string()
};
@ -998,6 +998,7 @@ async fn llm_filter_helper(
type_: "json_object".to_string(),
json_schema: None,
}),
store: Some(true),
metadata: Some(Metadata {
generation_name: format!("filter_data_catalog_{}_agent", generation_name_suffix),
user_id: user_id.to_string(),
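
The new store: Some(true) field here (and in the other call sites below) rides along on the chat completion request; on OpenAI-compatible backends, store typically asks the provider to retain the completion for later inspection. A hedged sketch of how such an optional field is usually declared and serialized; the real ChatCompletionRequest definition lives in the client's types module, which this diff does not show:

use serde::Serialize;

// Sketch only: every field name except `store` is an assumption.
#[derive(Serialize)]
struct ChatCompletionRequestSketch {
    model: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    store: Option<bool>,
}

fn main() {
    let req = ChatCompletionRequestSketch {
        model: "gpt-4.1-nano".to_string(),
        store: Some(true),
    };
    // Prints {"model":"gpt-4.1-nano","store":true}; the key is omitted when None.
    println!("{}", serde_json::to_string(&req).unwrap());
}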

View File

@ -64,7 +64,7 @@ Example Output for the above plan: `["Create line chart visualization 'Daily Tra
);
let model = if env::var("ENVIRONMENT").unwrap_or_else(|_| "development".to_string()) == "local" {
"gpt-4.1-mini".to_string()
"gpt-4.1-nano".to_string()
} else {
"gemini-2.0-flash-001".to_string()
};
@ -74,6 +74,7 @@ Example Output for the above plan: `["Create line chart visualization 'Daily Tra
messages: vec![AgentMessage::User { id: None, content: prompt, name: None }],
stream: Some(false),
response_format: Some(ResponseFormat { type_: "json_object".to_string(), json_schema: None }),
store: Some(true),
metadata: Some(Metadata {
generation_name: "generate_todos_from_plan".to_string(),
user_id: user_id.to_string(),

View File

@ -70,7 +70,7 @@ pub async fn generate_conversation_title(
let llm_client = LiteLLMClient::new(None, None);
let model = if env::var("ENVIRONMENT").unwrap_or_else(|_| "development".to_string()) == "local" {
"gpt-4.1-mini".to_string()
"gpt-4.1-nano".to_string()
} else {
"gemini-2.0-flash-001".to_string()
};
@ -83,6 +83,7 @@ pub async fn generate_conversation_title(
content: prompt,
name: None,
}],
store: Some(true),
metadata: Some(Metadata {
generation_name: "conversation_title".to_string(),
user_id: user_id.to_string(),

View File

@ -2714,7 +2714,7 @@ pub async fn generate_conversation_title(
let llm_client = LiteLLMClient::new(None, None);
let model = if env::var("ENVIRONMENT").unwrap_or_else(|_| "development".to_string()) == "local" {
"gpt-4.1-mini".to_string()
"gpt-4.1-nano".to_string()
} else {
"gemini-2.0-flash-001".to_string()
};
@ -2727,6 +2727,7 @@ pub async fn generate_conversation_title(
content: prompt,
name: None,
}],
store: Some(true),
metadata: Some(Metadata {
generation_name: "conversation_title".to_string(),
user_id: user_id.to_string(),

View File

@ -7,12 +7,13 @@ edition = "2021"
# Use workspace dependencies
anyhow = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
tokio = { workspace = true }
futures-util = { workspace = true }
serde_json = "1.0"
tokio = { version = "1", features = ["full"] }
futures-util = "0.3"
reqwest = { workspace = true }
dotenv = { workspace = true }
once_cell = { workspace = true }
once_cell = "1.19.0"
tracing = "0.1"
[dev-dependencies]
mockito = { workspace = true }

View File

@ -4,6 +4,7 @@ use reqwest::{header, Client};
use std::env;
use tokio::sync::mpsc;
use once_cell::sync::Lazy;
use tracing;
use super::types::*;
@ -29,9 +30,20 @@ impl LiteLLMClient {
}
pub fn new(api_key: Option<String>, base_url: Option<String>) -> Self {
let api_key = api_key.or_else(|| env::var("LLM_API_KEY").ok()).expect(
"LLM_API_KEY must be provided either through parameter or environment variable",
);
// Check for API key - when using LiteLLM with a config file, the API key is typically
// already in the config file, so we just need a dummy value here for the client
let api_key = api_key
.or_else(|| env::var("LLM_API_KEY").ok())
.unwrap_or_else(|| {
// If we have a LiteLLM config path, we can use a placeholder API key
// since auth will be handled by the LiteLLM server using the config
if env::var("LITELLM_CONFIG_PATH").is_ok() {
Self::debug_log("Using LiteLLM config from environment");
"dummy-key-not-used".to_string()
} else {
panic!("LLM_API_KEY must be provided either through parameter, environment variable, or LITELLM_CONFIG_PATH must be set");
}
});
let base_url = base_url
.or_else(|| env::var("LLM_BASE_URL").ok())
@ -81,16 +93,30 @@ impl LiteLLMClient {
.post(&url)
.json(&request)
.send()
.await?;
.await
.map_err(|e| {
tracing::error!("Failed to send chat completion request: {:?}", e);
anyhow::Error::from(e)
})?;
// Get the raw response text
let response_text = response.text().await?;
let response_text = response.text().await.map_err(|e| {
tracing::error!("Failed to read chat completion response text: {:?}", e);
anyhow::Error::from(e)
})?;
if *DEBUG_ENABLED {
Self::debug_log(&format!("Raw response payload: {}", response_text));
}
// Parse the response text into the expected type
let response: ChatCompletionResponse = serde_json::from_str(&response_text)?;
let response: ChatCompletionResponse = serde_json::from_str(&response_text).map_err(|e| {
tracing::error!(
"Failed to parse chat completion response. Text: {}, Error: {:?}",
response_text,
e
);
anyhow::Error::from(e)
})?;
// Log tool calls if present and debug is enabled
if *DEBUG_ENABLED {
@ -141,7 +167,11 @@ impl LiteLLMClient {
..request
})
.send()
.await?
.await
.map_err(|e| {
tracing::error!("Failed to send stream chat completion request: {:?}", e);
anyhow::Error::from(e)
})?
.bytes_stream();
let (tx, rx) = mpsc::channel(100);
@ -174,45 +204,55 @@ impl LiteLLMClient {
if debug_enabled {
Self::debug_log("Stream completed with [DONE] signal");
}
break;
return;
}
if let Ok(response) =
serde_json::from_str::<ChatCompletionChunk>(data)
{
// Log tool calls if present and debug is enabled
if debug_enabled {
if let Some(tool_calls) = &response.choices[0].delta.tool_calls
{
Self::debug_log("Tool calls in stream chunk:");
for tool_call in tool_calls {
if let (Some(id), Some(function)) =
(tool_call.id.clone(), tool_call.function.clone())
{
Self::debug_log(&format!("Tool Call ID: {}", id));
if let Some(name) = function.name {
Self::debug_log(&format!("Tool Name: {}", name));
}
if let Some(arguments) = function.arguments {
Self::debug_log(&format!(
"Tool Arguments: {}",
arguments
));
match serde_json::from_str::<ChatCompletionChunk>(data) {
Ok(response) => {
// Log tool calls if present and debug is enabled
if debug_enabled {
if let Some(tool_calls) = &response.choices[0].delta.tool_calls
{
Self::debug_log("Tool calls in stream chunk:");
for tool_call in tool_calls {
if let (Some(id), Some(function)) =
(tool_call.id.clone(), tool_call.function.clone())
{
Self::debug_log(&format!("Tool Call ID: {}", id));
if let Some(name) = function.name {
Self::debug_log(&format!("Tool Name: {}", name));
}
if let Some(arguments) = function.arguments {
Self::debug_log(&format!(
"Tool Arguments: {}",
arguments
));
}
}
}
}
Self::debug_log(&format!("Parsed stream chunk: {:?}", response));
}
// Use try_send instead of send to avoid blocking
if tx.try_send(Ok(response)).is_err() {
// If the channel is full, log it but continue processing
if debug_enabled {
Self::debug_log("Warning: Channel full, receiver not keeping up");
}
}
Self::debug_log(&format!("Parsed stream chunk: {:?}", response));
}
// Use try_send instead of send to avoid blocking
if tx.try_send(Ok(response)).is_err() {
// If the channel is full, log it but continue processing
Err(e) => {
if debug_enabled {
Self::debug_log("Warning: Channel full, receiver not keeping up");
Self::debug_log(&format!("Error in stream processing: {:?}", e));
}
tracing::error!("Error receiving chunk from stream: {:?}", e);
// Use try_send to avoid blocking
let _ = tx.try_send(Err(anyhow::Error::from(e)));
}
}
} else if !line.is_empty() {
tracing::warn!("Received unexpected line in stream: {}", line);
}
}
}
@ -220,6 +260,7 @@ impl LiteLLMClient {
if debug_enabled {
Self::debug_log(&format!("Error in stream processing: {:?}", e));
}
tracing::error!("Error receiving chunk from stream: {:?}", e);
// Use try_send to avoid blocking
let _ = tx.try_send(Err(anyhow::Error::from(e)));
}
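
Taken together, the client can now be constructed in two ways: with LLM_API_KEY set (the key is forwarded as before), or with only LITELLM_CONFIG_PATH set, in which case a placeholder key is substituted and authentication is left to the LiteLLM server's own config file. A hedged usage sketch, assuming LiteLLMClient is in scope and each call runs under the environment described in its comment:

fn construct_client_examples() {
    // With LLM_API_KEY (and LLM_BASE_URL) exported, e.g. from the .env above,
    // the key and base URL are picked up from the environment as before.
    let _from_env = LiteLLMClient::new(None, None);

    // With only LITELLM_CONFIG_PATH exported, new() logs "Using LiteLLM config
    // from environment" and falls back to the "dummy-key-not-used" placeholder.
    let _from_config = LiteLLMClient::new(None, None);

    // Explicit arguments still take precedence over both environment variables.
    let _explicit = LiteLLMClient::new(
        Some("sk-example".to_string()),
        Some("http://buster-litellm:4000".to_string()),
    );
}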

View File

@ -1,8 +1,20 @@
dev:
docker stop buster-redis-make || true && \
docker rm buster-redis-make || true && \
cd .. && docker run -d --name buster-redis-make -p 6379:6379 -v buster_redis_data:/data redis && cd api && \
supabase start
@echo "Checking services for dev target..."
@if ! (docker ps -q -f name=buster-redis-make -f status=running > /dev/null && \
docker ps -q -f name=supabase_db -f status=running > /dev/null && \
docker ps -q -f name=supabase_kong -f status=running > /dev/null); then \
echo "One or more dev services (Redis, Supabase DB, Supabase Kong) not running. Restarting all..."; \
docker stop buster-redis-make || true; \
docker rm buster-redis-make || true; \
supabase stop || true; \
echo "Starting Redis..."; \
(cd .. && docker run -d --name buster-redis-make -p 6379:6379 -v buster_redis_data:/data redis && cd api); \
echo "Starting Supabase..."; \
supabase start; \
echo "Services restarted."; \
else \
echo "Dev services (Redis and Supabase) already running."; \
fi
supabase db reset
export DATABASE_URL=postgres://postgres:postgres@127.0.0.1:54322/postgres && \
diesel migration run && \
@ -28,9 +40,16 @@ stop:
supabase stop
fast:
docker stop buster-redis-make || true && \
docker rm buster-redis-make || true && \
cd .. && docker run -d --name buster-redis-make -p 6379:6379 -v buster_redis_data:/data redis && cd api && \
@echo "Checking Redis for fast target..."
@if ! docker ps -q -f name=buster-redis-make -f status=running > /dev/null; then \
echo "Redis container 'buster-redis-make' not running. Starting it..."; \
docker stop buster-redis-make || true; \
docker rm buster-redis-make || true; \
(cd .. && docker run -d --name buster-redis-make -p 6379:6379 -v buster_redis_data:/data redis && cd api); \
echo "Redis started."; \
else \
echo "Redis container 'buster-redis-make' already running."; \
fi
export RUST_LOG=debug && \
export CARGO_INCREMENTAL=1 && \
nice cargo watch -C server -x run

View File

@ -323,7 +323,7 @@ async fn filter_datasets_with_llm(
let model =
if env::var("ENVIRONMENT").unwrap_or_else(|_| "development".to_string()) == "local" {
"gpt-4.1-mini".to_string()
"gpt-4.1-nano".to_string()
} else {
"gemini-2.0-flash-001".to_string()
};
@ -341,6 +341,7 @@ async fn filter_datasets_with_llm(
type_: "json_object".to_string(),
json_schema: None,
}),
store: Some(true),
metadata: Some(Metadata {
generation_name: "filter_data_catalog".to_string(),
user_id: user.id.to_string(),

View File

@ -155,7 +155,8 @@ pub async fn manage_settings_interactive() -> Result<(), BusterError> {
final_rerank_api_key.as_deref(),
final_rerank_model.as_deref(),
final_rerank_base_url.as_deref(),
current_llm_base_url.as_deref().or(Some(&llm_base_url_default)) // Ensure LLM_BASE_URL is present
current_llm_base_url.as_deref().or(Some(&llm_base_url_default)), // Ensure LLM_BASE_URL is present
None,
)?;
println!("Configuration saved to {}.", target_dotenv_path.display());

View File

@ -1,23 +1,33 @@
use crate::error::BusterError;
use dirs;
use serde::{Deserialize, Serialize};
use serde_yaml;
use std::fs;
use std::io::{self, Write};
use std::path::{Path, PathBuf};
// Moved from run.rs
pub fn prompt_for_input(prompt_message: &str, default_value: Option<&str>, is_sensitive: bool) -> Result<String, BusterError> {
pub fn prompt_for_input(
prompt_message: &str,
default_value: Option<&str>,
is_sensitive: bool,
) -> Result<String, BusterError> {
if let Some(def_val) = default_value {
print!("{} (default: {}): ", prompt_message, def_val);
} else {
print!("{}: ", prompt_message);
}
io::stdout().flush().map_err(|e| BusterError::CommandError(format!("Failed to flush stdout: {}", e)))?;
io::stdout()
.flush()
.map_err(|e| BusterError::CommandError(format!("Failed to flush stdout: {}", e)))?;
let mut input = String::new();
// Simple masking for sensitive input is complex in raw terminal io without extra crates.
// For a real CLI, rpassword or similar would be used.
// Here, we just read the line.
io::stdin().read_line(&mut input).map_err(|e| BusterError::CommandError(format!("Failed to read line: {}", e)))?;
io::stdin()
.read_line(&mut input)
.map_err(|e| BusterError::CommandError(format!("Failed to read line: {}", e)))?;
let trimmed_input = input.trim().to_string();
if trimmed_input.is_empty() {
@ -38,22 +48,44 @@ pub fn get_app_base_dir() -> Result<PathBuf, BusterError> {
.ok_or_else(|| BusterError::CommandError("Failed to get home directory.".to_string()))
}
pub fn get_cached_value(app_base_dir: &Path, cache_file_name: &str) -> Result<Option<String>, BusterError> {
pub fn get_cached_value(
app_base_dir: &Path,
cache_file_name: &str,
) -> Result<Option<String>, BusterError> {
let cache_file_path = app_base_dir.join(cache_file_name);
if cache_file_path.exists() {
fs::read_to_string(cache_file_path)
.map(|val| Some(val.trim().to_string()))
.map_err(|e| BusterError::CommandError(format!("Failed to read cached file {}: {}", cache_file_name, e)))
.map_err(|e| {
BusterError::CommandError(format!(
"Failed to read cached file {}: {}",
cache_file_name, e
))
})
} else {
Ok(None)
}
}
pub fn cache_value(app_base_dir: &Path, cache_file_name: &str, value: &str) -> Result<(), BusterError> {
pub fn cache_value(
app_base_dir: &Path,
cache_file_name: &str,
value: &str,
) -> Result<(), BusterError> {
let cache_file_path = app_base_dir.join(cache_file_name);
fs::create_dir_all(app_base_dir).map_err(|e| BusterError::CommandError(format!("Failed to create app base dir {}: {}", app_base_dir.display(), e)))?;
fs::write(cache_file_path, value)
.map_err(|e| BusterError::CommandError(format!("Failed to cache value to {}: {}", cache_file_name, e)))
fs::create_dir_all(app_base_dir).map_err(|e| {
BusterError::CommandError(format!(
"Failed to create app base dir {}: {}",
app_base_dir.display(),
e
))
})?;
fs::write(cache_file_path, value).map_err(|e| {
BusterError::CommandError(format!(
"Failed to cache value to {}: {}",
cache_file_name, e
))
})
}
pub fn update_env_file(
@ -62,7 +94,8 @@ pub fn update_env_file(
rerank_api_key: Option<&str>,
rerank_model: Option<&str>,
rerank_base_url: Option<&str>,
llm_base_url: Option<&str> // Added for completeness, though not prompted by user yet
llm_base_url: Option<&str>, // Added for completeness, though not prompted by user yet
litellm_config_path: Option<&str>, // Added for litellm config path
) -> Result<(), BusterError> {
let mut new_env_lines: Vec<String> = Vec::new();
let mut llm_key_updated = false;
@ -70,10 +103,15 @@ pub fn update_env_file(
let mut rerank_model_updated = false;
let mut rerank_base_updated = false;
let mut llm_base_updated = false;
let mut litellm_config_updated = false;
if target_dotenv_path.exists() {
let env_content = fs::read_to_string(target_dotenv_path).map_err(|e| {
BusterError::CommandError(format!("Failed to read .env file at {}: {}", target_dotenv_path.display(), e))
BusterError::CommandError(format!(
"Failed to read .env file at {}: {}",
target_dotenv_path.display(),
e
))
})?;
for line in env_content.lines() {
@ -92,6 +130,12 @@ pub fn update_env_file(
} else if line.starts_with("LLM_BASE_URL=") && llm_base_url.is_some() {
new_env_lines.push(format!("LLM_BASE_URL=\"{}\"", llm_base_url.unwrap()));
llm_base_updated = true;
} else if line.starts_with("LITELLM_CONFIG_PATH=") && litellm_config_path.is_some() {
new_env_lines.push(format!(
"LITELLM_CONFIG_PATH=\"{}\"",
litellm_config_path.unwrap()
));
litellm_config_updated = true;
} else {
new_env_lines.push(line.to_string());
}
@ -117,29 +161,288 @@ pub fn update_env_file(
// Ensure default LLM_BASE_URL if .env is being created from scratch and no override provided
new_env_lines.push("LLM_BASE_URL=\"https://api.openai.com/v1\"".to_string());
}
if !litellm_config_updated && litellm_config_path.is_some() {
new_env_lines.push(format!(
"LITELLM_CONFIG_PATH=\"{}\"",
litellm_config_path.unwrap()
));
}
fs::write(target_dotenv_path, new_env_lines.join("\n")).map_err(|e| {
BusterError::CommandError(format!("Failed to write updated .env file to {}: {}", target_dotenv_path.display(), e))
BusterError::CommandError(format!(
"Failed to write updated .env file to {}: {}",
target_dotenv_path.display(),
e
))
})
}
pub fn prompt_and_manage_openai_api_key(app_base_dir: &Path, force_prompt: bool) -> Result<String, BusterError> {
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct LiteLLMModelConfig {
pub model_name: String,
pub api_base: Option<String>,
pub api_key: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub litellm_params: Option<serde_yaml::Value>,
}
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct LiteLLMConfig {
pub model_list: Vec<LiteLLMModelConfig>,
#[serde(skip_serializing_if = "Option::is_none")]
pub environment_variables: Option<serde_yaml::Value>,
#[serde(skip_serializing_if = "Option::is_none")]
pub general_settings: Option<serde_yaml::Value>,
}
const OPENAI_MODELS: [&str; 5] = [
"gpt-4.1",
"gpt-4.1-mini",
"gpt-4.1-nano",
"o4-mini",
"o3",
];
const DEFAULT_OPENAI_MODEL: &str = "gpt-4.1";
pub fn create_litellm_yaml(
app_base_dir: &Path,
api_key: &str,
api_base: Option<&str>,
) -> Result<PathBuf, BusterError> {
let litellm_config_dir = app_base_dir.join("litellm_config");
fs::create_dir_all(&litellm_config_dir).map_err(|e| {
BusterError::CommandError(format!(
"Failed to create litellm config directory at {}: {}",
litellm_config_dir.display(),
e
))
})?;
let config_path = litellm_config_dir.join("config.yaml");
// Build model list
let model_list: Vec<LiteLLMModelConfig> = OPENAI_MODELS
.iter()
.map(|model_name| LiteLLMModelConfig {
model_name: model_name.to_string(),
api_base: api_base.map(|s| s.to_string()),
api_key: Some(api_key.to_string()),
litellm_params: None,
})
.collect();
// Env vars mapping
let mut env_vars_map = serde_yaml::Mapping::new();
env_vars_map.insert(
serde_yaml::Value::String("OPENAI_API_KEY".to_string()),
serde_yaml::Value::String(api_key.to_string()),
);
// General settings mapping (fallback_models etc.)
let mut general_settings_map = serde_yaml::Mapping::new();
general_settings_map.insert(
serde_yaml::Value::String("fallback_models".to_string()),
serde_yaml::Value::Sequence(
OPENAI_MODELS
.iter()
.map(|m| serde_yaml::Value::String((*m).to_string()))
.collect(),
),
);
let config = LiteLLMConfig {
model_list,
environment_variables: Some(serde_yaml::Value::Mapping(env_vars_map)),
general_settings: Some(serde_yaml::Value::Mapping(general_settings_map)),
};
let yaml_content = serde_yaml::to_string(&config).map_err(|e| {
BusterError::CommandError(format!("Failed to serialize LiteLLM config to YAML: {}", e))
})?;
fs::write(&config_path, yaml_content).map_err(|e| {
BusterError::CommandError(format!(
"Failed to write LiteLLM config file to {}: {}",
config_path.display(),
e
))
})?;
Ok(config_path)
}
pub fn update_litellm_yaml(
app_base_dir: &Path,
api_key: &str,
api_base: Option<&str>,
) -> Result<PathBuf, BusterError> {
let litellm_config_dir = app_base_dir.join("litellm_config");
let config_path = litellm_config_dir.join("config.yaml");
// Ensure directory exists
fs::create_dir_all(&litellm_config_dir).map_err(|e| {
BusterError::CommandError(format!(
"Failed to create litellm config directory at {}: {}",
litellm_config_dir.display(),
e
))
})?;
if !config_path.exists() {
return create_litellm_yaml(app_base_dir, api_key, api_base);
}
// Read existing config
let yaml_content = fs::read_to_string(&config_path).map_err(|e| {
BusterError::CommandError(format!(
"Failed to read LiteLLM config file at {}: {}",
config_path.display(),
e
))
})?;
let mut config: LiteLLMConfig = serde_yaml::from_str(&yaml_content).map_err(|e| {
BusterError::CommandError(format!("Failed to parse LiteLLM config YAML: {}", e))
})?;
// Ensure each model present and updated
for model_name in OPENAI_MODELS.iter() {
let mut found = false;
for model_config in &mut config.model_list {
if &model_config.model_name == model_name {
model_config.api_key = Some(api_key.to_string());
model_config.api_base = api_base.map(|s| s.to_string());
found = true;
break;
}
}
if !found {
config.model_list.push(LiteLLMModelConfig {
model_name: model_name.to_string(),
api_base: api_base.map(|s| s.to_string()),
api_key: Some(api_key.to_string()),
litellm_params: None,
});
}
}
// Update environment variables
match &mut config.environment_variables {
Some(serde_yaml::Value::Mapping(map)) => {
map.insert(
serde_yaml::Value::String("OPENAI_API_KEY".to_string()),
serde_yaml::Value::String(api_key.to_string()),
);
}
_ => {
let mut env_map = serde_yaml::Mapping::new();
env_map.insert(
serde_yaml::Value::String("OPENAI_API_KEY".to_string()),
serde_yaml::Value::String(api_key.to_string()),
);
config.environment_variables = Some(serde_yaml::Value::Mapping(env_map));
}
}
// Update general settings fallback_models to include all models
let fallback_seq: Vec<serde_yaml::Value> = OPENAI_MODELS
.iter()
.map(|m| serde_yaml::Value::String((*m).to_string()))
.collect();
match &mut config.general_settings {
Some(serde_yaml::Value::Mapping(settings)) => {
settings.insert(
serde_yaml::Value::String("fallback_models".to_string()),
serde_yaml::Value::Sequence(fallback_seq),
);
}
_ => {
let mut settings = serde_yaml::Mapping::new();
settings.insert(
serde_yaml::Value::String("fallback_models".to_string()),
serde_yaml::Value::Sequence(fallback_seq),
);
config.general_settings = Some(serde_yaml::Value::Mapping(settings));
}
}
// Serialize and write back
let updated_yaml = serde_yaml::to_string(&config).map_err(|e| {
BusterError::CommandError(format!(
"Failed to serialize updated LiteLLM config to YAML: {}",
e
))
})?;
fs::write(&config_path, updated_yaml).map_err(|e| {
BusterError::CommandError(format!(
"Failed to write updated LiteLLM config file to {}: {}",
config_path.display(),
e
))
})?;
Ok(config_path)
}
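
A short usage sketch of the two helpers above; the arguments are illustrative, and the real call sites are prompt_and_manage_openai_api_key below and setup_persistent_app_environment in the run command:

// Sketch only: assumes it sits in this module so create_litellm_yaml,
// update_litellm_yaml and BusterError are in scope.
fn litellm_config_example(app_base_dir: &Path) -> Result<(), BusterError> {
    // First run: writes litellm_config/config.yaml with an entry per OPENAI_MODELS name.
    let path = create_litellm_yaml(app_base_dir, "sk-initial", Some("https://api.openai.com/v1"))?;
    println!("wrote {}", path.display());
    // Later runs: rewrite api_key/api_base for each model and refresh
    // environment_variables and general_settings in place.
    let path = update_litellm_yaml(app_base_dir, "sk-rotated", None)?;
    println!("updated {}", path.display());
    Ok(())
}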
pub fn prompt_and_manage_openai_api_key(
app_base_dir: &Path,
force_prompt: bool,
) -> Result<String, BusterError> {
let cache_file = ".openai_api_key";
let mut current_key = get_cached_value(app_base_dir, cache_file)?;
if force_prompt || current_key.is_none() {
if current_key.is_some() {
let key_display = current_key.as_ref().map_or("", |k| if k.len() > 4 { &k[k.len()-4..] } else { "****" });
let update_choice = prompt_for_input(&format!("Current OpenAI API key ends with ...{}. Update? (y/n)", key_display), Some("n"), false)?.to_lowercase();
let key_display = current_key.as_ref().map_or("", |k| {
if k.len() > 4 {
&k[k.len() - 4..]
} else {
"****"
}
});
let update_choice = prompt_for_input(
&format!("Current OpenAI API key ends with ...{}. Update? (y/n)", key_display),
Some("n"),
false,
)?
.to_lowercase();
if update_choice != "y" {
return Ok(current_key.unwrap());
}
}
let new_key = prompt_for_input("Enter your OpenAI API Key:", None, true)?;
let api_base_choice = prompt_for_input(
"Use custom API base URL? (y/n):",
Some("n"),
false,
)?
.to_lowercase();
let api_base = if api_base_choice == "y" {
Some(
prompt_for_input(
"Enter the API base URL:",
Some("https://api.openai.com/v1"),
false,
)?,
)
} else {
Some("https://api.openai.com/v1".to_string())
};
// Update LiteLLM config first (borrows new_key)
update_litellm_yaml(app_base_dir, &new_key, api_base.as_deref())?;
// Cache the key after successful update
cache_value(app_base_dir, cache_file, &new_key)?;
current_key = Some(new_key);
println!("LiteLLM configuration file updated successfully.");
}
current_key.ok_or_else(|| BusterError::CommandError("OpenAI API Key setup failed.".to_string()))
current_key.ok_or_else(|| {
BusterError::CommandError("OpenAI API Key setup failed.".to_string())
})
}
pub struct RerankerConfig {
@ -149,7 +452,10 @@ pub struct RerankerConfig {
pub base_url: String,
}
pub fn prompt_and_manage_reranker_settings(app_base_dir: &Path, force_prompt: bool) -> Result<RerankerConfig, BusterError> {
pub fn prompt_and_manage_reranker_settings(
app_base_dir: &Path,
force_prompt: bool,
) -> Result<RerankerConfig, BusterError> {
let provider_cache = ".reranker_provider";
let key_cache = ".reranker_api_key";
let model_cache = ".reranker_model";
@ -161,14 +467,34 @@ pub fn prompt_and_manage_reranker_settings(app_base_dir: &Path, force_prompt: bo
let mut current_url = get_cached_value(app_base_dir, url_cache)?;
let mut needs_update = force_prompt;
if !needs_update && (current_provider.is_none() || current_key.is_none() || current_model.is_none() || current_url.is_none()) {
if !needs_update
&& (current_provider.is_none()
|| current_key.is_none()
|| current_model.is_none()
|| current_url.is_none())
{
needs_update = true; // If any part is missing, force update flow for initial setup
}
if needs_update {
if !force_prompt && current_provider.is_some() && current_model.is_some() { // Already prompted if force_prompt is true
let update_choice = prompt_for_input(&format!("Current Reranker: {} (Model: {}). Update settings? (y/n)", current_provider.as_ref().unwrap_or(&"N/A".to_string()), current_model.as_ref().unwrap_or(&"N/A".to_string())), Some("n"), false)?.to_lowercase();
if update_choice != "y" && current_provider.is_some() && current_key.is_some() && current_model.is_some() && current_url.is_some(){
if !force_prompt && current_provider.is_some() && current_model.is_some() {
// Already prompted if force_prompt is true
let update_choice = prompt_for_input(
&format!(
"Current Reranker: {} (Model: {}). Update settings? (y/n)",
current_provider.as_ref().unwrap_or(&"N/A".to_string()),
current_model.as_ref().unwrap_or(&"N/A".to_string())
),
Some("n"),
false,
)?
.to_lowercase();
if update_choice != "y"
&& current_provider.is_some()
&& current_key.is_some()
&& current_model.is_some()
&& current_url.is_some()
{
return Ok(RerankerConfig {
provider: current_provider.unwrap(),
api_key: current_key.unwrap(),
@ -177,9 +503,23 @@ pub fn prompt_and_manage_reranker_settings(app_base_dir: &Path, force_prompt: bo
});
}
} else if force_prompt && current_provider.is_some() && current_model.is_some() {
let update_choice = prompt_for_input(&format!("Current Reranker: {} (Model: {}). Update settings? (y/n)", current_provider.as_ref().unwrap_or(&"N/A".to_string()), current_model.as_ref().unwrap_or(&"N/A".to_string())), Some("n"), false)?.to_lowercase();
if update_choice != "y" && current_provider.is_some() && current_key.is_some() && current_model.is_some() && current_url.is_some(){
return Ok(RerankerConfig {
let update_choice = prompt_for_input(
&format!(
"Current Reranker: {} (Model: {}). Update settings? (y/n)",
current_provider.as_ref().unwrap_or(&"N/A".to_string()),
current_model.as_ref().unwrap_or(&"N/A".to_string())
),
Some("n"),
false,
)?
.to_lowercase();
if update_choice != "y"
&& current_provider.is_some()
&& current_key.is_some()
&& current_model.is_some()
&& current_url.is_some()
{
return Ok(RerankerConfig {
provider: current_provider.unwrap(),
api_key: current_key.unwrap(),
model: current_model.unwrap(),
@ -201,30 +541,60 @@ pub fn prompt_and_manage_reranker_settings(app_base_dir: &Path, force_prompt: bo
};
let (new_provider, default_model, default_url) = match provider_choice {
1 => ("Cohere", "rerank-english-v3.0", "https://api.cohere.com/v1/rerank"), // user asked for v3.5 but official docs say v3.0 for rerank model
2 => ("Mixedbread", "mixedbread-ai/mxbai-rerank-xsmall-v1", "https://api.mixedbread.ai/v1/reranking"),
3 => ("Jina", "jina-reranker-v1-base-en", "https://api.jina.ai/v1/rerank"),
1 => (
"Cohere",
"rerank-english-v3.0",
"https://api.cohere.com/v1/rerank",
), // user asked for v3.5 but official docs say v3.0 for rerank model
2 => (
"Mixedbread",
"mixedbread-ai/mxbai-rerank-xsmall-v1",
"https://api.mixedbread.ai/v1/reranking",
),
3 => (
"Jina",
"jina-reranker-v1-base-en",
"https://api.jina.ai/v1/rerank",
),
_ => unreachable!(),
};
let new_key_val = prompt_for_input(&format!("Enter your {} API Key:", new_provider), None, true)?;
let new_model_val = prompt_for_input(&format!("Enter {} model name:", new_provider), Some(default_model), false)?;
let new_url_val = prompt_for_input(&format!("Enter {} rerank base URL:", new_provider), Some(default_url), false)?;
let new_key_val =
prompt_for_input(&format!("Enter your {} API Key:", new_provider), None, true)?;
let new_model_val = prompt_for_input(
&format!("Enter {} model name:", new_provider),
Some(default_model),
false,
)?;
let new_url_val = prompt_for_input(
&format!("Enter {} rerank base URL:", new_provider),
Some(default_url),
false,
)?;
cache_value(app_base_dir, provider_cache, new_provider)?;
cache_value(app_base_dir, key_cache, &new_key_val)?;
cache_value(app_base_dir, model_cache, &new_model_val)?;
cache_value(app_base_dir, url_cache, &new_url_val)?;
current_provider = Some(new_provider.to_string());
current_key = Some(new_key_val);
current_model = Some(new_model_val);
current_url = Some(new_url_val);
}
if let (Some(prov), Some(key), Some(model), Some(url)) = (current_provider, current_key, current_model, current_url) {
Ok(RerankerConfig { provider: prov, api_key: key, model, base_url: url })
if let (Some(prov), Some(key), Some(model), Some(url)) =
(current_provider, current_key, current_model, current_url)
{
Ok(RerankerConfig {
provider: prov,
api_key: key,
model,
base_url: url,
})
} else {
Err(BusterError::CommandError("Reranker configuration setup failed. Some values are missing.".to_string()))
Err(BusterError::CommandError(
"Reranker configuration setup failed. Some values are missing.".to_string(),
))
}
}
}

View File

@ -117,6 +117,14 @@ async fn setup_persistent_app_environment() -> Result<PathBuf, BusterError> {
let llm_api_key = config_utils::prompt_and_manage_openai_api_key(&app_base_dir, false)?;
let reranker_config = config_utils::prompt_and_manage_reranker_settings(&app_base_dir, false)?;
// Create/update LiteLLM YAML config
let litellm_config_path = config_utils::update_litellm_yaml(
&app_base_dir,
&llm_api_key,
Some("https://api.openai.com/v1"),
)?;
let litellm_config_path_str = litellm_config_path.to_string_lossy();
// Update .env file (this is the root .env)
config_utils::update_env_file(
&main_dot_env_target_path, // Ensure this targets the root .env
@ -125,6 +133,7 @@ async fn setup_persistent_app_environment() -> Result<PathBuf, BusterError> {
Some(&reranker_config.model),
Some(&reranker_config.base_url),
None, // Not prompting for LLM_BASE_URL in this flow yet, example has it.
Some(&litellm_config_path_str), // Add LiteLLM config path to env
)
.map_err(|e| {
BusterError::CommandError(format!(
@ -146,6 +155,81 @@ async fn run_docker_compose_command(
) -> Result<(), BusterError> {
let persistent_app_dir = setup_persistent_app_environment().await?;
// Handle LiteLLM config if a start or reset operation is being performed
if operation_name == "Starting" || operation_name == "Resetting" {
// Check if litellm_config path is in environment
let litellm_config_path = if let Ok(path) = std::env::var("LITELLM_CONFIG_PATH") {
Some(path)
} else {
// Try to read from .env file
let env_path = persistent_app_dir.join(".env");
if env_path.exists() {
let content = fs::read_to_string(&env_path).map_err(|e| {
BusterError::CommandError(format!(
"Failed to read .env file at {}: {}",
env_path.display(),
e
))
})?;
content.lines()
.find(|line| line.starts_with("LITELLM_CONFIG_PATH="))
.and_then(|line| {
line.split('=').nth(1).map(|s| {
s.trim_matches(|c| c == '"' || c == '\'').to_string()
})
})
} else {
None
}
};
// If we have a litellm config path, modify docker-compose.yml to use it
if let Some(config_path) = litellm_config_path {
println!("Using custom LiteLLM configuration: {}", config_path);
// Read the docker-compose.yml file
let docker_compose_path = persistent_app_dir.join("docker-compose.yml");
let docker_compose_content = fs::read_to_string(&docker_compose_path).map_err(|e| {
BusterError::CommandError(format!(
"Failed to read docker-compose.yml: {}",
e
))
})?;
// Create a simple backup
fs::write(
persistent_app_dir.join("docker-compose.yml.bak"),
&docker_compose_content,
)
.map_err(|e| {
BusterError::CommandError(format!(
"Failed to create backup of docker-compose.yml: {}",
e
))
})?;
// Replace the litellm config path
let modified_content = docker_compose_content
.replace(
" - ./litellm_vertex_config.yaml:/litellm_vertex_config.yaml",
&format!(" - {}:/litellm_config.yaml", config_path)
)
.replace(
" command: [\"--config\", \"/litellm_vertex_config.yaml\", \"--port\", \"4001\"]",
" command: [\"--config\", \"/litellm_config.yaml\", \"--port\", \"4001\"]"
);
// Write the modified docker-compose.yml
fs::write(&docker_compose_path, modified_content).map_err(|e| {
BusterError::CommandError(format!(
"Failed to update docker-compose.yml with custom LiteLLM config: {}",
e
))
})?;
}
}
let data_db_path = persistent_app_dir.join("supabase/volumes/db/data");
fs::create_dir_all(&data_db_path).map_err(|e| {
BusterError::CommandError(format!(

View File

@ -33,12 +33,13 @@ services:
- EMBEDDING_MODEL=${EMBEDDING_MODEL}
- COHERE_API_KEY=${COHERE_API_KEY}
- ENVIRONMENT=${ENVIRONMENT}
- LOG_LEVEL=${LOG_LEVEL}
- LOG_LEVEL=DEBUG
- RERANK_API_KEY=${RERANK_API_KEY}
- RERANK_MODEL=${RERANK_MODEL}
- RERANK_BASE_URL=${RERANK_BASE_URL}
- LLM_API_KEY=${LLM_API_KEY}
- LLM_BASE_URL=${LLM_BASE_URL}
- RUST_LOG=debug
ports:
- "3001:3001"
- "3000:3000"
@ -66,22 +67,22 @@ services:
network_mode: "service:api"
# Pausing this for local deployments until we can build out better multi-model support.
# litellm:
# image: ghcr.io/berriai/litellm:main-latest
# container_name: buster-litellm
# volumes:
# - ./litellm_vertex_config.yaml:/litellm_vertex_config.yaml
# command: ["--config", "/litellm_vertex_config.yaml", "--port", "4001"]
# ports:
# - "4001:4001"
# healthcheck:
# test: ["CMD", "curl", "-f", "http://localhost:4001/health/readiness"]
# interval: 30s
# timeout: 10s
# retries: 3
# depends_on:
# api:
# condition: service_healthy
litellm:
image: ghcr.io/berriai/litellm:main-latest
container_name: buster-litellm
volumes:
- ./litellm_vertex_config.yaml:/litellm_vertex_config.yaml
command: ["--config", "/litellm_vertex_config.yaml", "--port", "4001"]
ports:
- "4001:4001"
healthcheck:
test: ["CMD", "curl", "-f", "http://localhost:4001/health/readiness"]
interval: 30s
timeout: 10s
retries: 3
depends_on:
api:
condition: service_healthy
volumes:
buster_redis_data: