rerank test and tweak

This commit is contained in:
dal 2025-05-07 18:31:37 -06:00
parent fb0077c583
commit 46cb2c3b3b
No known key found for this signature in database
GPG Key ID: 16F4B0E1E9F61122
3 changed files with 61 additions and 21 deletions

View File

@ -7,3 +7,7 @@ edition = "2021"
reqwest = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
[dev-dependencies]
dotenv = { workspace = true }
tokio = { version = "1", features = ["macros", "rt-multi-thread"] }

View File

@ -2,15 +2,7 @@ use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::error::Error;
#[derive(Debug)]
pub enum RerankerType {
Cohere,
Mxbai,
Jina,
}
pub struct Reranker {
reranker_type: RerankerType,
api_key: String,
base_url: String,
model: String,
@ -19,23 +11,11 @@ pub struct Reranker {
impl Reranker {
pub fn new() -> Result<Self, Box<dyn Error>> {
let provider = std::env::var("RERANK_PROVIDER")?;
let reranker_type = match provider.to_lowercase().as_str() {
"cohere" => RerankerType::Cohere,
"mxbai" => RerankerType::Mxbai,
"jina" => RerankerType::Jina,
_ => return Err("Invalid provider specified".into()),
};
let api_key = std::env::var("RERANK_API_KEY")?;
let model = std::env::var("RERANK_MODEL")?;
let base_url = match reranker_type {
RerankerType::Cohere => "https://api.cohere.com/v2/rerank",
RerankerType::Mxbai => "https://api.mixedbread.ai/v1/rerank",
RerankerType::Jina => "https://api.jina.ai/v1/rerank",
}.to_string();
let base_url = std::env::var("RERANK_BASE_URL")?;
let client = Client::new();
Ok(Self {
reranker_type,
api_key,
base_url,
model,

View File

@ -0,0 +1,56 @@
use rerank::{Reranker, RerankResult};
use std::error::Error;
#[tokio::test]
async fn test_reranker_integration() -> Result<(), Box<dyn Error>> {
// Load environment variables from .env file
dotenv::dotenv().ok();
// Initialize the reranker
let reranker = Reranker::new()?;
// Define a sample query and documents
let query = "What is the capital of France?";
let documents = vec![
"Paris is a major European city and a global center for art, fashion, gastronomy and culture.",
"London is the capital and largest city of England and the United Kingdom.",
"The Eiffel Tower is a wrought-iron lattice tower on the Champ de Mars in Paris, France.",
"Berlin is the capital and largest city of Germany by both area and population.",
];
let top_n = 2;
// Perform reranking
let results: Vec<RerankResult> = reranker.rerank(query, &documents, top_n).await?;
// Assertions
assert_eq!(results.len(), top_n, "Should return top_n results");
// Check that indices are within the bounds of the original documents
for result in &results {
assert!(result.index < documents.len(), "Result index should be valid");
}
// Optional: Print results for manual verification (can be removed later)
println!("Query: {}", query);
for result in &results {
println!(
"Document Index: {}, Score: {:.4}, Document: {}",
result.index,
result.relevance_score,
documents[result.index]
);
}
// Example assertion: if we expect Paris-related documents to be ranked higher.
// This is a very basic check and might need adjustment based on actual model behavior.
if !results.is_empty() {
let first_result_doc = documents[results[0].index];
assert!(
first_result_doc.to_lowercase().contains("paris"),
"The top result for 'capital of France' should ideally mention Paris. Model output: {}",
first_result_doc
);
}
Ok(())
}