mirror of https://github.com/buster-so/buster.git
ok thread is being inserted
This commit is contained in:
parent
977b5eb6de
commit
3bd5516486
|
@ -1,6 +1,6 @@
|
|||
---
|
||||
description: These are global rules and recommendations for the rust server.
|
||||
globs:
|
||||
globs: *
|
||||
---
|
||||
|
||||
# Global Rules and Project Structure
|
||||
|
@ -17,6 +17,9 @@ This is a Rust web server project built with Axum, focusing on high performance,
|
|||
- `database/` - Database models, schema, and connection management
|
||||
- `main.rs` - Application entry point and server setup
|
||||
|
||||
## Implementation
|
||||
When working with prds, you should always mark your progress off in them as you build.
|
||||
|
||||
## Database Connectivity
|
||||
- The primary database connection is managed through `get_pg_pool()`, which returns a lazy static `PgPool`
|
||||
- Always use this pool for database connections to ensure proper connection management
|
||||
|
@ -141,3 +144,4 @@ Remember to always consider:
|
|||
2. Transaction boundaries for data consistency
|
||||
3. Error propagation and cleanup
|
||||
4. Memory usage and ownership
|
||||
5. Please use comments to help document your code and make it more readable.
|
||||
|
|
|
@ -63,3 +63,5 @@ node_modules/
|
|||
|
||||
prds/
|
||||
docs/
|
||||
|
||||
.cargo/
|
|
@ -22,7 +22,6 @@ diesel = { version = "2", features = [
|
|||
"postgres",
|
||||
] }
|
||||
diesel-async = { version = "0.5.2", features = ["postgres", "bb8"] }
|
||||
diesel_full_text_search = "2.2.0"
|
||||
dotenv = "0.15.0"
|
||||
futures = "0.3.30"
|
||||
gcp-bigquery-client = "0.24.1"
|
||||
|
@ -31,7 +30,6 @@ jsonwebtoken = "9.3.0"
|
|||
lazy_static = "1.4.0"
|
||||
num-traits = "0.2.19"
|
||||
once_cell = "1.20.2"
|
||||
pgvector = { version = "0.4.0", features = ["diesel", "serde"] }
|
||||
rand = "0.8.5"
|
||||
redis = { version = "0.27.5", features = [
|
||||
"tokio-comp",
|
||||
|
@ -57,9 +55,6 @@ sqlx = { version = "0.8", features = [
|
|||
"chrono",
|
||||
"bigdecimal",
|
||||
] }
|
||||
stop-words = { version = "0.8.0", default-features = false, features = [
|
||||
"nltk",
|
||||
] }
|
||||
tempfile = "3.10.1"
|
||||
tiberius = { version = "0.12.2", default-features = false, features = [
|
||||
"chrono",
|
||||
|
@ -100,9 +95,3 @@ tokio = { version = "1.0", features = ["full", "test-util"] }
|
|||
|
||||
[profile.release]
|
||||
debug = false
|
||||
|
||||
[profile.dev]
|
||||
opt-level = 0
|
||||
incremental = true
|
||||
debug = 1
|
||||
|
||||
|
|
|
@ -205,28 +205,28 @@ pub enum BusterThreadMessage {
|
|||
File(BusterFileMessage),
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
#[derive(Debug, Serialize, Clone)]
|
||||
pub struct BusterChatMessageContainer {
|
||||
pub response_message: BusterChatMessage,
|
||||
pub chat_id: Uuid,
|
||||
pub message_id: Uuid,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
#[derive(Debug, Serialize, Clone)]
|
||||
#[serde(untagged)]
|
||||
pub enum ReasoningMessage {
|
||||
Thought(BusterThought),
|
||||
File(BusterFileMessage),
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
#[derive(Debug, Serialize, Clone)]
|
||||
pub struct BusterReasoningMessageContainer {
|
||||
pub reasoning: ReasoningMessage,
|
||||
pub chat_id: Uuid,
|
||||
pub message_id: Uuid,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
#[derive(Debug, Serialize, Clone)]
|
||||
pub struct BusterChatMessage {
|
||||
pub id: String,
|
||||
#[serde(rename = "type")]
|
||||
|
@ -235,7 +235,7 @@ pub struct BusterChatMessage {
|
|||
pub message_chunk: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
#[derive(Debug, Serialize, Clone)]
|
||||
pub struct BusterThought {
|
||||
pub id: String,
|
||||
#[serde(rename = "type")]
|
||||
|
@ -246,13 +246,13 @@ pub struct BusterThought {
|
|||
pub status: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
#[derive(Debug, Serialize, Clone)]
|
||||
pub struct BusterThoughtPillContainer {
|
||||
pub title: String,
|
||||
pub thought_pills: Vec<BusterThoughtPill>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
#[derive(Debug, Serialize, Clone)]
|
||||
pub struct BusterThoughtPill {
|
||||
pub id: String,
|
||||
pub text: String,
|
||||
|
@ -260,7 +260,7 @@ pub struct BusterThoughtPill {
|
|||
pub thought_file_type: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
#[derive(Debug, Serialize, Clone)]
|
||||
pub struct BusterFileMessage {
|
||||
pub id: String,
|
||||
#[serde(rename = "type")]
|
||||
|
@ -279,7 +279,7 @@ pub struct BusterFileLine {
|
|||
pub text: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
#[derive(Debug, Serialize, Clone)]
|
||||
#[serde(untagged)]
|
||||
pub enum BusterContainer {
|
||||
ChatMessage(BusterChatMessageContainer),
|
||||
|
|
|
@ -1,5 +1,7 @@
|
|||
use anyhow::{Error, Result};
|
||||
use chrono::Utc;
|
||||
use diesel::{insert_into, ExpressionMethods, QueryDsl};
|
||||
use diesel_async::RunQueryDsl;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use std::collections::HashMap;
|
||||
|
@ -8,10 +10,16 @@ use tracing;
|
|||
use uuid::Uuid;
|
||||
|
||||
use crate::{
|
||||
database::models::User,
|
||||
database::{
|
||||
lib::get_pg_pool,
|
||||
models::{Message, MessageToFile, Thread, User},
|
||||
schema::{messages, messages_to_files, threads},
|
||||
},
|
||||
routes::ws::{
|
||||
threads_and_messages::{
|
||||
post_thread::agent_message_transformer::transform_message,
|
||||
post_thread::agent_message_transformer::{
|
||||
transform_message, BusterContainer, ReasoningMessage,
|
||||
},
|
||||
threads_router::{ThreadEvent, ThreadRoute},
|
||||
},
|
||||
ws::{WsEvent, WsResponseMessage, WsSendMethod},
|
||||
|
@ -20,7 +28,7 @@ use crate::{
|
|||
},
|
||||
utils::{
|
||||
agent::{Agent, AgentThread},
|
||||
clients::ai::litellm::Message,
|
||||
clients::ai::litellm::Message as AgentMessage,
|
||||
tools::{
|
||||
file_tools::{
|
||||
CreateFilesTool, ModifyFilesTool, OpenFilesTool, SearchDataCatalogTool,
|
||||
|
@ -124,9 +132,17 @@ impl AgentThreadHandler {
|
|||
let chat_id = request.chat_id.unwrap_or_else(|| Uuid::new_v4());
|
||||
let message_id = request.message_id.unwrap_or_else(|| Uuid::new_v4());
|
||||
|
||||
let user_org_id = match user.attributes.get("organization_id") {
|
||||
Some(Value::String(org_id)) => Uuid::parse_str(&org_id).unwrap_or_default(),
|
||||
_ => {
|
||||
tracing::error!("User has no organization ID");
|
||||
return Err(anyhow::anyhow!("User has no organization ID"));
|
||||
}
|
||||
};
|
||||
|
||||
let init_response = TempInitChat {
|
||||
id: chat_id.clone(),
|
||||
title: "New Chat".to_string(),
|
||||
title: request.prompt.clone(), // Use prompt as title
|
||||
is_favorited: false,
|
||||
messages: vec![TempInitChatMessage {
|
||||
id: message_id.clone(),
|
||||
|
@ -162,7 +178,15 @@ impl AgentThreadHandler {
|
|||
|
||||
let rx = self.process_chat_request(request.clone()).await?;
|
||||
tokio::spawn(async move {
|
||||
Self::process_stream(rx, &user.id, &chat_id, &message_id).await;
|
||||
Self::process_stream(
|
||||
rx,
|
||||
&user.id,
|
||||
&user_org_id,
|
||||
&chat_id,
|
||||
&message_id,
|
||||
request.prompt,
|
||||
)
|
||||
.await;
|
||||
});
|
||||
Ok(())
|
||||
}
|
||||
|
@ -170,29 +194,61 @@ impl AgentThreadHandler {
|
|||
async fn process_chat_request(
|
||||
&self,
|
||||
request: ChatCreateNewChat,
|
||||
) -> Result<Receiver<Result<Message, Error>>> {
|
||||
) -> Result<Receiver<Result<AgentMessage, Error>>> {
|
||||
let thread = AgentThread::new(
|
||||
request.chat_id,
|
||||
vec![
|
||||
Message::developer(AGENT_PROMPT.to_string()),
|
||||
Message::user(request.prompt),
|
||||
AgentMessage::developer(AGENT_PROMPT.to_string()),
|
||||
AgentMessage::user(request.prompt),
|
||||
],
|
||||
);
|
||||
self.agent.stream_process_thread(&thread).await
|
||||
}
|
||||
|
||||
async fn process_stream(
|
||||
mut rx: Receiver<Result<Message, Error>>,
|
||||
mut rx: Receiver<Result<AgentMessage, Error>>,
|
||||
user_id: &Uuid,
|
||||
organization_id: &Uuid,
|
||||
chat_id: &Uuid,
|
||||
message_id: &Uuid,
|
||||
request: String,
|
||||
) {
|
||||
let subscription = user_id.to_string();
|
||||
let mut all_transformed_messages = Vec::new();
|
||||
|
||||
// Create thread first
|
||||
let thread = Thread {
|
||||
id: chat_id.clone(),
|
||||
title: request.clone(), // Use request as title
|
||||
organization_id: organization_id.clone(),
|
||||
created_by: user_id.clone(),
|
||||
created_at: Utc::now(),
|
||||
updated_at: Utc::now(),
|
||||
deleted_at: None,
|
||||
};
|
||||
|
||||
// Insert thread into database
|
||||
if let Err(e) = async {
|
||||
let mut conn = get_pg_pool().get().await?;
|
||||
insert_into(threads::table)
|
||||
.values(&thread)
|
||||
.execute(&mut conn)
|
||||
.await?;
|
||||
Ok::<_, Error>(())
|
||||
}
|
||||
.await
|
||||
{
|
||||
tracing::error!("Failed to create thread: {}", e);
|
||||
}
|
||||
|
||||
while let Some(msg_result) = rx.recv().await {
|
||||
if let Ok(msg) = msg_result {
|
||||
match transform_message(chat_id, message_id, msg) {
|
||||
Ok((transformed_messages, event)) => {
|
||||
// Store transformed messages for later database insertion
|
||||
all_transformed_messages.extend(transformed_messages.clone());
|
||||
|
||||
// Send websocket messages as before
|
||||
for transformed in transformed_messages {
|
||||
let response = WsResponseMessage::new_no_user(
|
||||
WsRoutes::Threads(ThreadRoute::Post),
|
||||
|
@ -214,6 +270,65 @@ impl AgentThreadHandler {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
// After all messages are received, store them in the database
|
||||
if !all_transformed_messages.is_empty() {
|
||||
// Create message record
|
||||
let message = Message {
|
||||
id: message_id.clone(),
|
||||
request: request.clone(),
|
||||
response: serde_json::to_value(&all_transformed_messages).unwrap_or_default(),
|
||||
thread_id: thread.id,
|
||||
created_by: user_id.clone(),
|
||||
created_at: Utc::now(),
|
||||
updated_at: Utc::now(),
|
||||
deleted_at: None,
|
||||
};
|
||||
|
||||
// Insert message into database
|
||||
if let Err(e) = async {
|
||||
let mut conn = get_pg_pool().get().await?;
|
||||
insert_into(messages::table)
|
||||
.values(&message)
|
||||
.execute(&mut conn)
|
||||
.await?;
|
||||
Ok::<_, Error>(())
|
||||
}
|
||||
.await
|
||||
{
|
||||
tracing::error!("Failed to create message: {}", e);
|
||||
}
|
||||
|
||||
// Process file messages and create MessageToFile records
|
||||
for container in all_transformed_messages {
|
||||
if let BusterContainer::ReasoningMessage(reasoning) = container {
|
||||
if let ReasoningMessage::File(file_msg) = reasoning.reasoning {
|
||||
let message_to_file = MessageToFile {
|
||||
id: Uuid::new_v4(),
|
||||
message_id: message.id,
|
||||
file_id: Uuid::parse_str(&file_msg.id).unwrap_or_default(),
|
||||
created_at: Utc::now(),
|
||||
updated_at: Utc::now(),
|
||||
deleted_at: None,
|
||||
};
|
||||
|
||||
// Insert message_to_file into database
|
||||
if let Err(e) = async {
|
||||
let mut conn = get_pg_pool().get().await?;
|
||||
insert_into(messages_to_files::table)
|
||||
.values(&message_to_file)
|
||||
.execute(&mut conn)
|
||||
.await?;
|
||||
Ok::<_, Error>(())
|
||||
}
|
||||
.await
|
||||
{
|
||||
tracing::error!("Failed to create message_to_file: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue