mirror of https://github.com/buster-so/buster.git
Refactor message transformation with new container types and UUIDs
This commit is contained in:
parent
14d379d942
commit
8b96ec01fb
|
@ -205,6 +205,27 @@ pub enum BusterThreadMessage {
|
||||||
File(BusterFileMessage),
|
File(BusterFileMessage),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize)]
|
||||||
|
pub struct BusterChatMessageContainer {
|
||||||
|
pub response_message: BusterChatMessage,
|
||||||
|
pub chat_id: Uuid,
|
||||||
|
pub message_id: Uuid,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize)]
|
||||||
|
#[serde(untagged)]
|
||||||
|
pub enum ReasoningMessage {
|
||||||
|
Thought(BusterThought),
|
||||||
|
File(BusterFileMessage),
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize)]
|
||||||
|
pub struct BusterReasoningMessageContainer {
|
||||||
|
pub reasoning: ReasoningMessage,
|
||||||
|
pub chat_id: Uuid,
|
||||||
|
pub message_id: Uuid,
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize)]
|
#[derive(Debug, Serialize)]
|
||||||
pub struct BusterChatMessage {
|
pub struct BusterChatMessage {
|
||||||
pub id: String,
|
pub id: String,
|
||||||
|
@ -258,9 +279,18 @@ pub struct BusterFileLine {
|
||||||
pub text: String,
|
pub text: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn transform_message(message: Message) -> Result<(Vec<BusterThreadMessage>, ThreadEvent)> {
|
#[derive(Debug, Serialize)]
|
||||||
println!("transform_message: {:?}", message);
|
#[serde(untagged)]
|
||||||
|
pub enum BusterContainer {
|
||||||
|
ChatMessage(BusterChatMessageContainer),
|
||||||
|
ReasoningMessage(BusterReasoningMessageContainer),
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn transform_message(
|
||||||
|
chat_id: Uuid,
|
||||||
|
message_id: Uuid,
|
||||||
|
message: Message,
|
||||||
|
) -> Result<(Vec<BusterContainer>, ThreadEvent)> {
|
||||||
match message {
|
match message {
|
||||||
Message::Assistant {
|
Message::Assistant {
|
||||||
id,
|
id,
|
||||||
|
@ -272,30 +302,31 @@ pub fn transform_message(message: Message) -> Result<(Vec<BusterThreadMessage>,
|
||||||
} => {
|
} => {
|
||||||
if let Some(content) = content {
|
if let Some(content) = content {
|
||||||
let messages = match transform_text_message(id, content, progress) {
|
let messages = match transform_text_message(id, content, progress) {
|
||||||
Ok(messages) => messages,
|
Ok(messages) => messages
|
||||||
|
.into_iter()
|
||||||
|
.map(BusterContainer::ChatMessage)
|
||||||
|
.collect(),
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
return Err(e);
|
return Err(e);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
return Ok((
|
return Ok((messages, ThreadEvent::GeneratingResponseMessage));
|
||||||
messages,
|
|
||||||
ThreadEvent::GeneratingResponseMessage,
|
|
||||||
));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Some(tool_calls) = tool_calls {
|
if let Some(tool_calls) = tool_calls {
|
||||||
let messages = match transform_assistant_tool_message(id, tool_calls, progress, initial) {
|
let messages =
|
||||||
Ok(messages) => messages,
|
match transform_assistant_tool_message(id, tool_calls, progress, initial) {
|
||||||
|
Ok(messages) => messages
|
||||||
|
.into_iter()
|
||||||
|
.map(BusterContainer::ReasoningMessage)
|
||||||
|
.collect(),
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
return Err(e);
|
return Err(e);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
return Ok((
|
return Ok((messages, ThreadEvent::GeneratingReasoningMessage));
|
||||||
messages,
|
|
||||||
ThreadEvent::GeneratingReasoningMessage,
|
|
||||||
));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Err(anyhow::anyhow!("Assistant message missing required fields"))
|
Err(anyhow::anyhow!("Assistant message missing required fields"))
|
||||||
|
@ -309,16 +340,16 @@ pub fn transform_message(message: Message) -> Result<(Vec<BusterThreadMessage>,
|
||||||
} => {
|
} => {
|
||||||
if let Some(name) = name {
|
if let Some(name) = name {
|
||||||
let messages = match transform_tool_message(id, name, content, progress) {
|
let messages = match transform_tool_message(id, name, content, progress) {
|
||||||
Ok(messages) => messages,
|
Ok(messages) => messages
|
||||||
|
.into_iter()
|
||||||
|
.map(BusterContainer::ReasoningMessage)
|
||||||
|
.collect(),
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
return Err(e);
|
return Err(e);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
return Ok((
|
return Ok((messages, ThreadEvent::GeneratingReasoningMessage));
|
||||||
messages,
|
|
||||||
ThreadEvent::GeneratingReasoningMessage,
|
|
||||||
));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Err(anyhow::anyhow!("Tool message missing name field"))
|
Err(anyhow::anyhow!("Tool message missing name field"))
|
||||||
|
@ -331,34 +362,42 @@ fn transform_text_message(
|
||||||
id: Option<String>,
|
id: Option<String>,
|
||||||
content: String,
|
content: String,
|
||||||
progress: Option<MessageProgress>,
|
progress: Option<MessageProgress>,
|
||||||
) -> Result<Vec<BusterThreadMessage>> {
|
) -> Result<Vec<BusterChatMessageContainer>> {
|
||||||
if let Some(progress) = progress {
|
if let Some(progress) = progress {
|
||||||
match progress {
|
match progress {
|
||||||
MessageProgress::InProgress => {
|
MessageProgress::InProgress => Ok(vec![BusterChatMessageContainer {
|
||||||
Ok(vec![BusterThreadMessage::ChatMessage(BusterChatMessage {
|
response_message: BusterChatMessage {
|
||||||
id: id.unwrap_or_else(|| Uuid::new_v4().to_string()),
|
id: id.unwrap_or_else(|| Uuid::new_v4().to_string()),
|
||||||
message_type: "text".to_string(),
|
message_type: "text".to_string(),
|
||||||
message: None,
|
message: None,
|
||||||
message_chunk: Some(content),
|
message_chunk: Some(content),
|
||||||
})])
|
},
|
||||||
}
|
chat_id: Uuid::new_v4(),
|
||||||
MessageProgress::Complete => {
|
message_id: Uuid::new_v4(),
|
||||||
Ok(vec![BusterThreadMessage::ChatMessage(BusterChatMessage {
|
}]),
|
||||||
|
MessageProgress::Complete => Ok(vec![BusterChatMessageContainer {
|
||||||
|
response_message: BusterChatMessage {
|
||||||
id: id.unwrap_or_else(|| Uuid::new_v4().to_string()),
|
id: id.unwrap_or_else(|| Uuid::new_v4().to_string()),
|
||||||
message_type: "text".to_string(),
|
message_type: "text".to_string(),
|
||||||
message: Some(content),
|
message: Some(content),
|
||||||
message_chunk: None,
|
message_chunk: None,
|
||||||
})])
|
},
|
||||||
}
|
chat_id: Uuid::new_v4(),
|
||||||
|
message_id: Uuid::new_v4(),
|
||||||
|
}]),
|
||||||
_ => Err(anyhow::anyhow!("Unsupported message progress")),
|
_ => Err(anyhow::anyhow!("Unsupported message progress")),
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
Ok(vec![BusterThreadMessage::ChatMessage(BusterChatMessage {
|
Ok(vec![BusterChatMessageContainer {
|
||||||
|
response_message: BusterChatMessage {
|
||||||
id: id.unwrap_or_else(|| Uuid::new_v4().to_string()),
|
id: id.unwrap_or_else(|| Uuid::new_v4().to_string()),
|
||||||
message_type: "text".to_string(),
|
message_type: "text".to_string(),
|
||||||
message: None,
|
message: None,
|
||||||
message_chunk: None,
|
message_chunk: None,
|
||||||
})])
|
},
|
||||||
|
chat_id: Uuid::new_v4(),
|
||||||
|
message_id: Uuid::new_v4(),
|
||||||
|
}])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -367,8 +406,8 @@ fn transform_tool_message(
|
||||||
name: String,
|
name: String,
|
||||||
content: String,
|
content: String,
|
||||||
progress: Option<MessageProgress>,
|
progress: Option<MessageProgress>,
|
||||||
) -> Result<Vec<BusterThreadMessage>> {
|
) -> Result<Vec<BusterReasoningMessageContainer>> {
|
||||||
match name.as_str() {
|
let messages = match name.as_str() {
|
||||||
"search_data_catalog" => tool_data_catalog_search(id, content, progress),
|
"search_data_catalog" => tool_data_catalog_search(id, content, progress),
|
||||||
"stored_values_search" => tool_stored_values_search(id, content, progress),
|
"stored_values_search" => tool_stored_values_search(id, content, progress),
|
||||||
"search_files" => tool_file_search(id, content, progress),
|
"search_files" => tool_file_search(id, content, progress),
|
||||||
|
@ -376,7 +415,20 @@ fn transform_tool_message(
|
||||||
"modify_files" => tool_modify_file(id, content, progress),
|
"modify_files" => tool_modify_file(id, content, progress),
|
||||||
"open_files" => tool_open_files(id, content, progress),
|
"open_files" => tool_open_files(id, content, progress),
|
||||||
_ => Err(anyhow::anyhow!("Unsupported tool name")),
|
_ => Err(anyhow::anyhow!("Unsupported tool name")),
|
||||||
}
|
}?;
|
||||||
|
|
||||||
|
Ok(messages
|
||||||
|
.into_iter()
|
||||||
|
.map(|message| BusterReasoningMessageContainer {
|
||||||
|
reasoning: match message {
|
||||||
|
BusterThreadMessage::Thought(thought) => ReasoningMessage::Thought(thought),
|
||||||
|
BusterThreadMessage::File(file) => ReasoningMessage::File(file),
|
||||||
|
_ => unreachable!("Tool messages should only return Thought or File"),
|
||||||
|
},
|
||||||
|
chat_id: Uuid::new_v4(),
|
||||||
|
message_id: Uuid::new_v4(),
|
||||||
|
})
|
||||||
|
.collect())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn transform_assistant_tool_message(
|
fn transform_assistant_tool_message(
|
||||||
|
@ -384,9 +436,9 @@ fn transform_assistant_tool_message(
|
||||||
tool_calls: Vec<ToolCall>,
|
tool_calls: Vec<ToolCall>,
|
||||||
progress: Option<MessageProgress>,
|
progress: Option<MessageProgress>,
|
||||||
initial: bool,
|
initial: bool,
|
||||||
) -> Result<Vec<BusterThreadMessage>> {
|
) -> Result<Vec<BusterReasoningMessageContainer>> {
|
||||||
if let Some(tool_call) = tool_calls.first() {
|
if let Some(tool_call) = tool_calls.first() {
|
||||||
match tool_call.function.name.as_str() {
|
let messages = match tool_call.function.name.as_str() {
|
||||||
"search_data_catalog" => assistant_data_catalog_search(id, progress, initial),
|
"search_data_catalog" => assistant_data_catalog_search(id, progress, initial),
|
||||||
"stored_values_search" => assistant_stored_values_search(id, progress, initial),
|
"stored_values_search" => assistant_stored_values_search(id, progress, initial),
|
||||||
"search_files" => assistant_file_search(id, progress, initial),
|
"search_files" => assistant_file_search(id, progress, initial),
|
||||||
|
@ -394,7 +446,20 @@ fn transform_assistant_tool_message(
|
||||||
"modify_files" => assistant_modify_file(id, tool_calls, progress),
|
"modify_files" => assistant_modify_file(id, tool_calls, progress),
|
||||||
"open_files" => assistant_open_files(id, progress, initial),
|
"open_files" => assistant_open_files(id, progress, initial),
|
||||||
_ => Err(anyhow::anyhow!("Unsupported tool name")),
|
_ => Err(anyhow::anyhow!("Unsupported tool name")),
|
||||||
}
|
}?;
|
||||||
|
|
||||||
|
Ok(messages
|
||||||
|
.into_iter()
|
||||||
|
.map(|message| BusterReasoningMessageContainer {
|
||||||
|
reasoning: match message {
|
||||||
|
BusterThreadMessage::Thought(thought) => ReasoningMessage::Thought(thought),
|
||||||
|
BusterThreadMessage::File(file) => ReasoningMessage::File(file),
|
||||||
|
_ => unreachable!("Assistant tool messages should only return Thought or File"),
|
||||||
|
},
|
||||||
|
chat_id: Uuid::new_v4(),
|
||||||
|
message_id: Uuid::new_v4(),
|
||||||
|
})
|
||||||
|
.collect())
|
||||||
} else {
|
} else {
|
||||||
Err(anyhow::anyhow!("Assistant tool message missing tool call"))
|
Err(anyhow::anyhow!("Assistant tool message missing tool call"))
|
||||||
}
|
}
|
||||||
|
|
|
@ -57,8 +57,8 @@ pub struct TempInitChatMessage {
|
||||||
#[derive(Debug, Deserialize, Clone)]
|
#[derive(Debug, Deserialize, Clone)]
|
||||||
pub struct ChatCreateNewChat {
|
pub struct ChatCreateNewChat {
|
||||||
pub prompt: String,
|
pub prompt: String,
|
||||||
pub chat_id: Option<String>,
|
pub chat_id: Option<Uuid>,
|
||||||
pub message_id: Option<String>,
|
pub message_id: Option<Uuid>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize)]
|
#[derive(Debug, Serialize)]
|
||||||
|
@ -167,14 +167,17 @@ impl AgentThreadHandler {
|
||||||
|
|
||||||
async fn process_stream(
|
async fn process_stream(
|
||||||
mut rx: Receiver<Result<Message, Error>>,
|
mut rx: Receiver<Result<Message, Error>>,
|
||||||
chat_id: Option<String>,
|
chat_id: Option<Uuid>,
|
||||||
user_id: &Uuid,
|
user_id: &Uuid,
|
||||||
) {
|
) {
|
||||||
let subscription = user_id.to_string();
|
let subscription = user_id.to_string();
|
||||||
|
|
||||||
|
let chat_id = chat_id.unwrap_or_else(|| Uuid::new_v4());
|
||||||
|
let message_id = Uuid::new_v4();
|
||||||
|
|
||||||
while let Some(msg_result) = rx.recv().await {
|
while let Some(msg_result) = rx.recv().await {
|
||||||
if let Ok(msg) = msg_result {
|
if let Ok(msg) = msg_result {
|
||||||
match transform_message(msg) {
|
match transform_message(chat_id, message_id, msg) {
|
||||||
Ok((transformed_messages, event)) => {
|
Ok((transformed_messages, event)) => {
|
||||||
for transformed in transformed_messages {
|
for transformed in transformed_messages {
|
||||||
let response = WsResponseMessage::new_no_user(
|
let response = WsResponseMessage::new_no_user(
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
use uuid::Uuid;
|
||||||
|
|
||||||
use crate::utils::clients::ai::litellm::Message;
|
use crate::utils::clients::ai::litellm::Message;
|
||||||
|
|
||||||
|
@ -7,15 +8,15 @@ use crate::utils::clients::ai::litellm::Message;
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
pub struct AgentThread {
|
pub struct AgentThread {
|
||||||
/// Unique identifier for the thread
|
/// Unique identifier for the thread
|
||||||
pub id: String,
|
pub id: Uuid,
|
||||||
/// Ordered sequence of messages in the conversation
|
/// Ordered sequence of messages in the conversation
|
||||||
pub messages: Vec<Message>,
|
pub messages: Vec<Message>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl AgentThread {
|
impl AgentThread {
|
||||||
pub fn new(id: Option<String>, messages: Vec<Message>) -> Self {
|
pub fn new(id: Option<Uuid>, messages: Vec<Message>) -> Self {
|
||||||
Self {
|
Self {
|
||||||
id: id.unwrap_or(uuid::Uuid::new_v4().to_string()),
|
id: id.unwrap_or(Uuid::new_v4()),
|
||||||
messages,
|
messages,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue