mirror of https://github.com/buster-so/buster.git
rerank test and tweak
This commit is contained in:
parent
fb0077c583
commit
46cb2c3b3b
|
@ -7,3 +7,7 @@ edition = "2021"
|
|||
reqwest = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
|
||||
[dev-dependencies]
|
||||
dotenv = { workspace = true }
|
||||
tokio = { version = "1", features = ["macros", "rt-multi-thread"] }
|
||||
|
|
|
@ -2,15 +2,7 @@ use reqwest::Client;
|
|||
use serde::{Deserialize, Serialize};
|
||||
use std::error::Error;
|
||||
|
||||
/// Identifies which hosted rerank provider a reranker talks to.
///
/// The enum carries no data, so it is cheap to copy and can be compared
/// directly (`Copy`, `PartialEq`, `Eq`), which lets callers branch on or
/// assert the selected provider without borrowing or pattern matching.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RerankerType {
    /// The Cohere rerank API.
    Cohere,
    /// The Mixedbread ("mxbai") rerank API.
    Mxbai,
    /// The Jina rerank API.
    Jina,
}
|
||||
|
||||
pub struct Reranker {
|
||||
reranker_type: RerankerType,
|
||||
api_key: String,
|
||||
base_url: String,
|
||||
model: String,
|
||||
|
@ -19,23 +11,11 @@ pub struct Reranker {
|
|||
|
||||
impl Reranker {
|
||||
pub fn new() -> Result<Self, Box<dyn Error>> {
|
||||
let provider = std::env::var("RERANK_PROVIDER")?;
|
||||
let reranker_type = match provider.to_lowercase().as_str() {
|
||||
"cohere" => RerankerType::Cohere,
|
||||
"mxbai" => RerankerType::Mxbai,
|
||||
"jina" => RerankerType::Jina,
|
||||
_ => return Err("Invalid provider specified".into()),
|
||||
};
|
||||
let api_key = std::env::var("RERANK_API_KEY")?;
|
||||
let model = std::env::var("RERANK_MODEL")?;
|
||||
let base_url = match reranker_type {
|
||||
RerankerType::Cohere => "https://api.cohere.com/v2/rerank",
|
||||
RerankerType::Mxbai => "https://api.mixedbread.ai/v1/rerank",
|
||||
RerankerType::Jina => "https://api.jina.ai/v1/rerank",
|
||||
}.to_string();
|
||||
let base_url = std::env::var("RERANK_BASE_URL")?;
|
||||
let client = Client::new();
|
||||
Ok(Self {
|
||||
reranker_type,
|
||||
api_key,
|
||||
base_url,
|
||||
model,
|
||||
|
|
|
@ -0,0 +1,56 @@
|
|||
use rerank::{Reranker, RerankResult};
|
||||
use std::error::Error;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_reranker_integration() -> Result<(), Box<dyn Error>> {
|
||||
// Load environment variables from .env file
|
||||
dotenv::dotenv().ok();
|
||||
|
||||
// Initialize the reranker
|
||||
let reranker = Reranker::new()?;
|
||||
|
||||
// Define a sample query and documents
|
||||
let query = "What is the capital of France?";
|
||||
let documents = vec![
|
||||
"Paris is a major European city and a global center for art, fashion, gastronomy and culture.",
|
||||
"London is the capital and largest city of England and the United Kingdom.",
|
||||
"The Eiffel Tower is a wrought-iron lattice tower on the Champ de Mars in Paris, France.",
|
||||
"Berlin is the capital and largest city of Germany by both area and population.",
|
||||
];
|
||||
let top_n = 2;
|
||||
|
||||
// Perform reranking
|
||||
let results: Vec<RerankResult> = reranker.rerank(query, &documents, top_n).await?;
|
||||
|
||||
// Assertions
|
||||
assert_eq!(results.len(), top_n, "Should return top_n results");
|
||||
|
||||
// Check that indices are within the bounds of the original documents
|
||||
for result in &results {
|
||||
assert!(result.index < documents.len(), "Result index should be valid");
|
||||
}
|
||||
|
||||
// Optional: Print results for manual verification (can be removed later)
|
||||
println!("Query: {}", query);
|
||||
for result in &results {
|
||||
println!(
|
||||
"Document Index: {}, Score: {:.4}, Document: {}",
|
||||
result.index,
|
||||
result.relevance_score,
|
||||
documents[result.index]
|
||||
);
|
||||
}
|
||||
|
||||
// Example assertion: if we expect Paris-related documents to be ranked higher.
|
||||
// This is a very basic check and might need adjustment based on actual model behavior.
|
||||
if !results.is_empty() {
|
||||
let first_result_doc = documents[results[0].index];
|
||||
assert!(
|
||||
first_result_doc.to_lowercase().contains("paris"),
|
||||
"The top result for 'capital of France' should ideally mention Paris. Model output: {}",
|
||||
first_result_doc
|
||||
);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
Loading…
Reference in New Issue