mirror of https://github.com/buster-so/buster.git
Merge branch 'evals' of https://github.com/buster-so/buster into evals
This commit is contained in:
commit
c8ae510019
|
@ -1,6 +1,6 @@
|
|||
use crate::tools::{IntoToolCallExecutor, ToolExecutor};
|
||||
use anyhow::Result;
|
||||
use braintrust::BraintrustClient;
|
||||
use braintrust::{BraintrustClient, TraceBuilder};
|
||||
use litellm::{
|
||||
AgentMessage, ChatCompletionRequest, DeltaToolCall, FunctionCall, LiteLLMClient,
|
||||
MessageProgress, Metadata, Tool, ToolCall, ToolChoice,
|
||||
|
@ -389,15 +389,6 @@ impl Agent {
|
|||
*current = Some(thread.clone());
|
||||
}
|
||||
|
||||
// Initialize Braintrust client
|
||||
let client = BraintrustClient::new(
|
||||
None,
|
||||
"c7b996a6-1c7c-482d-b23f-3d39de16f433"
|
||||
)?;
|
||||
|
||||
// Create a root span for this thread
|
||||
let root_span_id = thread.id.to_string();
|
||||
|
||||
if recursion_depth >= 30 {
|
||||
let message = AgentMessage::assistant(
|
||||
Some("max_recursion_depth_message".to_string()),
|
||||
|
@ -415,6 +406,11 @@ impl Agent {
|
|||
// Collect all registered tools and their schemas
|
||||
let tools = self.get_enabled_tools().await;
|
||||
|
||||
// Get the most recent user message for logging
|
||||
let user_message = thread.messages.last()
|
||||
.filter(|msg| matches!(msg, AgentMessage::User { .. }))
|
||||
.cloned();
|
||||
|
||||
// Create the tool-enabled request
|
||||
let request = ChatCompletionRequest {
|
||||
model: self.model.clone(),
|
||||
|
@ -430,22 +426,13 @@ impl Agent {
|
|||
}),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
// Create a span for the LLM call
|
||||
let llm_span = client.create_span(
|
||||
"llm_call",
|
||||
"llm",
|
||||
Some(&root_span_id),
|
||||
None
|
||||
).with_input(serde_json::to_value(&request)?);
|
||||
|
||||
|
||||
// Get the streaming response from the LLM
|
||||
let mut stream_rx = match self.llm_client.stream_chat_completion(request).await {
|
||||
Ok(rx) => rx,
|
||||
Err(e) => {
|
||||
// Log error in span
|
||||
let error_message = format!("Error starting stream: {:?}", e);
|
||||
client.log_span(llm_span.with_output(serde_json::json!({"error": error_message}))).await?;
|
||||
return Err(anyhow::anyhow!(error_message));
|
||||
},
|
||||
};
|
||||
|
@ -503,7 +490,6 @@ impl Agent {
|
|||
Err(e) => {
|
||||
// Log error in span
|
||||
let error_message = format!("Error in stream: {:?}", e);
|
||||
client.log_span(llm_span.with_output(serde_json::json!({"error": error_message}))).await?;
|
||||
return Err(anyhow::anyhow!(error_message));
|
||||
},
|
||||
}
|
||||
|
@ -530,24 +516,6 @@ impl Agent {
|
|||
Some(self.name.clone()),
|
||||
);
|
||||
|
||||
// Log the LLM response in Braintrust
|
||||
let llm_output = if let Some(content) = &final_message.get_content() {
|
||||
serde_json::json!({
|
||||
"content": content,
|
||||
"tool_calls": final_tool_calls
|
||||
})
|
||||
} else {
|
||||
serde_json::json!({
|
||||
"tool_calls": final_tool_calls
|
||||
})
|
||||
};
|
||||
|
||||
// Clone the span_id before moving llm_span
|
||||
let llm_span_id = llm_span.clone().span_id().to_string();
|
||||
|
||||
// Now we can safely move llm_span
|
||||
client.log_span(llm_span.with_output(llm_output)).await?;
|
||||
|
||||
// Broadcast the final assistant message
|
||||
self.get_stream_sender()
|
||||
.await
|
||||
|
@ -572,17 +540,15 @@ impl Agent {
|
|||
// Execute each requested tool
|
||||
for tool_call in tool_calls {
|
||||
if let Some(tool) = self.tools.read().await.get(&tool_call.function.name) {
|
||||
// Create a span for the tool call
|
||||
let tool_span = client.create_span(
|
||||
&tool_call.function.name,
|
||||
"tool",
|
||||
Some(&root_span_id),
|
||||
Some(&llm_span_id)
|
||||
);
|
||||
|
||||
// Parse the parameters and log them
|
||||
// Parse the parameters - log only the tool call as input
|
||||
let params: Value = serde_json::from_str(&tool_call.function.arguments)?;
|
||||
let tool_span = tool_span.with_input(params.clone());
|
||||
let tool_input = serde_json::json!({
|
||||
"function": {
|
||||
"name": tool_call.function.name,
|
||||
"arguments": params
|
||||
},
|
||||
"id": tool_call.id
|
||||
});
|
||||
|
||||
// Execute the tool
|
||||
let result = match tool.execute(params, tool_call.id.clone()).await {
|
||||
|
@ -590,14 +556,10 @@ impl Agent {
|
|||
Err(e) => {
|
||||
// Log error in tool span
|
||||
let error_message = format!("Tool execution error: {:?}", e);
|
||||
client.log_span(tool_span.with_output(serde_json::json!({"error": error_message}))).await?;
|
||||
return Err(anyhow::anyhow!(error_message));
|
||||
}
|
||||
};
|
||||
|
||||
// Log the tool result
|
||||
client.log_span(tool_span.with_output(result.clone())).await?;
|
||||
|
||||
let result_str = serde_json::to_string(&result)?;
|
||||
let tool_message = AgentMessage::tool(
|
||||
None,
|
||||
|
@ -623,6 +585,8 @@ impl Agent {
|
|||
new_thread.messages.push(final_message);
|
||||
new_thread.messages.extend(results);
|
||||
|
||||
// For recursive calls, we'll continue with the same trace
|
||||
// We don't finish the trace here to keep all interactions in one trace
|
||||
Box::pin(self.process_thread_with_depth(&new_thread, recursion_depth + 1)).await
|
||||
} else {
|
||||
// Send Done message and return
|
||||
|
|
|
@ -22,6 +22,7 @@ tokio-test = { workspace = true }
|
|||
mockito = { workspace = true }
|
||||
async-trait = { workspace = true }
|
||||
futures = { workspace = true }
|
||||
dotenv = { workspace = true }
|
||||
|
||||
# Feature flags
|
||||
[features]
|
||||
|
|
|
@ -113,9 +113,9 @@ impl Span {
|
|||
self
|
||||
}
|
||||
|
||||
/// Alias for add_metadata
|
||||
pub fn with_metadata(self, key: &str, value: &str) -> Self {
|
||||
self.add_metadata(key, value)
|
||||
/// Alias for add_metadata that converts any displayable value to a string
|
||||
pub fn with_metadata<T: std::fmt::Display>(self, key: &str, value: T) -> Self {
|
||||
self.add_metadata(key, &value.to_string())
|
||||
}
|
||||
|
||||
/// Get the span ID
|
||||
|
|
|
@ -0,0 +1,238 @@
|
|||
use anyhow::Result;
|
||||
use braintrust::{BraintrustClient, TraceBuilder};
|
||||
use dotenv::dotenv;
|
||||
use serde_json::json;
|
||||
use std::env;
|
||||
use std::time::Duration;
|
||||
use tokio::time::sleep;
|
||||
|
||||
// Helper function to initialize environment from .env file
|
||||
fn init_env() -> Result<()> {
|
||||
// Load environment variables from .env file
|
||||
dotenv().ok();
|
||||
|
||||
// Verify that the API key is set
|
||||
if env::var("BRAINTRUST_API_KEY").is_err() {
|
||||
println!("Warning: BRAINTRUST_API_KEY not found in environment or .env file");
|
||||
println!("Some tests may fail if they require a valid API key");
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_real_client_initialization() -> Result<()> {
|
||||
// Initialize environment
|
||||
init_env()?;
|
||||
|
||||
// Create client with environment API key (None means use env var)
|
||||
let client = BraintrustClient::new(None, "c7b996a6-1c7c-482d-b23f-3d39de16f433")?;
|
||||
|
||||
// Simple verification that client was created
|
||||
assert!(client.project_id() == "c7b996a6-1c7c-482d-b23f-3d39de16f433");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_real_span_logging() -> Result<()> {
|
||||
// Initialize environment
|
||||
init_env()?;
|
||||
|
||||
// Skip test if no API key is available
|
||||
if env::var("BRAINTRUST_API_KEY").is_err() {
|
||||
println!("Skipping test_real_span_logging: No API key available");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Create client (None means use env var)
|
||||
let client = BraintrustClient::new(None, "c7b996a6-1c7c-482d-b23f-3d39de16f433")?;
|
||||
|
||||
// Create a span
|
||||
let span = client.create_span("Integration Test Span", "test", None, None);
|
||||
|
||||
// Add data to the span
|
||||
let span = span
|
||||
.with_input(json!({
|
||||
"test_input": "This is a test input for integration testing",
|
||||
"timestamp": chrono::Utc::now().to_rfc3339()
|
||||
}))
|
||||
.with_output(json!({
|
||||
"test_output": "This is a test output for integration testing",
|
||||
"timestamp": chrono::Utc::now().to_rfc3339()
|
||||
}))
|
||||
.with_metadata("test_source", "integration_test")
|
||||
.with_metadata("test_id", uuid::Uuid::new_v4().to_string());
|
||||
|
||||
// Log the span
|
||||
client.log_span(span).await?;
|
||||
|
||||
// Allow some time for async processing
|
||||
sleep(Duration::from_millis(100)).await;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_real_trace_with_spans() -> Result<()> {
|
||||
// Initialize environment
|
||||
init_env()?;
|
||||
|
||||
// Skip test if no API key is available
|
||||
if env::var("BRAINTRUST_API_KEY").is_err() {
|
||||
println!("Skipping test_real_trace_with_spans: No API key available");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Create client (None means use env var)
|
||||
let client = BraintrustClient::new(None, "c7b996a6-1c7c-482d-b23f-3d39de16f433")?;
|
||||
|
||||
// Create a trace
|
||||
let trace_id = uuid::Uuid::new_v4().to_string();
|
||||
let trace = TraceBuilder::new(
|
||||
client.clone(),
|
||||
&format!("Integration Test Trace {}", trace_id)
|
||||
);
|
||||
|
||||
// Add a root span
|
||||
let root_span = trace.add_span("Root Operation", "function").await?;
|
||||
let mut root_span = root_span
|
||||
.with_input(json!({
|
||||
"operation": "root",
|
||||
"parameters": {
|
||||
"test": true,
|
||||
"timestamp": chrono::Utc::now().to_rfc3339()
|
||||
}
|
||||
}))
|
||||
.with_metadata("test_id", trace_id.clone());
|
||||
|
||||
// Log the root span
|
||||
client.log_span(root_span.clone()).await?;
|
||||
|
||||
sleep(Duration::from_secs(10)).await;
|
||||
|
||||
root_span = root_span
|
||||
.with_output(json!({
|
||||
"result": "success",
|
||||
"parameters": {
|
||||
"test": true,
|
||||
"timestamp": chrono::Utc::now().to_rfc3339()
|
||||
}
|
||||
}));
|
||||
|
||||
client.log_span(root_span).await?;
|
||||
|
||||
// Add an LLM span
|
||||
let llm_span = trace.add_span("LLM Call", "llm").await?;
|
||||
let mut llm_span = llm_span
|
||||
.with_input(json!({
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Hello, this is a test message for integration testing"
|
||||
}
|
||||
]
|
||||
}))
|
||||
.with_metadata("model", "test-model");
|
||||
|
||||
// Log the LLM span
|
||||
client.log_span(llm_span.clone()).await?;
|
||||
|
||||
sleep(Duration::from_secs(15)).await;
|
||||
|
||||
llm_span = llm_span
|
||||
.with_output(json!({
|
||||
"choices": [
|
||||
{
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "Hello! I'm responding to your integration test message."
|
||||
}
|
||||
}
|
||||
]
|
||||
}));
|
||||
|
||||
client.log_span(llm_span).await?;
|
||||
|
||||
// Add a tool span
|
||||
let tool_span = trace.add_span("Tool Execution", "tool").await?;
|
||||
let tool_span = tool_span
|
||||
.with_input(json!({
|
||||
"function": {
|
||||
"name": "test_tool",
|
||||
"arguments": {
|
||||
"param1": "value1",
|
||||
"param2": 42
|
||||
}
|
||||
},
|
||||
"id": uuid::Uuid::new_v4().to_string()
|
||||
}))
|
||||
.with_output(json!({
|
||||
"result": "Tool execution successful",
|
||||
"data": {
|
||||
"value": "test result",
|
||||
"timestamp": chrono::Utc::now().to_rfc3339()
|
||||
}
|
||||
}));
|
||||
|
||||
// Log the tool span
|
||||
client.log_span(tool_span).await?;
|
||||
|
||||
// Finish the trace
|
||||
trace.finish().await?;
|
||||
|
||||
// Allow some time for async processing
|
||||
sleep(Duration::from_secs(30)).await;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_real_error_handling() -> Result<()> {
|
||||
// Initialize environment
|
||||
init_env()?;
|
||||
|
||||
// Skip test if no API key is available
|
||||
if env::var("BRAINTRUST_API_KEY").is_err() {
|
||||
println!("Skipping test_real_error_handling: No API key available");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Create client (None means use env var)
|
||||
let client = BraintrustClient::new(None, "c7b996a6-1c7c-482d-b23f-3d39de16f433")?;
|
||||
|
||||
// Create a trace for error testing
|
||||
let trace = TraceBuilder::new(
|
||||
client.clone(),
|
||||
"Integration Test Error Handling"
|
||||
);
|
||||
|
||||
// Add a span that will contain an error
|
||||
let error_span = trace.add_span("Error Operation", "function").await?;
|
||||
|
||||
// Simulate an operation that results in an error
|
||||
let error_message = "This is a simulated error for testing";
|
||||
let error_span = error_span
|
||||
.with_input(json!({
|
||||
"operation": "error_test",
|
||||
"should_fail": true
|
||||
}))
|
||||
.with_output(json!({
|
||||
"error": error_message,
|
||||
"stack_trace": "simulated stack trace for testing",
|
||||
"timestamp": chrono::Utc::now().to_rfc3339()
|
||||
}))
|
||||
.with_metadata("error", true)
|
||||
.with_metadata("error_type", "SimulatedError");
|
||||
|
||||
// Log the error span
|
||||
client.log_span(error_span).await?;
|
||||
|
||||
// Finish the trace
|
||||
trace.finish().await?;
|
||||
|
||||
// Allow some time for async processing
|
||||
sleep(Duration::from_millis(100)).await;
|
||||
|
||||
Ok(())
|
||||
}
|
|
@ -12,7 +12,13 @@ pub struct GetRawLlmMessagesRequest {
|
|||
pub chat_id: Uuid,
|
||||
}
|
||||
|
||||
pub async fn get_raw_llm_messages_handler(chat_id: Uuid) -> Result<Value> {
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct GetRawLlmMessagesResponse {
|
||||
pub chat_id: Uuid,
|
||||
pub raw_llm_messages: Value,
|
||||
}
|
||||
|
||||
pub async fn get_raw_llm_messages_handler(chat_id: Uuid) -> Result<GetRawLlmMessagesResponse> {
|
||||
let pool = get_pg_pool();
|
||||
let mut conn = pool.get().await?;
|
||||
|
||||
|
@ -26,5 +32,8 @@ pub async fn get_raw_llm_messages_handler(chat_id: Uuid) -> Result<Value> {
|
|||
.first::<Value>(&mut conn)
|
||||
.await?;
|
||||
|
||||
Ok(raw_llm_messages)
|
||||
Ok(GetRawLlmMessagesResponse {
|
||||
chat_id,
|
||||
raw_llm_messages,
|
||||
})
|
||||
}
|
||||
|
|
|
@ -71,7 +71,7 @@ pub async fn generate_conversation_title(
|
|||
|
||||
// Create the request
|
||||
let request = ChatCompletionRequest {
|
||||
model: "gemini-2".to_string(),
|
||||
model: "gpt-4o-mini".to_string(),
|
||||
messages: vec![LiteLLMAgentMessage::User {
|
||||
id: None,
|
||||
content: prompt,
|
||||
|
|
|
@ -1764,7 +1764,7 @@ pub async fn generate_conversation_title(
|
|||
|
||||
// Create the request
|
||||
let request = ChatCompletionRequest {
|
||||
model: "gemini-2".to_string(),
|
||||
model: "gpt-4o-mini".to_string(),
|
||||
messages: vec![LiteLLMAgentMessage::User {
|
||||
id: None,
|
||||
content: prompt,
|
||||
|
|
|
@ -1,21 +1,20 @@
|
|||
use axum::{
|
||||
extract::{Path, State},
|
||||
Extension, Json,
|
||||
http::StatusCode,
|
||||
response::IntoResponse,
|
||||
};
|
||||
use handlers::chats::get_raw_llm_messages_handler;
|
||||
use middleware::AuthenticatedUser;
|
||||
use serde_json::Value;
|
||||
use uuid::Uuid;
|
||||
use crate::routes::rest::ApiResponse;
|
||||
use axum::{extract::Path, http::StatusCode, Extension};
|
||||
use handlers::chats::get_raw_llm_messages_handler::{
|
||||
get_raw_llm_messages_handler, GetRawLlmMessagesResponse,
|
||||
};
|
||||
use middleware::AuthenticatedUser;
|
||||
use uuid::Uuid;
|
||||
|
||||
pub async fn get_chat_raw_llm_messages(
|
||||
Extension(user): Extension<AuthenticatedUser>,
|
||||
Extension(user): Extension<AuthenticatedUser>,
|
||||
Path(chat_id): Path<Uuid>,
|
||||
) -> Result<ApiResponse<Value>, (StatusCode, &'static str)> {
|
||||
) -> Result<ApiResponse<GetRawLlmMessagesResponse>, (StatusCode, &'static str)> {
|
||||
match get_raw_llm_messages_handler(chat_id).await {
|
||||
Ok(response) => Ok(ApiResponse::JsonData(response)),
|
||||
Err(e) => Err((StatusCode::INTERNAL_SERVER_ERROR, "Failed to get raw LLM messages")),
|
||||
Err(e) => Err((
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
"Failed to get raw LLM messages",
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue