From 3bd5516486064424c4c20d5e4472a5b7e4f2558a Mon Sep 17 00:00:00 2001 From: dal Date: Thu, 13 Feb 2025 16:15:09 -0700 Subject: [PATCH] ok thread is being inserted --- api/.cursor/rules/global.mdc | 6 +- api/.gitignore | 4 +- api/Cargo.toml | 13 +- .../post_thread/agent_message_transformer.rs | 18 +-- .../post_thread/agent_thread.rs | 133 ++++++++++++++++-- 5 files changed, 142 insertions(+), 32 deletions(-) diff --git a/api/.cursor/rules/global.mdc b/api/.cursor/rules/global.mdc index 5965aa3c8..890f3aece 100644 --- a/api/.cursor/rules/global.mdc +++ b/api/.cursor/rules/global.mdc @@ -1,6 +1,6 @@ --- description: These are global rules and recommendations for the rust server. -globs: +globs: * --- # Global Rules and Project Structure @@ -17,6 +17,9 @@ This is a Rust web server project built with Axum, focusing on high performance, - `database/` - Database models, schema, and connection management - `main.rs` - Application entry point and server setup +## Implementation +When working with prds, you should always mark your progress off in them as you build. + ## Database Connectivity - The primary database connection is managed through `get_pg_pool()`, which returns a lazy static `PgPool` - Always use this pool for database connections to ensure proper connection management @@ -141,3 +144,4 @@ Remember to always consider: 2. Transaction boundaries for data consistency 3. Error propagation and cleanup 4. Memory usage and ownership +5. Please use comments to help document your code and make it more readable. diff --git a/api/.gitignore b/api/.gitignore index 4e1d55cb1..cf50ef242 100644 --- a/api/.gitignore +++ b/api/.gitignore @@ -62,4 +62,6 @@ Cargo.lock node_modules/ prds/ -docs/ \ No newline at end of file +docs/ + +.cargo/ \ No newline at end of file diff --git a/api/Cargo.toml b/api/Cargo.toml index 5ed843937..ef8e51298 100644 --- a/api/Cargo.toml +++ b/api/Cargo.toml @@ -22,7 +22,6 @@ diesel = { version = "2", features = [ "postgres", ] } diesel-async = { version = "0.5.2", features = ["postgres", "bb8"] } -diesel_full_text_search = "2.2.0" dotenv = "0.15.0" futures = "0.3.30" gcp-bigquery-client = "0.24.1" @@ -31,7 +30,6 @@ jsonwebtoken = "9.3.0" lazy_static = "1.4.0" num-traits = "0.2.19" once_cell = "1.20.2" -pgvector = { version = "0.4.0", features = ["diesel", "serde"] } rand = "0.8.5" redis = { version = "0.27.5", features = [ "tokio-comp", @@ -57,9 +55,6 @@ sqlx = { version = "0.8", features = [ "chrono", "bigdecimal", ] } -stop-words = { version = "0.8.0", default-features = false, features = [ - "nltk", -] } tempfile = "3.10.1" tiberius = { version = "0.12.2", default-features = false, features = [ "chrono", @@ -99,10 +94,4 @@ async-trait = "0.1.77" tokio = { version = "1.0", features = ["full", "test-util"] } [profile.release] -debug = false - -[profile.dev] -opt-level = 0 -incremental = true -debug = 1 - +debug = false \ No newline at end of file diff --git a/api/src/routes/ws/threads_and_messages/post_thread/agent_message_transformer.rs b/api/src/routes/ws/threads_and_messages/post_thread/agent_message_transformer.rs index 5e94ab89e..2654a823e 100644 --- a/api/src/routes/ws/threads_and_messages/post_thread/agent_message_transformer.rs +++ b/api/src/routes/ws/threads_and_messages/post_thread/agent_message_transformer.rs @@ -205,28 +205,28 @@ pub enum BusterThreadMessage { File(BusterFileMessage), } -#[derive(Debug, Serialize)] +#[derive(Debug, Serialize, Clone)] pub struct BusterChatMessageContainer { pub response_message: BusterChatMessage, pub chat_id: Uuid, pub message_id: Uuid, } -#[derive(Debug, Serialize)] +#[derive(Debug, Serialize, Clone)] #[serde(untagged)] pub enum ReasoningMessage { Thought(BusterThought), File(BusterFileMessage), } -#[derive(Debug, Serialize)] +#[derive(Debug, Serialize, Clone)] pub struct BusterReasoningMessageContainer { pub reasoning: ReasoningMessage, pub chat_id: Uuid, pub message_id: Uuid, } -#[derive(Debug, Serialize)] +#[derive(Debug, Serialize, Clone)] pub struct BusterChatMessage { pub id: String, #[serde(rename = "type")] @@ -235,7 +235,7 @@ pub struct BusterChatMessage { pub message_chunk: Option, } -#[derive(Debug, Serialize)] +#[derive(Debug, Serialize, Clone)] pub struct BusterThought { pub id: String, #[serde(rename = "type")] @@ -246,13 +246,13 @@ pub struct BusterThought { pub status: String, } -#[derive(Debug, Serialize)] +#[derive(Debug, Serialize, Clone)] pub struct BusterThoughtPillContainer { pub title: String, pub thought_pills: Vec, } -#[derive(Debug, Serialize)] +#[derive(Debug, Serialize, Clone)] pub struct BusterThoughtPill { pub id: String, pub text: String, @@ -260,7 +260,7 @@ pub struct BusterThoughtPill { pub thought_file_type: String, } -#[derive(Debug, Serialize)] +#[derive(Debug, Serialize, Clone)] pub struct BusterFileMessage { pub id: String, #[serde(rename = "type")] @@ -279,7 +279,7 @@ pub struct BusterFileLine { pub text: String, } -#[derive(Debug, Serialize)] +#[derive(Debug, Serialize, Clone)] #[serde(untagged)] pub enum BusterContainer { ChatMessage(BusterChatMessageContainer), diff --git a/api/src/routes/ws/threads_and_messages/post_thread/agent_thread.rs b/api/src/routes/ws/threads_and_messages/post_thread/agent_thread.rs index 5f11c9bdb..efe071442 100644 --- a/api/src/routes/ws/threads_and_messages/post_thread/agent_thread.rs +++ b/api/src/routes/ws/threads_and_messages/post_thread/agent_thread.rs @@ -1,5 +1,7 @@ use anyhow::{Error, Result}; use chrono::Utc; +use diesel::{insert_into, ExpressionMethods, QueryDsl}; +use diesel_async::RunQueryDsl; use serde::{Deserialize, Serialize}; use serde_json::Value; use std::collections::HashMap; @@ -8,10 +10,16 @@ use tracing; use uuid::Uuid; use crate::{ - database::models::User, + database::{ + lib::get_pg_pool, + models::{Message, MessageToFile, Thread, User}, + schema::{messages, messages_to_files, threads}, + }, routes::ws::{ threads_and_messages::{ - post_thread::agent_message_transformer::transform_message, + post_thread::agent_message_transformer::{ + transform_message, BusterContainer, ReasoningMessage, + }, threads_router::{ThreadEvent, ThreadRoute}, }, ws::{WsEvent, WsResponseMessage, WsSendMethod}, @@ -20,7 +28,7 @@ use crate::{ }, utils::{ agent::{Agent, AgentThread}, - clients::ai::litellm::Message, + clients::ai::litellm::Message as AgentMessage, tools::{ file_tools::{ CreateFilesTool, ModifyFilesTool, OpenFilesTool, SearchDataCatalogTool, @@ -124,9 +132,17 @@ impl AgentThreadHandler { let chat_id = request.chat_id.unwrap_or_else(|| Uuid::new_v4()); let message_id = request.message_id.unwrap_or_else(|| Uuid::new_v4()); + let user_org_id = match user.attributes.get("organization_id") { + Some(Value::String(org_id)) => Uuid::parse_str(&org_id).unwrap_or_default(), + _ => { + tracing::error!("User has no organization ID"); + return Err(anyhow::anyhow!("User has no organization ID")); + } + }; + let init_response = TempInitChat { id: chat_id.clone(), - title: "New Chat".to_string(), + title: request.prompt.clone(), // Use prompt as title is_favorited: false, messages: vec![TempInitChatMessage { id: message_id.clone(), @@ -162,7 +178,15 @@ impl AgentThreadHandler { let rx = self.process_chat_request(request.clone()).await?; tokio::spawn(async move { - Self::process_stream(rx, &user.id, &chat_id, &message_id).await; + Self::process_stream( + rx, + &user.id, + &user_org_id, + &chat_id, + &message_id, + request.prompt, + ) + .await; }); Ok(()) } @@ -170,29 +194,61 @@ impl AgentThreadHandler { async fn process_chat_request( &self, request: ChatCreateNewChat, - ) -> Result>> { + ) -> Result>> { let thread = AgentThread::new( request.chat_id, vec![ - Message::developer(AGENT_PROMPT.to_string()), - Message::user(request.prompt), + AgentMessage::developer(AGENT_PROMPT.to_string()), + AgentMessage::user(request.prompt), ], ); self.agent.stream_process_thread(&thread).await } async fn process_stream( - mut rx: Receiver>, + mut rx: Receiver>, user_id: &Uuid, + organization_id: &Uuid, chat_id: &Uuid, message_id: &Uuid, + request: String, ) { let subscription = user_id.to_string(); + let mut all_transformed_messages = Vec::new(); + + // Create thread first + let thread = Thread { + id: chat_id.clone(), + title: request.clone(), // Use request as title + organization_id: organization_id.clone(), + created_by: user_id.clone(), + created_at: Utc::now(), + updated_at: Utc::now(), + deleted_at: None, + }; + + // Insert thread into database + if let Err(e) = async { + let mut conn = get_pg_pool().get().await?; + insert_into(threads::table) + .values(&thread) + .execute(&mut conn) + .await?; + Ok::<_, Error>(()) + } + .await + { + tracing::error!("Failed to create thread: {}", e); + } while let Some(msg_result) = rx.recv().await { if let Ok(msg) = msg_result { match transform_message(chat_id, message_id, msg) { Ok((transformed_messages, event)) => { + // Store transformed messages for later database insertion + all_transformed_messages.extend(transformed_messages.clone()); + + // Send websocket messages as before for transformed in transformed_messages { let response = WsResponseMessage::new_no_user( WsRoutes::Threads(ThreadRoute::Post), @@ -214,6 +270,65 @@ impl AgentThreadHandler { } } } + + // After all messages are received, store them in the database + if !all_transformed_messages.is_empty() { + // Create message record + let message = Message { + id: message_id.clone(), + request: request.clone(), + response: serde_json::to_value(&all_transformed_messages).unwrap_or_default(), + thread_id: thread.id, + created_by: user_id.clone(), + created_at: Utc::now(), + updated_at: Utc::now(), + deleted_at: None, + }; + + // Insert message into database + if let Err(e) = async { + let mut conn = get_pg_pool().get().await?; + insert_into(messages::table) + .values(&message) + .execute(&mut conn) + .await?; + Ok::<_, Error>(()) + } + .await + { + tracing::error!("Failed to create message: {}", e); + } + + // Process file messages and create MessageToFile records + for container in all_transformed_messages { + if let BusterContainer::ReasoningMessage(reasoning) = container { + if let ReasoningMessage::File(file_msg) = reasoning.reasoning { + let message_to_file = MessageToFile { + id: Uuid::new_v4(), + message_id: message.id, + file_id: Uuid::parse_str(&file_msg.id).unwrap_or_default(), + created_at: Utc::now(), + updated_at: Utc::now(), + deleted_at: None, + }; + + // Insert message_to_file into database + if let Err(e) = async { + let mut conn = get_pg_pool().get().await?; + insert_into(messages_to_files::table) + .values(&message_to_file) + .execute(&mut conn) + .await?; + Ok::<_, Error>(()) + } + .await + { + tracing::error!("Failed to create message_to_file: {}", e); + } + } + } + } + } } }