mirror of https://github.com/buster-so/buster.git
rerank test and tweak
This commit is contained in:
parent
fb0077c583
commit
46cb2c3b3b
|
@ -7,3 +7,7 @@ edition = "2021"
|
||||||
reqwest = { workspace = true }
|
reqwest = { workspace = true }
|
||||||
serde = { workspace = true }
|
serde = { workspace = true }
|
||||||
serde_json = { workspace = true }
|
serde_json = { workspace = true }
|
||||||
|
|
||||||
|
[dev-dependencies]
|
||||||
|
dotenv = { workspace = true }
|
||||||
|
tokio = { version = "1", features = ["macros", "rt-multi-thread"] }
|
||||||
|
|
|
@ -2,15 +2,7 @@ use reqwest::Client;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use std::error::Error;
|
use std::error::Error;
|
||||||
|
|
||||||
#[derive(Debug)]
|
|
||||||
pub enum RerankerType {
|
|
||||||
Cohere,
|
|
||||||
Mxbai,
|
|
||||||
Jina,
|
|
||||||
}
|
|
||||||
|
|
||||||
pub struct Reranker {
|
pub struct Reranker {
|
||||||
reranker_type: RerankerType,
|
|
||||||
api_key: String,
|
api_key: String,
|
||||||
base_url: String,
|
base_url: String,
|
||||||
model: String,
|
model: String,
|
||||||
|
@ -19,23 +11,11 @@ pub struct Reranker {
|
||||||
|
|
||||||
impl Reranker {
|
impl Reranker {
|
||||||
pub fn new() -> Result<Self, Box<dyn Error>> {
|
pub fn new() -> Result<Self, Box<dyn Error>> {
|
||||||
let provider = std::env::var("RERANK_PROVIDER")?;
|
|
||||||
let reranker_type = match provider.to_lowercase().as_str() {
|
|
||||||
"cohere" => RerankerType::Cohere,
|
|
||||||
"mxbai" => RerankerType::Mxbai,
|
|
||||||
"jina" => RerankerType::Jina,
|
|
||||||
_ => return Err("Invalid provider specified".into()),
|
|
||||||
};
|
|
||||||
let api_key = std::env::var("RERANK_API_KEY")?;
|
let api_key = std::env::var("RERANK_API_KEY")?;
|
||||||
let model = std::env::var("RERANK_MODEL")?;
|
let model = std::env::var("RERANK_MODEL")?;
|
||||||
let base_url = match reranker_type {
|
let base_url = std::env::var("RERANK_BASE_URL")?;
|
||||||
RerankerType::Cohere => "https://api.cohere.com/v2/rerank",
|
|
||||||
RerankerType::Mxbai => "https://api.mixedbread.ai/v1/rerank",
|
|
||||||
RerankerType::Jina => "https://api.jina.ai/v1/rerank",
|
|
||||||
}.to_string();
|
|
||||||
let client = Client::new();
|
let client = Client::new();
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
reranker_type,
|
|
||||||
api_key,
|
api_key,
|
||||||
base_url,
|
base_url,
|
||||||
model,
|
model,
|
||||||
|
|
|
@ -0,0 +1,56 @@
|
||||||
|
use rerank::{Reranker, RerankResult};
|
||||||
|
use std::error::Error;
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_reranker_integration() -> Result<(), Box<dyn Error>> {
|
||||||
|
// Load environment variables from .env file
|
||||||
|
dotenv::dotenv().ok();
|
||||||
|
|
||||||
|
// Initialize the reranker
|
||||||
|
let reranker = Reranker::new()?;
|
||||||
|
|
||||||
|
// Define a sample query and documents
|
||||||
|
let query = "What is the capital of France?";
|
||||||
|
let documents = vec![
|
||||||
|
"Paris is a major European city and a global center for art, fashion, gastronomy and culture.",
|
||||||
|
"London is the capital and largest city of England and the United Kingdom.",
|
||||||
|
"The Eiffel Tower is a wrought-iron lattice tower on the Champ de Mars in Paris, France.",
|
||||||
|
"Berlin is the capital and largest city of Germany by both area and population.",
|
||||||
|
];
|
||||||
|
let top_n = 2;
|
||||||
|
|
||||||
|
// Perform reranking
|
||||||
|
let results: Vec<RerankResult> = reranker.rerank(query, &documents, top_n).await?;
|
||||||
|
|
||||||
|
// Assertions
|
||||||
|
assert_eq!(results.len(), top_n, "Should return top_n results");
|
||||||
|
|
||||||
|
// Check that indices are within the bounds of the original documents
|
||||||
|
for result in &results {
|
||||||
|
assert!(result.index < documents.len(), "Result index should be valid");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Optional: Print results for manual verification (can be removed later)
|
||||||
|
println!("Query: {}", query);
|
||||||
|
for result in &results {
|
||||||
|
println!(
|
||||||
|
"Document Index: {}, Score: {:.4}, Document: {}",
|
||||||
|
result.index,
|
||||||
|
result.relevance_score,
|
||||||
|
documents[result.index]
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Example assertion: if we expect Paris-related documents to be ranked higher.
|
||||||
|
// This is a very basic check and might need adjustment based on actual model behavior.
|
||||||
|
if !results.is_empty() {
|
||||||
|
let first_result_doc = documents[results[0].index];
|
||||||
|
assert!(
|
||||||
|
first_result_doc.to_lowercase().contains("paris"),
|
||||||
|
"The top result for 'capital of France' should ideally mention Paris. Model output: {}",
|
||||||
|
first_result_doc
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
Loading…
Reference in New Issue