added logging to agent

This commit is contained in:
dal 2025-03-18 12:16:24 -06:00
parent c1ca69966c
commit 3a9c9dbf84
No known key found for this signature in database
GPG Key ID: 16F4B0E1E9F61122
3 changed files with 168 additions and 8 deletions

View File

@ -26,6 +26,7 @@ diesel-async = { workspace = true }
serde_yaml = { workspace = true }
tracing = { workspace = true }
indexmap = { workspace = true }
once_cell = { workspace = true }
# Development dependencies
[dev-dependencies]

View File

@ -5,14 +5,30 @@ use litellm::{
AgentMessage, ChatCompletionRequest, DeltaToolCall, FunctionCall, LiteLLMClient,
MessageProgress, Metadata, Tool, ToolCall, ToolChoice,
};
use once_cell::sync::Lazy;
use serde_json::Value;
use std::{collections::HashMap, env, sync::Arc};
use tokio::sync::{broadcast, RwLock};
use uuid::Uuid;
use std::time::{Duration, Instant};
use crate::models::AgentThread;
// Global BraintrustClient instance
static BRAINTRUST_CLIENT: Lazy<Option<Arc<BraintrustClient>>> = Lazy::new(|| {
match std::env::var("BRAINTRUST_API_KEY") {
Ok(_) => {
match BraintrustClient::new(None, "buster-agent-logs") {
Ok(client) => Some(client),
Err(e) => {
eprintln!("Failed to create Braintrust client: {}", e);
None
}
}
}
Err(_) => None,
}
});
#[derive(Debug, Clone)]
pub struct AgentError(pub String);
@ -349,7 +365,7 @@ impl Agent {
tokio::spawn(async move {
tokio::select! {
result = agent_clone.process_thread_with_depth(&thread_clone, 0) => {
result = agent_clone.process_thread_with_depth(&thread_clone, 0, None, None) => {
if let Err(e) = result {
let err_msg = format!("Error processing thread: {:?}", e);
let _ = agent_clone.get_stream_sender().await.send(Err(AgentError(err_msg)));
@ -382,6 +398,8 @@ impl Agent {
&self,
thread: &AgentThread,
recursion_depth: u32,
trace_builder: Option<TraceBuilder>,
parent_span: Option<braintrust::Span>,
) -> Result<()> {
// Set the initial thread
{
@ -389,6 +407,35 @@ impl Agent {
*current = Some(thread.clone());
}
// Initialize trace and parent span if not provided (first call)
let (trace, parent_span) = if trace_builder.is_none() && parent_span.is_none() {
if let Some(client) = &*BRAINTRUST_CLIENT {
// Create a new trace for this conversation
let trace = TraceBuilder::new(client.clone(), &format!("Agent Thread {}", thread.id));
// Create the parent span for the entire conversation
let parent_span = trace.add_span("User Conversation", "conversation").await?;
// Get the most recent user message for logging
let user_message = thread.messages.iter()
.filter(|msg| matches!(msg, AgentMessage::User { .. }))
.last()
.cloned();
if let Some(user_msg) = user_message {
// Log the user message as input to the parent span
let parent_span = parent_span.with_input(serde_json::to_value(&user_msg)?);
(Some(trace), Some(parent_span))
} else {
(Some(trace), Some(parent_span))
}
} else {
(None, None)
}
} else {
(trace_builder, parent_span)
};
if recursion_depth >= 30 {
let message = AgentMessage::assistant(
Some("max_recursion_depth_message".to_string()),
@ -428,15 +475,39 @@ impl Agent {
};
// Get the streaming response from the LLM
let mut stream_rx = match self.llm_client.stream_chat_completion(request).await {
let mut stream_rx = match self.llm_client.stream_chat_completion(request.clone()).await {
Ok(rx) => rx,
Err(e) => {
// Log error in span
if let Some(parent_span) = parent_span.clone() {
if let Some(client) = &*BRAINTRUST_CLIENT {
let error_span = parent_span.with_output(serde_json::json!({
"error": format!("Error starting stream: {:?}", e)
}));
let _ = client.log_span(error_span).await;
}
}
let error_message = format!("Error starting stream: {:?}", e);
return Err(anyhow::anyhow!(error_message));
},
};
// Create an assistant span to track the assistant's response
let assistant_span = if let (Some(trace), Some(parent)) = (&trace, &parent_span) {
if let Some(client) = &*BRAINTRUST_CLIENT {
// Create a span for the assistant message
let span = trace.add_child_span("Assistant Response", "llm", parent).await?;
// Add the request as input
let span = span.with_input(serde_json::to_value(&request)?);
Some(span)
} else {
None
}
} else {
None
};
// Process the streaming chunks
let mut buffer = MessageBuffer::new();
let mut is_complete = false;
@ -489,6 +560,14 @@ impl Agent {
}
Err(e) => {
// Log error in span
if let Some(assistant_span) = &assistant_span {
if let Some(client) = &*BRAINTRUST_CLIENT {
let span = assistant_span.with_output(serde_json::json!({
"error": format!("Error in stream: {:?}", e)
}));
let _ = client.log_span(span).await;
}
}
let error_message = format!("Error in stream: {:?}", e);
return Err(anyhow::anyhow!(error_message));
},
@ -524,8 +603,29 @@ impl Agent {
// Update thread with assistant message
self.update_current_thread(final_message.clone()).await?;
// Log the assistant message
if let Some(assistant_span) = assistant_span {
if let Some(client) = &*BRAINTRUST_CLIENT {
let span = assistant_span.with_output(serde_json::to_value(&final_message)?);
let _ = client.log_span(span).await;
}
}
// If this is an auto response without tool calls, it means we're done
if final_tool_calls.is_none() {
// Log the final output to the parent span
if let Some(parent_span) = parent_span {
if let Some(client) = &*BRAINTRUST_CLIENT {
let span = parent_span.with_output(serde_json::to_value(&final_message)?);
let _ = client.log_span(span).await;
// If we have a trace, finish it
if let Some(trace) = trace {
let _ = trace.finish().await;
}
}
}
// Send Done message and return
self.get_stream_sender()
.await
@ -540,7 +640,37 @@ impl Agent {
// Execute each requested tool
for tool_call in tool_calls {
if let Some(tool) = self.tools.read().await.get(&tool_call.function.name) {
// Parse the parameters - log only the tool call as input
// Create a tool span
let tool_span = if let (Some(trace), Some(assistant)) = (&trace, &assistant_span) {
if let Some(client) = &*BRAINTRUST_CLIENT {
// Create a span for the tool execution
let span = trace.add_child_span(
&format!("Tool: {}", tool_call.function.name),
"tool",
assistant
).await?;
// Parse the parameters - log only the tool call as input
let params: Value = serde_json::from_str(&tool_call.function.arguments)?;
let tool_input = serde_json::json!({
"function": {
"name": tool_call.function.name,
"arguments": params
},
"id": tool_call.id
});
// Add the tool call as input
let span = span.with_input(tool_input);
Some(span)
} else {
None
}
} else {
None
};
// Parse the parameters
let params: Value = serde_json::from_str(&tool_call.function.arguments)?;
let tool_input = serde_json::json!({
"function": {
@ -555,6 +685,14 @@ impl Agent {
Ok(r) => r,
Err(e) => {
// Log error in tool span
if let Some(tool_span) = tool_span {
if let Some(client) = &*BRAINTRUST_CLIENT {
let span = tool_span.with_output(serde_json::json!({
"error": format!("Tool execution error: {:?}", e)
}));
let _ = client.log_span(span).await;
}
}
let error_message = format!("Tool execution error: {:?}", e);
return Err(anyhow::anyhow!(error_message));
}
@ -563,12 +701,20 @@ impl Agent {
let result_str = serde_json::to_string(&result)?;
let tool_message = AgentMessage::tool(
None,
result_str,
result_str.clone(),
tool_call.id.clone(),
Some(tool_call.function.name.clone()),
MessageProgress::Complete,
);
// Log the tool result
if let Some(tool_span) = tool_span {
if let Some(client) = &*BRAINTRUST_CLIENT {
let span = tool_span.with_output(serde_json::to_value(&tool_message)?);
let _ = client.log_span(span).await;
}
}
// Broadcast the tool message as soon as we receive it
self.get_stream_sender()
.await
@ -587,8 +733,21 @@ impl Agent {
// For recursive calls, we'll continue with the same trace
// We don't finish the trace here to keep all interactions in one trace
Box::pin(self.process_thread_with_depth(&new_thread, recursion_depth + 1)).await
Box::pin(self.process_thread_with_depth(&new_thread, recursion_depth + 1, trace, parent_span)).await
} else {
// Log the final output to the parent span
if let Some(parent_span) = parent_span {
if let Some(client) = &*BRAINTRUST_CLIENT {
let span = parent_span.with_output(serde_json::to_value(&final_message)?);
let _ = client.log_span(span).await;
// If we have a trace, finish it
if let Some(trace) = trace {
let _ = trace.finish().await;
}
}
}
// Send Done message and return
self.get_stream_sender()
.await

View File

@ -85,7 +85,7 @@ async fn test_real_trace_with_spans() -> Result<()> {
}
// Create client (None means use env var)
let client = BraintrustClient::new(None, "172afc4a-16b7-4d59-978e-4c87cade87b6")?;
let client = BraintrustClient::new(None, "96af8b2b-cf3c-494f-9092-44eb3d5b96ff")?;
// Create a trace
let trace_id = uuid::Uuid::new_v4().to_string();
@ -249,7 +249,7 @@ async fn test_real_get_prompt() -> Result<()> {
}
// Create client (None means use env var)
let client = BraintrustClient::new(None, "c7b996a6-1c7c-482d-b23f-3d39de16f433")?;
let client = BraintrustClient::new(None, "96af8b2b-cf3c-494f-9092-44eb3d5b96ff")?;
// Attempt to fetch the prompt with ID "7f6fbd7a-d03a-42e7-a115-b87f5e9f86ee"
let prompt_id = "7f6fbd7a-d03a-42e7-a115-b87f5e9f86ee";