diff --git a/api/Cargo.toml b/api/Cargo.toml index 289d7ba32..07532fa64 100644 --- a/api/Cargo.toml +++ b/api/Cargo.toml @@ -110,7 +110,6 @@ diesel_migrations = "2.0.0" html-escape = "0.2.13" tokio-cron-scheduler = "0.13.0" tokio-retry = "0.3.0" -fastembed = "4.9.0" [profile.release] debug = false diff --git a/api/libs/rerank/Cargo.toml b/api/libs/rerank/Cargo.toml index 52ad7369c..2c0b57863 100644 --- a/api/libs/rerank/Cargo.toml +++ b/api/libs/rerank/Cargo.toml @@ -8,7 +8,6 @@ reqwest = { workspace = true } serde = { workspace = true } serde_json = { workspace = true } dotenv = { workspace = true } -fastembed = "4.8.0" [dev-dependencies] dotenv = { workspace = true } diff --git a/api/libs/rerank/src/lib.rs b/api/libs/rerank/src/lib.rs index 95fb716b9..5411bb394 100644 --- a/api/libs/rerank/src/lib.rs +++ b/api/libs/rerank/src/lib.rs @@ -9,24 +9,15 @@ pub struct Reranker { base_url: String, model: String, client: Client, - environment: String, } impl Reranker { pub fn new() -> Result> { dotenv().ok(); - let environment = env::var("ENVIRONMENT").unwrap_or_else(|_| "production".to_string()); - // If local environment, we don't need these values - let (api_key, model, base_url) = if environment == "local" { - (String::new(), String::new(), String::new()) - } else { - ( - env::var("RERANK_API_KEY")?, - env::var("RERANK_MODEL")?, - env::var("RERANK_BASE_URL")?, - ) - }; + let api_key = env::var("RERANK_API_KEY")?; + let model = env::var("RERANK_MODEL")?; + let base_url = env::var("RERANK_BASE_URL")?; let client = Client::new(); Ok(Self { @@ -34,7 +25,6 @@ impl Reranker { base_url, model, client, - environment, }) } @@ -44,18 +34,13 @@ impl Reranker { documents: &[&str], top_n: usize, ) -> Result, Box> { - // Use local fastembed reranking if ENVIRONMENT is set to local - if self.environment == "local" { - return self.local_rerank(query, documents, top_n).await; - } - - // Otherwise use the remote API let request_body = RerankRequest { query: query.to_string(), documents: documents.iter().map(|s| s.to_string()).collect(), top_n, model: self.model.clone(), }; + let response = self .client .post(&self.base_url) @@ -63,40 +48,10 @@ impl Reranker { .json(&request_body) .send() .await?; + let response_body: RerankResponse = response.json().await?; Ok(response_body.results) } - - async fn local_rerank( - &self, - query: &str, - documents: &[&str], - top_n: usize, - ) -> Result, Box> { - use fastembed::{TextRerank, RerankInitOptions, RerankerModel}; - - // Initialize the reranker model - let model = TextRerank::try_new( - RerankInitOptions::new(RerankerModel::JINARerankerV1TurboEn).with_show_download_progress(true), - )?; - - // Limit top_n to the number of documents - let actual_top_n = std::cmp::min(top_n, documents.len()); - - // Perform reranking - let fastembed_results = model.rerank(query, documents.to_vec(),false, Some(actual_top_n))?; - - // Convert fastembed results to our RerankResult format - let results = fastembed_results - .iter() - .map(|result| RerankResult { - index: result.index, - relevance_score: result.score, - }) - .collect(); - - Ok(results) - } } #[derive(Serialize)] diff --git a/api/server/Cargo.toml b/api/server/Cargo.toml index 9eadbf44b..8d54a687d 100644 --- a/api/server/Cargo.toml +++ b/api/server/Cargo.toml @@ -39,7 +39,6 @@ tower-http = { workspace = true } tracing = { workspace = true } tracing-subscriber = { workspace = true } uuid = { workspace = true } -fastembed = { workspace = true } # Local dependencies handlers = { path = "../libs/handlers" } diff --git a/api/server/src/main.rs b/api/server/src/main.rs index 7b799fe1d..674d82524 100644 --- a/api/server/src/main.rs +++ b/api/server/src/main.rs @@ -22,7 +22,6 @@ use tower::ServiceBuilder; use tower_http::{compression::CompressionLayer, trace::TraceLayer}; use tracing::{error, info, warn}; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilter}; -use fastembed::{InitOptions, RerankInitOptions, RerankerModel, TextRerank}; pub const MIGRATIONS: EmbeddedMigrations = embed_migrations!(); @@ -34,13 +33,6 @@ async fn main() -> Result<(), anyhow::Error> { let environment = env::var("ENVIRONMENT").unwrap_or_else(|_| "development".to_string()); let is_development = environment == "development"; - if environment == "local" { - let options = - RerankInitOptions::new(RerankerModel::JINARerankerV1TurboEn).with_show_download_progress(true); - let model = TextRerank::try_new(options)?; - println!("Model loaded and ready!"); - } - ring::default_provider() .install_default() .expect("Failed to install default crypto provider");