From 0b5ed947705b9efcd46b91c257997883f4ca1bff Mon Sep 17 00:00:00 2001 From: dal Date: Tue, 18 Mar 2025 12:26:38 -0600 Subject: [PATCH] agent with braintrust logs --- api/libs/agents/src/agent.rs | 116 ++++++++++++++++++++++------------- 1 file changed, 73 insertions(+), 43 deletions(-) diff --git a/api/libs/agents/src/agent.rs b/api/libs/agents/src/agent.rs index 4487a64b5..b70a24baf 100644 --- a/api/libs/agents/src/agent.rs +++ b/api/libs/agents/src/agent.rs @@ -408,13 +408,13 @@ impl Agent { } // Initialize trace and parent span if not provided (first call) - let (trace, parent_span) = if trace_builder.is_none() && parent_span.is_none() { + let (trace_builder, parent_span) = if trace_builder.is_none() && parent_span.is_none() { if let Some(client) = &*BRAINTRUST_CLIENT { // Create a new trace for this conversation let trace = TraceBuilder::new(client.clone(), &format!("Agent Thread {}", thread.id)); // Create the parent span for the entire conversation - let parent_span = trace.add_span("User Conversation", "conversation").await?; + let span = trace.add_span("User Conversation", "conversation").await?; // Get the most recent user message for logging let user_message = thread.messages.iter() @@ -422,13 +422,17 @@ impl Agent { .last() .cloned(); - if let Some(user_msg) = user_message { + let span = if let Some(user_msg) = user_message { // Log the user message as input to the parent span - let parent_span = parent_span.with_input(serde_json::to_value(&user_msg)?); - (Some(trace), Some(parent_span)) + span.with_input(serde_json::to_value(&user_msg)?) } else { - (Some(trace), Some(parent_span)) - } + span + }; + + // Log the initial span + client.log_span(span.clone()).await?; + + (Some(trace), Some(span)) } else { (None, None) } @@ -453,8 +457,8 @@ impl Agent { // Collect all registered tools and their schemas let tools = self.get_enabled_tools().await; - // Get the most recent user message for logging - let user_message = thread.messages.last() + // Get the most recent user message for logging (used only in error logging) + let _user_message = thread.messages.last() .filter(|msg| matches!(msg, AgentMessage::User { .. })) .cloned(); @@ -493,13 +497,17 @@ impl Agent { }; // Create an assistant span to track the assistant's response - let assistant_span = if let (Some(trace), Some(parent)) = (&trace, &parent_span) { + let assistant_span = if let (Some(trace), Some(parent)) = (&trace_builder, &parent_span) { if let Some(client) = &*BRAINTRUST_CLIENT { // Create a span for the assistant message let span = trace.add_child_span("Assistant Response", "llm", parent).await?; // Add the request as input let span = span.with_input(serde_json::to_value(&request)?); + + // Log the assistant span + client.log_span(span.clone()).await?; + Some(span) } else { None @@ -510,7 +518,7 @@ impl Agent { // Process the streaming chunks let mut buffer = MessageBuffer::new(); - let mut is_complete = false; + let mut _is_complete = false; while let Some(chunk_result) = stream_rx.recv().await { match chunk_result { @@ -555,17 +563,21 @@ impl Agent { // Check if this is the final chunk if chunk.choices[0].finish_reason.is_some() { - is_complete = true; + _is_complete = true; } } Err(e) => { // Log error in span if let Some(assistant_span) = &assistant_span { if let Some(client) = &*BRAINTRUST_CLIENT { - let span = assistant_span.with_output(serde_json::json!({ + // Create error info + let error_info = serde_json::json!({ "error": format!("Error in stream: {:?}", e) - })); - let _ = client.log_span(span).await; + }); + + // Log error as output to span + let error_span = assistant_span.clone().with_output(error_info); + let _ = client.log_span(error_span).await; } } let error_message = format!("Error in stream: {:?}", e); @@ -604,9 +616,9 @@ impl Agent { self.update_current_thread(final_message.clone()).await?; // Log the assistant message - if let Some(assistant_span) = assistant_span { + if let Some(ref assistant_span) = assistant_span { if let Some(client) = &*BRAINTRUST_CLIENT { - let span = assistant_span.with_output(serde_json::to_value(&final_message)?); + let span = assistant_span.clone().with_output(serde_json::to_value(&final_message)?); let _ = client.log_span(span).await; } } @@ -614,18 +626,17 @@ impl Agent { // If this is an auto response without tool calls, it means we're done if final_tool_calls.is_none() { // Log the final output to the parent span - if let Some(parent_span) = parent_span { + if let Some(parent_span) = &parent_span { if let Some(client) = &*BRAINTRUST_CLIENT { - let span = parent_span.with_output(serde_json::to_value(&final_message)?); - let _ = client.log_span(span).await; - - // If we have a trace, finish it - if let Some(trace) = trace { - let _ = trace.finish().await; - } + // Create a new span with the final message as output + let final_span = parent_span.clone().with_output(serde_json::to_value(&final_message)?); + let _ = client.log_span(final_span).await; } } + // Finish the trace without consuming it + self.finish_trace(&trace_builder).await?; + // Send Done message and return self.get_stream_sender() .await @@ -641,7 +652,7 @@ impl Agent { for tool_call in tool_calls { if let Some(tool) = self.tools.read().await.get(&tool_call.function.name) { // Create a tool span - let tool_span = if let (Some(trace), Some(assistant)) = (&trace, &assistant_span) { + let tool_span = if let (Some(trace), Some(assistant)) = (&trace_builder, &assistant_span) { if let Some(client) = &*BRAINTRUST_CLIENT { // Create a span for the tool execution let span = trace.add_child_span( @@ -662,6 +673,10 @@ impl Agent { // Add the tool call as input let span = span.with_input(tool_input); + + // Log the tool span + client.log_span(span.clone()).await?; + Some(span) } else { None @@ -672,7 +687,7 @@ impl Agent { // Parse the parameters let params: Value = serde_json::from_str(&tool_call.function.arguments)?; - let tool_input = serde_json::json!({ + let _tool_input = serde_json::json!({ "function": { "name": tool_call.function.name, "arguments": params @@ -685,12 +700,15 @@ impl Agent { Ok(r) => r, Err(e) => { // Log error in tool span - if let Some(tool_span) = tool_span { + if let Some(tool_span) = &tool_span { if let Some(client) = &*BRAINTRUST_CLIENT { - let span = tool_span.with_output(serde_json::json!({ + let error_info = serde_json::json!({ "error": format!("Tool execution error: {:?}", e) - })); - let _ = client.log_span(span).await; + }); + + // Create a new span with the error output + let error_span = tool_span.clone().with_output(error_info); + let _ = client.log_span(error_span).await; } } let error_message = format!("Tool execution error: {:?}", e); @@ -708,10 +726,11 @@ impl Agent { ); // Log the tool result - if let Some(tool_span) = tool_span { + if let Some(tool_span) = &tool_span { if let Some(client) = &*BRAINTRUST_CLIENT { - let span = tool_span.with_output(serde_json::to_value(&tool_message)?); - let _ = client.log_span(span).await; + // Create a new span with the tool message as output + let result_span = tool_span.clone().with_output(serde_json::to_value(&tool_message)?); + let _ = client.log_span(result_span).await; } } @@ -733,21 +752,20 @@ impl Agent { // For recursive calls, we'll continue with the same trace // We don't finish the trace here to keep all interactions in one trace - Box::pin(self.process_thread_with_depth(&new_thread, recursion_depth + 1, trace, parent_span)).await + Box::pin(self.process_thread_with_depth(&new_thread, recursion_depth + 1, trace_builder, parent_span)).await } else { // Log the final output to the parent span - if let Some(parent_span) = parent_span { + if let Some(parent_span) = &parent_span { if let Some(client) = &*BRAINTRUST_CLIENT { - let span = parent_span.with_output(serde_json::to_value(&final_message)?); - let _ = client.log_span(span).await; - - // If we have a trace, finish it - if let Some(trace) = trace { - let _ = trace.finish().await; - } + // Create a new span with the final message as output + let final_span = parent_span.clone().with_output(serde_json::to_value(&final_message)?); + let _ = client.log_span(final_span).await; } } + // Finish the trace without consuming it + self.finish_trace(&trace_builder).await?; + // Send Done message and return self.get_stream_sender() .await @@ -778,6 +796,18 @@ impl Agent { self.tools.read().await } + /// Helper method to finish a trace without consuming the TraceBuilder + async fn finish_trace(&self, trace: &Option) -> Result<()> { + if let Some(trace) = trace { + if let Some(client) = &*BRAINTRUST_CLIENT { + let root_span = trace.root_span(); + let finished_root = root_span.clone().with_output(serde_json::json!("Trace completed")); + client.log_span(finished_root).await?; + } + } + Ok(()) + } + // Add this new method alongside other channel-related methods pub async fn close(&self) { let mut tx = self.stream_tx.write().await;