mirror of https://github.com/buster-so/buster.git
added braintrust to agent. need to tweak a few more things
This commit is contained in:
parent
a9cd975a0b
commit
b08ab936cf
|
@ -389,11 +389,15 @@ impl Agent {
|
||||||
*current = Some(thread.clone());
|
*current = Some(thread.clone());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Initialize Braintrust client
|
||||||
let client = BraintrustClient::new(
|
let client = BraintrustClient::new(
|
||||||
None,
|
None,
|
||||||
"c7b996a6-1c7c-482d-b23f-3d39de16f433"
|
"c7b996a6-1c7c-482d-b23f-3d39de16f433"
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
|
// Create a root span for this thread
|
||||||
|
let root_span_id = thread.id.to_string();
|
||||||
|
|
||||||
if recursion_depth >= 30 {
|
if recursion_depth >= 30 {
|
||||||
let message = AgentMessage::assistant(
|
let message = AgentMessage::assistant(
|
||||||
Some("max_recursion_depth_message".to_string()),
|
Some("max_recursion_depth_message".to_string()),
|
||||||
|
@ -427,10 +431,23 @@ impl Agent {
|
||||||
..Default::default()
|
..Default::default()
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Create a span for the LLM call
|
||||||
|
let llm_span = client.create_span(
|
||||||
|
"llm_call",
|
||||||
|
"llm",
|
||||||
|
Some(&root_span_id),
|
||||||
|
None
|
||||||
|
).with_input(serde_json::to_value(&request)?);
|
||||||
|
|
||||||
// Get the streaming response from the LLM
|
// Get the streaming response from the LLM
|
||||||
let mut stream_rx = match self.llm_client.stream_chat_completion(request).await {
|
let mut stream_rx = match self.llm_client.stream_chat_completion(request).await {
|
||||||
Ok(rx) => rx,
|
Ok(rx) => rx,
|
||||||
Err(e) => return Err(anyhow::anyhow!("Error starting stream: {:?}", e)),
|
Err(e) => {
|
||||||
|
// Log error in span
|
||||||
|
let error_message = format!("Error starting stream: {:?}", e);
|
||||||
|
client.log_span(llm_span.with_output(serde_json::json!({"error": error_message}))).await?;
|
||||||
|
return Err(anyhow::anyhow!(error_message));
|
||||||
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
// Process the streaming chunks
|
// Process the streaming chunks
|
||||||
|
@ -483,7 +500,12 @@ impl Agent {
|
||||||
is_complete = true;
|
is_complete = true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Err(e) => return Err(anyhow::anyhow!("Error in stream: {:?}", e)),
|
Err(e) => {
|
||||||
|
// Log error in span
|
||||||
|
let error_message = format!("Error in stream: {:?}", e);
|
||||||
|
client.log_span(llm_span.with_output(serde_json::json!({"error": error_message}))).await?;
|
||||||
|
return Err(anyhow::anyhow!(error_message));
|
||||||
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -508,6 +530,24 @@ impl Agent {
|
||||||
Some(self.name.clone()),
|
Some(self.name.clone()),
|
||||||
);
|
);
|
||||||
|
|
||||||
|
// Log the LLM response in Braintrust
|
||||||
|
let llm_output = if let Some(content) = &final_message.get_content() {
|
||||||
|
serde_json::json!({
|
||||||
|
"content": content,
|
||||||
|
"tool_calls": final_tool_calls
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
serde_json::json!({
|
||||||
|
"tool_calls": final_tool_calls
|
||||||
|
})
|
||||||
|
};
|
||||||
|
|
||||||
|
// Clone the span_id before moving llm_span
|
||||||
|
let llm_span_id = llm_span.clone().span_id().to_string();
|
||||||
|
|
||||||
|
// Now we can safely move llm_span
|
||||||
|
client.log_span(llm_span.with_output(llm_output)).await?;
|
||||||
|
|
||||||
// Broadcast the final assistant message
|
// Broadcast the final assistant message
|
||||||
self.get_stream_sender()
|
self.get_stream_sender()
|
||||||
.await
|
.await
|
||||||
|
@ -532,8 +572,32 @@ impl Agent {
|
||||||
// Execute each requested tool
|
// Execute each requested tool
|
||||||
for tool_call in tool_calls {
|
for tool_call in tool_calls {
|
||||||
if let Some(tool) = self.tools.read().await.get(&tool_call.function.name) {
|
if let Some(tool) = self.tools.read().await.get(&tool_call.function.name) {
|
||||||
|
// Create a span for the tool call
|
||||||
|
let tool_span = client.create_span(
|
||||||
|
&tool_call.function.name,
|
||||||
|
"tool",
|
||||||
|
Some(&root_span_id),
|
||||||
|
Some(&llm_span_id)
|
||||||
|
);
|
||||||
|
|
||||||
|
// Parse the parameters and log them
|
||||||
let params: Value = serde_json::from_str(&tool_call.function.arguments)?;
|
let params: Value = serde_json::from_str(&tool_call.function.arguments)?;
|
||||||
let result = tool.execute(params, tool_call.id.clone()).await?;
|
let tool_span = tool_span.with_input(params.clone());
|
||||||
|
|
||||||
|
// Execute the tool
|
||||||
|
let result = match tool.execute(params, tool_call.id.clone()).await {
|
||||||
|
Ok(r) => r,
|
||||||
|
Err(e) => {
|
||||||
|
// Log error in tool span
|
||||||
|
let error_message = format!("Tool execution error: {:?}", e);
|
||||||
|
client.log_span(tool_span.with_output(serde_json::json!({"error": error_message}))).await?;
|
||||||
|
return Err(anyhow::anyhow!(error_message));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Log the tool result
|
||||||
|
client.log_span(tool_span.with_output(result.clone())).await?;
|
||||||
|
|
||||||
let result_str = serde_json::to_string(&result)?;
|
let result_str = serde_json::to_string(&result)?;
|
||||||
let tool_message = AgentMessage::tool(
|
let tool_message = AgentMessage::tool(
|
||||||
None,
|
None,
|
||||||
|
|
Loading…
Reference in New Issue