Refactor message transformation with new container types and UUIDs

This commit is contained in:
dal 2025-02-11 11:53:06 -07:00
parent 14d379d942
commit 8b96ec01fb
No known key found for this signature in database
GPG Key ID: 16F4B0E1E9F61122
3 changed files with 119 additions and 50 deletions

View File

@ -205,6 +205,27 @@ pub enum BusterThreadMessage {
File(BusterFileMessage), File(BusterFileMessage),
} }
#[derive(Debug, Serialize)]
pub struct BusterChatMessageContainer {
pub response_message: BusterChatMessage,
pub chat_id: Uuid,
pub message_id: Uuid,
}
#[derive(Debug, Serialize)]
#[serde(untagged)]
pub enum ReasoningMessage {
Thought(BusterThought),
File(BusterFileMessage),
}
#[derive(Debug, Serialize)]
pub struct BusterReasoningMessageContainer {
pub reasoning: ReasoningMessage,
pub chat_id: Uuid,
pub message_id: Uuid,
}
#[derive(Debug, Serialize)] #[derive(Debug, Serialize)]
pub struct BusterChatMessage { pub struct BusterChatMessage {
pub id: String, pub id: String,
@ -258,9 +279,18 @@ pub struct BusterFileLine {
pub text: String, pub text: String,
} }
pub fn transform_message(message: Message) -> Result<(Vec<BusterThreadMessage>, ThreadEvent)> { #[derive(Debug, Serialize)]
println!("transform_message: {:?}", message); #[serde(untagged)]
pub enum BusterContainer {
ChatMessage(BusterChatMessageContainer),
ReasoningMessage(BusterReasoningMessageContainer),
}
pub fn transform_message(
chat_id: Uuid,
message_id: Uuid,
message: Message,
) -> Result<(Vec<BusterContainer>, ThreadEvent)> {
match message { match message {
Message::Assistant { Message::Assistant {
id, id,
@ -272,30 +302,31 @@ pub fn transform_message(message: Message) -> Result<(Vec<BusterThreadMessage>,
} => { } => {
if let Some(content) = content { if let Some(content) = content {
let messages = match transform_text_message(id, content, progress) { let messages = match transform_text_message(id, content, progress) {
Ok(messages) => messages, Ok(messages) => messages
.into_iter()
.map(BusterContainer::ChatMessage)
.collect(),
Err(e) => { Err(e) => {
return Err(e); return Err(e);
} }
}; };
return Ok(( return Ok((messages, ThreadEvent::GeneratingResponseMessage));
messages,
ThreadEvent::GeneratingResponseMessage,
));
} }
if let Some(tool_calls) = tool_calls { if let Some(tool_calls) = tool_calls {
let messages = match transform_assistant_tool_message(id, tool_calls, progress, initial) { let messages =
Ok(messages) => messages, match transform_assistant_tool_message(id, tool_calls, progress, initial) {
Err(e) => { Ok(messages) => messages
return Err(e); .into_iter()
} .map(BusterContainer::ReasoningMessage)
}; .collect(),
Err(e) => {
return Err(e);
}
};
return Ok(( return Ok((messages, ThreadEvent::GeneratingReasoningMessage));
messages,
ThreadEvent::GeneratingReasoningMessage,
));
} }
Err(anyhow::anyhow!("Assistant message missing required fields")) Err(anyhow::anyhow!("Assistant message missing required fields"))
@ -309,16 +340,16 @@ pub fn transform_message(message: Message) -> Result<(Vec<BusterThreadMessage>,
} => { } => {
if let Some(name) = name { if let Some(name) = name {
let messages = match transform_tool_message(id, name, content, progress) { let messages = match transform_tool_message(id, name, content, progress) {
Ok(messages) => messages, Ok(messages) => messages
.into_iter()
.map(BusterContainer::ReasoningMessage)
.collect(),
Err(e) => { Err(e) => {
return Err(e); return Err(e);
} }
}; };
return Ok(( return Ok((messages, ThreadEvent::GeneratingReasoningMessage));
messages,
ThreadEvent::GeneratingReasoningMessage,
));
} }
Err(anyhow::anyhow!("Tool message missing name field")) Err(anyhow::anyhow!("Tool message missing name field"))
@ -331,34 +362,42 @@ fn transform_text_message(
id: Option<String>, id: Option<String>,
content: String, content: String,
progress: Option<MessageProgress>, progress: Option<MessageProgress>,
) -> Result<Vec<BusterThreadMessage>> { ) -> Result<Vec<BusterChatMessageContainer>> {
if let Some(progress) = progress { if let Some(progress) = progress {
match progress { match progress {
MessageProgress::InProgress => { MessageProgress::InProgress => Ok(vec![BusterChatMessageContainer {
Ok(vec![BusterThreadMessage::ChatMessage(BusterChatMessage { response_message: BusterChatMessage {
id: id.unwrap_or_else(|| Uuid::new_v4().to_string()), id: id.unwrap_or_else(|| Uuid::new_v4().to_string()),
message_type: "text".to_string(), message_type: "text".to_string(),
message: None, message: None,
message_chunk: Some(content), message_chunk: Some(content),
})]) },
} chat_id: Uuid::new_v4(),
MessageProgress::Complete => { message_id: Uuid::new_v4(),
Ok(vec![BusterThreadMessage::ChatMessage(BusterChatMessage { }]),
MessageProgress::Complete => Ok(vec![BusterChatMessageContainer {
response_message: BusterChatMessage {
id: id.unwrap_or_else(|| Uuid::new_v4().to_string()), id: id.unwrap_or_else(|| Uuid::new_v4().to_string()),
message_type: "text".to_string(), message_type: "text".to_string(),
message: Some(content), message: Some(content),
message_chunk: None, message_chunk: None,
})]) },
} chat_id: Uuid::new_v4(),
message_id: Uuid::new_v4(),
}]),
_ => Err(anyhow::anyhow!("Unsupported message progress")), _ => Err(anyhow::anyhow!("Unsupported message progress")),
} }
} else { } else {
Ok(vec![BusterThreadMessage::ChatMessage(BusterChatMessage { Ok(vec![BusterChatMessageContainer {
id: id.unwrap_or_else(|| Uuid::new_v4().to_string()), response_message: BusterChatMessage {
message_type: "text".to_string(), id: id.unwrap_or_else(|| Uuid::new_v4().to_string()),
message: None, message_type: "text".to_string(),
message_chunk: None, message: None,
})]) message_chunk: None,
},
chat_id: Uuid::new_v4(),
message_id: Uuid::new_v4(),
}])
} }
} }
@ -367,8 +406,8 @@ fn transform_tool_message(
name: String, name: String,
content: String, content: String,
progress: Option<MessageProgress>, progress: Option<MessageProgress>,
) -> Result<Vec<BusterThreadMessage>> { ) -> Result<Vec<BusterReasoningMessageContainer>> {
match name.as_str() { let messages = match name.as_str() {
"search_data_catalog" => tool_data_catalog_search(id, content, progress), "search_data_catalog" => tool_data_catalog_search(id, content, progress),
"stored_values_search" => tool_stored_values_search(id, content, progress), "stored_values_search" => tool_stored_values_search(id, content, progress),
"search_files" => tool_file_search(id, content, progress), "search_files" => tool_file_search(id, content, progress),
@ -376,7 +415,20 @@ fn transform_tool_message(
"modify_files" => tool_modify_file(id, content, progress), "modify_files" => tool_modify_file(id, content, progress),
"open_files" => tool_open_files(id, content, progress), "open_files" => tool_open_files(id, content, progress),
_ => Err(anyhow::anyhow!("Unsupported tool name")), _ => Err(anyhow::anyhow!("Unsupported tool name")),
} }?;
Ok(messages
.into_iter()
.map(|message| BusterReasoningMessageContainer {
reasoning: match message {
BusterThreadMessage::Thought(thought) => ReasoningMessage::Thought(thought),
BusterThreadMessage::File(file) => ReasoningMessage::File(file),
_ => unreachable!("Tool messages should only return Thought or File"),
},
chat_id: Uuid::new_v4(),
message_id: Uuid::new_v4(),
})
.collect())
} }
fn transform_assistant_tool_message( fn transform_assistant_tool_message(
@ -384,9 +436,9 @@ fn transform_assistant_tool_message(
tool_calls: Vec<ToolCall>, tool_calls: Vec<ToolCall>,
progress: Option<MessageProgress>, progress: Option<MessageProgress>,
initial: bool, initial: bool,
) -> Result<Vec<BusterThreadMessage>> { ) -> Result<Vec<BusterReasoningMessageContainer>> {
if let Some(tool_call) = tool_calls.first() { if let Some(tool_call) = tool_calls.first() {
match tool_call.function.name.as_str() { let messages = match tool_call.function.name.as_str() {
"search_data_catalog" => assistant_data_catalog_search(id, progress, initial), "search_data_catalog" => assistant_data_catalog_search(id, progress, initial),
"stored_values_search" => assistant_stored_values_search(id, progress, initial), "stored_values_search" => assistant_stored_values_search(id, progress, initial),
"search_files" => assistant_file_search(id, progress, initial), "search_files" => assistant_file_search(id, progress, initial),
@ -394,7 +446,20 @@ fn transform_assistant_tool_message(
"modify_files" => assistant_modify_file(id, tool_calls, progress), "modify_files" => assistant_modify_file(id, tool_calls, progress),
"open_files" => assistant_open_files(id, progress, initial), "open_files" => assistant_open_files(id, progress, initial),
_ => Err(anyhow::anyhow!("Unsupported tool name")), _ => Err(anyhow::anyhow!("Unsupported tool name")),
} }?;
Ok(messages
.into_iter()
.map(|message| BusterReasoningMessageContainer {
reasoning: match message {
BusterThreadMessage::Thought(thought) => ReasoningMessage::Thought(thought),
BusterThreadMessage::File(file) => ReasoningMessage::File(file),
_ => unreachable!("Assistant tool messages should only return Thought or File"),
},
chat_id: Uuid::new_v4(),
message_id: Uuid::new_v4(),
})
.collect())
} else { } else {
Err(anyhow::anyhow!("Assistant tool message missing tool call")) Err(anyhow::anyhow!("Assistant tool message missing tool call"))
} }

View File

@ -57,8 +57,8 @@ pub struct TempInitChatMessage {
#[derive(Debug, Deserialize, Clone)] #[derive(Debug, Deserialize, Clone)]
pub struct ChatCreateNewChat { pub struct ChatCreateNewChat {
pub prompt: String, pub prompt: String,
pub chat_id: Option<String>, pub chat_id: Option<Uuid>,
pub message_id: Option<String>, pub message_id: Option<Uuid>,
} }
#[derive(Debug, Serialize)] #[derive(Debug, Serialize)]
@ -167,14 +167,17 @@ impl AgentThreadHandler {
async fn process_stream( async fn process_stream(
mut rx: Receiver<Result<Message, Error>>, mut rx: Receiver<Result<Message, Error>>,
chat_id: Option<String>, chat_id: Option<Uuid>,
user_id: &Uuid, user_id: &Uuid,
) { ) {
let subscription = user_id.to_string(); let subscription = user_id.to_string();
let chat_id = chat_id.unwrap_or_else(|| Uuid::new_v4());
let message_id = Uuid::new_v4();
while let Some(msg_result) = rx.recv().await { while let Some(msg_result) = rx.recv().await {
if let Ok(msg) = msg_result { if let Ok(msg) = msg_result {
match transform_message(msg) { match transform_message(chat_id, message_id, msg) {
Ok((transformed_messages, event)) => { Ok((transformed_messages, event)) => {
for transformed in transformed_messages { for transformed in transformed_messages {
let response = WsResponseMessage::new_no_user( let response = WsResponseMessage::new_no_user(

View File

@ -1,4 +1,5 @@
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use uuid::Uuid;
use crate::utils::clients::ai::litellm::Message; use crate::utils::clients::ai::litellm::Message;
@ -7,15 +8,15 @@ use crate::utils::clients::ai::litellm::Message;
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AgentThread { pub struct AgentThread {
/// Unique identifier for the thread /// Unique identifier for the thread
pub id: String, pub id: Uuid,
/// Ordered sequence of messages in the conversation /// Ordered sequence of messages in the conversation
pub messages: Vec<Message>, pub messages: Vec<Message>,
} }
impl AgentThread { impl AgentThread {
pub fn new(id: Option<String>, messages: Vec<Message>) -> Self { pub fn new(id: Option<Uuid>, messages: Vec<Message>) -> Self {
Self { Self {
id: id.unwrap_or(uuid::Uuid::new_v4().to_string()), id: id.unwrap_or(Uuid::new_v4()),
messages, messages,
} }
} }