ok thread is being inserted

This commit is contained in:
dal 2025-02-13 16:15:09 -07:00
parent 977b5eb6de
commit 3bd5516486
No known key found for this signature in database
GPG Key ID: 16F4B0E1E9F61122
5 changed files with 142 additions and 32 deletions

View File

@ -1,6 +1,6 @@
---
description: These are global rules and recommendations for the rust server.
globs:
globs: *
---
# Global Rules and Project Structure
@ -17,6 +17,9 @@ This is a Rust web server project built with Axum, focusing on high performance,
- `database/` - Database models, schema, and connection management
- `main.rs` - Application entry point and server setup
## Implementation
When working with prds, you should always mark your progress off in them as you build.
## Database Connectivity
- The primary database connection is managed through `get_pg_pool()`, which returns a lazy static `PgPool`
- Always use this pool for database connections to ensure proper connection management
@ -141,3 +144,4 @@ Remember to always consider:
2. Transaction boundaries for data consistency
3. Error propagation and cleanup
4. Memory usage and ownership
5. Please use comments to help document your code and make it more readable.

4
api/.gitignore vendored
View File

@ -62,4 +62,6 @@ Cargo.lock
node_modules/
prds/
docs/
docs/
.cargo/

View File

@ -22,7 +22,6 @@ diesel = { version = "2", features = [
"postgres",
] }
diesel-async = { version = "0.5.2", features = ["postgres", "bb8"] }
diesel_full_text_search = "2.2.0"
dotenv = "0.15.0"
futures = "0.3.30"
gcp-bigquery-client = "0.24.1"
@ -31,7 +30,6 @@ jsonwebtoken = "9.3.0"
lazy_static = "1.4.0"
num-traits = "0.2.19"
once_cell = "1.20.2"
pgvector = { version = "0.4.0", features = ["diesel", "serde"] }
rand = "0.8.5"
redis = { version = "0.27.5", features = [
"tokio-comp",
@ -57,9 +55,6 @@ sqlx = { version = "0.8", features = [
"chrono",
"bigdecimal",
] }
stop-words = { version = "0.8.0", default-features = false, features = [
"nltk",
] }
tempfile = "3.10.1"
tiberius = { version = "0.12.2", default-features = false, features = [
"chrono",
@ -99,10 +94,4 @@ async-trait = "0.1.77"
tokio = { version = "1.0", features = ["full", "test-util"] }
[profile.release]
debug = false
[profile.dev]
opt-level = 0
incremental = true
debug = 1
debug = false

View File

@ -205,28 +205,28 @@ pub enum BusterThreadMessage {
File(BusterFileMessage),
}
#[derive(Debug, Serialize)]
#[derive(Debug, Serialize, Clone)]
pub struct BusterChatMessageContainer {
pub response_message: BusterChatMessage,
pub chat_id: Uuid,
pub message_id: Uuid,
}
#[derive(Debug, Serialize)]
#[derive(Debug, Serialize, Clone)]
#[serde(untagged)]
pub enum ReasoningMessage {
Thought(BusterThought),
File(BusterFileMessage),
}
#[derive(Debug, Serialize)]
#[derive(Debug, Serialize, Clone)]
pub struct BusterReasoningMessageContainer {
pub reasoning: ReasoningMessage,
pub chat_id: Uuid,
pub message_id: Uuid,
}
#[derive(Debug, Serialize)]
#[derive(Debug, Serialize, Clone)]
pub struct BusterChatMessage {
pub id: String,
#[serde(rename = "type")]
@ -235,7 +235,7 @@ pub struct BusterChatMessage {
pub message_chunk: Option<String>,
}
#[derive(Debug, Serialize)]
#[derive(Debug, Serialize, Clone)]
pub struct BusterThought {
pub id: String,
#[serde(rename = "type")]
@ -246,13 +246,13 @@ pub struct BusterThought {
pub status: String,
}
#[derive(Debug, Serialize)]
#[derive(Debug, Serialize, Clone)]
pub struct BusterThoughtPillContainer {
pub title: String,
pub thought_pills: Vec<BusterThoughtPill>,
}
#[derive(Debug, Serialize)]
#[derive(Debug, Serialize, Clone)]
pub struct BusterThoughtPill {
pub id: String,
pub text: String,
@ -260,7 +260,7 @@ pub struct BusterThoughtPill {
pub thought_file_type: String,
}
#[derive(Debug, Serialize)]
#[derive(Debug, Serialize, Clone)]
pub struct BusterFileMessage {
pub id: String,
#[serde(rename = "type")]
@ -279,7 +279,7 @@ pub struct BusterFileLine {
pub text: String,
}
#[derive(Debug, Serialize)]
#[derive(Debug, Serialize, Clone)]
#[serde(untagged)]
pub enum BusterContainer {
ChatMessage(BusterChatMessageContainer),

View File

@ -1,5 +1,7 @@
use anyhow::{Error, Result};
use chrono::Utc;
use diesel::{insert_into, ExpressionMethods, QueryDsl};
use diesel_async::RunQueryDsl;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;
@ -8,10 +10,16 @@ use tracing;
use uuid::Uuid;
use crate::{
database::models::User,
database::{
lib::get_pg_pool,
models::{Message, MessageToFile, Thread, User},
schema::{messages, messages_to_files, threads},
},
routes::ws::{
threads_and_messages::{
post_thread::agent_message_transformer::transform_message,
post_thread::agent_message_transformer::{
transform_message, BusterContainer, ReasoningMessage,
},
threads_router::{ThreadEvent, ThreadRoute},
},
ws::{WsEvent, WsResponseMessage, WsSendMethod},
@ -20,7 +28,7 @@ use crate::{
},
utils::{
agent::{Agent, AgentThread},
clients::ai::litellm::Message,
clients::ai::litellm::Message as AgentMessage,
tools::{
file_tools::{
CreateFilesTool, ModifyFilesTool, OpenFilesTool, SearchDataCatalogTool,
@ -124,9 +132,17 @@ impl AgentThreadHandler {
let chat_id = request.chat_id.unwrap_or_else(|| Uuid::new_v4());
let message_id = request.message_id.unwrap_or_else(|| Uuid::new_v4());
let user_org_id = match user.attributes.get("organization_id") {
Some(Value::String(org_id)) => Uuid::parse_str(&org_id).unwrap_or_default(),
_ => {
tracing::error!("User has no organization ID");
return Err(anyhow::anyhow!("User has no organization ID"));
}
};
let init_response = TempInitChat {
id: chat_id.clone(),
title: "New Chat".to_string(),
title: request.prompt.clone(), // Use prompt as title
is_favorited: false,
messages: vec![TempInitChatMessage {
id: message_id.clone(),
@ -162,7 +178,15 @@ impl AgentThreadHandler {
let rx = self.process_chat_request(request.clone()).await?;
tokio::spawn(async move {
Self::process_stream(rx, &user.id, &chat_id, &message_id).await;
Self::process_stream(
rx,
&user.id,
&user_org_id,
&chat_id,
&message_id,
request.prompt,
)
.await;
});
Ok(())
}
@ -170,29 +194,61 @@ impl AgentThreadHandler {
async fn process_chat_request(
&self,
request: ChatCreateNewChat,
) -> Result<Receiver<Result<Message, Error>>> {
) -> Result<Receiver<Result<AgentMessage, Error>>> {
let thread = AgentThread::new(
request.chat_id,
vec![
Message::developer(AGENT_PROMPT.to_string()),
Message::user(request.prompt),
AgentMessage::developer(AGENT_PROMPT.to_string()),
AgentMessage::user(request.prompt),
],
);
self.agent.stream_process_thread(&thread).await
}
async fn process_stream(
mut rx: Receiver<Result<Message, Error>>,
mut rx: Receiver<Result<AgentMessage, Error>>,
user_id: &Uuid,
organization_id: &Uuid,
chat_id: &Uuid,
message_id: &Uuid,
request: String,
) {
let subscription = user_id.to_string();
let mut all_transformed_messages = Vec::new();
// Create thread first
let thread = Thread {
id: chat_id.clone(),
title: request.clone(), // Use request as title
organization_id: organization_id.clone(),
created_by: user_id.clone(),
created_at: Utc::now(),
updated_at: Utc::now(),
deleted_at: None,
};
// Insert thread into database
if let Err(e) = async {
let mut conn = get_pg_pool().get().await?;
insert_into(threads::table)
.values(&thread)
.execute(&mut conn)
.await?;
Ok::<_, Error>(())
}
.await
{
tracing::error!("Failed to create thread: {}", e);
}
while let Some(msg_result) = rx.recv().await {
if let Ok(msg) = msg_result {
match transform_message(chat_id, message_id, msg) {
Ok((transformed_messages, event)) => {
// Store transformed messages for later database insertion
all_transformed_messages.extend(transformed_messages.clone());
// Send websocket messages as before
for transformed in transformed_messages {
let response = WsResponseMessage::new_no_user(
WsRoutes::Threads(ThreadRoute::Post),
@ -214,6 +270,65 @@ impl AgentThreadHandler {
}
}
}
// After all messages are received, store them in the database
if !all_transformed_messages.is_empty() {
// Create message record
let message = Message {
id: message_id.clone(),
request: request.clone(),
response: serde_json::to_value(&all_transformed_messages).unwrap_or_default(),
thread_id: thread.id,
created_by: user_id.clone(),
created_at: Utc::now(),
updated_at: Utc::now(),
deleted_at: None,
};
// Insert message into database
if let Err(e) = async {
let mut conn = get_pg_pool().get().await?;
insert_into(messages::table)
.values(&message)
.execute(&mut conn)
.await?;
Ok::<_, Error>(())
}
.await
{
tracing::error!("Failed to create message: {}", e);
}
// Process file messages and create MessageToFile records
for container in all_transformed_messages {
if let BusterContainer::ReasoningMessage(reasoning) = container {
if let ReasoningMessage::File(file_msg) = reasoning.reasoning {
let message_to_file = MessageToFile {
id: Uuid::new_v4(),
message_id: message.id,
file_id: Uuid::parse_str(&file_msg.id).unwrap_or_default(),
created_at: Utc::now(),
updated_at: Utc::now(),
deleted_at: None,
};
// Insert message_to_file into database
if let Err(e) = async {
let mut conn = get_pg_pool().get().await?;
insert_into(messages_to_files::table)
.values(&message_to_file)
.execute(&mut conn)
.await?;
Ok::<_, Error>(())
}
.await
{
tracing::error!("Failed to create message_to_file: {}", e);
}
}
}
}
}
}
}