Remove redundant foreign key constraints and joinables for messages_to_files

This commit is contained in:
dal 2025-02-14 13:32:50 -07:00
parent 743c256dbc
commit a59e9a26c2
No known key found for this signature in database
GPG Key ID: 16F4B0E1E9F61122
5 changed files with 185 additions and 145 deletions

View File

@ -11,16 +11,5 @@ CREATE TABLE messages_to_files (
-- Index for faster lookups by message_id
CREATE INDEX messages_files_message_id_idx ON messages_to_files(message_id);
-- Add foreign key constraints for file_id to both metric_files and dashboard_files
ALTER TABLE messages_to_files
ADD CONSTRAINT fk_metric_files
FOREIGN KEY (file_id)
REFERENCES metric_files(id);
ALTER TABLE messages_to_files
ADD CONSTRAINT fk_dashboard_files
FOREIGN KEY (file_id)
REFERENCES dashboard_files(id);
-- Index for faster lookups by file_id
CREATE INDEX messages_files_file_id_idx ON messages_to_files(file_id);

View File

@ -607,9 +607,7 @@ diesel::joinable!(messages -> users (created_by));
diesel::joinable!(messages_deprecated -> datasets (dataset_id));
diesel::joinable!(messages_deprecated -> threads_deprecated (thread_id));
diesel::joinable!(messages_deprecated -> users (sent_by));
diesel::joinable!(messages_to_files -> dashboard_files (file_id));
diesel::joinable!(messages_to_files -> messages (message_id));
diesel::joinable!(messages_to_files -> metric_files (file_id));
diesel::joinable!(permission_groups -> organizations (organization_id));
diesel::joinable!(permission_groups_to_users -> permission_groups (permission_group_id));
diesel::joinable!(permission_groups_to_users -> users (user_id));

View File

@ -19,7 +19,8 @@ use crate::{
routes::ws::{
threads_and_messages::{
post_thread::agent_message_transformer::{
transform_message, BusterContainer, ReasoningMessage,
BusterChatMessage, BusterChatMessageContainer, BusterContainer, BusterFileLine,
BusterFileMessage, BusterReasoningMessageContainer, ReasoningMessage,
},
threads_router::{ThreadEvent, ThreadRoute},
},
@ -40,6 +41,8 @@ use crate::{
},
};
use super::agent_message_transformer::transform_message;
#[derive(Debug, Serialize, Deserialize)]
pub struct TempInitChat {
pub id: Uuid,
@ -247,7 +250,7 @@ impl AgentThreadHandler {
user_id: &Uuid,
) -> Result<(), Error> {
let mut conn = get_pg_pool().get().await?;
// Update final message state
diesel::update(messages::table)
.filter(messages::id.eq(message.id))
@ -257,7 +260,7 @@ impl AgentThreadHandler {
))
.execute(&mut conn)
.await?;
// Process any completed metric or dashboard files
for container in all_transformed_messages {
match container {
@ -348,7 +351,7 @@ impl AgentThreadHandler {
_ => (), // Skip non-reasoning messages
}
}
Ok(())
}
@ -420,7 +423,8 @@ impl AgentThreadHandler {
BusterContainer::ReasoningMessage(reasoning) => {
match &reasoning.reasoning {
ReasoningMessage::Thought(thought) => {
thought.status == "completed" && thought.thoughts.is_some()
thought.status == "completed"
&& thought.thoughts.is_some()
}
ReasoningMessage::File(file) => {
file.status == "completed" && file.file.is_some()
@ -435,7 +439,8 @@ impl AgentThreadHandler {
all_transformed_messages.extend(storage_messages);
// Update message in memory with latest messages
message.response = serde_json::to_value(&all_transformed_messages).unwrap_or_default();
message.response =
serde_json::to_value(&all_transformed_messages).unwrap_or_default();
message.updated_at = Utc::now();
// Send websocket messages for real-time updates
@ -467,7 +472,9 @@ impl AgentThreadHandler {
all_transformed_messages.clone(),
organization_id,
user_id,
).await {
)
.await
{
tracing::error!("Failed to store final message state: {}", store_err);
}
break;
@ -481,7 +488,9 @@ impl AgentThreadHandler {
all_transformed_messages,
organization_id,
user_id,
).await {
)
.await
{
tracing::error!("Failed to store final message state: {}", e);
}
}
@ -524,8 +533,6 @@ const AGENT_PROMPT: &str = r##"
You are an expert analytics/data engineer helping non-technical users get answers to their analytics questions quickly and accurately. You primarily do this by creating or returning metrics and dashboards that already exist or can be built from available datasets.
Before you begin your work and after the user message, respond acknowledging the user request and explaining simply what you are going to do. Do it in a friendly way.
## Core Responsibilities
- Only open (and show) files that clearly fulfill the user's request
- Search data catalog if you can't find solutions to verify you can build what's needed

View File

@ -108,39 +108,10 @@ impl Agent {
})
.collect();
// First, make request with tool_choice set to none
let initial_request = ChatCompletionRequest {
model: self.model.clone(),
messages: thread.messages.clone(),
tools: if tools.is_empty() {
None
} else {
Some(tools.clone())
},
tool_choice: Some(ToolChoice::None("none".to_string())),
..Default::default()
};
// Get initial response
let initial_response = self.llm_client.chat_completion(initial_request).await?;
let initial_message = &initial_response.choices[0].message;
// Ensure we have content from the initial message
let initial_content = match initial_message {
Message::Assistant { content, .. } => content.clone().unwrap_or_default(),
_ => return Err(anyhow::anyhow!("Expected assistant message from LLM")),
};
// Create a new thread with the initial response
let mut tool_thread = thread.clone();
tool_thread
.messages
.push(Message::assistant(None, Some(initial_content), None, None, None));
// Create the tool-enabled request
let request = ChatCompletionRequest {
model: self.model.clone(),
messages: tool_thread.messages.clone(),
messages: thread.messages.clone(),
tools: if tools.is_empty() { None } else { Some(tools) },
tool_choice: Some(ToolChoice::Auto("auto".to_string())),
..Default::default()
@ -255,98 +226,10 @@ impl Agent {
})
.collect();
let mut tool_thread = thread.clone();
// Only do initial message phase if this is the first call (recursion_depth = 0)
if recursion_depth == 0 {
// First, make request with tool_choice set to none
let initial_request = ChatCompletionRequest {
model: model.to_string(),
messages: thread.messages.clone(),
tools: if tools.is_empty() {
None
} else {
Some(tools.clone())
},
tool_choice: Some(ToolChoice::None("none".to_string())),
stream: Some(true),
..Default::default()
};
// Get streaming response for initial thoughts
let mut initial_stream = llm_client.stream_chat_completion(initial_request).await?;
let mut initial_message = Message::assistant(
None,
Some(String::new()),
None,
None,
None,
);
// Process initial stream chunks
while let Some(chunk_result) = initial_stream.recv().await {
match chunk_result {
Ok(chunk) => {
initial_message.set_id(chunk.id.clone());
let delta = &chunk.choices[0].delta;
// Handle content updates - send delta directly
if let Some(content) = &delta.content {
// Send the delta chunk immediately with InProgress
let _ = tx
.send(Ok(Message::assistant(
Some("initial_message".to_string()),
Some(content.clone()),
None,
Some(MessageProgress::InProgress),
None,
)))
.await;
// Also accumulate for our thread history
if let Message::Assistant {
content: msg_content,
..
} = &mut initial_message
{
if let Some(existing) = msg_content {
existing.push_str(content);
}
}
}
}
Err(e) => {
let _ = tx.send(Err(anyhow::Error::from(e))).await;
return Ok(());
}
}
}
// Ensure we have content in the initial message
let initial_content = match &initial_message {
Message::Assistant { content, .. } => content.clone().unwrap_or_default(),
_ => String::new(),
};
// Send the complete message with accumulated content
if !initial_content.trim().is_empty() {
let _ = tx
.send(Ok(Message::assistant(
Some("initial_message".to_string()),
Some(initial_content.clone()),
None,
Some(MessageProgress::Complete),
None,
)))
.await;
}
}
// Create the tool-enabled request
let request = ChatCompletionRequest {
model: model.to_string(),
messages: tool_thread.messages.clone(),
messages: thread.messages.clone(),
tools: if tools.is_empty() { None } else { Some(tools) },
tool_choice: Some(ToolChoice::Auto("auto".to_string())),
stream: Some(true),
@ -355,7 +238,8 @@ impl Agent {
// Get streaming response
let mut stream = llm_client.stream_chat_completion(request).await?;
let mut current_message = Message::assistant(None, Some(String::new()), None, None, None);
let mut current_message =
Message::assistant(None, Some(String::new()), None, None, None);
let mut current_pending_tool: Option<PendingToolCall> = None;
let mut has_tool_calls = false;
let mut tool_results = Vec::new();
@ -380,7 +264,7 @@ impl Agent {
// Create and preserve the assistant message with the tool call
let is_first = !first_tool_message_sent;
first_tool_message_sent = true;
let assistant_tool_message = Message::assistant(
Some(chunk.id.clone()),
None,
@ -753,4 +637,3 @@ mod tests {
println!("Response: {:?}", response);
}
}

View File

@ -0,0 +1,163 @@
use anyhow::Result;
use chrono::Utc;
use serde_json::json;
use uuid::Uuid;
use crate::database::{
models::{Message, Thread, User},
schema::{messages, messages_to_files, metric_files, threads},
};
use crate::routes::ws::threads_and_messages::post_thread::{
agent_message_transformer::{BusterContainer, ReasoningMessage},
agent_thread::AgentThreadHandler,
};
use crate::tests::common::{db::TestDb, env::setup_test_env};
use crate::utils::clients::ai::litellm::Message as AgentMessage;
async fn setup_test_thread(test_db: &TestDb, user: &User) -> Result<(Thread, Message)> {
let thread_id = Uuid::new_v4();
let message_id = Uuid::new_v4();
// Create thread
let thread = Thread {
id: thread_id,
title: "Test Thread".to_string(),
organization_id: Uuid::parse_str(&user.attributes["organization_id"].as_str().unwrap())?,
created_by: user.id,
created_at: Utc::now(),
updated_at: Utc::now(),
deleted_at: None,
};
diesel::insert_into(threads::table)
.values(&thread)
.execute(&mut test_db.pool.get().await?)
.await?;
// Create initial message
let message = Message {
id: message_id,
request: "test request".to_string(),
response: json!({}),
thread_id,
created_by: user.id,
created_at: Utc::now(),
updated_at: Utc::now(),
deleted_at: None,
};
diesel::insert_into(messages::table)
.values(&message)
.execute(&mut test_db.pool.get().await?)
.await?;
Ok((thread, message))
}
#[tokio::test]
async fn test_end_to_end_agent_thread_flow() -> Result<()> {
// Setup test environment
setup_test_env();
let test_db = TestDb::new().await?;
let user = test_db.create_test_user().await?;
// Setup test thread and message
let (thread, message) = setup_test_thread(&test_db, &user).await?;
// Create agent handler
let handler = AgentThreadHandler::new()?;
// Create test request
let request = ChatCreateNewChat {
prompt: "Test prompt".to_string(),
chat_id: Some(thread.id),
message_id: Some(message.id),
};
// Process request
handler.handle_request(request, user.clone()).await?;
// Verify final state
let stored_message = messages::table
.filter(messages::id.eq(message.id))
.first::<Message>(&mut test_db.pool.get().await?)
.await?;
// Message should be updated with final state
assert!(!stored_message.response.as_array().unwrap().is_empty());
Ok(())
}
#[tokio::test]
async fn test_file_creation_and_linking() -> Result<()> {
// Setup test environment
setup_test_env();
let test_db = TestDb::new().await?;
let user = test_db.create_test_user().await?;
// Setup test thread and message
let (thread, message) = setup_test_thread(&test_db, &user).await?;
// Create test messages with file creation
let transformed_messages = vec![
BusterContainer::ReasoningMessage(ReasoningMessage::File(/* create test file message */)),
];
// Store final state
AgentThreadHandler::store_final_message_state(
&message,
transformed_messages,
&thread.organization_id,
&user.id,
)
.await?;
// Verify file was created and linked
let file_links = messages_to_files::table
.filter(messages_to_files::message_id.eq(message.id))
.count()
.get_result::<i64>(&mut test_db.pool.get().await?)
.await?;
assert!(file_links > 0);
Ok(())
}
#[tokio::test]
async fn test_concurrent_agent_threads() -> Result<()> {
// Setup test environment
setup_test_env();
let test_db = TestDb::new().await?;
let user = test_db.create_test_user().await?;
// Create multiple threads
let mut handles = vec![];
let handler = AgentThreadHandler::new()?;
for i in 0..3 {
let (thread, message) = setup_test_thread(&test_db, &user).await?;
let handler = handler.clone();
let user = user.clone();
let request = ChatCreateNewChat {
prompt: format!("Test prompt {}", i),
chat_id: Some(thread.id),
message_id: Some(message.id),
};
let handle = tokio::spawn(async move {
handler.handle_request(request, user).await
});
handles.push(handle);
}
// Wait for all threads to complete
for handle in handles {
handle.await??;
}
Ok(())
}