Merge branch 'evals' of https://github.com/buster-so/buster into evals

This commit is contained in:
Nate Kelley 2025-03-14 13:51:16 -06:00
commit c8ae510019
No known key found for this signature in database
GPG Key ID: FD90372AB8D98B4F
8 changed files with 284 additions and 73 deletions

View File

@ -1,6 +1,6 @@
use crate::tools::{IntoToolCallExecutor, ToolExecutor};
use anyhow::Result;
use braintrust::BraintrustClient;
use braintrust::{BraintrustClient, TraceBuilder};
use litellm::{
AgentMessage, ChatCompletionRequest, DeltaToolCall, FunctionCall, LiteLLMClient,
MessageProgress, Metadata, Tool, ToolCall, ToolChoice,
@ -389,15 +389,6 @@ impl Agent {
*current = Some(thread.clone());
}
// Initialize Braintrust client
let client = BraintrustClient::new(
None,
"c7b996a6-1c7c-482d-b23f-3d39de16f433"
)?;
// Create a root span for this thread
let root_span_id = thread.id.to_string();
if recursion_depth >= 30 {
let message = AgentMessage::assistant(
Some("max_recursion_depth_message".to_string()),
@ -415,6 +406,11 @@ impl Agent {
// Collect all registered tools and their schemas
let tools = self.get_enabled_tools().await;
// Get the most recent user message for logging
let user_message = thread.messages.last()
.filter(|msg| matches!(msg, AgentMessage::User { .. }))
.cloned();
// Create the tool-enabled request
let request = ChatCompletionRequest {
model: self.model.clone(),
@ -430,22 +426,13 @@ impl Agent {
}),
..Default::default()
};
// Create a span for the LLM call
let llm_span = client.create_span(
"llm_call",
"llm",
Some(&root_span_id),
None
).with_input(serde_json::to_value(&request)?);
// Get the streaming response from the LLM
let mut stream_rx = match self.llm_client.stream_chat_completion(request).await {
Ok(rx) => rx,
Err(e) => {
// Log error in span
let error_message = format!("Error starting stream: {:?}", e);
client.log_span(llm_span.with_output(serde_json::json!({"error": error_message}))).await?;
return Err(anyhow::anyhow!(error_message));
},
};
@ -503,7 +490,6 @@ impl Agent {
Err(e) => {
// Log error in span
let error_message = format!("Error in stream: {:?}", e);
client.log_span(llm_span.with_output(serde_json::json!({"error": error_message}))).await?;
return Err(anyhow::anyhow!(error_message));
},
}
@ -530,24 +516,6 @@ impl Agent {
Some(self.name.clone()),
);
// Log the LLM response in Braintrust
let llm_output = if let Some(content) = &final_message.get_content() {
serde_json::json!({
"content": content,
"tool_calls": final_tool_calls
})
} else {
serde_json::json!({
"tool_calls": final_tool_calls
})
};
// Clone the span_id before moving llm_span
let llm_span_id = llm_span.clone().span_id().to_string();
// Now we can safely move llm_span
client.log_span(llm_span.with_output(llm_output)).await?;
// Broadcast the final assistant message
self.get_stream_sender()
.await
@ -572,17 +540,15 @@ impl Agent {
// Execute each requested tool
for tool_call in tool_calls {
if let Some(tool) = self.tools.read().await.get(&tool_call.function.name) {
// Create a span for the tool call
let tool_span = client.create_span(
&tool_call.function.name,
"tool",
Some(&root_span_id),
Some(&llm_span_id)
);
// Parse the parameters and log them
// Parse the parameters - log only the tool call as input
let params: Value = serde_json::from_str(&tool_call.function.arguments)?;
let tool_span = tool_span.with_input(params.clone());
let tool_input = serde_json::json!({
"function": {
"name": tool_call.function.name,
"arguments": params
},
"id": tool_call.id
});
// Execute the tool
let result = match tool.execute(params, tool_call.id.clone()).await {
@ -590,14 +556,10 @@ impl Agent {
Err(e) => {
// Log error in tool span
let error_message = format!("Tool execution error: {:?}", e);
client.log_span(tool_span.with_output(serde_json::json!({"error": error_message}))).await?;
return Err(anyhow::anyhow!(error_message));
}
};
// Log the tool result
client.log_span(tool_span.with_output(result.clone())).await?;
let result_str = serde_json::to_string(&result)?;
let tool_message = AgentMessage::tool(
None,
@ -623,6 +585,8 @@ impl Agent {
new_thread.messages.push(final_message);
new_thread.messages.extend(results);
// For recursive calls, we'll continue with the same trace
// We don't finish the trace here to keep all interactions in one trace
Box::pin(self.process_thread_with_depth(&new_thread, recursion_depth + 1)).await
} else {
// Send Done message and return

View File

@ -22,6 +22,7 @@ tokio-test = { workspace = true }
mockito = { workspace = true }
async-trait = { workspace = true }
futures = { workspace = true }
dotenv = { workspace = true }
# Feature flags
[features]

View File

@ -113,9 +113,9 @@ impl Span {
self
}
/// Alias for add_metadata
pub fn with_metadata(self, key: &str, value: &str) -> Self {
self.add_metadata(key, value)
/// Alias for add_metadata that converts any displayable value to a string
pub fn with_metadata<T: std::fmt::Display>(self, key: &str, value: T) -> Self {
self.add_metadata(key, &value.to_string())
}
/// Get the span ID

View File

@ -0,0 +1,238 @@
use anyhow::Result;
use braintrust::{BraintrustClient, TraceBuilder};
use dotenv::dotenv;
use serde_json::json;
use std::env;
use std::time::Duration;
use tokio::time::sleep;
/// Load variables from `.env` and warn when the Braintrust API key is absent.
///
/// Never returns an error: a missing key only prints a warning, because the
/// individual tests decide for themselves whether to skip without a key.
fn init_env() -> Result<()> {
    // Best-effort load; a missing .env file is not an error.
    dotenv().ok();

    // Warn (but do not fail) when no API key is visible to the tests.
    let have_key = env::var("BRAINTRUST_API_KEY").is_ok();
    if !have_key {
        println!("Warning: BRAINTRUST_API_KEY not found in environment or .env file");
        println!("Some tests may fail if they require a valid API key");
    }

    Ok(())
}
#[tokio::test]
async fn test_real_client_initialization() -> Result<()> {
    // Single source of truth for the project id; it was previously duplicated
    // in the constructor call and the assertion.
    const PROJECT_ID: &str = "c7b996a6-1c7c-482d-b23f-3d39de16f433";

    // Initialize environment
    init_env()?;

    // Create client with environment API key (None means use env var)
    let client = BraintrustClient::new(None, PROJECT_ID)?;

    // assert_eq! prints both sides on failure, unlike assert!(a == b),
    // which only reports that the comparison was false.
    assert_eq!(client.project_id(), PROJECT_ID);
    Ok(())
}
/// End-to-end check that a fully populated span can be logged through the
/// real client. Skips itself when no API key is configured.
#[tokio::test]
async fn test_real_span_logging() -> Result<()> {
    init_env()?;

    // Without credentials there is nothing meaningful to exercise.
    if env::var("BRAINTRUST_API_KEY").is_err() {
        println!("Skipping test_real_span_logging: No API key available");
        return Ok(());
    }

    // None → the client reads the API key from the environment.
    let client = BraintrustClient::new(None, "c7b996a6-1c7c-482d-b23f-3d39de16f433")?;

    // Build the span in one chain: input, output, and two metadata entries.
    let populated_span = client
        .create_span("Integration Test Span", "test", None, None)
        .with_input(json!({
            "test_input": "This is a test input for integration testing",
            "timestamp": chrono::Utc::now().to_rfc3339()
        }))
        .with_output(json!({
            "test_output": "This is a test output for integration testing",
            "timestamp": chrono::Utc::now().to_rfc3339()
        }))
        .with_metadata("test_source", "integration_test")
        .with_metadata("test_id", uuid::Uuid::new_v4().to_string());

    client.log_span(populated_span).await?;

    // Give the async sender a moment to flush before the runtime shuts down.
    sleep(Duration::from_millis(100)).await;
    Ok(())
}
/// Builds a trace containing root, LLM, and tool spans and logs each through
/// the real client. Root and LLM spans are logged twice (input first, output
/// later) to exercise re-logging of an existing span.
#[tokio::test]
async fn test_real_trace_with_spans() -> Result<()> {
    // Initialize environment
    init_env()?;

    // Skip test if no API key is available
    if env::var("BRAINTRUST_API_KEY").is_err() {
        println!("Skipping test_real_trace_with_spans: No API key available");
        return Ok(());
    }

    // Create client (None means use env var)
    let client = BraintrustClient::new(None, "c7b996a6-1c7c-482d-b23f-3d39de16f433")?;

    // Unique trace name so repeated runs don't collide in the dashboard.
    let trace_id = uuid::Uuid::new_v4().to_string();
    let trace = TraceBuilder::new(
        client.clone(),
        &format!("Integration Test Trace {}", trace_id)
    );

    // NOTE(review): the original version slept 10s / 15s / 30s between log
    // calls (~55s per run), which made the suite unusably slow. The short
    // pauses below preserve the log-input-then-log-output ordering being
    // exercised; restore longer waits only if the backend demonstrably
    // requires them for propagation.

    // Root span: logged first with input only, then re-logged with output.
    let root_span = trace.add_span("Root Operation", "function").await?;
    let mut root_span = root_span
        .with_input(json!({
            "operation": "root",
            "parameters": {
                "test": true,
                "timestamp": chrono::Utc::now().to_rfc3339()
            }
        }))
        .with_metadata("test_id", trace_id.clone());

    client.log_span(root_span.clone()).await?;
    sleep(Duration::from_millis(100)).await;

    root_span = root_span
        .with_output(json!({
            "result": "success",
            "parameters": {
                "test": true,
                "timestamp": chrono::Utc::now().to_rfc3339()
            }
        }));
    client.log_span(root_span).await?;

    // LLM span: same two-phase logging pattern.
    let llm_span = trace.add_span("LLM Call", "llm").await?;
    let mut llm_span = llm_span
        .with_input(json!({
            "messages": [
                {
                    "role": "user",
                    "content": "Hello, this is a test message for integration testing"
                }
            ]
        }))
        .with_metadata("model", "test-model");

    client.log_span(llm_span.clone()).await?;
    sleep(Duration::from_millis(100)).await;

    llm_span = llm_span
        .with_output(json!({
            "choices": [
                {
                    "message": {
                        "role": "assistant",
                        "content": "Hello! I'm responding to your integration test message."
                    }
                }
            ]
        }));
    client.log_span(llm_span).await?;

    // Tool span: logged once with both input and output already attached.
    let tool_span = trace.add_span("Tool Execution", "tool").await?;
    let tool_span = tool_span
        .with_input(json!({
            "function": {
                "name": "test_tool",
                "arguments": {
                    "param1": "value1",
                    "param2": 42
                }
            },
            "id": uuid::Uuid::new_v4().to_string()
        }))
        .with_output(json!({
            "result": "Tool execution successful",
            "data": {
                "value": "test result",
                "timestamp": chrono::Utc::now().to_rfc3339()
            }
        }));

    // Log the tool span
    client.log_span(tool_span).await?;

    // Finish the trace
    trace.finish().await?;

    // Allow the async sender to flush before the runtime shuts down.
    sleep(Duration::from_millis(200)).await;
    Ok(())
}
#[tokio::test]
async fn test_real_error_handling() -> Result<()> {
// Initialize environment
init_env()?;
// Skip test if no API key is available
if env::var("BRAINTRUST_API_KEY").is_err() {
println!("Skipping test_real_error_handling: No API key available");
return Ok(());
}
// Create client (None means use env var)
let client = BraintrustClient::new(None, "c7b996a6-1c7c-482d-b23f-3d39de16f433")?;
// Create a trace for error testing
let trace = TraceBuilder::new(
client.clone(),
"Integration Test Error Handling"
);
// Add a span that will contain an error
let error_span = trace.add_span("Error Operation", "function").await?;
// Simulate an operation that results in an error
let error_message = "This is a simulated error for testing";
let error_span = error_span
.with_input(json!({
"operation": "error_test",
"should_fail": true
}))
.with_output(json!({
"error": error_message,
"stack_trace": "simulated stack trace for testing",
"timestamp": chrono::Utc::now().to_rfc3339()
}))
.with_metadata("error", true)
.with_metadata("error_type", "SimulatedError");
// Log the error span
client.log_span(error_span).await?;
// Finish the trace
trace.finish().await?;
// Allow some time for async processing
sleep(Duration::from_millis(100)).await;
Ok(())
}

View File

@ -12,7 +12,13 @@ pub struct GetRawLlmMessagesRequest {
pub chat_id: Uuid,
}
pub async fn get_raw_llm_messages_handler(chat_id: Uuid) -> Result<Value> {
/// Response body for the raw-LLM-messages endpoint: echoes the requested
/// chat id alongside the stored raw message payload.
#[derive(Debug, Serialize, Deserialize)]
pub struct GetRawLlmMessagesResponse {
    // Id of the chat whose messages were fetched (echoed from the request).
    pub chat_id: Uuid,
    // Raw JSON as persisted in the database; kept as an untyped `Value`
    // because the stored message shape is not modeled here — presumably the
    // agent's message format. TODO(review): confirm against the writer side.
    pub raw_llm_messages: Value,
}
pub async fn get_raw_llm_messages_handler(chat_id: Uuid) -> Result<GetRawLlmMessagesResponse> {
let pool = get_pg_pool();
let mut conn = pool.get().await?;
@ -26,5 +32,8 @@ pub async fn get_raw_llm_messages_handler(chat_id: Uuid) -> Result<Value> {
.first::<Value>(&mut conn)
.await?;
Ok(raw_llm_messages)
Ok(GetRawLlmMessagesResponse {
chat_id,
raw_llm_messages,
})
}

View File

@ -71,7 +71,7 @@ pub async fn generate_conversation_title(
// Create the request
let request = ChatCompletionRequest {
model: "gemini-2".to_string(),
model: "gpt-4o-mini".to_string(),
messages: vec![LiteLLMAgentMessage::User {
id: None,
content: prompt,

View File

@ -1764,7 +1764,7 @@ pub async fn generate_conversation_title(
// Create the request
let request = ChatCompletionRequest {
model: "gemini-2".to_string(),
model: "gpt-4o-mini".to_string(),
messages: vec![LiteLLMAgentMessage::User {
id: None,
content: prompt,

View File

@ -1,21 +1,20 @@
use axum::{
extract::{Path, State},
Extension, Json,
http::StatusCode,
response::IntoResponse,
};
use handlers::chats::get_raw_llm_messages_handler;
use middleware::AuthenticatedUser;
use serde_json::Value;
use uuid::Uuid;
use crate::routes::rest::ApiResponse;
use axum::{extract::Path, http::StatusCode, Extension};
use handlers::chats::get_raw_llm_messages_handler::{
get_raw_llm_messages_handler, GetRawLlmMessagesResponse,
};
use middleware::AuthenticatedUser;
use uuid::Uuid;
pub async fn get_chat_raw_llm_messages(
Extension(user): Extension<AuthenticatedUser>,
Extension(user): Extension<AuthenticatedUser>,
Path(chat_id): Path<Uuid>,
) -> Result<ApiResponse<Value>, (StatusCode, &'static str)> {
) -> Result<ApiResponse<GetRawLlmMessagesResponse>, (StatusCode, &'static str)> {
match get_raw_llm_messages_handler(chat_id).await {
Ok(response) => Ok(ApiResponse::JsonData(response)),
Err(e) => Err((StatusCode::INTERNAL_SERVER_ERROR, "Failed to get raw LLM messages")),
Err(e) => Err((
StatusCode::INTERNAL_SERVER_ERROR,
"Failed to get raw LLM messages",
)),
}
}