ok everything sending back except create and modify

This commit is contained in:
dal 2025-02-10 12:15:21 -07:00
parent d0400b5226
commit 8c8372b50e
No known key found for this signature in database
GPG Key ID: 16F4B0E1E9F61122
3 changed files with 145 additions and 24 deletions

View File

@ -6,6 +6,8 @@ use uuid::Uuid;
use crate::utils::clients::ai::litellm::{Message, MessageProgress, ToolCall};
use crate::utils::tools::file_tools::file_types::file::FileEnum;
use crate::utils::tools::file_tools::open_files::OpenFilesOutput;
use crate::utils::tools::file_tools::search_data_catalog::SearchDataCatalogOutput;
use crate::utils::tools::file_tools::search_files::SearchFilesOutput;
@ -51,7 +53,7 @@ pub struct BusterThoughtPill {
}
pub fn transform_message(message: Message) -> Result<BusterThreadMessage> {
let buster_message = match message {
match message {
Message::Assistant {
id,
content,
@ -66,6 +68,8 @@ pub fn transform_message(message: Message) -> Result<BusterThreadMessage> {
if let (Some(name), Some(tool_calls)) = (name, tool_calls) {
return transform_assistant_tool_message(id, name, tool_calls, progress);
}
Err(anyhow::anyhow!("Assistant message missing required fields"))
}
Message::Tool {
id,
@ -77,9 +81,11 @@ pub fn transform_message(message: Message) -> Result<BusterThreadMessage> {
if let Some(name) = name {
return transform_tool_message(id, name, content, progress);
}
Err(anyhow::anyhow!("Tool message missing name field"))
}
_ => Err(anyhow::anyhow!("Unsupported message type")),
};
}
}
fn transform_text_message(
@ -122,11 +128,12 @@ fn transform_tool_message(
progress: Option<MessageProgress>,
) -> Result<BusterThreadMessage> {
match name.as_str() {
"data_catalog_search" => assistant_data_catalog_search(id, progress),
"stored_values_search" => assistant_stored_values_search(id, progress),
"file_search" => assistant_file_search(id, progress),
"create_file" => assistant_create_file(id, content, progress),
"modify_file" => assistant_modify_file(id, content, progress),
"data_catalog_search" => tool_data_catalog_search(id, content, progress),
"stored_values_search" => tool_stored_values_search(id, content, progress),
"file_search" => tool_file_search(id, content, progress),
"create_file" => tool_create_file(id, content, progress),
"modify_file" => tool_modify_file(id, content, progress),
"open_files" => tool_open_files(id, content, progress),
_ => Err(anyhow::anyhow!("Unsupported tool name")),
}
}
@ -141,8 +148,10 @@ fn transform_assistant_tool_message(
"data_catalog_search" => assistant_data_catalog_search(id, progress),
"stored_values_search" => assistant_stored_values_search(id, progress),
"file_search" => assistant_file_search(id, progress),
"create_file" => assistant_create_file(id, "".to_string(), progress),
"modify_file" => assistant_modify_file(id, "".to_string(), progress),
"create_file" => assistant_create_file(id, progress),
"modify_file" => assistant_modify_file(id, progress),
"open_files" => assistant_open_files(id, progress),
_ => Err(anyhow::anyhow!("Unsupported tool name")),
}
}
@ -466,7 +475,7 @@ fn process_file_search_results(
Ok(buster_thought_pill_containers)
}
fn assistant_open_file(
fn assistant_open_files(
id: Option<String>,
progress: Option<MessageProgress>,
) -> Result<BusterThreadMessage> {
@ -489,22 +498,69 @@ fn assistant_open_file(
}
}
fn tool_open_file(
fn tool_open_files(
id: Option<String>,
content: String,
progress: Option<MessageProgress>,
) -> Result<BusterThreadMessage> {
Ok(BusterThreadMessage::ChatMessage(BusterChatMessage {
id: id.unwrap_or_else(|| Uuid::new_v4().to_string()),
message_type: "text".to_string(),
message: Some(content),
message_chunk: None,
}))
if let Some(progress) = progress {
let open_files_result = match serde_json::from_str::<OpenFilesOutput>(&content) {
Ok(result) => result,
Err(e) => {
return Err(anyhow::anyhow!("Failed to parse open files result: {}", e));
}
};
let duration = (open_files_result.duration as f64 / 1000.0 * 10.0).round() / 10.0;
let result_count = open_files_result.results.len();
let mut file_results: HashMap<String, Vec<BusterThoughtPill>> = HashMap::new();
for result in open_files_result.results {
let file_type = match result {
FileEnum::Dashboard(_) => "dashboard",
FileEnum::Metric(_) => "metric",
}
.to_string();
file_results
.entry(file_type.clone())
.or_insert_with(Vec::new)
.push(BusterThoughtPill {
id: Uuid::new_v4().to_string(),
text: open_files_result.message.clone(),
thought_file_type: file_type,
});
}
let thought_pill_containers = file_results
.into_iter()
.map(|(title, thought_pills)| BusterThoughtPillContainer {
title: title.chars().next().unwrap().to_uppercase().to_string() + &title[1..],
thought_pills,
})
.collect::<Vec<_>>();
let buster_thought = BusterThreadMessage::Thought(BusterThought {
id: id.unwrap_or_else(|| Uuid::new_v4().to_string()),
thought_type: "thought".to_string(),
thought_title: format!("Looked through {} assets", result_count),
thought_secondary_title: format!("{} seconds", duration),
thought_pills: Some(thought_pill_containers),
status: "completed".to_string(),
});
match progress {
MessageProgress::Complete => Ok(buster_thought),
_ => Err(anyhow::anyhow!("Tool open file only supports complete.")),
}
} else {
Err(anyhow::anyhow!("Tool open file requires progress."))
}
}
fn assistant_create_file(
id: Option<String>,
content: String,
progress: Option<MessageProgress>,
) -> Result<BusterThreadMessage> {
if let Some(progress) = progress {
@ -528,7 +584,6 @@ fn assistant_create_file(
fn assistant_modify_file(
id: Option<String>,
content: String,
progress: Option<MessageProgress>,
) -> Result<BusterThreadMessage> {
if let Some(progress) = progress {
@ -550,6 +605,72 @@ fn assistant_modify_file(
}
}
fn tool_create_file(
id: Option<String>,
content: String,
progress: Option<MessageProgress>,
) -> Result<BusterThreadMessage> {
if let Some(progress) = progress {
let duration = 0.1; // File creation is typically very fast
let buster_thought = BusterThreadMessage::Thought(BusterThought {
id: id.unwrap_or_else(|| Uuid::new_v4().to_string()),
thought_type: "thought".to_string(),
thought_title: "Created new file".to_string(),
thought_secondary_title: format!("{} seconds", duration),
thought_pills: Some(vec![BusterThoughtPillContainer {
title: "Created".to_string(),
thought_pills: vec![BusterThoughtPill {
id: Uuid::new_v4().to_string(),
text: content,
thought_file_type: "file".to_string(),
}],
}]),
status: "completed".to_string(),
});
match progress {
MessageProgress::Complete => Ok(buster_thought),
_ => Err(anyhow::anyhow!("Tool create file only supports complete.")),
}
} else {
Err(anyhow::anyhow!("Tool create file requires progress."))
}
}
fn tool_modify_file(
id: Option<String>,
content: String,
progress: Option<MessageProgress>,
) -> Result<BusterThreadMessage> {
if let Some(progress) = progress {
let duration = 0.1; // File modification is typically very fast
let buster_thought = BusterThreadMessage::Thought(BusterThought {
id: id.unwrap_or_else(|| Uuid::new_v4().to_string()),
thought_type: "thought".to_string(),
thought_title: "Modified file".to_string(),
thought_secondary_title: format!("{} seconds", duration),
thought_pills: Some(vec![BusterThoughtPillContainer {
title: "Modified".to_string(),
thought_pills: vec![BusterThoughtPill {
id: Uuid::new_v4().to_string(),
text: content,
thought_file_type: "file".to_string(),
}],
}]),
status: "completed".to_string(),
});
match progress {
MessageProgress::Complete => Ok(buster_thought),
_ => Err(anyhow::anyhow!("Tool modify file only supports complete.")),
}
} else {
Err(anyhow::anyhow!("Tool modify file requires progress."))
}
}
#[cfg(test)]
mod tests {
use super::*;

View File

@ -35,11 +35,11 @@ struct OpenFilesParams {
files: Vec<FileRequest>,
}
#[derive(Debug, Serialize)]
#[derive(Debug, Serialize, Deserialize)]
pub struct OpenFilesOutput {
message: String,
duration: i64,
results: Vec<FileEnum>,
pub message: String,
pub duration: i64,
pub results: Vec<FileEnum>,
}
pub struct OpenFilesTool;

View File

@ -34,7 +34,7 @@ pub struct SearchDataCatalogOutput {
}
#[derive(Debug, Serialize, Deserialize)]
struct DatasetSearchResult {
pub struct DatasetSearchResult {
pub id: Uuid,
pub name: String,
pub yml_content: String,