diff --git a/api/src/routes/ws/threads_and_messages/post_thread/agent_message_transformer.rs b/api/src/routes/ws/threads_and_messages/post_thread/agent_message_transformer.rs index aa5cf0a30..e659dcec2 100644 --- a/api/src/routes/ws/threads_and_messages/post_thread/agent_message_transformer.rs +++ b/api/src/routes/ws/threads_and_messages/post_thread/agent_message_transformer.rs @@ -4,6 +4,13 @@ use uuid::Uuid; use crate::utils::clients::ai::litellm::Message; +#[derive(Debug, Serialize)] +#[serde(untagged)] +pub enum BusterThreadMessage { + ChatMessage(BusterChatMessage), + Thought(BusterThought), +} + #[derive(Debug, Serialize)] pub struct BusterChatMessage { pub id: String, @@ -13,17 +20,53 @@ pub struct BusterChatMessage { pub message_chunk: Option, } -pub fn transform_message(message: Message) -> Result { +#[derive(Debug, Serialize)] +pub struct BusterThought { + pub id: String, + #[serde(rename = "type")] + pub thought_type: String, + pub thought_title: String, + pub thought_secondary_title: String, + pub thought_pills: Option>, + pub status: String, +} + +#[derive(Debug, Serialize)] +pub struct BusterThoughtPill { + pub id: String, + pub text: String, + #[serde(rename = "type")] + pub thought_file_type: String, +} + +pub fn transform_message(message: Message) -> Result { match message { Message::Assistant { id, content, .. } => { let id = id.unwrap_or_else(|| Uuid::new_v4().to_string()); let content = content.ok_or_else(|| anyhow::anyhow!("Missing content"))?; - Ok(BusterChatMessage { + Ok(BusterThreadMessage::ChatMessage(BusterChatMessage { id, message_type: "text".to_string(), message: None, message_chunk: Some(content), - }) + })) + } + Message::Tool { + id, + content, + tool_call_id, + name, + progress, + } => { + tracing::debug!("Tool message: {:?}", message); + Ok(BusterThreadMessage::Thought(BusterThought { + id, + thought_type: "text".to_string(), + thought_title: "".to_string(), + thought_secondary_title: "".to_string(), + thought_pills: None, + status: "".to_string(), + })) } _ => Err(anyhow::anyhow!("Unsupported message type")), } @@ -57,6 +100,7 @@ mod tests { content: "content".to_string(), tool_call_id: "test".to_string(), name: None, + progress: None, }; let result = transform_message(message); diff --git a/api/src/utils/agent/agent.rs b/api/src/utils/agent/agent.rs index 47f084436..981542e67 100644 --- a/api/src/utils/agent/agent.rs +++ b/api/src/utils/agent/agent.rs @@ -183,7 +183,7 @@ impl Agent { if let Some(tool) = self.tools.get(&tool_call.function.name) { let result = tool.execute(tool_call).await?; let result_str = serde_json::to_string(&result)?; - results.push(Message::tool(result_str, tool_call.id.clone())); + results.push(Message::tool(result_str, tool_call.id.clone(), None)); } } @@ -375,6 +375,7 @@ impl Agent { let tool_result = Message::tool( result_str, tool_call.id.clone(), + None, ); let _ = tx.send(Ok(tool_result.clone())).await; @@ -388,6 +389,7 @@ impl Agent { let tool_error = Message::tool( error_msg, tool_call.id.clone(), + None, ); let _ = tx.send(Ok(tool_error.clone())).await; diff --git a/api/src/utils/clients/ai/litellm/types.rs b/api/src/utils/clients/ai/litellm/types.rs index 30e53d220..2b69b629c 100644 --- a/api/src/utils/clients/ai/litellm/types.rs +++ b/api/src/utils/clients/ai/litellm/types.rs @@ -130,6 +130,8 @@ pub enum Message { tool_call_id: String, #[serde(skip_serializing_if = "Option::is_none")] name: Option, + #[serde(skip)] + progress: Option, }, } @@ -162,16 +164,21 @@ impl Message { content, name: None, tool_calls, - progress: None, + progress, } } - pub fn tool(content: impl Into, tool_call_id: impl Into) -> Self { + pub fn tool( + content: impl Into, + tool_call_id: impl Into, + progress: Option, + ) -> Self { Self::Tool { id: None, content: content.into(), tool_call_id: tool_call_id.into(), name: None, + progress, } }