added braintrust to agent. need to tweak a few more things

This commit is contained in:
dal 2025-03-14 08:56:15 -06:00
parent a9cd975a0b
commit b08ab936cf
No known key found for this signature in database
GPG Key ID: 16F4B0E1E9F61122
1 changed files with 68 additions and 4 deletions

View File

@ -389,11 +389,15 @@ impl Agent {
*current = Some(thread.clone()); *current = Some(thread.clone());
} }
// Initialize Braintrust client
let client = BraintrustClient::new( let client = BraintrustClient::new(
None, None,
"c7b996a6-1c7c-482d-b23f-3d39de16f433" "c7b996a6-1c7c-482d-b23f-3d39de16f433"
)?; )?;
// Create a root span for this thread
let root_span_id = thread.id.to_string();
if recursion_depth >= 30 { if recursion_depth >= 30 {
let message = AgentMessage::assistant( let message = AgentMessage::assistant(
Some("max_recursion_depth_message".to_string()), Some("max_recursion_depth_message".to_string()),
@ -427,10 +431,23 @@ impl Agent {
..Default::default() ..Default::default()
}; };
// Create a span for the LLM call
let llm_span = client.create_span(
"llm_call",
"llm",
Some(&root_span_id),
None
).with_input(serde_json::to_value(&request)?);
// Get the streaming response from the LLM // Get the streaming response from the LLM
let mut stream_rx = match self.llm_client.stream_chat_completion(request).await { let mut stream_rx = match self.llm_client.stream_chat_completion(request).await {
Ok(rx) => rx, Ok(rx) => rx,
Err(e) => return Err(anyhow::anyhow!("Error starting stream: {:?}", e)), Err(e) => {
// Log error in span
let error_message = format!("Error starting stream: {:?}", e);
client.log_span(llm_span.with_output(serde_json::json!({"error": error_message}))).await?;
return Err(anyhow::anyhow!(error_message));
},
}; };
// Process the streaming chunks // Process the streaming chunks
@ -483,7 +500,12 @@ impl Agent {
is_complete = true; is_complete = true;
} }
} }
Err(e) => return Err(anyhow::anyhow!("Error in stream: {:?}", e)), Err(e) => {
// Log error in span
let error_message = format!("Error in stream: {:?}", e);
client.log_span(llm_span.with_output(serde_json::json!({"error": error_message}))).await?;
return Err(anyhow::anyhow!(error_message));
},
} }
} }
@ -508,6 +530,24 @@ impl Agent {
Some(self.name.clone()), Some(self.name.clone()),
); );
// Log the LLM response in Braintrust
let llm_output = if let Some(content) = &final_message.get_content() {
serde_json::json!({
"content": content,
"tool_calls": final_tool_calls
})
} else {
serde_json::json!({
"tool_calls": final_tool_calls
})
};
// Clone the span_id before moving llm_span
let llm_span_id = llm_span.clone().span_id().to_string();
// Now we can safely move llm_span
client.log_span(llm_span.with_output(llm_output)).await?;
// Broadcast the final assistant message // Broadcast the final assistant message
self.get_stream_sender() self.get_stream_sender()
.await .await
@ -532,8 +572,32 @@ impl Agent {
// Execute each requested tool // Execute each requested tool
for tool_call in tool_calls { for tool_call in tool_calls {
if let Some(tool) = self.tools.read().await.get(&tool_call.function.name) { if let Some(tool) = self.tools.read().await.get(&tool_call.function.name) {
// Create a span for the tool call
let tool_span = client.create_span(
&tool_call.function.name,
"tool",
Some(&root_span_id),
Some(&llm_span_id)
);
// Parse the parameters and log them
let params: Value = serde_json::from_str(&tool_call.function.arguments)?; let params: Value = serde_json::from_str(&tool_call.function.arguments)?;
let result = tool.execute(params, tool_call.id.clone()).await?; let tool_span = tool_span.with_input(params.clone());
// Execute the tool
let result = match tool.execute(params, tool_call.id.clone()).await {
Ok(r) => r,
Err(e) => {
// Log error in tool span
let error_message = format!("Tool execution error: {:?}", e);
client.log_span(tool_span.with_output(serde_json::json!({"error": error_message}))).await?;
return Err(anyhow::anyhow!(error_message));
}
};
// Log the tool result
client.log_span(tool_span.with_output(result.clone())).await?;
let result_str = serde_json::to_string(&result)?; let result_str = serde_json::to_string(&result)?;
let tool_message = AgentMessage::tool( let tool_message = AgentMessage::tool(
None, None,