diff --git a/api/src/routes/ws/threads_and_messages/post_thread/agent_message_transformer.rs b/api/src/routes/ws/threads_and_messages/post_thread/agent_message_transformer.rs index 2cb8a49fc..ec003db3a 100644 --- a/api/src/routes/ws/threads_and_messages/post_thread/agent_message_transformer.rs +++ b/api/src/routes/ws/threads_and_messages/post_thread/agent_message_transformer.rs @@ -287,8 +287,8 @@ pub enum BusterContainer { } pub fn transform_message( - chat_id: Uuid, - message_id: Uuid, + chat_id: &Uuid, + message_id: &Uuid, message: Message, ) -> Result<(Vec, ThreadEvent)> { match message { @@ -301,23 +301,33 @@ pub fn transform_message( initial, } => { if let Some(content) = content { - let messages = - match transform_text_message(id, content, progress, chat_id, message_id) { - Ok(messages) => messages - .into_iter() - .map(BusterContainer::ChatMessage) - .collect(), - Err(e) => { - return Err(e); - } - }; + let messages = match transform_text_message( + id, + content, + progress, + chat_id.clone(), + message_id.clone(), + ) { + Ok(messages) => messages + .into_iter() + .map(BusterContainer::ChatMessage) + .collect(), + Err(e) => { + return Err(e); + } + }; return Ok((messages, ThreadEvent::GeneratingResponseMessage)); } if let Some(tool_calls) = tool_calls { let messages = match transform_assistant_tool_message( - id, tool_calls, progress, initial, chat_id, message_id, + id, + tool_calls, + progress, + initial, + chat_id.clone(), + message_id.clone(), ) { Ok(messages) => messages .into_iter() @@ -342,7 +352,12 @@ pub fn transform_message( } => { if let Some(name) = name { let messages = match transform_tool_message( - id, name, content, progress, chat_id, message_id, + id, + name, + content, + progress, + chat_id.clone(), + message_id.clone(), ) { Ok(messages) => messages .into_iter() diff --git a/api/src/routes/ws/threads_and_messages/post_thread/agent_thread.rs b/api/src/routes/ws/threads_and_messages/post_thread/agent_thread.rs index f3534c39a..5f11c9bdb 100644 --- a/api/src/routes/ws/threads_and_messages/post_thread/agent_thread.rs +++ b/api/src/routes/ws/threads_and_messages/post_thread/agent_thread.rs @@ -33,7 +33,7 @@ use crate::{ #[derive(Debug, Serialize, Deserialize)] pub struct TempInitChat { - pub id: String, + pub id: Uuid, pub title: String, pub is_favorited: bool, pub messages: Vec, @@ -47,7 +47,7 @@ pub struct TempInitChat { #[derive(Debug, Serialize, Deserialize)] pub struct TempInitChatMessage { - pub id: String, + pub id: Uuid, pub request_message: TempRequestMessage, pub response_messages: Vec, pub reasoning: Vec, @@ -121,12 +121,15 @@ impl AgentThreadHandler { pub async fn handle_request(&self, request: ChatCreateNewChat, user: User) -> Result<()> { let subscription = &user.id.to_string(); + let chat_id = request.chat_id.unwrap_or_else(|| Uuid::new_v4()); + let message_id = request.message_id.unwrap_or_else(|| Uuid::new_v4()); + let init_response = TempInitChat { - id: Uuid::new_v4().to_string(), + id: chat_id.clone(), title: "New Chat".to_string(), is_favorited: false, messages: vec![TempInitChatMessage { - id: Uuid::new_v4().to_string(), + id: message_id.clone(), request_message: TempRequestMessage { request: request.prompt.clone(), sender_id: user.id, @@ -159,7 +162,7 @@ impl AgentThreadHandler { let rx = self.process_chat_request(request.clone()).await?; tokio::spawn(async move { - Self::process_stream(rx, request.chat_id, &user.id).await; + Self::process_stream(rx, &user.id, &chat_id, &message_id).await; }); Ok(()) } @@ -180,14 +183,12 @@ impl AgentThreadHandler { async fn process_stream( mut rx: Receiver>, - chat_id: Option, user_id: &Uuid, + chat_id: &Uuid, + message_id: &Uuid, ) { let subscription = user_id.to_string(); - let chat_id = chat_id.unwrap_or_else(|| Uuid::new_v4()); - let message_id = Uuid::new_v4(); - while let Some(msg_result) = rx.recv().await { if let Ok(msg) = msg_result { match transform_message(chat_id, message_id, msg) {