Refactor message transformation with new container types and UUIDs

This commit is contained in:
dal 2025-02-11 11:53:06 -07:00
parent 14d379d942
commit 8b96ec01fb
No known key found for this signature in database
GPG Key ID: 16F4B0E1E9F61122
3 changed files with 119 additions and 50 deletions

View File

@ -205,6 +205,27 @@ pub enum BusterThreadMessage {
File(BusterFileMessage),
}
#[derive(Debug, Serialize)]
pub struct BusterChatMessageContainer {
pub response_message: BusterChatMessage,
pub chat_id: Uuid,
pub message_id: Uuid,
}
#[derive(Debug, Serialize)]
#[serde(untagged)]
pub enum ReasoningMessage {
Thought(BusterThought),
File(BusterFileMessage),
}
#[derive(Debug, Serialize)]
pub struct BusterReasoningMessageContainer {
pub reasoning: ReasoningMessage,
pub chat_id: Uuid,
pub message_id: Uuid,
}
#[derive(Debug, Serialize)]
pub struct BusterChatMessage {
pub id: String,
@ -258,9 +279,18 @@ pub struct BusterFileLine {
pub text: String,
}
pub fn transform_message(message: Message) -> Result<(Vec<BusterThreadMessage>, ThreadEvent)> {
println!("transform_message: {:?}", message);
#[derive(Debug, Serialize)]
#[serde(untagged)]
pub enum BusterContainer {
ChatMessage(BusterChatMessageContainer),
ReasoningMessage(BusterReasoningMessageContainer),
}
pub fn transform_message(
chat_id: Uuid,
message_id: Uuid,
message: Message,
) -> Result<(Vec<BusterContainer>, ThreadEvent)> {
match message {
Message::Assistant {
id,
@ -272,30 +302,31 @@ pub fn transform_message(message: Message) -> Result<(Vec<BusterThreadMessage>,
} => {
if let Some(content) = content {
let messages = match transform_text_message(id, content, progress) {
Ok(messages) => messages,
Ok(messages) => messages
.into_iter()
.map(BusterContainer::ChatMessage)
.collect(),
Err(e) => {
return Err(e);
}
};
return Ok((
messages,
ThreadEvent::GeneratingResponseMessage,
));
return Ok((messages, ThreadEvent::GeneratingResponseMessage));
}
if let Some(tool_calls) = tool_calls {
let messages = match transform_assistant_tool_message(id, tool_calls, progress, initial) {
Ok(messages) => messages,
Err(e) => {
return Err(e);
}
};
let messages =
match transform_assistant_tool_message(id, tool_calls, progress, initial) {
Ok(messages) => messages
.into_iter()
.map(BusterContainer::ReasoningMessage)
.collect(),
Err(e) => {
return Err(e);
}
};
return Ok((
messages,
ThreadEvent::GeneratingReasoningMessage,
));
return Ok((messages, ThreadEvent::GeneratingReasoningMessage));
}
Err(anyhow::anyhow!("Assistant message missing required fields"))
@ -309,16 +340,16 @@ pub fn transform_message(message: Message) -> Result<(Vec<BusterThreadMessage>,
} => {
if let Some(name) = name {
let messages = match transform_tool_message(id, name, content, progress) {
Ok(messages) => messages,
Ok(messages) => messages
.into_iter()
.map(BusterContainer::ReasoningMessage)
.collect(),
Err(e) => {
return Err(e);
}
};
return Ok((
messages,
ThreadEvent::GeneratingReasoningMessage,
));
return Ok((messages, ThreadEvent::GeneratingReasoningMessage));
}
Err(anyhow::anyhow!("Tool message missing name field"))
@ -331,34 +362,42 @@ fn transform_text_message(
id: Option<String>,
content: String,
progress: Option<MessageProgress>,
) -> Result<Vec<BusterThreadMessage>> {
) -> Result<Vec<BusterChatMessageContainer>> {
if let Some(progress) = progress {
match progress {
MessageProgress::InProgress => {
Ok(vec![BusterThreadMessage::ChatMessage(BusterChatMessage {
MessageProgress::InProgress => Ok(vec![BusterChatMessageContainer {
response_message: BusterChatMessage {
id: id.unwrap_or_else(|| Uuid::new_v4().to_string()),
message_type: "text".to_string(),
message: None,
message_chunk: Some(content),
})])
}
MessageProgress::Complete => {
Ok(vec![BusterThreadMessage::ChatMessage(BusterChatMessage {
},
chat_id: Uuid::new_v4(),
message_id: Uuid::new_v4(),
}]),
MessageProgress::Complete => Ok(vec![BusterChatMessageContainer {
response_message: BusterChatMessage {
id: id.unwrap_or_else(|| Uuid::new_v4().to_string()),
message_type: "text".to_string(),
message: Some(content),
message_chunk: None,
})])
}
},
chat_id: Uuid::new_v4(),
message_id: Uuid::new_v4(),
}]),
_ => Err(anyhow::anyhow!("Unsupported message progress")),
}
} else {
Ok(vec![BusterThreadMessage::ChatMessage(BusterChatMessage {
id: id.unwrap_or_else(|| Uuid::new_v4().to_string()),
message_type: "text".to_string(),
message: None,
message_chunk: None,
})])
Ok(vec![BusterChatMessageContainer {
response_message: BusterChatMessage {
id: id.unwrap_or_else(|| Uuid::new_v4().to_string()),
message_type: "text".to_string(),
message: None,
message_chunk: None,
},
chat_id: Uuid::new_v4(),
message_id: Uuid::new_v4(),
}])
}
}
@ -367,8 +406,8 @@ fn transform_tool_message(
name: String,
content: String,
progress: Option<MessageProgress>,
) -> Result<Vec<BusterThreadMessage>> {
match name.as_str() {
) -> Result<Vec<BusterReasoningMessageContainer>> {
let messages = match name.as_str() {
"search_data_catalog" => tool_data_catalog_search(id, content, progress),
"stored_values_search" => tool_stored_values_search(id, content, progress),
"search_files" => tool_file_search(id, content, progress),
@ -376,7 +415,20 @@ fn transform_tool_message(
"modify_files" => tool_modify_file(id, content, progress),
"open_files" => tool_open_files(id, content, progress),
_ => Err(anyhow::anyhow!("Unsupported tool name")),
}
}?;
Ok(messages
.into_iter()
.map(|message| BusterReasoningMessageContainer {
reasoning: match message {
BusterThreadMessage::Thought(thought) => ReasoningMessage::Thought(thought),
BusterThreadMessage::File(file) => ReasoningMessage::File(file),
_ => unreachable!("Tool messages should only return Thought or File"),
},
chat_id: Uuid::new_v4(),
message_id: Uuid::new_v4(),
})
.collect())
}
fn transform_assistant_tool_message(
@ -384,9 +436,9 @@ fn transform_assistant_tool_message(
tool_calls: Vec<ToolCall>,
progress: Option<MessageProgress>,
initial: bool,
) -> Result<Vec<BusterThreadMessage>> {
) -> Result<Vec<BusterReasoningMessageContainer>> {
if let Some(tool_call) = tool_calls.first() {
match tool_call.function.name.as_str() {
let messages = match tool_call.function.name.as_str() {
"search_data_catalog" => assistant_data_catalog_search(id, progress, initial),
"stored_values_search" => assistant_stored_values_search(id, progress, initial),
"search_files" => assistant_file_search(id, progress, initial),
@ -394,7 +446,20 @@ fn transform_assistant_tool_message(
"modify_files" => assistant_modify_file(id, tool_calls, progress),
"open_files" => assistant_open_files(id, progress, initial),
_ => Err(anyhow::anyhow!("Unsupported tool name")),
}
}?;
Ok(messages
.into_iter()
.map(|message| BusterReasoningMessageContainer {
reasoning: match message {
BusterThreadMessage::Thought(thought) => ReasoningMessage::Thought(thought),
BusterThreadMessage::File(file) => ReasoningMessage::File(file),
_ => unreachable!("Assistant tool messages should only return Thought or File"),
},
chat_id: Uuid::new_v4(),
message_id: Uuid::new_v4(),
})
.collect())
} else {
Err(anyhow::anyhow!("Assistant tool message missing tool call"))
}

View File

@ -57,8 +57,8 @@ pub struct TempInitChatMessage {
#[derive(Debug, Deserialize, Clone)]
pub struct ChatCreateNewChat {
pub prompt: String,
pub chat_id: Option<String>,
pub message_id: Option<String>,
pub chat_id: Option<Uuid>,
pub message_id: Option<Uuid>,
}
#[derive(Debug, Serialize)]
@ -167,14 +167,17 @@ impl AgentThreadHandler {
async fn process_stream(
mut rx: Receiver<Result<Message, Error>>,
chat_id: Option<String>,
chat_id: Option<Uuid>,
user_id: &Uuid,
) {
let subscription = user_id.to_string();
let chat_id = chat_id.unwrap_or_else(|| Uuid::new_v4());
let message_id = Uuid::new_v4();
while let Some(msg_result) = rx.recv().await {
if let Ok(msg) = msg_result {
match transform_message(msg) {
match transform_message(chat_id, message_id, msg) {
Ok((transformed_messages, event)) => {
for transformed in transformed_messages {
let response = WsResponseMessage::new_no_user(

View File

@ -1,4 +1,5 @@
use serde::{Deserialize, Serialize};
use uuid::Uuid;
use crate::utils::clients::ai::litellm::Message;
@ -7,15 +8,15 @@ use crate::utils::clients::ai::litellm::Message;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AgentThread {
/// Unique identifier for the thread
pub id: String,
pub id: Uuid,
/// Ordered sequence of messages in the conversation
pub messages: Vec<Message>,
}
impl AgentThread {
pub fn new(id: Option<String>, messages: Vec<Message>) -> Self {
pub fn new(id: Option<Uuid>, messages: Vec<Message>) -> Self {
Self {
id: id.unwrap_or(uuid::Uuid::new_v4().to_string()),
id: id.unwrap_or(Uuid::new_v4()),
messages,
}
}