Merge branch 'evals' into big-nate/bus-939-create-new-structure-for-chats

This commit is contained in:
Nate Kelley 2025-03-04 15:23:54 -07:00
commit e27928f79b
No known key found for this signature in database
GPG Key ID: FD90372AB8D98B4F
6 changed files with 215 additions and 83 deletions

View File

@ -329,7 +329,7 @@ impl Agent {
model: self.model.clone(),
messages: thread.messages.clone(),
tools: if tools.is_empty() { None } else { Some(tools) },
tool_choice: Some(ToolChoice::Required),
tool_choice: Some(ToolChoice::Auto),
stream: Some(true), // Enable streaming
metadata: Some(Metadata {
generation_name: "agent".to_string(),
@ -337,7 +337,6 @@ impl Agent {
session_id: thread.id.to_string(),
trace_id: thread.id.to_string(),
}),
store: Some(true),
..Default::default()
};

View File

@ -104,7 +104,7 @@ impl BusterSuperAgent {
HashMap::new(),
user_id,
session_id,
"manager_agent".to_string(),
"buster_super_agent".to_string(),
));
let manager = Self { agent };
@ -116,7 +116,7 @@ impl BusterSuperAgent {
// Create a new agent with the same core properties and shared state/stream
let agent = Arc::new(Agent::from_existing(
existing_agent,
"manager_agent".to_string(),
"buster_super_agent".to_string(),
));
let manager = Self { agent };
manager.load_tools().await?;
@ -127,7 +127,7 @@ impl BusterSuperAgent {
&self,
thread: &mut AgentThread,
) -> Result<broadcast::Receiver<Result<AgentMessage, AgentError>>> {
thread.set_developer_message(MANAGER_AGENT_PROMPT.to_string());
thread.set_developer_message(BUSTER_SUPER_AGENT_PROMPT.to_string());
// Get shutdown receiver
let rx = self.stream_process_thread(thread).await?;
@ -141,7 +141,7 @@ impl BusterSuperAgent {
}
}
const MANAGER_AGENT_PROMPT: &str = r##"### Role & Task
const BUSTER_SUPER_AGENT_PROMPT: &str = r##"### Role & Task
You are Buster, an expert analytics and data engineer. Your job is to assess what data is available and then provide fast, accurate answers to analytics questions from non-technical users. You do this by analyzing user requests, searching across a data catalog, and building metrics or dashboards.
---
### Actions Available (Tools)

View File

@ -1,4 +1,5 @@
use std::sync::Arc;
use std::collections::HashSet;
use anyhow::Result;
use async_trait::async_trait;
@ -10,6 +11,7 @@ use database::{
use diesel::prelude::*;
use diesel_async::RunQueryDsl;
use agents::{Agent, AgentMessage};
use serde_json::Value;
use uuid::Uuid;
use super::ContextLoader;
@ -22,6 +24,33 @@ impl ChatContextLoader {
pub fn new(chat_id: Uuid) -> Self {
Self { chat_id }
}
// Helper function to check for tool usage and set appropriate context
async fn update_context_from_tool_calls(agent: &Arc<Agent>, message: &AgentMessage) {
if let AgentMessage::Assistant { tool_calls: Some(tool_calls), .. } = message {
for tool_call in tool_calls {
match tool_call.function.name.as_str() {
"search_data_catalog" => {
agent.set_state_value(String::from("data_context"), Value::Bool(true))
.await;
},
"create_metrics" | "update_metrics" => {
agent.set_state_value(String::from("metrics_available"), Value::Bool(true))
.await;
},
"create_dashboard" | "update_dashboard" => {
agent.set_state_value(String::from("dashboards_available"), Value::Bool(true))
.await;
},
name if name.contains("file") || name.contains("read") || name.contains("write") || name.contains("edit") => {
agent.set_state_value(String::from("files_available"), Value::Bool(true))
.await;
},
_ => {}
}
}
}
}
}
#[async_trait]
@ -36,23 +65,33 @@ impl ContextLoader for ChatContextLoader {
.first::<database::models::Chat>(&mut conn)
.await?;
// Get all messages for the chat
let messages = messages::table
// Get only the most recent message for the chat
let message = messages::table
.filter(messages::chat_id.eq(chat.id))
.order_by(messages::created_at.asc())
.load::<database::models::Message>(&mut conn)
.order_by(messages::created_at.desc())
.first::<database::models::Message>(&mut conn)
.await?;
// Track seen message IDs
let mut seen_ids = HashSet::new();
// Convert messages to AgentMessages
let mut agent_messages = Vec::new();
for message in messages {
// Add user message
agent_messages.push(AgentMessage::user(message.request_message));
// Add assistant messages from response
if let Ok(response_messages) = serde_json::from_value::<Vec<AgentMessage>>(message.response_messages)
{
agent_messages.extend(response_messages);
// Process only the most recent message's raw LLM messages
if let Ok(raw_messages) = serde_json::from_value::<Vec<AgentMessage>>(message.raw_llm_messages) {
// Check each message for tool calls and update context
for agent_message in &raw_messages {
Self::update_context_from_tool_calls(agent, agent_message).await;
// Only add messages with new IDs
if let Some(id) = agent_message.get_id() {
if seen_ids.insert(id.to_string()) {
agent_messages.push(agent_message.clone());
}
} else {
// Messages without IDs are always included
agent_messages.push(agent_message.clone());
}
}
}

View File

@ -28,6 +28,7 @@ use crate::chats::{
chat_context::ChatContextLoader, dashboard_context::DashboardContextLoader,
metric_context::MetricContextLoader, validate_context_request, ContextLoader,
},
get_chat_handler,
streaming_parser::StreamingParser,
};
use crate::messages::types::{ChatMessage, ChatUserMessage};
@ -42,6 +43,7 @@ pub enum ThreadEvent {
GeneratingReasoningMessage,
GeneratingTitle,
InitializeChat,
Completed,
}
#[derive(Debug, Deserialize, Clone)]
@ -59,12 +61,8 @@ pub async fn post_chat_handler(
tx: Option<mpsc::Sender<Result<(BusterContainer, ThreadEvent)>>>,
) -> Result<ChatWithMessages> {
let reasoning_duration = Instant::now();
// Validate context request
validate_context_request(request.chat_id, request.metric_id, request.dashboard_id)?;
let chat_id = request.chat_id.unwrap_or_else(Uuid::new_v4);
let message_id = request.message_id.unwrap_or_else(Uuid::new_v4);
let user_org_id = match user.attributes.get("organization_id") {
Some(Value::String(org_id)) => Uuid::parse_str(&org_id).unwrap_or_default(),
_ => {
@ -73,47 +71,15 @@ pub async fn post_chat_handler(
}
};
// Initialize chat - either get existing or create new
let (chat_id, message_id, mut chat_with_messages) =
initialize_chat(&request, &user, user_org_id).await?;
tracing::info!(
"Starting post_chat_handler for chat_id: {}, message_id: {}, organization_id: {}, user_id: {}",
chat_id, message_id, user_org_id, user.id
);
// Create chat
let chat = Chat {
id: chat_id,
title: request.prompt.clone(),
organization_id: user_org_id,
created_by: user.id.clone(),
created_at: Utc::now(),
updated_at: Utc::now(),
deleted_at: None,
updated_by: user.id.clone(),
};
let mut chat_with_messages = ChatWithMessages {
id: chat_id,
title: request.prompt.clone(),
is_favorited: false,
messages: vec![ChatMessage {
id: message_id,
request_message: ChatUserMessage {
request: request.prompt.clone(),
sender_id: user.id.clone(),
sender_name: user.name.clone().unwrap_or_default(),
sender_avatar: None,
},
response_messages: vec![],
reasoning: vec![],
created_at: Utc::now().to_string(),
}],
created_at: Utc::now().to_string(),
updated_at: Utc::now().to_string(),
created_by: user.id.to_string(),
created_by_id: user.id.to_string(),
created_by_name: user.name.clone().unwrap_or_default(),
created_by_avatar: None,
};
// Send initial chat state to client
if let Some(tx) = tx.clone() {
tx.send(Ok((
@ -126,12 +92,6 @@ pub async fn post_chat_handler(
// Create database connection
let mut conn = get_pg_pool().get().await?;
// Create chat in database
insert_into(chats::table)
.values(&chat)
.execute(&mut conn)
.await?;
// Initialize agent with context if provided
let mut initial_messages = vec![];
@ -162,6 +122,9 @@ pub async fn post_chat_handler(
// Add the new user message
initial_messages.push(AgentMessage::user(request.prompt.clone()));
// Initialize raw_llm_messages with initial_messages
let mut raw_llm_messages = initial_messages.clone();
// Initialize the agent thread
let mut chat = AgentThread::new(Some(chat_id), user.id, initial_messages);
@ -187,11 +150,30 @@ pub async fn post_chat_handler(
while let Ok(message_result) = rx.recv().await {
match message_result {
Ok(msg) => {
// Store the original message
// Store the original message for file processing
all_messages.push(msg.clone());
// Only store completed messages in raw_llm_messages
match &msg {
AgentMessage::Assistant { progress, .. } => {
if matches!(progress, MessageProgress::Complete) {
raw_llm_messages.push(msg.clone());
}
}
AgentMessage::Tool { progress, .. } => {
if matches!(progress, MessageProgress::Complete) {
raw_llm_messages.push(msg.clone());
}
}
// User messages and other types don't have progress, so we store them all
AgentMessage::User { .. } => {
raw_llm_messages.push(msg.clone());
}
_ => {} // Ignore other message types
}
// Always transform the message
match transform_message(&chat_id, &message_id, msg) {
match transform_message(&chat_id, &message_id, msg, tx.as_ref()).await {
Ok((containers, event)) => {
// Store all transformed containers
for container in containers.clone() {
@ -261,7 +243,7 @@ pub async fn post_chat_handler(
reasoning: serde_json::to_value(&reasoning_messages)?,
final_reasoning_message,
title: title.title.clone().unwrap_or_default(),
raw_llm_messages: Value::Array(vec![]),
raw_llm_messages: serde_json::to_value(&raw_llm_messages)?,
};
// Insert message into database
@ -283,6 +265,15 @@ pub async fn post_chat_handler(
chat_with_messages.title = title;
}
// Send final completed state
if let Some(tx) = &tx {
tx.send(Ok((
BusterContainer::Chat(chat_with_messages.clone()),
ThreadEvent::Completed,
)))
.await?;
}
tracing::info!("Completed post_chat_handler for chat_id: {}", chat_id);
Ok(chat_with_messages)
}
@ -346,15 +337,14 @@ async fn process_completed_files(
user_id: &Uuid,
) -> Result<()> {
// Transform messages to BusterContainer format
let transformed_messages: Vec<BusterContainer> = messages
.iter()
.filter_map(|msg| {
transform_message(&message.chat_id, &message.id, msg.clone())
.ok()
.map(|(containers, _)| containers)
})
.flatten()
.collect();
let mut transformed_messages = Vec::new();
for msg in messages {
if let Ok((containers, _)) =
transform_message(&message.chat_id, &message.id, msg.clone(), None).await
{
transformed_messages.extend(containers);
}
}
// Process any completed metric or dashboard files
for container in transformed_messages {
@ -575,10 +565,11 @@ pub enum BusterContainer {
GeneratingTitle(BusterGeneratingTitle),
}
pub fn transform_message(
pub async fn transform_message(
chat_id: &Uuid,
message_id: &Uuid,
message: AgentMessage,
tx: Option<&mpsc::Sender<Result<(BusterContainer, ThreadEvent)>>>,
) -> Result<(Vec<BusterContainer>, ThreadEvent)> {
println!("MESSAGE_STREAM: Transforming message: {:?}", message);
@ -631,13 +622,25 @@ pub fn transform_message(
status: Some("completed".to_string()),
});
containers.push(BusterContainer::ReasoningMessage(
BusterReasoningMessageContainer {
let reasoning_container =
BusterContainer::ReasoningMessage(BusterReasoningMessageContainer {
reasoning: reasoning_message,
chat_id: *chat_id,
message_id: *message_id,
},
));
});
// Send the finished reasoning message separately
if let Some(tx) = tx {
if let Err(e) = tx
.send(Ok((
reasoning_container,
ThreadEvent::GeneratingReasoningMessage,
)))
.await
{
tracing::warn!("Failed to send finished reasoning message: {:?}", e);
}
}
}
return Ok((containers, ThreadEvent::GeneratingResponseMessage));
@ -1393,3 +1396,78 @@ pub async fn generate_conversation_title(
Ok(title)
}
async fn initialize_chat(
request: &ChatCreateNewChat,
user: &User,
user_org_id: Uuid,
) -> Result<(Uuid, Uuid, ChatWithMessages)> {
let message_id = request.message_id.unwrap_or_else(Uuid::new_v4);
if let Some(existing_chat_id) = request.chat_id {
// Get existing chat - no need to create new chat in DB
let mut existing_chat = get_chat_handler(&existing_chat_id, &user.id).await?;
// Add new message to existing chat
existing_chat.messages.push(ChatMessage {
id: message_id,
request_message: ChatUserMessage {
request: request.prompt.clone(),
sender_id: user.id.clone(),
sender_name: user.name.clone().unwrap_or_default(),
sender_avatar: None,
},
response_messages: vec![],
reasoning: vec![],
created_at: Utc::now().to_string(),
});
Ok((existing_chat_id, message_id, existing_chat))
} else {
// Create new chat since we don't have an existing one
let chat_id = Uuid::new_v4();
let chat = Chat {
id: chat_id,
title: request.prompt.clone(),
organization_id: user_org_id,
created_by: user.id.clone(),
created_at: Utc::now(),
updated_at: Utc::now(),
deleted_at: None,
updated_by: user.id.clone(),
};
let chat_with_messages = ChatWithMessages {
id: chat_id,
title: request.prompt.clone(),
is_favorited: false,
messages: vec![ChatMessage {
id: message_id,
request_message: ChatUserMessage {
request: request.prompt.clone(),
sender_id: user.id.clone(),
sender_name: user.name.clone().unwrap_or_default(),
sender_avatar: None,
},
response_messages: vec![],
reasoning: vec![],
created_at: Utc::now().to_string(),
}],
created_at: Utc::now().to_string(),
updated_at: Utc::now().to_string(),
created_by: user.id.to_string(),
created_by_id: user.id.to_string(),
created_by_name: user.name.clone().unwrap_or_default(),
created_by_avatar: None,
};
// Only create new chat in DB if this is a new chat
let mut conn = get_pg_pool().get().await?;
insert_into(chats::table)
.values(&chat)
.execute(&mut conn)
.await?;
Ok((chat_id, message_id, chat_with_messages))
}
}

View File

@ -250,6 +250,15 @@ impl AgentMessage {
Self::User { id, .. } => *id = Some(new_id),
}
}
pub fn get_id(&self) -> Option<String> {
match self {
Self::Assistant { id, .. } => id.clone(),
Self::Tool { id, .. } => id.clone(),
Self::Developer { id, .. } => id.clone(),
Self::User { id, .. } => id.clone(),
}
}
}
#[derive(Debug, Serialize, Deserialize, Clone)]
@ -584,7 +593,9 @@ mod tests {
async fn test_chat_completion_request_with_tools() {
let request = ChatCompletionRequest {
model: "o1".to_string(),
messages: vec![AgentMessage::user("Hello whats the weather in vineyard ut!")],
messages: vec![AgentMessage::user(
"Hello whats the weather in vineyard ut!",
)],
max_completion_tokens: Some(100),
tools: Some(vec![Tool {
tool_type: "function".to_string(),
@ -877,7 +888,9 @@ mod tests {
// Test request with function tool
let request = ChatCompletionRequest {
model: "gpt-4o".to_string(),
messages: vec![AgentMessage::user("What's the weather like in Boston today?")],
messages: vec![AgentMessage::user(
"What's the weather like in Boston today?",
)],
tools: Some(vec![Tool {
tool_type: "function".to_string(),
function: json!({

View File

@ -46,6 +46,9 @@ pub async fn post_thread(
ThreadEvent::InitializeChat => {
WsEvent::Threads(WSThreadEvent::InitializeChat)
}
ThreadEvent::Completed => {
WsEvent::Threads(WSThreadEvent::Complete)
}
};
let response = WsResponseMessage::new_no_user(