agent with braintrust logs

This commit is contained in:
dal 2025-03-18 12:26:38 -06:00
parent 3a9c9dbf84
commit 0b5ed94770
No known key found for this signature in database
GPG Key ID: 16F4B0E1E9F61122
1 changed files with 73 additions and 43 deletions

View File

@ -408,13 +408,13 @@ impl Agent {
}
// Initialize trace and parent span if not provided (first call)
let (trace, parent_span) = if trace_builder.is_none() && parent_span.is_none() {
let (trace_builder, parent_span) = if trace_builder.is_none() && parent_span.is_none() {
if let Some(client) = &*BRAINTRUST_CLIENT {
// Create a new trace for this conversation
let trace = TraceBuilder::new(client.clone(), &format!("Agent Thread {}", thread.id));
// Create the parent span for the entire conversation
let parent_span = trace.add_span("User Conversation", "conversation").await?;
let span = trace.add_span("User Conversation", "conversation").await?;
// Get the most recent user message for logging
let user_message = thread.messages.iter()
@ -422,13 +422,17 @@ impl Agent {
.last()
.cloned();
if let Some(user_msg) = user_message {
let span = if let Some(user_msg) = user_message {
// Log the user message as input to the parent span
let parent_span = parent_span.with_input(serde_json::to_value(&user_msg)?);
(Some(trace), Some(parent_span))
span.with_input(serde_json::to_value(&user_msg)?)
} else {
(Some(trace), Some(parent_span))
}
span
};
// Log the initial span
client.log_span(span.clone()).await?;
(Some(trace), Some(span))
} else {
(None, None)
}
@ -453,8 +457,8 @@ impl Agent {
// Collect all registered tools and their schemas
let tools = self.get_enabled_tools().await;
// Get the most recent user message for logging
let user_message = thread.messages.last()
// Get the most recent user message for logging (used only in error logging)
let _user_message = thread.messages.last()
.filter(|msg| matches!(msg, AgentMessage::User { .. }))
.cloned();
@ -493,13 +497,17 @@ impl Agent {
};
// Create an assistant span to track the assistant's response
let assistant_span = if let (Some(trace), Some(parent)) = (&trace, &parent_span) {
let assistant_span = if let (Some(trace), Some(parent)) = (&trace_builder, &parent_span) {
if let Some(client) = &*BRAINTRUST_CLIENT {
// Create a span for the assistant message
let span = trace.add_child_span("Assistant Response", "llm", parent).await?;
// Add the request as input
let span = span.with_input(serde_json::to_value(&request)?);
// Log the assistant span
client.log_span(span.clone()).await?;
Some(span)
} else {
None
@ -510,7 +518,7 @@ impl Agent {
// Process the streaming chunks
let mut buffer = MessageBuffer::new();
let mut is_complete = false;
let mut _is_complete = false;
while let Some(chunk_result) = stream_rx.recv().await {
match chunk_result {
@ -555,17 +563,21 @@ impl Agent {
// Check if this is the final chunk
if chunk.choices[0].finish_reason.is_some() {
is_complete = true;
_is_complete = true;
}
}
Err(e) => {
// Log error in span
if let Some(assistant_span) = &assistant_span {
if let Some(client) = &*BRAINTRUST_CLIENT {
let span = assistant_span.with_output(serde_json::json!({
// Create error info
let error_info = serde_json::json!({
"error": format!("Error in stream: {:?}", e)
}));
let _ = client.log_span(span).await;
});
// Log error as output to span
let error_span = assistant_span.clone().with_output(error_info);
let _ = client.log_span(error_span).await;
}
}
let error_message = format!("Error in stream: {:?}", e);
@ -604,9 +616,9 @@ impl Agent {
self.update_current_thread(final_message.clone()).await?;
// Log the assistant message
if let Some(assistant_span) = assistant_span {
if let Some(ref assistant_span) = assistant_span {
if let Some(client) = &*BRAINTRUST_CLIENT {
let span = assistant_span.with_output(serde_json::to_value(&final_message)?);
let span = assistant_span.clone().with_output(serde_json::to_value(&final_message)?);
let _ = client.log_span(span).await;
}
}
@ -614,18 +626,17 @@ impl Agent {
// If this is an auto response without tool calls, it means we're done
if final_tool_calls.is_none() {
// Log the final output to the parent span
if let Some(parent_span) = parent_span {
if let Some(parent_span) = &parent_span {
if let Some(client) = &*BRAINTRUST_CLIENT {
let span = parent_span.with_output(serde_json::to_value(&final_message)?);
let _ = client.log_span(span).await;
// If we have a trace, finish it
if let Some(trace) = trace {
let _ = trace.finish().await;
}
// Create a new span with the final message as output
let final_span = parent_span.clone().with_output(serde_json::to_value(&final_message)?);
let _ = client.log_span(final_span).await;
}
}
// Finish the trace without consuming it
self.finish_trace(&trace_builder).await?;
// Send Done message and return
self.get_stream_sender()
.await
@ -641,7 +652,7 @@ impl Agent {
for tool_call in tool_calls {
if let Some(tool) = self.tools.read().await.get(&tool_call.function.name) {
// Create a tool span
let tool_span = if let (Some(trace), Some(assistant)) = (&trace, &assistant_span) {
let tool_span = if let (Some(trace), Some(assistant)) = (&trace_builder, &assistant_span) {
if let Some(client) = &*BRAINTRUST_CLIENT {
// Create a span for the tool execution
let span = trace.add_child_span(
@ -662,6 +673,10 @@ impl Agent {
// Add the tool call as input
let span = span.with_input(tool_input);
// Log the tool span
client.log_span(span.clone()).await?;
Some(span)
} else {
None
@ -672,7 +687,7 @@ impl Agent {
// Parse the parameters
let params: Value = serde_json::from_str(&tool_call.function.arguments)?;
let tool_input = serde_json::json!({
let _tool_input = serde_json::json!({
"function": {
"name": tool_call.function.name,
"arguments": params
@ -685,12 +700,15 @@ impl Agent {
Ok(r) => r,
Err(e) => {
// Log error in tool span
if let Some(tool_span) = tool_span {
if let Some(tool_span) = &tool_span {
if let Some(client) = &*BRAINTRUST_CLIENT {
let span = tool_span.with_output(serde_json::json!({
let error_info = serde_json::json!({
"error": format!("Tool execution error: {:?}", e)
}));
let _ = client.log_span(span).await;
});
// Create a new span with the error output
let error_span = tool_span.clone().with_output(error_info);
let _ = client.log_span(error_span).await;
}
}
let error_message = format!("Tool execution error: {:?}", e);
@ -708,10 +726,11 @@ impl Agent {
);
// Log the tool result
if let Some(tool_span) = tool_span {
if let Some(tool_span) = &tool_span {
if let Some(client) = &*BRAINTRUST_CLIENT {
let span = tool_span.with_output(serde_json::to_value(&tool_message)?);
let _ = client.log_span(span).await;
// Create a new span with the tool message as output
let result_span = tool_span.clone().with_output(serde_json::to_value(&tool_message)?);
let _ = client.log_span(result_span).await;
}
}
@ -733,21 +752,20 @@ impl Agent {
// For recursive calls, we'll continue with the same trace
// We don't finish the trace here to keep all interactions in one trace
Box::pin(self.process_thread_with_depth(&new_thread, recursion_depth + 1, trace, parent_span)).await
Box::pin(self.process_thread_with_depth(&new_thread, recursion_depth + 1, trace_builder, parent_span)).await
} else {
// Log the final output to the parent span
if let Some(parent_span) = parent_span {
if let Some(parent_span) = &parent_span {
if let Some(client) = &*BRAINTRUST_CLIENT {
let span = parent_span.with_output(serde_json::to_value(&final_message)?);
let _ = client.log_span(span).await;
// If we have a trace, finish it
if let Some(trace) = trace {
let _ = trace.finish().await;
}
// Create a new span with the final message as output
let final_span = parent_span.clone().with_output(serde_json::to_value(&final_message)?);
let _ = client.log_span(final_span).await;
}
}
// Finish the trace without consuming it
self.finish_trace(&trace_builder).await?;
// Send Done message and return
self.get_stream_sender()
.await
@ -778,6 +796,18 @@ impl Agent {
self.tools.read().await
}
/// Helper method to finish a trace without consuming the TraceBuilder
async fn finish_trace(&self, trace: &Option<TraceBuilder>) -> Result<()> {
if let Some(trace) = trace {
if let Some(client) = &*BRAINTRUST_CLIENT {
let root_span = trace.root_span();
let finished_root = root_span.clone().with_output(serde_json::json!("Trace completed"));
client.log_span(finished_root).await?;
}
}
Ok(())
}
// Add this new method alongside other channel-related methods
pub async fn close(&self) {
let mut tx = self.stream_tx.write().await;