diff --git a/api/libs/rerank/Cargo.toml b/api/libs/rerank/Cargo.toml index 7476f5f3c..44adfc306 100644 --- a/api/libs/rerank/Cargo.toml +++ b/api/libs/rerank/Cargo.toml @@ -7,3 +7,7 @@ edition = "2021" reqwest = { workspace = true } serde = { workspace = true } serde_json = { workspace = true } + +[dev-dependencies] +dotenv = { workspace = true } +tokio = { version = "1", features = ["macros", "rt-multi-thread"] } diff --git a/api/libs/rerank/src/lib.rs b/api/libs/rerank/src/lib.rs index 4df2ba84f..67177db37 100644 --- a/api/libs/rerank/src/lib.rs +++ b/api/libs/rerank/src/lib.rs @@ -2,15 +2,7 @@ use reqwest::Client; use serde::{Deserialize, Serialize}; use std::error::Error; -#[derive(Debug)] -pub enum RerankerType { - Cohere, - Mxbai, - Jina, -} - pub struct Reranker { - reranker_type: RerankerType, api_key: String, base_url: String, model: String, @@ -19,23 +11,11 @@ pub struct Reranker { impl Reranker { pub fn new() -> Result> { - let provider = std::env::var("RERANK_PROVIDER")?; - let reranker_type = match provider.to_lowercase().as_str() { - "cohere" => RerankerType::Cohere, - "mxbai" => RerankerType::Mxbai, - "jina" => RerankerType::Jina, - _ => return Err("Invalid provider specified".into()), - }; let api_key = std::env::var("RERANK_API_KEY")?; let model = std::env::var("RERANK_MODEL")?; - let base_url = match reranker_type { - RerankerType::Cohere => "https://api.cohere.com/v2/rerank", - RerankerType::Mxbai => "https://api.mixedbread.ai/v1/rerank", - RerankerType::Jina => "https://api.jina.ai/v1/rerank", - }.to_string(); + let base_url = std::env::var("RERANK_BASE_URL")?; let client = Client::new(); Ok(Self { - reranker_type, api_key, base_url, model, diff --git a/api/libs/rerank/tests/integration_test.rs b/api/libs/rerank/tests/integration_test.rs new file mode 100644 index 000000000..d208044f3 --- /dev/null +++ b/api/libs/rerank/tests/integration_test.rs @@ -0,0 +1,56 @@ +use rerank::{Reranker, RerankResult}; +use std::error::Error; + +#[tokio::test] +async fn test_reranker_integration() -> Result<(), Box> { + // Load environment variables from .env file + dotenv::dotenv().ok(); + + // Initialize the reranker + let reranker = Reranker::new()?; + + // Define a sample query and documents + let query = "What is the capital of France?"; + let documents = vec![ + "Paris is a major European city and a global center for art, fashion, gastronomy and culture.", + "London is the capital and largest city of England and the United Kingdom.", + "The Eiffel Tower is a wrought-iron lattice tower on the Champ de Mars in Paris, France.", + "Berlin is the capital and largest city of Germany by both area and population.", + ]; + let top_n = 2; + + // Perform reranking + let results: Vec = reranker.rerank(query, &documents, top_n).await?; + + // Assertions + assert_eq!(results.len(), top_n, "Should return top_n results"); + + // Check that indices are within the bounds of the original documents + for result in &results { + assert!(result.index < documents.len(), "Result index should be valid"); + } + + // Optional: Print results for manual verification (can be removed later) + println!("Query: {}", query); + for result in &results { + println!( + "Document Index: {}, Score: {:.4}, Document: {}", + result.index, + result.relevance_score, + documents[result.index] + ); + } + + // Example assertion: if we expect Paris-related documents to be ranked higher. + // This is a very basic check and might need adjustment based on actual model behavior. + if !results.is_empty() { + let first_result_doc = documents[results[0].index]; + assert!( + first_result_doc.to_lowercase().contains("paris"), + "The top result for 'capital of France' should ideally mention Paris. Model output: {}", + first_result_doc + ); + } + + Ok(()) +} \ No newline at end of file