mirror of https://github.com/buster-so/buster.git
ok everything sending back except create and modify
This commit is contained in:
parent
d0400b5226
commit
8c8372b50e
|
@ -6,6 +6,8 @@ use uuid::Uuid;
|
|||
|
||||
use crate::utils::clients::ai::litellm::{Message, MessageProgress, ToolCall};
|
||||
|
||||
use crate::utils::tools::file_tools::file_types::file::FileEnum;
|
||||
use crate::utils::tools::file_tools::open_files::OpenFilesOutput;
|
||||
use crate::utils::tools::file_tools::search_data_catalog::SearchDataCatalogOutput;
|
||||
use crate::utils::tools::file_tools::search_files::SearchFilesOutput;
|
||||
|
||||
|
@ -51,7 +53,7 @@ pub struct BusterThoughtPill {
|
|||
}
|
||||
|
||||
pub fn transform_message(message: Message) -> Result<BusterThreadMessage> {
|
||||
let buster_message = match message {
|
||||
match message {
|
||||
Message::Assistant {
|
||||
id,
|
||||
content,
|
||||
|
@ -66,6 +68,8 @@ pub fn transform_message(message: Message) -> Result<BusterThreadMessage> {
|
|||
if let (Some(name), Some(tool_calls)) = (name, tool_calls) {
|
||||
return transform_assistant_tool_message(id, name, tool_calls, progress);
|
||||
}
|
||||
|
||||
Err(anyhow::anyhow!("Assistant message missing required fields"))
|
||||
}
|
||||
Message::Tool {
|
||||
id,
|
||||
|
@ -77,9 +81,11 @@ pub fn transform_message(message: Message) -> Result<BusterThreadMessage> {
|
|||
if let Some(name) = name {
|
||||
return transform_tool_message(id, name, content, progress);
|
||||
}
|
||||
|
||||
Err(anyhow::anyhow!("Tool message missing name field"))
|
||||
}
|
||||
_ => Err(anyhow::anyhow!("Unsupported message type")),
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
fn transform_text_message(
|
||||
|
@ -122,11 +128,12 @@ fn transform_tool_message(
|
|||
progress: Option<MessageProgress>,
|
||||
) -> Result<BusterThreadMessage> {
|
||||
match name.as_str() {
|
||||
"data_catalog_search" => assistant_data_catalog_search(id, progress),
|
||||
"stored_values_search" => assistant_stored_values_search(id, progress),
|
||||
"file_search" => assistant_file_search(id, progress),
|
||||
"create_file" => assistant_create_file(id, content, progress),
|
||||
"modify_file" => assistant_modify_file(id, content, progress),
|
||||
"data_catalog_search" => tool_data_catalog_search(id, content, progress),
|
||||
"stored_values_search" => tool_stored_values_search(id, content, progress),
|
||||
"file_search" => tool_file_search(id, content, progress),
|
||||
"create_file" => tool_create_file(id, content, progress),
|
||||
"modify_file" => tool_modify_file(id, content, progress),
|
||||
"open_files" => tool_open_files(id, content, progress),
|
||||
_ => Err(anyhow::anyhow!("Unsupported tool name")),
|
||||
}
|
||||
}
|
||||
|
@ -141,8 +148,10 @@ fn transform_assistant_tool_message(
|
|||
"data_catalog_search" => assistant_data_catalog_search(id, progress),
|
||||
"stored_values_search" => assistant_stored_values_search(id, progress),
|
||||
"file_search" => assistant_file_search(id, progress),
|
||||
"create_file" => assistant_create_file(id, "".to_string(), progress),
|
||||
"modify_file" => assistant_modify_file(id, "".to_string(), progress),
|
||||
"create_file" => assistant_create_file(id, progress),
|
||||
"modify_file" => assistant_modify_file(id, progress),
|
||||
"open_files" => assistant_open_files(id, progress),
|
||||
_ => Err(anyhow::anyhow!("Unsupported tool name")),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -466,7 +475,7 @@ fn process_file_search_results(
|
|||
Ok(buster_thought_pill_containers)
|
||||
}
|
||||
|
||||
fn assistant_open_file(
|
||||
fn assistant_open_files(
|
||||
id: Option<String>,
|
||||
progress: Option<MessageProgress>,
|
||||
) -> Result<BusterThreadMessage> {
|
||||
|
@ -489,22 +498,69 @@ fn assistant_open_file(
|
|||
}
|
||||
}
|
||||
|
||||
fn tool_open_file(
|
||||
fn tool_open_files(
|
||||
id: Option<String>,
|
||||
content: String,
|
||||
progress: Option<MessageProgress>,
|
||||
) -> Result<BusterThreadMessage> {
|
||||
Ok(BusterThreadMessage::ChatMessage(BusterChatMessage {
|
||||
id: id.unwrap_or_else(|| Uuid::new_v4().to_string()),
|
||||
message_type: "text".to_string(),
|
||||
message: Some(content),
|
||||
message_chunk: None,
|
||||
}))
|
||||
if let Some(progress) = progress {
|
||||
let open_files_result = match serde_json::from_str::<OpenFilesOutput>(&content) {
|
||||
Ok(result) => result,
|
||||
Err(e) => {
|
||||
return Err(anyhow::anyhow!("Failed to parse open files result: {}", e));
|
||||
}
|
||||
};
|
||||
|
||||
let duration = (open_files_result.duration as f64 / 1000.0 * 10.0).round() / 10.0;
|
||||
let result_count = open_files_result.results.len();
|
||||
|
||||
let mut file_results: HashMap<String, Vec<BusterThoughtPill>> = HashMap::new();
|
||||
|
||||
for result in open_files_result.results {
|
||||
let file_type = match result {
|
||||
FileEnum::Dashboard(_) => "dashboard",
|
||||
FileEnum::Metric(_) => "metric",
|
||||
}
|
||||
.to_string();
|
||||
|
||||
file_results
|
||||
.entry(file_type.clone())
|
||||
.or_insert_with(Vec::new)
|
||||
.push(BusterThoughtPill {
|
||||
id: Uuid::new_v4().to_string(),
|
||||
text: open_files_result.message.clone(),
|
||||
thought_file_type: file_type,
|
||||
});
|
||||
}
|
||||
|
||||
let thought_pill_containers = file_results
|
||||
.into_iter()
|
||||
.map(|(title, thought_pills)| BusterThoughtPillContainer {
|
||||
title: title.chars().next().unwrap().to_uppercase().to_string() + &title[1..],
|
||||
thought_pills,
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let buster_thought = BusterThreadMessage::Thought(BusterThought {
|
||||
id: id.unwrap_or_else(|| Uuid::new_v4().to_string()),
|
||||
thought_type: "thought".to_string(),
|
||||
thought_title: format!("Looked through {} assets", result_count),
|
||||
thought_secondary_title: format!("{} seconds", duration),
|
||||
thought_pills: Some(thought_pill_containers),
|
||||
status: "completed".to_string(),
|
||||
});
|
||||
|
||||
match progress {
|
||||
MessageProgress::Complete => Ok(buster_thought),
|
||||
_ => Err(anyhow::anyhow!("Tool open file only supports complete.")),
|
||||
}
|
||||
} else {
|
||||
Err(anyhow::anyhow!("Tool open file requires progress."))
|
||||
}
|
||||
}
|
||||
|
||||
fn assistant_create_file(
|
||||
id: Option<String>,
|
||||
content: String,
|
||||
progress: Option<MessageProgress>,
|
||||
) -> Result<BusterThreadMessage> {
|
||||
if let Some(progress) = progress {
|
||||
|
@ -528,7 +584,6 @@ fn assistant_create_file(
|
|||
|
||||
fn assistant_modify_file(
|
||||
id: Option<String>,
|
||||
content: String,
|
||||
progress: Option<MessageProgress>,
|
||||
) -> Result<BusterThreadMessage> {
|
||||
if let Some(progress) = progress {
|
||||
|
@ -550,6 +605,72 @@ fn assistant_modify_file(
|
|||
}
|
||||
}
|
||||
|
||||
fn tool_create_file(
|
||||
id: Option<String>,
|
||||
content: String,
|
||||
progress: Option<MessageProgress>,
|
||||
) -> Result<BusterThreadMessage> {
|
||||
if let Some(progress) = progress {
|
||||
let duration = 0.1; // File creation is typically very fast
|
||||
|
||||
let buster_thought = BusterThreadMessage::Thought(BusterThought {
|
||||
id: id.unwrap_or_else(|| Uuid::new_v4().to_string()),
|
||||
thought_type: "thought".to_string(),
|
||||
thought_title: "Created new file".to_string(),
|
||||
thought_secondary_title: format!("{} seconds", duration),
|
||||
thought_pills: Some(vec![BusterThoughtPillContainer {
|
||||
title: "Created".to_string(),
|
||||
thought_pills: vec![BusterThoughtPill {
|
||||
id: Uuid::new_v4().to_string(),
|
||||
text: content,
|
||||
thought_file_type: "file".to_string(),
|
||||
}],
|
||||
}]),
|
||||
status: "completed".to_string(),
|
||||
});
|
||||
|
||||
match progress {
|
||||
MessageProgress::Complete => Ok(buster_thought),
|
||||
_ => Err(anyhow::anyhow!("Tool create file only supports complete.")),
|
||||
}
|
||||
} else {
|
||||
Err(anyhow::anyhow!("Tool create file requires progress."))
|
||||
}
|
||||
}
|
||||
|
||||
fn tool_modify_file(
|
||||
id: Option<String>,
|
||||
content: String,
|
||||
progress: Option<MessageProgress>,
|
||||
) -> Result<BusterThreadMessage> {
|
||||
if let Some(progress) = progress {
|
||||
let duration = 0.1; // File modification is typically very fast
|
||||
|
||||
let buster_thought = BusterThreadMessage::Thought(BusterThought {
|
||||
id: id.unwrap_or_else(|| Uuid::new_v4().to_string()),
|
||||
thought_type: "thought".to_string(),
|
||||
thought_title: "Modified file".to_string(),
|
||||
thought_secondary_title: format!("{} seconds", duration),
|
||||
thought_pills: Some(vec![BusterThoughtPillContainer {
|
||||
title: "Modified".to_string(),
|
||||
thought_pills: vec![BusterThoughtPill {
|
||||
id: Uuid::new_v4().to_string(),
|
||||
text: content,
|
||||
thought_file_type: "file".to_string(),
|
||||
}],
|
||||
}]),
|
||||
status: "completed".to_string(),
|
||||
});
|
||||
|
||||
match progress {
|
||||
MessageProgress::Complete => Ok(buster_thought),
|
||||
_ => Err(anyhow::anyhow!("Tool modify file only supports complete.")),
|
||||
}
|
||||
} else {
|
||||
Err(anyhow::anyhow!("Tool modify file requires progress."))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
|
|
@ -35,11 +35,11 @@ struct OpenFilesParams {
|
|||
files: Vec<FileRequest>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct OpenFilesOutput {
|
||||
message: String,
|
||||
duration: i64,
|
||||
results: Vec<FileEnum>,
|
||||
pub message: String,
|
||||
pub duration: i64,
|
||||
pub results: Vec<FileEnum>,
|
||||
}
|
||||
|
||||
pub struct OpenFilesTool;
|
||||
|
|
|
@ -34,7 +34,7 @@ pub struct SearchDataCatalogOutput {
|
|||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct DatasetSearchResult {
|
||||
pub struct DatasetSearchResult {
|
||||
pub id: Uuid,
|
||||
pub name: String,
|
||||
pub yml_content: String,
|
||||
|
|
Loading…
Reference in New Issue