added in progress on tools

This commit is contained in:
dal 2025-02-10 08:24:56 -07:00
parent 8bd14e0ee7
commit c849a22b4f
No known key found for this signature in database
GPG Key ID: 16F4B0E1E9F61122
3 changed files with 59 additions and 6 deletions

View File

@ -4,6 +4,13 @@ use uuid::Uuid;
use crate::utils::clients::ai::litellm::Message;
#[derive(Debug, Serialize)]
#[serde(untagged)]
pub enum BusterThreadMessage {
ChatMessage(BusterChatMessage),
Thought(BusterThought),
}
#[derive(Debug, Serialize)]
pub struct BusterChatMessage {
pub id: String,
@ -13,17 +20,53 @@ pub struct BusterChatMessage {
pub message_chunk: Option<String>,
}
pub fn transform_message(message: Message) -> Result<BusterChatMessage> {
#[derive(Debug, Serialize)]
pub struct BusterThought {
pub id: String,
#[serde(rename = "type")]
pub thought_type: String,
pub thought_title: String,
pub thought_secondary_title: String,
pub thought_pills: Option<Vec<BusterThoughtPill>>,
pub status: String,
}
#[derive(Debug, Serialize)]
pub struct BusterThoughtPill {
pub id: String,
pub text: String,
#[serde(rename = "type")]
pub thought_file_type: String,
}
pub fn transform_message(message: Message) -> Result<BusterThreadMessage> {
match message {
Message::Assistant { id, content, .. } => {
let id = id.unwrap_or_else(|| Uuid::new_v4().to_string());
let content = content.ok_or_else(|| anyhow::anyhow!("Missing content"))?;
Ok(BusterChatMessage {
Ok(BusterThreadMessage::ChatMessage(BusterChatMessage {
id,
message_type: "text".to_string(),
message: None,
message_chunk: Some(content),
})
}))
}
Message::Tool {
id,
content,
tool_call_id,
name,
progress,
} => {
tracing::debug!("Tool message: {:?}", message);
Ok(BusterThreadMessage::Thought(BusterThought {
id,
thought_type: "text".to_string(),
thought_title: "".to_string(),
thought_secondary_title: "".to_string(),
thought_pills: None,
status: "".to_string(),
}))
}
_ => Err(anyhow::anyhow!("Unsupported message type")),
}
@ -57,6 +100,7 @@ mod tests {
content: "content".to_string(),
tool_call_id: "test".to_string(),
name: None,
progress: None,
};
let result = transform_message(message);

View File

@ -183,7 +183,7 @@ impl Agent {
if let Some(tool) = self.tools.get(&tool_call.function.name) {
let result = tool.execute(tool_call).await?;
let result_str = serde_json::to_string(&result)?;
results.push(Message::tool(result_str, tool_call.id.clone()));
results.push(Message::tool(result_str, tool_call.id.clone(), None));
}
}
@ -375,6 +375,7 @@ impl Agent {
let tool_result = Message::tool(
result_str,
tool_call.id.clone(),
None,
);
let _ = tx.send(Ok(tool_result.clone())).await;
@ -388,6 +389,7 @@ impl Agent {
let tool_error = Message::tool(
error_msg,
tool_call.id.clone(),
None,
);
let _ = tx.send(Ok(tool_error.clone())).await;

View File

@ -130,6 +130,8 @@ pub enum Message {
tool_call_id: String,
#[serde(skip_serializing_if = "Option::is_none")]
name: Option<String>,
#[serde(skip)]
progress: Option<MessageProgress>,
},
}
@ -162,16 +164,21 @@ impl Message {
content,
name: None,
tool_calls,
progress: None,
progress,
}
}
pub fn tool(content: impl Into<String>, tool_call_id: impl Into<String>) -> Self {
pub fn tool(
content: impl Into<String>,
tool_call_id: impl Into<String>,
progress: Option<MessageProgress>,
) -> Self {
Self::Tool {
id: None,
content: content.into(),
tool_call_id: tool_call_id.into(),
name: None,
progress,
}
}