mirror of https://github.com/buster-so/buster.git
rip out fastembed due to build errors
This commit is contained in:
parent
5e7c8fc6b1
commit
993929720e
|
@ -110,7 +110,6 @@ diesel_migrations = "2.0.0"
|
|||
html-escape = "0.2.13"
|
||||
tokio-cron-scheduler = "0.13.0"
|
||||
tokio-retry = "0.3.0"
|
||||
fastembed = "4.9.0"
|
||||
|
||||
[profile.release]
|
||||
debug = false
|
||||
|
|
|
@ -8,7 +8,6 @@ reqwest = { workspace = true }
|
|||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
dotenv = { workspace = true }
|
||||
fastembed = "4.8.0"
|
||||
|
||||
[dev-dependencies]
|
||||
dotenv = { workspace = true }
|
||||
|
|
|
@ -9,24 +9,15 @@ pub struct Reranker {
|
|||
base_url: String,
|
||||
model: String,
|
||||
client: Client,
|
||||
environment: String,
|
||||
}
|
||||
|
||||
impl Reranker {
|
||||
pub fn new() -> Result<Self, Box<dyn Error>> {
|
||||
dotenv().ok();
|
||||
let environment = env::var("ENVIRONMENT").unwrap_or_else(|_| "production".to_string());
|
||||
|
||||
// If local environment, we don't need these values
|
||||
let (api_key, model, base_url) = if environment == "local" {
|
||||
(String::new(), String::new(), String::new())
|
||||
} else {
|
||||
(
|
||||
env::var("RERANK_API_KEY")?,
|
||||
env::var("RERANK_MODEL")?,
|
||||
env::var("RERANK_BASE_URL")?,
|
||||
)
|
||||
};
|
||||
let api_key = env::var("RERANK_API_KEY")?;
|
||||
let model = env::var("RERANK_MODEL")?;
|
||||
let base_url = env::var("RERANK_BASE_URL")?;
|
||||
|
||||
let client = Client::new();
|
||||
Ok(Self {
|
||||
|
@ -34,7 +25,6 @@ impl Reranker {
|
|||
base_url,
|
||||
model,
|
||||
client,
|
||||
environment,
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -44,18 +34,13 @@ impl Reranker {
|
|||
documents: &[&str],
|
||||
top_n: usize,
|
||||
) -> Result<Vec<RerankResult>, Box<dyn Error>> {
|
||||
// Use local fastembed reranking if ENVIRONMENT is set to local
|
||||
if self.environment == "local" {
|
||||
return self.local_rerank(query, documents, top_n).await;
|
||||
}
|
||||
|
||||
// Otherwise use the remote API
|
||||
let request_body = RerankRequest {
|
||||
query: query.to_string(),
|
||||
documents: documents.iter().map(|s| s.to_string()).collect(),
|
||||
top_n,
|
||||
model: self.model.clone(),
|
||||
};
|
||||
|
||||
let response = self
|
||||
.client
|
||||
.post(&self.base_url)
|
||||
|
@ -63,40 +48,10 @@ impl Reranker {
|
|||
.json(&request_body)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
let response_body: RerankResponse = response.json().await?;
|
||||
Ok(response_body.results)
|
||||
}
|
||||
|
||||
async fn local_rerank(
|
||||
&self,
|
||||
query: &str,
|
||||
documents: &[&str],
|
||||
top_n: usize,
|
||||
) -> Result<Vec<RerankResult>, Box<dyn Error>> {
|
||||
use fastembed::{TextRerank, RerankInitOptions, RerankerModel};
|
||||
|
||||
// Initialize the reranker model
|
||||
let model = TextRerank::try_new(
|
||||
RerankInitOptions::new(RerankerModel::JINARerankerV1TurboEn).with_show_download_progress(true),
|
||||
)?;
|
||||
|
||||
// Limit top_n to the number of documents
|
||||
let actual_top_n = std::cmp::min(top_n, documents.len());
|
||||
|
||||
// Perform reranking
|
||||
let fastembed_results = model.rerank(query, documents.to_vec(),false, Some(actual_top_n))?;
|
||||
|
||||
// Convert fastembed results to our RerankResult format
|
||||
let results = fastembed_results
|
||||
.iter()
|
||||
.map(|result| RerankResult {
|
||||
index: result.index,
|
||||
relevance_score: result.score,
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(results)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
|
|
|
@ -39,7 +39,6 @@ tower-http = { workspace = true }
|
|||
tracing = { workspace = true }
|
||||
tracing-subscriber = { workspace = true }
|
||||
uuid = { workspace = true }
|
||||
fastembed = { workspace = true }
|
||||
|
||||
# Local dependencies
|
||||
handlers = { path = "../libs/handlers" }
|
||||
|
|
|
@ -22,7 +22,6 @@ use tower::ServiceBuilder;
|
|||
use tower_http::{compression::CompressionLayer, trace::TraceLayer};
|
||||
use tracing::{error, info, warn};
|
||||
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilter};
|
||||
use fastembed::{InitOptions, RerankInitOptions, RerankerModel, TextRerank};
|
||||
|
||||
pub const MIGRATIONS: EmbeddedMigrations = embed_migrations!();
|
||||
|
||||
|
@ -34,13 +33,6 @@ async fn main() -> Result<(), anyhow::Error> {
|
|||
let environment = env::var("ENVIRONMENT").unwrap_or_else(|_| "development".to_string());
|
||||
let is_development = environment == "development";
|
||||
|
||||
if environment == "local" {
|
||||
let options =
|
||||
RerankInitOptions::new(RerankerModel::JINARerankerV1TurboEn).with_show_download_progress(true);
|
||||
let model = TextRerank::try_new(options)?;
|
||||
println!("Model loaded and ready!");
|
||||
}
|
||||
|
||||
ring::default_provider()
|
||||
.install_default()
|
||||
.expect("Failed to install default crypto provider");
|
||||
|
|
Loading…
Reference in New Issue