From 0a9f17fa42516707b9ce08f2c5b0871f399587a5 Mon Sep 17 00:00:00 2001 From: dal Date: Fri, 14 Mar 2025 12:50:12 -0600 Subject: [PATCH 1/3] agent with gpt-4o-mini --- api/libs/agents/src/agent.rs | 70 ++---- api/libs/braintrust/Cargo.toml | 1 + .../braintrust/tests/integration_tests.rs | 220 ++++++++++++++++++ .../helpers/generate_conversation_title.rs | 2 +- .../handlers/src/chats/post_chat_handler.rs | 2 +- 5 files changed, 240 insertions(+), 55 deletions(-) create mode 100644 api/libs/braintrust/tests/integration_tests.rs diff --git a/api/libs/agents/src/agent.rs b/api/libs/agents/src/agent.rs index 06f90d503..4cb54d673 100644 --- a/api/libs/agents/src/agent.rs +++ b/api/libs/agents/src/agent.rs @@ -1,6 +1,6 @@ use crate::tools::{IntoToolCallExecutor, ToolExecutor}; use anyhow::Result; -use braintrust::BraintrustClient; +use braintrust::{BraintrustClient, TraceBuilder}; use litellm::{ AgentMessage, ChatCompletionRequest, DeltaToolCall, FunctionCall, LiteLLMClient, MessageProgress, Metadata, Tool, ToolCall, ToolChoice, @@ -389,15 +389,6 @@ impl Agent { *current = Some(thread.clone()); } - // Initialize Braintrust client - let client = BraintrustClient::new( - None, - "c7b996a6-1c7c-482d-b23f-3d39de16f433" - )?; - - // Create a root span for this thread - let root_span_id = thread.id.to_string(); - if recursion_depth >= 30 { let message = AgentMessage::assistant( Some("max_recursion_depth_message".to_string()), @@ -415,6 +406,11 @@ impl Agent { // Collect all registered tools and their schemas let tools = self.get_enabled_tools().await; + // Get the most recent user message for logging + let user_message = thread.messages.last() + .filter(|msg| matches!(msg, AgentMessage::User { .. })) + .cloned(); + // Create the tool-enabled request let request = ChatCompletionRequest { model: self.model.clone(), @@ -430,22 +426,13 @@ impl Agent { }), ..Default::default() }; - - // Create a span for the LLM call - let llm_span = client.create_span( - "llm_call", - "llm", - Some(&root_span_id), - None - ).with_input(serde_json::to_value(&request)?); - + // Get the streaming response from the LLM let mut stream_rx = match self.llm_client.stream_chat_completion(request).await { Ok(rx) => rx, Err(e) => { // Log error in span let error_message = format!("Error starting stream: {:?}", e); - client.log_span(llm_span.with_output(serde_json::json!({"error": error_message}))).await?; return Err(anyhow::anyhow!(error_message)); }, }; @@ -503,7 +490,6 @@ impl Agent { Err(e) => { // Log error in span let error_message = format!("Error in stream: {:?}", e); - client.log_span(llm_span.with_output(serde_json::json!({"error": error_message}))).await?; return Err(anyhow::anyhow!(error_message)); }, } @@ -530,24 +516,6 @@ impl Agent { Some(self.name.clone()), ); - // Log the LLM response in Braintrust - let llm_output = if let Some(content) = &final_message.get_content() { - serde_json::json!({ - "content": content, - "tool_calls": final_tool_calls - }) - } else { - serde_json::json!({ - "tool_calls": final_tool_calls - }) - }; - - // Clone the span_id before moving llm_span - let llm_span_id = llm_span.clone().span_id().to_string(); - - // Now we can safely move llm_span - client.log_span(llm_span.with_output(llm_output)).await?; - // Broadcast the final assistant message self.get_stream_sender() .await @@ -572,17 +540,15 @@ impl Agent { // Execute each requested tool for tool_call in tool_calls { if let Some(tool) = self.tools.read().await.get(&tool_call.function.name) { - // Create a span for the tool call - let tool_span = client.create_span( - &tool_call.function.name, - "tool", - Some(&root_span_id), - Some(&llm_span_id) - ); - - // Parse the parameters and log them + // Parse the parameters - log only the tool call as input let params: Value = serde_json::from_str(&tool_call.function.arguments)?; - let tool_span = tool_span.with_input(params.clone()); + let tool_input = serde_json::json!({ + "function": { + "name": tool_call.function.name, + "arguments": params + }, + "id": tool_call.id + }); // Execute the tool let result = match tool.execute(params, tool_call.id.clone()).await { @@ -590,14 +556,10 @@ impl Agent { Err(e) => { // Log error in tool span let error_message = format!("Tool execution error: {:?}", e); - client.log_span(tool_span.with_output(serde_json::json!({"error": error_message}))).await?; return Err(anyhow::anyhow!(error_message)); } }; - // Log the tool result - client.log_span(tool_span.with_output(result.clone())).await?; - let result_str = serde_json::to_string(&result)?; let tool_message = AgentMessage::tool( None, @@ -623,6 +585,8 @@ impl Agent { new_thread.messages.push(final_message); new_thread.messages.extend(results); + // For recursive calls, we'll continue with the same trace + // We don't finish the trace here to keep all interactions in one trace Box::pin(self.process_thread_with_depth(&new_thread, recursion_depth + 1)).await } else { // Send Done message and return diff --git a/api/libs/braintrust/Cargo.toml b/api/libs/braintrust/Cargo.toml index df4ae38d8..b433dcc10 100644 --- a/api/libs/braintrust/Cargo.toml +++ b/api/libs/braintrust/Cargo.toml @@ -22,6 +22,7 @@ tokio-test = { workspace = true } mockito = { workspace = true } async-trait = { workspace = true } futures = { workspace = true } +dotenv = { workspace = true } # Feature flags [features] diff --git a/api/libs/braintrust/tests/integration_tests.rs b/api/libs/braintrust/tests/integration_tests.rs new file mode 100644 index 000000000..18cd13816 --- /dev/null +++ b/api/libs/braintrust/tests/integration_tests.rs @@ -0,0 +1,220 @@ +use anyhow::Result; +use braintrust::{BraintrustClient, TraceBuilder}; +use dotenv::dotenv; +use serde_json::json; +use std::env; +use std::time::Duration; +use tokio::time::sleep; + +// Helper function to initialize environment from .env file +fn init_env() -> Result<()> { + // Load environment variables from .env file + dotenv().ok(); + + // Verify that the API key is set + if env::var("BRAINTRUST_API_KEY").is_err() { + println!("Warning: BRAINTRUST_API_KEY not found in environment or .env file"); + println!("Some tests may fail if they require a valid API key"); + } + + Ok(()) +} + +#[tokio::test] +async fn test_real_client_initialization() -> Result<()> { + // Initialize environment + init_env()?; + + // Create client with environment API key (None means use env var) + let client = BraintrustClient::new(None, "c7b996a6-1c7c-482d-b23f-3d39de16f433")?; + + // Simple verification that client was created + assert!(client.project_id() == "c7b996a6-1c7c-482d-b23f-3d39de16f433"); + + Ok(()) +} + +#[tokio::test] +async fn test_real_span_logging() -> Result<()> { + // Initialize environment + init_env()?; + + // Skip test if no API key is available + if env::var("BRAINTRUST_API_KEY").is_err() { + println!("Skipping test_real_span_logging: No API key available"); + return Ok(()); + } + + // Create client (None means use env var) + let client = BraintrustClient::new(None, "c7b996a6-1c7c-482d-b23f-3d39de16f433")?; + + // Create a span + let span = client.create_span("Integration Test Span", "test", None, None); + + // Add data to the span + let span = span + .with_input(json!({ + "test_input": "This is a test input for integration testing", + "timestamp": chrono::Utc::now().to_rfc3339() + })) + .with_output(json!({ + "test_output": "This is a test output for integration testing", + "timestamp": chrono::Utc::now().to_rfc3339() + })) + .with_metadata("test_source", "integration_test") + .with_metadata("test_id", uuid::Uuid::new_v4().to_string()); + + // Log the span + client.log_span(span).await?; + + // Allow some time for async processing + sleep(Duration::from_millis(100)).await; + + Ok(()) +} + +#[tokio::test] +async fn test_real_trace_with_spans() -> Result<()> { + // Initialize environment + init_env()?; + + // Skip test if no API key is available + if env::var("BRAINTRUST_API_KEY").is_err() { + println!("Skipping test_real_trace_with_spans: No API key available"); + return Ok(()); + } + + // Create client (None means use env var) + let client = BraintrustClient::new(None, "c7b996a6-1c7c-482d-b23f-3d39de16f433")?; + + // Create a trace + let trace_id = uuid::Uuid::new_v4().to_string(); + let trace = TraceBuilder::new( + client.clone(), + &format!("Integration Test Trace {}", trace_id) + ); + + // Add a root span + let root_span = trace.add_span("Root Operation", "function").await?; + let root_span = root_span + .with_input(json!({ + "operation": "root", + "parameters": { + "test": true, + "timestamp": chrono::Utc::now().to_rfc3339() + } + })) + .with_metadata("test_id", trace_id.clone()); + + // Log the root span + client.log_span(root_span).await?; + + // Add an LLM span + let llm_span = trace.add_span("LLM Call", "llm").await?; + let llm_span = llm_span + .with_input(json!({ + "messages": [ + { + "role": "user", + "content": "Hello, this is a test message for integration testing" + } + ] + })) + .with_output(json!({ + "choices": [ + { + "message": { + "role": "assistant", + "content": "Hello! I'm responding to your integration test message." + } + } + ] + })) + .with_tokens(15, 12) + .with_metadata("model", "test-model"); + + // Log the LLM span + client.log_span(llm_span).await?; + + // Add a tool span + let tool_span = trace.add_span("Tool Execution", "tool").await?; + let tool_span = tool_span + .with_input(json!({ + "function": { + "name": "test_tool", + "arguments": { + "param1": "value1", + "param2": 42 + } + }, + "id": uuid::Uuid::new_v4().to_string() + })) + .with_output(json!({ + "result": "Tool execution successful", + "data": { + "value": "test result", + "timestamp": chrono::Utc::now().to_rfc3339() + } + })); + + // Log the tool span + client.log_span(tool_span).await?; + + // Finish the trace + trace.finish().await?; + + // Allow some time for async processing + sleep(Duration::from_millis(200)).await; + + Ok(()) +} + +#[tokio::test] +async fn test_real_error_handling() -> Result<()> { + // Initialize environment + init_env()?; + + // Skip test if no API key is available + if env::var("BRAINTRUST_API_KEY").is_err() { + println!("Skipping test_real_error_handling: No API key available"); + return Ok(()); + } + + // Create client (None means use env var) + let client = BraintrustClient::new(None, "c7b996a6-1c7c-482d-b23f-3d39de16f433")?; + + // Create a trace for error testing + let trace = TraceBuilder::new( + client.clone(), + "Integration Test Error Handling" + ); + + // Add a span that will contain an error + let error_span = trace.add_span("Error Operation", "function").await?; + + // Simulate an operation that results in an error + let error_message = "This is a simulated error for testing"; + let error_span = error_span + .with_input(json!({ + "operation": "error_test", + "should_fail": true + })) + .with_output(json!({ + "error": error_message, + "stack_trace": "simulated stack trace for testing", + "timestamp": chrono::Utc::now().to_rfc3339() + })) + .with_metadata("error", true) + .with_metadata("error_type", "SimulatedError"); + + // Log the error span + client.log_span(error_span).await?; + + // Finish the trace + trace.finish().await?; + + // Allow some time for async processing + sleep(Duration::from_millis(100)).await; + + Ok(()) +} diff --git a/api/libs/handlers/src/chats/helpers/generate_conversation_title.rs b/api/libs/handlers/src/chats/helpers/generate_conversation_title.rs index 779d5e74b..6bd08c671 100644 --- a/api/libs/handlers/src/chats/helpers/generate_conversation_title.rs +++ b/api/libs/handlers/src/chats/helpers/generate_conversation_title.rs @@ -71,7 +71,7 @@ pub async fn generate_conversation_title( // Create the request let request = ChatCompletionRequest { - model: "gemini-2".to_string(), + model: "gpt-4o-mini".to_string(), messages: vec![LiteLLMAgentMessage::User { id: None, content: prompt, diff --git a/api/libs/handlers/src/chats/post_chat_handler.rs b/api/libs/handlers/src/chats/post_chat_handler.rs index a043fb26d..4e11314ad 100644 --- a/api/libs/handlers/src/chats/post_chat_handler.rs +++ b/api/libs/handlers/src/chats/post_chat_handler.rs @@ -1764,7 +1764,7 @@ pub async fn generate_conversation_title( // Create the request let request = ChatCompletionRequest { - model: "gemini-2".to_string(), + model: "gpt-4o-mini".to_string(), messages: vec![LiteLLMAgentMessage::User { id: None, content: prompt, From 93243073022d138d11b27d93d7688a25cb313c94 Mon Sep 17 00:00:00 2001 From: dal Date: Fri, 14 Mar 2025 13:20:48 -0600 Subject: [PATCH 2/3] ok integration tests are working. About to implement with actual agent. --- api/libs/braintrust/src/types.rs | 6 +-- .../braintrust/tests/integration_tests.rs | 48 +++++++++++++------ 2 files changed, 36 insertions(+), 18 deletions(-) diff --git a/api/libs/braintrust/src/types.rs b/api/libs/braintrust/src/types.rs index 76df0ba97..172b47e79 100644 --- a/api/libs/braintrust/src/types.rs +++ b/api/libs/braintrust/src/types.rs @@ -113,9 +113,9 @@ impl Span { self } - /// Alias for add_metadata - pub fn with_metadata(self, key: &str, value: &str) -> Self { - self.add_metadata(key, value) + /// Alias for add_metadata that converts any displayable value to a string + pub fn with_metadata(self, key: &str, value: T) -> Self { + self.add_metadata(key, &value.to_string()) } /// Get the span ID diff --git a/api/libs/braintrust/tests/integration_tests.rs b/api/libs/braintrust/tests/integration_tests.rs index 18cd13816..bc003ba87 100644 --- a/api/libs/braintrust/tests/integration_tests.rs +++ b/api/libs/braintrust/tests/integration_tests.rs @@ -96,7 +96,7 @@ async fn test_real_trace_with_spans() -> Result<()> { // Add a root span let root_span = trace.add_span("Root Operation", "function").await?; - let root_span = root_span + let mut root_span = root_span .with_input(json!({ "operation": "root", "parameters": { @@ -107,11 +107,24 @@ async fn test_real_trace_with_spans() -> Result<()> { .with_metadata("test_id", trace_id.clone()); // Log the root span + client.log_span(root_span.clone()).await?; + + sleep(Duration::from_secs(10)).await; + + root_span = root_span + .with_output(json!({ + "result": "success", + "parameters": { + "test": true, + "timestamp": chrono::Utc::now().to_rfc3339() + } + })); + client.log_span(root_span).await?; // Add an LLM span let llm_span = trace.add_span("LLM Call", "llm").await?; - let llm_span = llm_span + let mut llm_span = llm_span .with_input(json!({ "messages": [ { @@ -120,22 +133,27 @@ async fn test_real_trace_with_spans() -> Result<()> { } ] })) - .with_output(json!({ - "choices": [ - { - "message": { - "role": "assistant", - "content": "Hello! I'm responding to your integration test message." - } - } - ] - })) - .with_tokens(15, 12) .with_metadata("model", "test-model"); // Log the LLM span - client.log_span(llm_span).await?; + client.log_span(llm_span.clone()).await?; + sleep(Duration::from_secs(15)).await; + + llm_span = llm_span + .with_output(json!({ + "choices": [ + { + "message": { + "role": "assistant", + "content": "Hello! I'm responding to your integration test message." + } + } + ] + })); + + client.log_span(llm_span).await?; + // Add a tool span let tool_span = trace.add_span("Tool Execution", "tool").await?; let tool_span = tool_span @@ -164,7 +182,7 @@ async fn test_real_trace_with_spans() -> Result<()> { trace.finish().await?; // Allow some time for async processing - sleep(Duration::from_millis(200)).await; + sleep(Duration::from_secs(30)).await; Ok(()) } From 02f2ff9bd1c7c33f1cf1e3f26e7756887938a88f Mon Sep 17 00:00:00 2001 From: dal Date: Fri, 14 Mar 2025 13:50:25 -0600 Subject: [PATCH 3/3] raw llm --- .../src/chats/get_raw_llm_messages_handler.rs | 13 ++++++++-- .../routes/chats/get_chat_raw_llm_messages.rs | 25 +++++++++---------- 2 files changed, 23 insertions(+), 15 deletions(-) diff --git a/api/libs/handlers/src/chats/get_raw_llm_messages_handler.rs b/api/libs/handlers/src/chats/get_raw_llm_messages_handler.rs index 6d7f2f6e3..9d7a22dc5 100644 --- a/api/libs/handlers/src/chats/get_raw_llm_messages_handler.rs +++ b/api/libs/handlers/src/chats/get_raw_llm_messages_handler.rs @@ -12,7 +12,13 @@ pub struct GetRawLlmMessagesRequest { pub chat_id: Uuid, } -pub async fn get_raw_llm_messages_handler(chat_id: Uuid) -> Result { +#[derive(Debug, Serialize, Deserialize)] +pub struct GetRawLlmMessagesResponse { + pub chat_id: Uuid, + pub raw_llm_messages: Value, +} + +pub async fn get_raw_llm_messages_handler(chat_id: Uuid) -> Result { let pool = get_pg_pool(); let mut conn = pool.get().await?; @@ -26,5 +32,8 @@ pub async fn get_raw_llm_messages_handler(chat_id: Uuid) -> Result { .first::(&mut conn) .await?; - Ok(raw_llm_messages) + Ok(GetRawLlmMessagesResponse { + chat_id, + raw_llm_messages, + }) } diff --git a/api/src/routes/rest/routes/chats/get_chat_raw_llm_messages.rs b/api/src/routes/rest/routes/chats/get_chat_raw_llm_messages.rs index 512d39695..7b843632c 100644 --- a/api/src/routes/rest/routes/chats/get_chat_raw_llm_messages.rs +++ b/api/src/routes/rest/routes/chats/get_chat_raw_llm_messages.rs @@ -1,21 +1,20 @@ -use axum::{ - extract::{Path, State}, - Extension, Json, - http::StatusCode, - response::IntoResponse, -}; -use handlers::chats::get_raw_llm_messages_handler; -use middleware::AuthenticatedUser; -use serde_json::Value; -use uuid::Uuid; use crate::routes::rest::ApiResponse; +use axum::{extract::Path, http::StatusCode, Extension}; +use handlers::chats::get_raw_llm_messages_handler::{ + get_raw_llm_messages_handler, GetRawLlmMessagesResponse, +}; +use middleware::AuthenticatedUser; +use uuid::Uuid; pub async fn get_chat_raw_llm_messages( - Extension(user): Extension, + Extension(user): Extension, Path(chat_id): Path, -) -> Result, (StatusCode, &'static str)> { +) -> Result, (StatusCode, &'static str)> { match get_raw_llm_messages_handler(chat_id).await { Ok(response) => Ok(ApiResponse::JsonData(response)), - Err(e) => Err((StatusCode::INTERNAL_SERVER_ERROR, "Failed to get raw LLM messages")), + Err(e) => Err(( + StatusCode::INTERNAL_SERVER_ERROR, + "Failed to get raw LLM messages", + )), } }