diff --git a/api/src/routes/ws/threads_and_messages/post_thread/agent_message_transformer.rs b/api/src/routes/ws/threads_and_messages/post_thread/agent_message_transformer.rs index e73b91251..92b871c45 100644 --- a/api/src/routes/ws/threads_and_messages/post_thread/agent_message_transformer.rs +++ b/api/src/routes/ws/threads_and_messages/post_thread/agent_message_transformer.rs @@ -205,6 +205,27 @@ pub enum BusterThreadMessage { File(BusterFileMessage), } +#[derive(Debug, Serialize)] +pub struct BusterChatMessageContainer { + pub response_message: BusterChatMessage, + pub chat_id: Uuid, + pub message_id: Uuid, +} + +#[derive(Debug, Serialize)] +#[serde(untagged)] +pub enum ReasoningMessage { + Thought(BusterThought), + File(BusterFileMessage), +} + +#[derive(Debug, Serialize)] +pub struct BusterReasoningMessageContainer { + pub reasoning: ReasoningMessage, + pub chat_id: Uuid, + pub message_id: Uuid, +} + #[derive(Debug, Serialize)] pub struct BusterChatMessage { pub id: String, @@ -258,9 +279,18 @@ pub struct BusterFileLine { pub text: String, } -pub fn transform_message(message: Message) -> Result<(Vec, ThreadEvent)> { - println!("transform_message: {:?}", message); +#[derive(Debug, Serialize)] +#[serde(untagged)] +pub enum BusterContainer { + ChatMessage(BusterChatMessageContainer), + ReasoningMessage(BusterReasoningMessageContainer), +} +pub fn transform_message( + chat_id: Uuid, + message_id: Uuid, + message: Message, +) -> Result<(Vec, ThreadEvent)> { match message { Message::Assistant { id, @@ -272,30 +302,31 @@ pub fn transform_message(message: Message) -> Result<(Vec, } => { if let Some(content) = content { let messages = match transform_text_message(id, content, progress) { - Ok(messages) => messages, + Ok(messages) => messages + .into_iter() + .map(BusterContainer::ChatMessage) + .collect(), Err(e) => { return Err(e); } }; - return Ok(( - messages, - ThreadEvent::GeneratingResponseMessage, - )); + return Ok((messages, ThreadEvent::GeneratingResponseMessage)); } if let Some(tool_calls) = tool_calls { - let messages = match transform_assistant_tool_message(id, tool_calls, progress, initial) { - Ok(messages) => messages, - Err(e) => { - return Err(e); - } - }; + let messages = + match transform_assistant_tool_message(id, tool_calls, progress, initial) { + Ok(messages) => messages + .into_iter() + .map(BusterContainer::ReasoningMessage) + .collect(), + Err(e) => { + return Err(e); + } + }; - return Ok(( - messages, - ThreadEvent::GeneratingReasoningMessage, - )); + return Ok((messages, ThreadEvent::GeneratingReasoningMessage)); } Err(anyhow::anyhow!("Assistant message missing required fields")) @@ -309,16 +340,16 @@ pub fn transform_message(message: Message) -> Result<(Vec, } => { if let Some(name) = name { let messages = match transform_tool_message(id, name, content, progress) { - Ok(messages) => messages, + Ok(messages) => messages + .into_iter() + .map(BusterContainer::ReasoningMessage) + .collect(), Err(e) => { return Err(e); } }; - return Ok(( - messages, - ThreadEvent::GeneratingReasoningMessage, - )); + return Ok((messages, ThreadEvent::GeneratingReasoningMessage)); } Err(anyhow::anyhow!("Tool message missing name field")) @@ -331,34 +362,42 @@ fn transform_text_message( id: Option, content: String, progress: Option, -) -> Result> { +) -> Result> { if let Some(progress) = progress { match progress { - MessageProgress::InProgress => { - Ok(vec![BusterThreadMessage::ChatMessage(BusterChatMessage { + MessageProgress::InProgress => Ok(vec![BusterChatMessageContainer { + response_message: BusterChatMessage { id: id.unwrap_or_else(|| Uuid::new_v4().to_string()), message_type: "text".to_string(), message: None, message_chunk: Some(content), - })]) - } - MessageProgress::Complete => { - Ok(vec![BusterThreadMessage::ChatMessage(BusterChatMessage { + }, + chat_id: Uuid::new_v4(), + message_id: Uuid::new_v4(), + }]), + MessageProgress::Complete => Ok(vec![BusterChatMessageContainer { + response_message: BusterChatMessage { id: id.unwrap_or_else(|| Uuid::new_v4().to_string()), message_type: "text".to_string(), message: Some(content), message_chunk: None, - })]) - } + }, + chat_id: Uuid::new_v4(), + message_id: Uuid::new_v4(), + }]), _ => Err(anyhow::anyhow!("Unsupported message progress")), } } else { - Ok(vec![BusterThreadMessage::ChatMessage(BusterChatMessage { - id: id.unwrap_or_else(|| Uuid::new_v4().to_string()), - message_type: "text".to_string(), - message: None, - message_chunk: None, - })]) + Ok(vec![BusterChatMessageContainer { + response_message: BusterChatMessage { + id: id.unwrap_or_else(|| Uuid::new_v4().to_string()), + message_type: "text".to_string(), + message: None, + message_chunk: None, + }, + chat_id: Uuid::new_v4(), + message_id: Uuid::new_v4(), + }]) } } @@ -367,8 +406,8 @@ fn transform_tool_message( name: String, content: String, progress: Option, -) -> Result> { - match name.as_str() { +) -> Result> { + let messages = match name.as_str() { "search_data_catalog" => tool_data_catalog_search(id, content, progress), "stored_values_search" => tool_stored_values_search(id, content, progress), "search_files" => tool_file_search(id, content, progress), @@ -376,7 +415,20 @@ fn transform_tool_message( "modify_files" => tool_modify_file(id, content, progress), "open_files" => tool_open_files(id, content, progress), _ => Err(anyhow::anyhow!("Unsupported tool name")), - } + }?; + + Ok(messages + .into_iter() + .map(|message| BusterReasoningMessageContainer { + reasoning: match message { + BusterThreadMessage::Thought(thought) => ReasoningMessage::Thought(thought), + BusterThreadMessage::File(file) => ReasoningMessage::File(file), + _ => unreachable!("Tool messages should only return Thought or File"), + }, + chat_id: Uuid::new_v4(), + message_id: Uuid::new_v4(), + }) + .collect()) } fn transform_assistant_tool_message( @@ -384,9 +436,9 @@ fn transform_assistant_tool_message( tool_calls: Vec, progress: Option, initial: bool, -) -> Result> { +) -> Result> { if let Some(tool_call) = tool_calls.first() { - match tool_call.function.name.as_str() { + let messages = match tool_call.function.name.as_str() { "search_data_catalog" => assistant_data_catalog_search(id, progress, initial), "stored_values_search" => assistant_stored_values_search(id, progress, initial), "search_files" => assistant_file_search(id, progress, initial), @@ -394,7 +446,20 @@ fn transform_assistant_tool_message( "modify_files" => assistant_modify_file(id, tool_calls, progress), "open_files" => assistant_open_files(id, progress, initial), _ => Err(anyhow::anyhow!("Unsupported tool name")), - } + }?; + + Ok(messages + .into_iter() + .map(|message| BusterReasoningMessageContainer { + reasoning: match message { + BusterThreadMessage::Thought(thought) => ReasoningMessage::Thought(thought), + BusterThreadMessage::File(file) => ReasoningMessage::File(file), + _ => unreachable!("Assistant tool messages should only return Thought or File"), + }, + chat_id: Uuid::new_v4(), + message_id: Uuid::new_v4(), + }) + .collect()) } else { Err(anyhow::anyhow!("Assistant tool message missing tool call")) } diff --git a/api/src/routes/ws/threads_and_messages/post_thread/agent_thread.rs b/api/src/routes/ws/threads_and_messages/post_thread/agent_thread.rs index 158422dab..b234a6298 100644 --- a/api/src/routes/ws/threads_and_messages/post_thread/agent_thread.rs +++ b/api/src/routes/ws/threads_and_messages/post_thread/agent_thread.rs @@ -57,8 +57,8 @@ pub struct TempInitChatMessage { #[derive(Debug, Deserialize, Clone)] pub struct ChatCreateNewChat { pub prompt: String, - pub chat_id: Option, - pub message_id: Option, + pub chat_id: Option, + pub message_id: Option, } #[derive(Debug, Serialize)] @@ -167,14 +167,17 @@ impl AgentThreadHandler { async fn process_stream( mut rx: Receiver>, - chat_id: Option, + chat_id: Option, user_id: &Uuid, ) { let subscription = user_id.to_string(); + let chat_id = chat_id.unwrap_or_else(|| Uuid::new_v4()); + let message_id = Uuid::new_v4(); + while let Some(msg_result) = rx.recv().await { if let Ok(msg) = msg_result { - match transform_message(msg) { + match transform_message(chat_id, message_id, msg) { Ok((transformed_messages, event)) => { for transformed in transformed_messages { let response = WsResponseMessage::new_no_user( diff --git a/api/src/utils/agent/types.rs b/api/src/utils/agent/types.rs index 39583c23d..5253770be 100644 --- a/api/src/utils/agent/types.rs +++ b/api/src/utils/agent/types.rs @@ -1,4 +1,5 @@ use serde::{Deserialize, Serialize}; +use uuid::Uuid; use crate::utils::clients::ai::litellm::Message; @@ -7,15 +8,15 @@ use crate::utils::clients::ai::litellm::Message; #[derive(Debug, Clone, Serialize, Deserialize)] pub struct AgentThread { /// Unique identifier for the thread - pub id: String, + pub id: Uuid, /// Ordered sequence of messages in the conversation pub messages: Vec, } impl AgentThread { - pub fn new(id: Option, messages: Vec) -> Self { + pub fn new(id: Option, messages: Vec) -> Self { Self { - id: id.unwrap_or(uuid::Uuid::new_v4().to_string()), + id: id.unwrap_or(Uuid::new_v4()), messages, } }