diff --git a/api/.cursorrules b/api/.cursorrules index 2549ff4a3..63b2126e7 100644 --- a/api/.cursorrules +++ b/api/.cursorrules @@ -1,2 +1,8 @@ - this is an axum web server -- all tests need to be tokio async tests \ No newline at end of file + +## Testing Guidelines +- all tests need to be tokio async tests +- all unit tests should be inline with the code they are testing +- all integration tests should be in the tests folder +- all tests should comment the test case and the expected output +- Makes sure to use mockito::Server::new_async() instead of mockito::Server::new() \ No newline at end of file diff --git a/api/src/utils/clients/ai/litellm/client.rs b/api/src/utils/clients/ai/litellm/client.rs index 683ee7dba..b7ffbc066 100644 --- a/api/src/utils/clients/ai/litellm/client.rs +++ b/api/src/utils/clients/ai/litellm/client.rs @@ -132,11 +132,16 @@ mod tests { #[tokio::test] async fn test_chat_completion_success() { - let mut server = mockito::Server::new(); + let mut server = mockito::Server::new_async().await; + + // Create expected request body + let request = create_test_request(); + let request_body = serde_json::to_string(&request).unwrap(); let mock = server.mock("POST", "/chat/completions") .match_header("content-type", "application/json") - .match_header("Authorization", "Bearer test-key") + .match_header("authorization", "Bearer test-key") + .match_body(mockito::Matcher::JsonString(request_body)) .with_status(200) .with_header("content-type", "application/json") .with_body(r#"{ @@ -165,7 +170,7 @@ mod tests { Some(server.url()), ); - let response = client.chat_completion(create_test_request()).await.unwrap(); + let response = client.chat_completion(request).await.unwrap(); assert_eq!(response.id, "test-id"); assert_eq!(response.choices[0].message.content, "Hello there!"); @@ -174,11 +179,17 @@ mod tests { #[tokio::test] async fn test_chat_completion_error() { - let mut server = mockito::Server::new(); + let mut server = mockito::Server::new_async().await; + + let request = create_test_request(); + let request_body = serde_json::to_string(&request).unwrap(); let mock = server.mock("POST", "/chat/completions") .match_header("content-type", "application/json") + .match_header("authorization", "Bearer test-key") + .match_body(mockito::Matcher::JsonString(request_body)) .with_status(400) + .with_header("content-type", "application/json") .with_body(r#"{"error": "Invalid request"}"#) .create(); @@ -187,7 +198,7 @@ mod tests { Some(server.url()), ); - let result = client.chat_completion(create_test_request()).await; + let result = client.chat_completion(request).await; assert!(result.is_err()); mock.assert(); @@ -195,11 +206,16 @@ mod tests { #[tokio::test] async fn test_stream_chat_completion() { - let mut server = mockito::Server::new(); + let mut server = mockito::Server::new_async().await; + + let mut request = create_test_request(); + request.stream = Some(true); + let request_body = serde_json::to_string(&request).unwrap(); let mock = server.mock("POST", "/chat/completions") .match_header("content-type", "application/json") - .match_body(mockito::Matcher::JsonString(r#"{"stream":true}"#.to_string())) + .match_header("authorization", "Bearer test-key") + .match_body(mockito::Matcher::JsonString(request_body)) .with_status(200) .with_header("content-type", "text/event-stream") .with_body( @@ -214,9 +230,6 @@ mod tests { Some(server.url()), ); - let mut request = create_test_request(); - request.stream = Some(true); - let mut stream = client.stream_chat_completion(request).await.unwrap(); let mut chunks = Vec::new();