mirror of https://github.com/buster-so/buster.git
Staging (#330)
* Create a better handler for clicking favorites * chore(versions): bump api to v0.1.9; bump web to v0.1.9; bump cli to v0.1.9 [skip ci] * chore: update tag_info.json with potential release versions [skip ci] * Create a better handler for clicking favorites * update chat favorites * chore(versions): bump api to v0.1.10; bump web to v0.1.10; bump cli to v0.1.10 [skip ci] * chore: update tag_info.json with potential release versions [skip ci] * Update tests to be ran with multiple workers * create chat records update * Create createChatRecord.test.ts * chore(versions): bump api to v0.1.11; bump web to v0.1.11; bump cli to v0.1.11 [skip ci] * chore: update tag_info.json with potential release versions [skip ci] * fix yesterday bucket * add fast embed rerank for local deployment (#329) * add fast embed rerank for local * chore(versions): bump api to v0.1.12; bump web to v0.1.12; bump cli to v0.1.12 [skip ci] * chore: update tag_info.json with potential release versions [skip ci] --------- Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> --------- Co-authored-by: Nate Kelley <nate@buster.so> Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> Co-authored-by: Nate Kelley <133379588+nate-kelley-buster@users.noreply.github.com>
This commit is contained in:
parent
d6ac872afe
commit
fac8e8b673
|
@ -9,6 +9,8 @@
|
|||
crash.log
|
||||
crash.*.log
|
||||
|
||||
.fastembed_cache/
|
||||
|
||||
# Exclude all .tfvars files, which are likely to contain sensitive data, such as
|
||||
# password, private keys, and other secrets. These should not be part of version
|
||||
# control as they are data points which are potentially sensitive and subject
|
||||
|
|
|
@ -110,6 +110,7 @@ diesel_migrations = "2.0.0"
|
|||
html-escape = "0.2.13"
|
||||
tokio-cron-scheduler = "0.13.0"
|
||||
tokio-retry = "0.3.0"
|
||||
fastembed = "4.8.0"
|
||||
|
||||
[profile.release]
|
||||
debug = false
|
||||
|
|
|
@ -7,6 +7,8 @@ edition = "2021"
|
|||
reqwest = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
dotenv = { workspace = true }
|
||||
fastembed = "4.8.0"
|
||||
|
||||
[dev-dependencies]
|
||||
dotenv = { workspace = true }
|
||||
|
|
|
@ -1,25 +1,40 @@
|
|||
use reqwest::Client;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::error::Error;
|
||||
use dotenv::dotenv;
|
||||
use std::env;
|
||||
|
||||
pub struct Reranker {
|
||||
api_key: String,
|
||||
base_url: String,
|
||||
model: String,
|
||||
client: Client,
|
||||
environment: String,
|
||||
}
|
||||
|
||||
impl Reranker {
|
||||
pub fn new() -> Result<Self, Box<dyn Error>> {
|
||||
let api_key = std::env::var("RERANK_API_KEY")?;
|
||||
let model = std::env::var("RERANK_MODEL")?;
|
||||
let base_url = std::env::var("RERANK_BASE_URL")?;
|
||||
dotenv().ok();
|
||||
let environment = env::var("ENVIRONMENT").unwrap_or_else(|_| "production".to_string());
|
||||
|
||||
// If local environment, we don't need these values
|
||||
let (api_key, model, base_url) = if environment == "local" {
|
||||
(String::new(), String::new(), String::new())
|
||||
} else {
|
||||
(
|
||||
env::var("RERANK_API_KEY")?,
|
||||
env::var("RERANK_MODEL")?,
|
||||
env::var("RERANK_BASE_URL")?,
|
||||
)
|
||||
};
|
||||
|
||||
let client = Client::new();
|
||||
Ok(Self {
|
||||
api_key,
|
||||
base_url,
|
||||
model,
|
||||
client,
|
||||
environment,
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -29,6 +44,12 @@ impl Reranker {
|
|||
documents: &[&str],
|
||||
top_n: usize,
|
||||
) -> Result<Vec<RerankResult>, Box<dyn Error>> {
|
||||
// Use local fastembed reranking if ENVIRONMENT is set to local
|
||||
if self.environment == "local" {
|
||||
return self.local_rerank(query, documents, top_n).await;
|
||||
}
|
||||
|
||||
// Otherwise use the remote API
|
||||
let request_body = RerankRequest {
|
||||
query: query.to_string(),
|
||||
documents: documents.iter().map(|s| s.to_string()).collect(),
|
||||
|
@ -45,6 +66,37 @@ impl Reranker {
|
|||
let response_body: RerankResponse = response.json().await?;
|
||||
Ok(response_body.results)
|
||||
}
|
||||
|
||||
async fn local_rerank(
|
||||
&self,
|
||||
query: &str,
|
||||
documents: &[&str],
|
||||
top_n: usize,
|
||||
) -> Result<Vec<RerankResult>, Box<dyn Error>> {
|
||||
use fastembed::{TextRerank, RerankInitOptions, RerankerModel};
|
||||
|
||||
// Initialize the reranker model
|
||||
let model = TextRerank::try_new(
|
||||
RerankInitOptions::new(RerankerModel::JINARerankerV1TurboEn).with_show_download_progress(true),
|
||||
)?;
|
||||
|
||||
// Limit top_n to the number of documents
|
||||
let actual_top_n = std::cmp::min(top_n, documents.len());
|
||||
|
||||
// Perform reranking
|
||||
let fastembed_results = model.rerank(query, documents.to_vec(),false, Some(actual_top_n))?;
|
||||
|
||||
// Convert fastembed results to our RerankResult format
|
||||
let results = fastembed_results
|
||||
.iter()
|
||||
.map(|result| RerankResult {
|
||||
index: result.index,
|
||||
relevance_score: result.score,
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(results)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
|
@ -60,7 +112,7 @@ struct RerankResponse {
|
|||
results: Vec<RerankResult>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
#[derive(Deserialize, Clone, Debug)]
|
||||
pub struct RerankResult {
|
||||
pub index: usize,
|
||||
pub relevance_score: f32,
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
[package]
|
||||
name = "buster_server"
|
||||
version = "0.1.11"
|
||||
version = "0.1.12"
|
||||
edition = "2021"
|
||||
default-run = "buster_server"
|
||||
|
||||
|
@ -39,6 +39,7 @@ tower-http = { workspace = true }
|
|||
tracing = { workspace = true }
|
||||
tracing-subscriber = { workspace = true }
|
||||
uuid = { workspace = true }
|
||||
fastembed = { workspace = true }
|
||||
|
||||
# Local dependencies
|
||||
handlers = { path = "../libs/handlers" }
|
||||
|
|
|
@ -5,12 +5,15 @@ use std::env;
|
|||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use axum::{Extension, Router, extract::Request};
|
||||
use middleware::{cors::cors, error::{init_sentry, sentry_layer, init_tracing_subscriber}};
|
||||
use axum::{extract::Request, Extension, Router};
|
||||
use database::{self, pool::init_pools};
|
||||
use diesel::{Connection, PgConnection};
|
||||
use diesel_migrations::{embed_migrations, EmbeddedMigrations, MigrationHarness};
|
||||
use dotenv::dotenv;
|
||||
use middleware::{
|
||||
cors::cors,
|
||||
error::{init_sentry, init_tracing_subscriber, sentry_layer},
|
||||
};
|
||||
use rustls::crypto::ring;
|
||||
use stored_values::jobs::trigger_stale_sync_jobs;
|
||||
use tokio::sync::broadcast;
|
||||
|
@ -19,6 +22,7 @@ use tower::ServiceBuilder;
|
|||
use tower_http::{compression::CompressionLayer, trace::TraceLayer};
|
||||
use tracing::{error, info, warn};
|
||||
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilter};
|
||||
use fastembed::{InitOptions, RerankInitOptions, RerankerModel, TextRerank};
|
||||
|
||||
pub const MIGRATIONS: EmbeddedMigrations = embed_migrations!();
|
||||
|
||||
|
@ -30,6 +34,13 @@ async fn main() -> Result<(), anyhow::Error> {
|
|||
let environment = env::var("ENVIRONMENT").unwrap_or_else(|_| "development".to_string());
|
||||
let is_development = environment == "development";
|
||||
|
||||
if environment == "local" {
|
||||
let options =
|
||||
RerankInitOptions::new(RerankerModel::JINARerankerV1TurboEn).with_show_download_progress(true);
|
||||
let model = TextRerank::try_new(options)?;
|
||||
println!("Model loaded and ready!");
|
||||
}
|
||||
|
||||
ring::default_provider()
|
||||
.install_default()
|
||||
.expect("Failed to install default crypto provider");
|
||||
|
@ -43,9 +54,9 @@ async fn main() -> Result<(), anyhow::Error> {
|
|||
let log_level = env::var("LOG_LEVEL")
|
||||
.unwrap_or_else(|_| "warn".to_string())
|
||||
.to_uppercase();
|
||||
|
||||
let env_filter = EnvFilter::try_from_default_env()
|
||||
.unwrap_or_else(|_| EnvFilter::new(log_level));
|
||||
|
||||
let env_filter =
|
||||
EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new(log_level));
|
||||
|
||||
// Initialize the tracing subscriber with Sentry integration using our middleware helper
|
||||
init_tracing_subscriber(env_filter);
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
[package]
|
||||
name = "buster-cli"
|
||||
version = "0.1.11"
|
||||
version = "0.1.12"
|
||||
edition = "2021"
|
||||
build = "build.rs"
|
||||
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
{
|
||||
"api_tag": "api/v0.1.11", "api_version": "0.1.11"
|
||||
"api_tag": "api/v0.1.12", "api_version": "0.1.12"
|
||||
,
|
||||
"web_tag": "web/v0.1.11", "web_version": "0.1.11"
|
||||
"web_tag": "web/v0.1.12", "web_version": "0.1.12"
|
||||
,
|
||||
"cli_tag": "cli/v0.1.11", "cli_version": "0.1.11"
|
||||
"cli_tag": "cli/v0.1.12", "cli_version": "0.1.12"
|
||||
}
|
||||
|
|
|
@ -1,12 +1,12 @@
|
|||
{
|
||||
"name": "web",
|
||||
"version": "0.1.11",
|
||||
"version": "0.1.12",
|
||||
"lockfileVersion": 3,
|
||||
"requires": true,
|
||||
"packages": {
|
||||
"": {
|
||||
"name": "web",
|
||||
"version": "0.1.11",
|
||||
"version": "0.1.12",
|
||||
"dependencies": {
|
||||
"@dnd-kit/core": "^6.3.1",
|
||||
"@dnd-kit/modifiers": "^9.0.0",
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
{
|
||||
"name": "web",
|
||||
"version": "0.1.11",
|
||||
"version": "0.1.12",
|
||||
"private": true,
|
||||
"scripts": {
|
||||
"dev": "next dev --turbo",
|
||||
|
|
Loading…
Reference in New Issue