mirror of https://github.com/buster-so/buster.git
Refactor message transformation with new container types and UUIDs
This commit is contained in:
parent
14d379d942
commit
8b96ec01fb
|
@ -205,6 +205,27 @@ pub enum BusterThreadMessage {
|
|||
File(BusterFileMessage),
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct BusterChatMessageContainer {
|
||||
pub response_message: BusterChatMessage,
|
||||
pub chat_id: Uuid,
|
||||
pub message_id: Uuid,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum ReasoningMessage {
|
||||
Thought(BusterThought),
|
||||
File(BusterFileMessage),
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct BusterReasoningMessageContainer {
|
||||
pub reasoning: ReasoningMessage,
|
||||
pub chat_id: Uuid,
|
||||
pub message_id: Uuid,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct BusterChatMessage {
|
||||
pub id: String,
|
||||
|
@ -258,9 +279,18 @@ pub struct BusterFileLine {
|
|||
pub text: String,
|
||||
}
|
||||
|
||||
pub fn transform_message(message: Message) -> Result<(Vec<BusterThreadMessage>, ThreadEvent)> {
|
||||
println!("transform_message: {:?}", message);
|
||||
#[derive(Debug, Serialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum BusterContainer {
|
||||
ChatMessage(BusterChatMessageContainer),
|
||||
ReasoningMessage(BusterReasoningMessageContainer),
|
||||
}
|
||||
|
||||
pub fn transform_message(
|
||||
chat_id: Uuid,
|
||||
message_id: Uuid,
|
||||
message: Message,
|
||||
) -> Result<(Vec<BusterContainer>, ThreadEvent)> {
|
||||
match message {
|
||||
Message::Assistant {
|
||||
id,
|
||||
|
@ -272,30 +302,31 @@ pub fn transform_message(message: Message) -> Result<(Vec<BusterThreadMessage>,
|
|||
} => {
|
||||
if let Some(content) = content {
|
||||
let messages = match transform_text_message(id, content, progress) {
|
||||
Ok(messages) => messages,
|
||||
Ok(messages) => messages
|
||||
.into_iter()
|
||||
.map(BusterContainer::ChatMessage)
|
||||
.collect(),
|
||||
Err(e) => {
|
||||
return Err(e);
|
||||
}
|
||||
};
|
||||
|
||||
return Ok((
|
||||
messages,
|
||||
ThreadEvent::GeneratingResponseMessage,
|
||||
));
|
||||
return Ok((messages, ThreadEvent::GeneratingResponseMessage));
|
||||
}
|
||||
|
||||
if let Some(tool_calls) = tool_calls {
|
||||
let messages = match transform_assistant_tool_message(id, tool_calls, progress, initial) {
|
||||
Ok(messages) => messages,
|
||||
Err(e) => {
|
||||
return Err(e);
|
||||
}
|
||||
};
|
||||
let messages =
|
||||
match transform_assistant_tool_message(id, tool_calls, progress, initial) {
|
||||
Ok(messages) => messages
|
||||
.into_iter()
|
||||
.map(BusterContainer::ReasoningMessage)
|
||||
.collect(),
|
||||
Err(e) => {
|
||||
return Err(e);
|
||||
}
|
||||
};
|
||||
|
||||
return Ok((
|
||||
messages,
|
||||
ThreadEvent::GeneratingReasoningMessage,
|
||||
));
|
||||
return Ok((messages, ThreadEvent::GeneratingReasoningMessage));
|
||||
}
|
||||
|
||||
Err(anyhow::anyhow!("Assistant message missing required fields"))
|
||||
|
@ -309,16 +340,16 @@ pub fn transform_message(message: Message) -> Result<(Vec<BusterThreadMessage>,
|
|||
} => {
|
||||
if let Some(name) = name {
|
||||
let messages = match transform_tool_message(id, name, content, progress) {
|
||||
Ok(messages) => messages,
|
||||
Ok(messages) => messages
|
||||
.into_iter()
|
||||
.map(BusterContainer::ReasoningMessage)
|
||||
.collect(),
|
||||
Err(e) => {
|
||||
return Err(e);
|
||||
}
|
||||
};
|
||||
|
||||
return Ok((
|
||||
messages,
|
||||
ThreadEvent::GeneratingReasoningMessage,
|
||||
));
|
||||
return Ok((messages, ThreadEvent::GeneratingReasoningMessage));
|
||||
}
|
||||
|
||||
Err(anyhow::anyhow!("Tool message missing name field"))
|
||||
|
@ -331,34 +362,42 @@ fn transform_text_message(
|
|||
id: Option<String>,
|
||||
content: String,
|
||||
progress: Option<MessageProgress>,
|
||||
) -> Result<Vec<BusterThreadMessage>> {
|
||||
) -> Result<Vec<BusterChatMessageContainer>> {
|
||||
if let Some(progress) = progress {
|
||||
match progress {
|
||||
MessageProgress::InProgress => {
|
||||
Ok(vec![BusterThreadMessage::ChatMessage(BusterChatMessage {
|
||||
MessageProgress::InProgress => Ok(vec![BusterChatMessageContainer {
|
||||
response_message: BusterChatMessage {
|
||||
id: id.unwrap_or_else(|| Uuid::new_v4().to_string()),
|
||||
message_type: "text".to_string(),
|
||||
message: None,
|
||||
message_chunk: Some(content),
|
||||
})])
|
||||
}
|
||||
MessageProgress::Complete => {
|
||||
Ok(vec![BusterThreadMessage::ChatMessage(BusterChatMessage {
|
||||
},
|
||||
chat_id: Uuid::new_v4(),
|
||||
message_id: Uuid::new_v4(),
|
||||
}]),
|
||||
MessageProgress::Complete => Ok(vec![BusterChatMessageContainer {
|
||||
response_message: BusterChatMessage {
|
||||
id: id.unwrap_or_else(|| Uuid::new_v4().to_string()),
|
||||
message_type: "text".to_string(),
|
||||
message: Some(content),
|
||||
message_chunk: None,
|
||||
})])
|
||||
}
|
||||
},
|
||||
chat_id: Uuid::new_v4(),
|
||||
message_id: Uuid::new_v4(),
|
||||
}]),
|
||||
_ => Err(anyhow::anyhow!("Unsupported message progress")),
|
||||
}
|
||||
} else {
|
||||
Ok(vec![BusterThreadMessage::ChatMessage(BusterChatMessage {
|
||||
id: id.unwrap_or_else(|| Uuid::new_v4().to_string()),
|
||||
message_type: "text".to_string(),
|
||||
message: None,
|
||||
message_chunk: None,
|
||||
})])
|
||||
Ok(vec![BusterChatMessageContainer {
|
||||
response_message: BusterChatMessage {
|
||||
id: id.unwrap_or_else(|| Uuid::new_v4().to_string()),
|
||||
message_type: "text".to_string(),
|
||||
message: None,
|
||||
message_chunk: None,
|
||||
},
|
||||
chat_id: Uuid::new_v4(),
|
||||
message_id: Uuid::new_v4(),
|
||||
}])
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -367,8 +406,8 @@ fn transform_tool_message(
|
|||
name: String,
|
||||
content: String,
|
||||
progress: Option<MessageProgress>,
|
||||
) -> Result<Vec<BusterThreadMessage>> {
|
||||
match name.as_str() {
|
||||
) -> Result<Vec<BusterReasoningMessageContainer>> {
|
||||
let messages = match name.as_str() {
|
||||
"search_data_catalog" => tool_data_catalog_search(id, content, progress),
|
||||
"stored_values_search" => tool_stored_values_search(id, content, progress),
|
||||
"search_files" => tool_file_search(id, content, progress),
|
||||
|
@ -376,7 +415,20 @@ fn transform_tool_message(
|
|||
"modify_files" => tool_modify_file(id, content, progress),
|
||||
"open_files" => tool_open_files(id, content, progress),
|
||||
_ => Err(anyhow::anyhow!("Unsupported tool name")),
|
||||
}
|
||||
}?;
|
||||
|
||||
Ok(messages
|
||||
.into_iter()
|
||||
.map(|message| BusterReasoningMessageContainer {
|
||||
reasoning: match message {
|
||||
BusterThreadMessage::Thought(thought) => ReasoningMessage::Thought(thought),
|
||||
BusterThreadMessage::File(file) => ReasoningMessage::File(file),
|
||||
_ => unreachable!("Tool messages should only return Thought or File"),
|
||||
},
|
||||
chat_id: Uuid::new_v4(),
|
||||
message_id: Uuid::new_v4(),
|
||||
})
|
||||
.collect())
|
||||
}
|
||||
|
||||
fn transform_assistant_tool_message(
|
||||
|
@ -384,9 +436,9 @@ fn transform_assistant_tool_message(
|
|||
tool_calls: Vec<ToolCall>,
|
||||
progress: Option<MessageProgress>,
|
||||
initial: bool,
|
||||
) -> Result<Vec<BusterThreadMessage>> {
|
||||
) -> Result<Vec<BusterReasoningMessageContainer>> {
|
||||
if let Some(tool_call) = tool_calls.first() {
|
||||
match tool_call.function.name.as_str() {
|
||||
let messages = match tool_call.function.name.as_str() {
|
||||
"search_data_catalog" => assistant_data_catalog_search(id, progress, initial),
|
||||
"stored_values_search" => assistant_stored_values_search(id, progress, initial),
|
||||
"search_files" => assistant_file_search(id, progress, initial),
|
||||
|
@ -394,7 +446,20 @@ fn transform_assistant_tool_message(
|
|||
"modify_files" => assistant_modify_file(id, tool_calls, progress),
|
||||
"open_files" => assistant_open_files(id, progress, initial),
|
||||
_ => Err(anyhow::anyhow!("Unsupported tool name")),
|
||||
}
|
||||
}?;
|
||||
|
||||
Ok(messages
|
||||
.into_iter()
|
||||
.map(|message| BusterReasoningMessageContainer {
|
||||
reasoning: match message {
|
||||
BusterThreadMessage::Thought(thought) => ReasoningMessage::Thought(thought),
|
||||
BusterThreadMessage::File(file) => ReasoningMessage::File(file),
|
||||
_ => unreachable!("Assistant tool messages should only return Thought or File"),
|
||||
},
|
||||
chat_id: Uuid::new_v4(),
|
||||
message_id: Uuid::new_v4(),
|
||||
})
|
||||
.collect())
|
||||
} else {
|
||||
Err(anyhow::anyhow!("Assistant tool message missing tool call"))
|
||||
}
|
||||
|
|
|
@ -57,8 +57,8 @@ pub struct TempInitChatMessage {
|
|||
#[derive(Debug, Deserialize, Clone)]
|
||||
pub struct ChatCreateNewChat {
|
||||
pub prompt: String,
|
||||
pub chat_id: Option<String>,
|
||||
pub message_id: Option<String>,
|
||||
pub chat_id: Option<Uuid>,
|
||||
pub message_id: Option<Uuid>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
|
@ -167,14 +167,17 @@ impl AgentThreadHandler {
|
|||
|
||||
async fn process_stream(
|
||||
mut rx: Receiver<Result<Message, Error>>,
|
||||
chat_id: Option<String>,
|
||||
chat_id: Option<Uuid>,
|
||||
user_id: &Uuid,
|
||||
) {
|
||||
let subscription = user_id.to_string();
|
||||
|
||||
let chat_id = chat_id.unwrap_or_else(|| Uuid::new_v4());
|
||||
let message_id = Uuid::new_v4();
|
||||
|
||||
while let Some(msg_result) = rx.recv().await {
|
||||
if let Ok(msg) = msg_result {
|
||||
match transform_message(msg) {
|
||||
match transform_message(chat_id, message_id, msg) {
|
||||
Ok((transformed_messages, event)) => {
|
||||
for transformed in transformed_messages {
|
||||
let response = WsResponseMessage::new_no_user(
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
use serde::{Deserialize, Serialize};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::utils::clients::ai::litellm::Message;
|
||||
|
||||
|
@ -7,15 +8,15 @@ use crate::utils::clients::ai::litellm::Message;
|
|||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct AgentThread {
|
||||
/// Unique identifier for the thread
|
||||
pub id: String,
|
||||
pub id: Uuid,
|
||||
/// Ordered sequence of messages in the conversation
|
||||
pub messages: Vec<Message>,
|
||||
}
|
||||
|
||||
impl AgentThread {
|
||||
pub fn new(id: Option<String>, messages: Vec<Message>) -> Self {
|
||||
pub fn new(id: Option<Uuid>, messages: Vec<Message>) -> Self {
|
||||
Self {
|
||||
id: id.unwrap_or(uuid::Uuid::new_v4().to_string()),
|
||||
id: id.unwrap_or(Uuid::new_v4()),
|
||||
messages,
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue