diff --git a/api/libs/agents/Cargo.toml b/api/libs/agents/Cargo.toml index 599698b31..a9ffeba5e 100644 --- a/api/libs/agents/Cargo.toml +++ b/api/libs/agents/Cargo.toml @@ -26,6 +26,7 @@ diesel-async = { workspace = true } serde_yaml = { workspace = true } tracing = { workspace = true } indexmap = { workspace = true } +once_cell = { workspace = true } # Development dependencies [dev-dependencies] diff --git a/api/libs/agents/src/agent.rs b/api/libs/agents/src/agent.rs index 48006817a..4487a64b5 100644 --- a/api/libs/agents/src/agent.rs +++ b/api/libs/agents/src/agent.rs @@ -5,14 +5,30 @@ use litellm::{ AgentMessage, ChatCompletionRequest, DeltaToolCall, FunctionCall, LiteLLMClient, MessageProgress, Metadata, Tool, ToolCall, ToolChoice, }; +use once_cell::sync::Lazy; use serde_json::Value; use std::{collections::HashMap, env, sync::Arc}; use tokio::sync::{broadcast, RwLock}; use uuid::Uuid; use std::time::{Duration, Instant}; - use crate::models::AgentThread; +// Global BraintrustClient instance +static BRAINTRUST_CLIENT: Lazy<Option<Arc<BraintrustClient>>> = Lazy::new(|| { + match std::env::var("BRAINTRUST_API_KEY") { + Ok(_) => { + match BraintrustClient::new(None, "buster-agent-logs") { + Ok(client) => Some(client), + Err(e) => { + eprintln!("Failed to create Braintrust client: {}", e); + None + } + } + } + Err(_) => None, + } +}); + #[derive(Debug, Clone)] pub struct AgentError(pub String); @@ -349,7 +365,7 @@ impl Agent { tokio::spawn(async move { tokio::select!
{ - result = agent_clone.process_thread_with_depth(&thread_clone, 0) => { + result = agent_clone.process_thread_with_depth(&thread_clone, 0, None, None) => { if let Err(e) = result { let err_msg = format!("Error processing thread: {:?}", e); let _ = agent_clone.get_stream_sender().await.send(Err(AgentError(err_msg))); @@ -382,6 +398,8 @@ impl Agent { &self, thread: &AgentThread, recursion_depth: u32, + trace_builder: Option<TraceBuilder>, + parent_span: Option<Span>, ) -> Result<()> { // Set the initial thread { @@ -389,6 +407,35 @@ impl Agent { *current = Some(thread.clone()); } + // Initialize trace and parent span if not provided (first call) + let (trace, parent_span) = if trace_builder.is_none() && parent_span.is_none() { + if let Some(client) = &*BRAINTRUST_CLIENT { + // Create a new trace for this conversation + let trace = TraceBuilder::new(client.clone(), &format!("Agent Thread {}", thread.id)); + + // Create the parent span for the entire conversation + let parent_span = trace.add_span("User Conversation", "conversation").await?; + + // Get the most recent user message for logging + let user_message = thread.messages.iter() + .filter(|msg| matches!(msg, AgentMessage::User { ..
})) + .last() + .cloned(); + + if let Some(user_msg) = user_message { + // Log the user message as input to the parent span + let parent_span = parent_span.with_input(serde_json::to_value(&user_msg)?); + (Some(trace), Some(parent_span)) + } else { + (Some(trace), Some(parent_span)) + } + } else { + (None, None) + } + } else { + (trace_builder, parent_span) + }; + if recursion_depth >= 30 { let message = AgentMessage::assistant( Some("max_recursion_depth_message".to_string()), @@ -428,15 +475,39 @@ impl Agent { }; // Get the streaming response from the LLM - let mut stream_rx = match self.llm_client.stream_chat_completion(request).await { + let mut stream_rx = match self.llm_client.stream_chat_completion(request.clone()).await { Ok(rx) => rx, Err(e) => { // Log error in span + if let Some(parent_span) = parent_span.clone() { + if let Some(client) = &*BRAINTRUST_CLIENT { + let error_span = parent_span.with_output(serde_json::json!({ + "error": format!("Error starting stream: {:?}", e) + })); + let _ = client.log_span(error_span).await; + } + } let error_message = format!("Error starting stream: {:?}", e); return Err(anyhow::anyhow!(error_message)); }, }; + // Create an assistant span to track the assistant's response + let assistant_span = if let (Some(trace), Some(parent)) = (&trace, &parent_span) { + if let Some(client) = &*BRAINTRUST_CLIENT { + // Create a span for the assistant message + let span = trace.add_child_span("Assistant Response", "llm", parent).await?; + + // Add the request as input + let span = span.with_input(serde_json::to_value(&request)?); + Some(span) + } else { + None + } + } else { + None + }; + // Process the streaming chunks let mut buffer = MessageBuffer::new(); let mut is_complete = false; @@ -489,6 +560,14 @@ impl Agent { } Err(e) => { // Log error in span + if let Some(assistant_span) = &assistant_span { + if let Some(client) = &*BRAINTRUST_CLIENT { + let span = assistant_span.with_output(serde_json::json!({ + "error": format!("Error in 
stream: {:?}", e) + })); + let _ = client.log_span(span).await; + } + } let error_message = format!("Error in stream: {:?}", e); return Err(anyhow::anyhow!(error_message)); }, @@ -524,8 +603,29 @@ impl Agent { // Update thread with assistant message self.update_current_thread(final_message.clone()).await?; + // Log the assistant message + if let Some(assistant_span) = assistant_span { + if let Some(client) = &*BRAINTRUST_CLIENT { + let span = assistant_span.with_output(serde_json::to_value(&final_message)?); + let _ = client.log_span(span).await; + } + } + // If this is an auto response without tool calls, it means we're done if final_tool_calls.is_none() { + // Log the final output to the parent span + if let Some(parent_span) = parent_span { + if let Some(client) = &*BRAINTRUST_CLIENT { + let span = parent_span.with_output(serde_json::to_value(&final_message)?); + let _ = client.log_span(span).await; + + // If we have a trace, finish it + if let Some(trace) = trace { + let _ = trace.finish().await; + } + } + } + // Send Done message and return self.get_stream_sender() .await @@ -540,7 +640,37 @@ impl Agent { // Execute each requested tool for tool_call in tool_calls { if let Some(tool) = self.tools.read().await.get(&tool_call.function.name) { - // Parse the parameters - log only the tool call as input + // Create a tool span + let tool_span = if let (Some(trace), Some(assistant)) = (&trace, &assistant_span) { + if let Some(client) = &*BRAINTRUST_CLIENT { + // Create a span for the tool execution + let span = trace.add_child_span( + &format!("Tool: {}", tool_call.function.name), + "tool", + assistant + ).await?; + + // Parse the parameters - log only the tool call as input + let params: Value = serde_json::from_str(&tool_call.function.arguments)?; + let tool_input = serde_json::json!({ + "function": { + "name": tool_call.function.name, + "arguments": params + }, + "id": tool_call.id + }); + + // Add the tool call as input + let span = span.with_input(tool_input); 
+ Some(span) + } else { + None + } + } else { + None + }; + + // Parse the parameters let params: Value = serde_json::from_str(&tool_call.function.arguments)?; let tool_input = serde_json::json!({ "function": { @@ -555,6 +685,14 @@ impl Agent { Ok(r) => r, Err(e) => { // Log error in tool span + if let Some(tool_span) = tool_span { + if let Some(client) = &*BRAINTRUST_CLIENT { + let span = tool_span.with_output(serde_json::json!({ + "error": format!("Tool execution error: {:?}", e) + })); + let _ = client.log_span(span).await; + } + } let error_message = format!("Tool execution error: {:?}", e); return Err(anyhow::anyhow!(error_message)); } @@ -563,12 +701,20 @@ impl Agent { let result_str = serde_json::to_string(&result)?; let tool_message = AgentMessage::tool( None, - result_str, + result_str.clone(), tool_call.id.clone(), Some(tool_call.function.name.clone()), MessageProgress::Complete, ); + // Log the tool result + if let Some(tool_span) = tool_span { + if let Some(client) = &*BRAINTRUST_CLIENT { + let span = tool_span.with_output(serde_json::to_value(&tool_message)?); + let _ = client.log_span(span).await; + } + } + // Broadcast the tool message as soon as we receive it self.get_stream_sender() .await @@ -587,8 +733,21 @@ impl Agent { // For recursive calls, we'll continue with the same trace // We don't finish the trace here to keep all interactions in one trace - Box::pin(self.process_thread_with_depth(&new_thread, recursion_depth + 1)).await + Box::pin(self.process_thread_with_depth(&new_thread, recursion_depth + 1, trace, parent_span)).await } else { + // Log the final output to the parent span + if let Some(parent_span) = parent_span { + if let Some(client) = &*BRAINTRUST_CLIENT { + let span = parent_span.with_output(serde_json::to_value(&final_message)?); + let _ = client.log_span(span).await; + + // If we have a trace, finish it + if let Some(trace) = trace { + let _ = trace.finish().await; + } + } + } + // Send Done message and return 
self.get_stream_sender() .await diff --git a/api/libs/braintrust/tests/integration_tests.rs b/api/libs/braintrust/tests/integration_tests.rs index 1bc43cf26..b9381772f 100644 --- a/api/libs/braintrust/tests/integration_tests.rs +++ b/api/libs/braintrust/tests/integration_tests.rs @@ -85,7 +85,7 @@ async fn test_real_trace_with_spans() -> Result<()> { } // Create client (None means use env var) - let client = BraintrustClient::new(None, "172afc4a-16b7-4d59-978e-4c87cade87b6")?; + let client = BraintrustClient::new(None, "96af8b2b-cf3c-494f-9092-44eb3d5b96ff")?; // Create a trace let trace_id = uuid::Uuid::new_v4().to_string(); @@ -249,7 +249,7 @@ async fn test_real_get_prompt() -> Result<()> { } // Create client (None means use env var) - let client = BraintrustClient::new(None, "c7b996a6-1c7c-482d-b23f-3d39de16f433")?; + let client = BraintrustClient::new(None, "96af8b2b-cf3c-494f-9092-44eb3d5b96ff")?; // Attempt to fetch the prompt with ID "7f6fbd7a-d03a-42e7-a115-b87f5e9f86ee" let prompt_id = "7f6fbd7a-d03a-42e7-a115-b87f5e9f86ee";