mirror of https://github.com/buster-so/buster.git
Merge branch 'evals' into big-nate/bus-939-create-new-structure-for-chats
This commit is contained in:
commit
e27928f79b
|
@ -329,7 +329,7 @@ impl Agent {
|
|||
model: self.model.clone(),
|
||||
messages: thread.messages.clone(),
|
||||
tools: if tools.is_empty() { None } else { Some(tools) },
|
||||
tool_choice: Some(ToolChoice::Required),
|
||||
tool_choice: Some(ToolChoice::Auto),
|
||||
stream: Some(true), // Enable streaming
|
||||
metadata: Some(Metadata {
|
||||
generation_name: "agent".to_string(),
|
||||
|
@ -337,7 +337,6 @@ impl Agent {
|
|||
session_id: thread.id.to_string(),
|
||||
trace_id: thread.id.to_string(),
|
||||
}),
|
||||
store: Some(true),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
|
|
|
@ -104,7 +104,7 @@ impl BusterSuperAgent {
|
|||
HashMap::new(),
|
||||
user_id,
|
||||
session_id,
|
||||
"manager_agent".to_string(),
|
||||
"buster_super_agent".to_string(),
|
||||
));
|
||||
|
||||
let manager = Self { agent };
|
||||
|
@ -116,7 +116,7 @@ impl BusterSuperAgent {
|
|||
// Create a new agent with the same core properties and shared state/stream
|
||||
let agent = Arc::new(Agent::from_existing(
|
||||
existing_agent,
|
||||
"manager_agent".to_string(),
|
||||
"buster_super_agent".to_string(),
|
||||
));
|
||||
let manager = Self { agent };
|
||||
manager.load_tools().await?;
|
||||
|
@ -127,7 +127,7 @@ impl BusterSuperAgent {
|
|||
&self,
|
||||
thread: &mut AgentThread,
|
||||
) -> Result<broadcast::Receiver<Result<AgentMessage, AgentError>>> {
|
||||
thread.set_developer_message(MANAGER_AGENT_PROMPT.to_string());
|
||||
thread.set_developer_message(BUSTER_SUPER_AGENT_PROMPT.to_string());
|
||||
|
||||
// Get shutdown receiver
|
||||
let rx = self.stream_process_thread(thread).await?;
|
||||
|
@ -141,7 +141,7 @@ impl BusterSuperAgent {
|
|||
}
|
||||
}
|
||||
|
||||
const MANAGER_AGENT_PROMPT: &str = r##"### Role & Task
|
||||
const BUSTER_SUPER_AGENT_PROMPT: &str = r##"### Role & Task
|
||||
You are Buster, an expert analytics and data engineer. Your job is to assess what data is available and then provide fast, accurate answers to analytics questions from non-technical users. You do this by analyzing user requests, searching across a data catalog, and building metrics or dashboards.
|
||||
---
|
||||
### Actions Available (Tools)
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
use std::sync::Arc;
|
||||
use std::collections::HashSet;
|
||||
|
||||
use anyhow::Result;
|
||||
use async_trait::async_trait;
|
||||
|
@ -10,6 +11,7 @@ use database::{
|
|||
use diesel::prelude::*;
|
||||
use diesel_async::RunQueryDsl;
|
||||
use agents::{Agent, AgentMessage};
|
||||
use serde_json::Value;
|
||||
use uuid::Uuid;
|
||||
|
||||
use super::ContextLoader;
|
||||
|
@ -22,6 +24,33 @@ impl ChatContextLoader {
|
|||
pub fn new(chat_id: Uuid) -> Self {
|
||||
Self { chat_id }
|
||||
}
|
||||
|
||||
// Helper function to check for tool usage and set appropriate context
|
||||
async fn update_context_from_tool_calls(agent: &Arc<Agent>, message: &AgentMessage) {
|
||||
if let AgentMessage::Assistant { tool_calls: Some(tool_calls), .. } = message {
|
||||
for tool_call in tool_calls {
|
||||
match tool_call.function.name.as_str() {
|
||||
"search_data_catalog" => {
|
||||
agent.set_state_value(String::from("data_context"), Value::Bool(true))
|
||||
.await;
|
||||
},
|
||||
"create_metrics" | "update_metrics" => {
|
||||
agent.set_state_value(String::from("metrics_available"), Value::Bool(true))
|
||||
.await;
|
||||
},
|
||||
"create_dashboard" | "update_dashboard" => {
|
||||
agent.set_state_value(String::from("dashboards_available"), Value::Bool(true))
|
||||
.await;
|
||||
},
|
||||
name if name.contains("file") || name.contains("read") || name.contains("write") || name.contains("edit") => {
|
||||
agent.set_state_value(String::from("files_available"), Value::Bool(true))
|
||||
.await;
|
||||
},
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
|
@ -36,23 +65,33 @@ impl ContextLoader for ChatContextLoader {
|
|||
.first::<database::models::Chat>(&mut conn)
|
||||
.await?;
|
||||
|
||||
// Get all messages for the chat
|
||||
let messages = messages::table
|
||||
// Get only the most recent message for the chat
|
||||
let message = messages::table
|
||||
.filter(messages::chat_id.eq(chat.id))
|
||||
.order_by(messages::created_at.asc())
|
||||
.load::<database::models::Message>(&mut conn)
|
||||
.order_by(messages::created_at.desc())
|
||||
.first::<database::models::Message>(&mut conn)
|
||||
.await?;
|
||||
|
||||
// Track seen message IDs
|
||||
let mut seen_ids = HashSet::new();
|
||||
// Convert messages to AgentMessages
|
||||
let mut agent_messages = Vec::new();
|
||||
for message in messages {
|
||||
// Add user message
|
||||
agent_messages.push(AgentMessage::user(message.request_message));
|
||||
|
||||
// Add assistant messages from response
|
||||
if let Ok(response_messages) = serde_json::from_value::<Vec<AgentMessage>>(message.response_messages)
|
||||
{
|
||||
agent_messages.extend(response_messages);
|
||||
|
||||
// Process only the most recent message's raw LLM messages
|
||||
if let Ok(raw_messages) = serde_json::from_value::<Vec<AgentMessage>>(message.raw_llm_messages) {
|
||||
// Check each message for tool calls and update context
|
||||
for agent_message in &raw_messages {
|
||||
Self::update_context_from_tool_calls(agent, agent_message).await;
|
||||
|
||||
// Only add messages with new IDs
|
||||
if let Some(id) = agent_message.get_id() {
|
||||
if seen_ids.insert(id.to_string()) {
|
||||
agent_messages.push(agent_message.clone());
|
||||
}
|
||||
} else {
|
||||
// Messages without IDs are always included
|
||||
agent_messages.push(agent_message.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -28,6 +28,7 @@ use crate::chats::{
|
|||
chat_context::ChatContextLoader, dashboard_context::DashboardContextLoader,
|
||||
metric_context::MetricContextLoader, validate_context_request, ContextLoader,
|
||||
},
|
||||
get_chat_handler,
|
||||
streaming_parser::StreamingParser,
|
||||
};
|
||||
use crate::messages::types::{ChatMessage, ChatUserMessage};
|
||||
|
@ -42,6 +43,7 @@ pub enum ThreadEvent {
|
|||
GeneratingReasoningMessage,
|
||||
GeneratingTitle,
|
||||
InitializeChat,
|
||||
Completed,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Clone)]
|
||||
|
@ -59,12 +61,8 @@ pub async fn post_chat_handler(
|
|||
tx: Option<mpsc::Sender<Result<(BusterContainer, ThreadEvent)>>>,
|
||||
) -> Result<ChatWithMessages> {
|
||||
let reasoning_duration = Instant::now();
|
||||
// Validate context request
|
||||
validate_context_request(request.chat_id, request.metric_id, request.dashboard_id)?;
|
||||
|
||||
let chat_id = request.chat_id.unwrap_or_else(Uuid::new_v4);
|
||||
let message_id = request.message_id.unwrap_or_else(Uuid::new_v4);
|
||||
|
||||
let user_org_id = match user.attributes.get("organization_id") {
|
||||
Some(Value::String(org_id)) => Uuid::parse_str(&org_id).unwrap_or_default(),
|
||||
_ => {
|
||||
|
@ -73,47 +71,15 @@ pub async fn post_chat_handler(
|
|||
}
|
||||
};
|
||||
|
||||
// Initialize chat - either get existing or create new
|
||||
let (chat_id, message_id, mut chat_with_messages) =
|
||||
initialize_chat(&request, &user, user_org_id).await?;
|
||||
|
||||
tracing::info!(
|
||||
"Starting post_chat_handler for chat_id: {}, message_id: {}, organization_id: {}, user_id: {}",
|
||||
chat_id, message_id, user_org_id, user.id
|
||||
);
|
||||
|
||||
// Create chat
|
||||
let chat = Chat {
|
||||
id: chat_id,
|
||||
title: request.prompt.clone(),
|
||||
organization_id: user_org_id,
|
||||
created_by: user.id.clone(),
|
||||
created_at: Utc::now(),
|
||||
updated_at: Utc::now(),
|
||||
deleted_at: None,
|
||||
updated_by: user.id.clone(),
|
||||
};
|
||||
|
||||
let mut chat_with_messages = ChatWithMessages {
|
||||
id: chat_id,
|
||||
title: request.prompt.clone(),
|
||||
is_favorited: false,
|
||||
messages: vec![ChatMessage {
|
||||
id: message_id,
|
||||
request_message: ChatUserMessage {
|
||||
request: request.prompt.clone(),
|
||||
sender_id: user.id.clone(),
|
||||
sender_name: user.name.clone().unwrap_or_default(),
|
||||
sender_avatar: None,
|
||||
},
|
||||
response_messages: vec![],
|
||||
reasoning: vec![],
|
||||
created_at: Utc::now().to_string(),
|
||||
}],
|
||||
created_at: Utc::now().to_string(),
|
||||
updated_at: Utc::now().to_string(),
|
||||
created_by: user.id.to_string(),
|
||||
created_by_id: user.id.to_string(),
|
||||
created_by_name: user.name.clone().unwrap_or_default(),
|
||||
created_by_avatar: None,
|
||||
};
|
||||
|
||||
// Send initial chat state to client
|
||||
if let Some(tx) = tx.clone() {
|
||||
tx.send(Ok((
|
||||
|
@ -126,12 +92,6 @@ pub async fn post_chat_handler(
|
|||
// Create database connection
|
||||
let mut conn = get_pg_pool().get().await?;
|
||||
|
||||
// Create chat in database
|
||||
insert_into(chats::table)
|
||||
.values(&chat)
|
||||
.execute(&mut conn)
|
||||
.await?;
|
||||
|
||||
// Initialize agent with context if provided
|
||||
let mut initial_messages = vec![];
|
||||
|
||||
|
@ -162,6 +122,9 @@ pub async fn post_chat_handler(
|
|||
// Add the new user message
|
||||
initial_messages.push(AgentMessage::user(request.prompt.clone()));
|
||||
|
||||
// Initialize raw_llm_messages with initial_messages
|
||||
let mut raw_llm_messages = initial_messages.clone();
|
||||
|
||||
// Initialize the agent thread
|
||||
let mut chat = AgentThread::new(Some(chat_id), user.id, initial_messages);
|
||||
|
||||
|
@ -187,11 +150,30 @@ pub async fn post_chat_handler(
|
|||
while let Ok(message_result) = rx.recv().await {
|
||||
match message_result {
|
||||
Ok(msg) => {
|
||||
// Store the original message
|
||||
// Store the original message for file processing
|
||||
all_messages.push(msg.clone());
|
||||
|
||||
// Only store completed messages in raw_llm_messages
|
||||
match &msg {
|
||||
AgentMessage::Assistant { progress, .. } => {
|
||||
if matches!(progress, MessageProgress::Complete) {
|
||||
raw_llm_messages.push(msg.clone());
|
||||
}
|
||||
}
|
||||
AgentMessage::Tool { progress, .. } => {
|
||||
if matches!(progress, MessageProgress::Complete) {
|
||||
raw_llm_messages.push(msg.clone());
|
||||
}
|
||||
}
|
||||
// User messages and other types don't have progress, so we store them all
|
||||
AgentMessage::User { .. } => {
|
||||
raw_llm_messages.push(msg.clone());
|
||||
}
|
||||
_ => {} // Ignore other message types
|
||||
}
|
||||
|
||||
// Always transform the message
|
||||
match transform_message(&chat_id, &message_id, msg) {
|
||||
match transform_message(&chat_id, &message_id, msg, tx.as_ref()).await {
|
||||
Ok((containers, event)) => {
|
||||
// Store all transformed containers
|
||||
for container in containers.clone() {
|
||||
|
@ -261,7 +243,7 @@ pub async fn post_chat_handler(
|
|||
reasoning: serde_json::to_value(&reasoning_messages)?,
|
||||
final_reasoning_message,
|
||||
title: title.title.clone().unwrap_or_default(),
|
||||
raw_llm_messages: Value::Array(vec![]),
|
||||
raw_llm_messages: serde_json::to_value(&raw_llm_messages)?,
|
||||
};
|
||||
|
||||
// Insert message into database
|
||||
|
@ -283,6 +265,15 @@ pub async fn post_chat_handler(
|
|||
chat_with_messages.title = title;
|
||||
}
|
||||
|
||||
// Send final completed state
|
||||
if let Some(tx) = &tx {
|
||||
tx.send(Ok((
|
||||
BusterContainer::Chat(chat_with_messages.clone()),
|
||||
ThreadEvent::Completed,
|
||||
)))
|
||||
.await?;
|
||||
}
|
||||
|
||||
tracing::info!("Completed post_chat_handler for chat_id: {}", chat_id);
|
||||
Ok(chat_with_messages)
|
||||
}
|
||||
|
@ -346,15 +337,14 @@ async fn process_completed_files(
|
|||
user_id: &Uuid,
|
||||
) -> Result<()> {
|
||||
// Transform messages to BusterContainer format
|
||||
let transformed_messages: Vec<BusterContainer> = messages
|
||||
.iter()
|
||||
.filter_map(|msg| {
|
||||
transform_message(&message.chat_id, &message.id, msg.clone())
|
||||
.ok()
|
||||
.map(|(containers, _)| containers)
|
||||
})
|
||||
.flatten()
|
||||
.collect();
|
||||
let mut transformed_messages = Vec::new();
|
||||
for msg in messages {
|
||||
if let Ok((containers, _)) =
|
||||
transform_message(&message.chat_id, &message.id, msg.clone(), None).await
|
||||
{
|
||||
transformed_messages.extend(containers);
|
||||
}
|
||||
}
|
||||
|
||||
// Process any completed metric or dashboard files
|
||||
for container in transformed_messages {
|
||||
|
@ -575,10 +565,11 @@ pub enum BusterContainer {
|
|||
GeneratingTitle(BusterGeneratingTitle),
|
||||
}
|
||||
|
||||
pub fn transform_message(
|
||||
pub async fn transform_message(
|
||||
chat_id: &Uuid,
|
||||
message_id: &Uuid,
|
||||
message: AgentMessage,
|
||||
tx: Option<&mpsc::Sender<Result<(BusterContainer, ThreadEvent)>>>,
|
||||
) -> Result<(Vec<BusterContainer>, ThreadEvent)> {
|
||||
println!("MESSAGE_STREAM: Transforming message: {:?}", message);
|
||||
|
||||
|
@ -631,13 +622,25 @@ pub fn transform_message(
|
|||
status: Some("completed".to_string()),
|
||||
});
|
||||
|
||||
containers.push(BusterContainer::ReasoningMessage(
|
||||
BusterReasoningMessageContainer {
|
||||
let reasoning_container =
|
||||
BusterContainer::ReasoningMessage(BusterReasoningMessageContainer {
|
||||
reasoning: reasoning_message,
|
||||
chat_id: *chat_id,
|
||||
message_id: *message_id,
|
||||
},
|
||||
));
|
||||
});
|
||||
|
||||
// Send the finished reasoning message separately
|
||||
if let Some(tx) = tx {
|
||||
if let Err(e) = tx
|
||||
.send(Ok((
|
||||
reasoning_container,
|
||||
ThreadEvent::GeneratingReasoningMessage,
|
||||
)))
|
||||
.await
|
||||
{
|
||||
tracing::warn!("Failed to send finished reasoning message: {:?}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return Ok((containers, ThreadEvent::GeneratingResponseMessage));
|
||||
|
@ -1393,3 +1396,78 @@ pub async fn generate_conversation_title(
|
|||
|
||||
Ok(title)
|
||||
}
|
||||
|
||||
async fn initialize_chat(
|
||||
request: &ChatCreateNewChat,
|
||||
user: &User,
|
||||
user_org_id: Uuid,
|
||||
) -> Result<(Uuid, Uuid, ChatWithMessages)> {
|
||||
let message_id = request.message_id.unwrap_or_else(Uuid::new_v4);
|
||||
|
||||
if let Some(existing_chat_id) = request.chat_id {
|
||||
// Get existing chat - no need to create new chat in DB
|
||||
let mut existing_chat = get_chat_handler(&existing_chat_id, &user.id).await?;
|
||||
|
||||
// Add new message to existing chat
|
||||
existing_chat.messages.push(ChatMessage {
|
||||
id: message_id,
|
||||
request_message: ChatUserMessage {
|
||||
request: request.prompt.clone(),
|
||||
sender_id: user.id.clone(),
|
||||
sender_name: user.name.clone().unwrap_or_default(),
|
||||
sender_avatar: None,
|
||||
},
|
||||
response_messages: vec![],
|
||||
reasoning: vec![],
|
||||
created_at: Utc::now().to_string(),
|
||||
});
|
||||
|
||||
Ok((existing_chat_id, message_id, existing_chat))
|
||||
} else {
|
||||
// Create new chat since we don't have an existing one
|
||||
let chat_id = Uuid::new_v4();
|
||||
let chat = Chat {
|
||||
id: chat_id,
|
||||
title: request.prompt.clone(),
|
||||
organization_id: user_org_id,
|
||||
created_by: user.id.clone(),
|
||||
created_at: Utc::now(),
|
||||
updated_at: Utc::now(),
|
||||
deleted_at: None,
|
||||
updated_by: user.id.clone(),
|
||||
};
|
||||
|
||||
let chat_with_messages = ChatWithMessages {
|
||||
id: chat_id,
|
||||
title: request.prompt.clone(),
|
||||
is_favorited: false,
|
||||
messages: vec![ChatMessage {
|
||||
id: message_id,
|
||||
request_message: ChatUserMessage {
|
||||
request: request.prompt.clone(),
|
||||
sender_id: user.id.clone(),
|
||||
sender_name: user.name.clone().unwrap_or_default(),
|
||||
sender_avatar: None,
|
||||
},
|
||||
response_messages: vec![],
|
||||
reasoning: vec![],
|
||||
created_at: Utc::now().to_string(),
|
||||
}],
|
||||
created_at: Utc::now().to_string(),
|
||||
updated_at: Utc::now().to_string(),
|
||||
created_by: user.id.to_string(),
|
||||
created_by_id: user.id.to_string(),
|
||||
created_by_name: user.name.clone().unwrap_or_default(),
|
||||
created_by_avatar: None,
|
||||
};
|
||||
|
||||
// Only create new chat in DB if this is a new chat
|
||||
let mut conn = get_pg_pool().get().await?;
|
||||
insert_into(chats::table)
|
||||
.values(&chat)
|
||||
.execute(&mut conn)
|
||||
.await?;
|
||||
|
||||
Ok((chat_id, message_id, chat_with_messages))
|
||||
}
|
||||
}
|
||||
|
|
|
@ -250,6 +250,15 @@ impl AgentMessage {
|
|||
Self::User { id, .. } => *id = Some(new_id),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_id(&self) -> Option<String> {
|
||||
match self {
|
||||
Self::Assistant { id, .. } => id.clone(),
|
||||
Self::Tool { id, .. } => id.clone(),
|
||||
Self::Developer { id, .. } => id.clone(),
|
||||
Self::User { id, .. } => id.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
|
@ -584,7 +593,9 @@ mod tests {
|
|||
async fn test_chat_completion_request_with_tools() {
|
||||
let request = ChatCompletionRequest {
|
||||
model: "o1".to_string(),
|
||||
messages: vec![AgentMessage::user("Hello whats the weather in vineyard ut!")],
|
||||
messages: vec![AgentMessage::user(
|
||||
"Hello whats the weather in vineyard ut!",
|
||||
)],
|
||||
max_completion_tokens: Some(100),
|
||||
tools: Some(vec![Tool {
|
||||
tool_type: "function".to_string(),
|
||||
|
@ -877,7 +888,9 @@ mod tests {
|
|||
// Test request with function tool
|
||||
let request = ChatCompletionRequest {
|
||||
model: "gpt-4o".to_string(),
|
||||
messages: vec![AgentMessage::user("What's the weather like in Boston today?")],
|
||||
messages: vec![AgentMessage::user(
|
||||
"What's the weather like in Boston today?",
|
||||
)],
|
||||
tools: Some(vec![Tool {
|
||||
tool_type: "function".to_string(),
|
||||
function: json!({
|
||||
|
|
|
@ -46,6 +46,9 @@ pub async fn post_thread(
|
|||
ThreadEvent::InitializeChat => {
|
||||
WsEvent::Threads(WSThreadEvent::InitializeChat)
|
||||
}
|
||||
ThreadEvent::Completed => {
|
||||
WsEvent::Threads(WSThreadEvent::Complete)
|
||||
}
|
||||
};
|
||||
|
||||
let response = WsResponseMessage::new_no_user(
|
||||
|
|
Loading…
Reference in New Issue