From b08ab936cf2929b905aa4993ad60b4393c539184 Mon Sep 17 00:00:00 2001 From: dal Date: Fri, 14 Mar 2025 08:56:15 -0600 Subject: [PATCH] added braintrust to agent. need to tweak a few more things --- api/libs/agents/src/agent.rs | 72 ++++++++++++++++++++++++++++++++++-- 1 file changed, 68 insertions(+), 4 deletions(-) diff --git a/api/libs/agents/src/agent.rs b/api/libs/agents/src/agent.rs index 041fa2a65..06f90d503 100644 --- a/api/libs/agents/src/agent.rs +++ b/api/libs/agents/src/agent.rs @@ -389,11 +389,15 @@ impl Agent { *current = Some(thread.clone()); } + // Initialize Braintrust client let client = BraintrustClient::new( None, "c7b996a6-1c7c-482d-b23f-3d39de16f433" )?; - + + // Create a root span for this thread + let root_span_id = thread.id.to_string(); + if recursion_depth >= 30 { let message = AgentMessage::assistant( Some("max_recursion_depth_message".to_string()), @@ -427,10 +431,23 @@ impl Agent { ..Default::default() }; + // Create a span for the LLM call + let llm_span = client.create_span( + "llm_call", + "llm", + Some(&root_span_id), + None + ).with_input(serde_json::to_value(&request)?); + // Get the streaming response from the LLM let mut stream_rx = match self.llm_client.stream_chat_completion(request).await { Ok(rx) => rx, - Err(e) => return Err(anyhow::anyhow!("Error starting stream: {:?}", e)), + Err(e) => { + // Log error in span + let error_message = format!("Error starting stream: {:?}", e); + client.log_span(llm_span.with_output(serde_json::json!({"error": error_message}))).await?; + return Err(anyhow::anyhow!(error_message)); + }, }; // Process the streaming chunks @@ -483,7 +500,12 @@ impl Agent { is_complete = true; } } - Err(e) => return Err(anyhow::anyhow!("Error in stream: {:?}", e)), + Err(e) => { + // Log error in span + let error_message = format!("Error in stream: {:?}", e); + client.log_span(llm_span.with_output(serde_json::json!({"error": error_message}))).await?; + return Err(anyhow::anyhow!(error_message)); + }, } } @@ -508,6 +530,24 @@ impl Agent { Some(self.name.clone()), ); + // Log the LLM response in Braintrust + let llm_output = if let Some(content) = &final_message.get_content() { + serde_json::json!({ + "content": content, + "tool_calls": final_tool_calls + }) + } else { + serde_json::json!({ + "tool_calls": final_tool_calls + }) + }; + + // Clone the span_id before moving llm_span + let llm_span_id = llm_span.clone().span_id().to_string(); + + // Now we can safely move llm_span + client.log_span(llm_span.with_output(llm_output)).await?; + // Broadcast the final assistant message self.get_stream_sender() .await @@ -532,8 +572,32 @@ impl Agent { // Execute each requested tool for tool_call in tool_calls { if let Some(tool) = self.tools.read().await.get(&tool_call.function.name) { + // Create a span for the tool call + let tool_span = client.create_span( + &tool_call.function.name, + "tool", + Some(&root_span_id), + Some(&llm_span_id) + ); + + // Parse the parameters and log them let params: Value = serde_json::from_str(&tool_call.function.arguments)?; - let result = tool.execute(params, tool_call.id.clone()).await?; + let tool_span = tool_span.with_input(params.clone()); + + // Execute the tool + let result = match tool.execute(params, tool_call.id.clone()).await { + Ok(r) => r, + Err(e) => { + // Log error in tool span + let error_message = format!("Tool execution error: {:?}", e); + client.log_span(tool_span.with_output(serde_json::json!({"error": error_message}))).await?; + return Err(anyhow::anyhow!(error_message)); + } + }; + + // Log the tool result + client.log_span(tool_span.with_output(result.clone())).await?; + let result_str = serde_json::to_string(&result)?; let tool_message = AgentMessage::tool( None,