mirror of https://github.com/buster-so/buster.git
transforms for events
This commit is contained in:
parent
233b580e1c
commit
d0400b5226
|
@ -1,3 +1,5 @@
|
|||
use std::collections::HashMap;
|
||||
|
||||
use anyhow::Result;
|
||||
use serde::Serialize;
|
||||
use uuid::Uuid;
|
||||
|
@ -5,6 +7,7 @@ use uuid::Uuid;
|
|||
use crate::utils::clients::ai::litellm::{Message, MessageProgress, ToolCall};
|
||||
|
||||
use crate::utils::tools::file_tools::search_data_catalog::SearchDataCatalogOutput;
|
||||
use crate::utils::tools::file_tools::search_files::SearchFilesOutput;
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
#[serde(untagged)]
|
||||
|
@ -48,7 +51,7 @@ pub struct BusterThoughtPill {
|
|||
}
|
||||
|
||||
pub fn transform_message(message: Message) -> Result<BusterThreadMessage> {
|
||||
match message {
|
||||
let buster_message = match message {
|
||||
Message::Assistant {
|
||||
id,
|
||||
content,
|
||||
|
@ -63,13 +66,6 @@ pub fn transform_message(message: Message) -> Result<BusterThreadMessage> {
|
|||
if let (Some(name), Some(tool_calls)) = (name, tool_calls) {
|
||||
return transform_assistant_tool_message(id, name, tool_calls, progress);
|
||||
}
|
||||
|
||||
Ok(BusterThreadMessage::ChatMessage(BusterChatMessage {
|
||||
id,
|
||||
message_type: "text".to_string(),
|
||||
message: None,
|
||||
message_chunk: Some(content),
|
||||
}))
|
||||
}
|
||||
Message::Tool {
|
||||
id,
|
||||
|
@ -78,21 +74,12 @@ pub fn transform_message(message: Message) -> Result<BusterThreadMessage> {
|
|||
name,
|
||||
progress,
|
||||
} => {
|
||||
if let (Some(name), Some(content)) = (name, content) {
|
||||
if let Some(name) = name {
|
||||
return transform_tool_message(id, name, content, progress);
|
||||
}
|
||||
|
||||
Ok(BusterThreadMessage::Thought(BusterThought {
|
||||
id: tool_call_id.clone(),
|
||||
thought_type: "text".to_string(),
|
||||
thought_title: "".to_string(),
|
||||
thought_secondary_title: "".to_string(),
|
||||
thought_pills: None,
|
||||
status: "".to_string(),
|
||||
}))
|
||||
}
|
||||
_ => Err(anyhow::anyhow!("Unsupported message type")),
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
fn transform_text_message(
|
||||
|
@ -151,11 +138,11 @@ fn transform_assistant_tool_message(
|
|||
progress: Option<MessageProgress>,
|
||||
) -> Result<BusterThreadMessage> {
|
||||
match name.as_str() {
|
||||
"data_catalog_search" => assistant_data_catalog_search(id, content, progress),
|
||||
"stored_values_search" => assistant_stored_values_search(id, content, progress),
|
||||
"file_search" => assistant_file_search(id, content, progress),
|
||||
"create_file" => assistant_create_file(id, content, progress),
|
||||
"modify_file" => assistant_modify_file(id, content, progress),
|
||||
"data_catalog_search" => assistant_data_catalog_search(id, progress),
|
||||
"stored_values_search" => assistant_stored_values_search(id, progress),
|
||||
"file_search" => assistant_file_search(id, progress),
|
||||
"create_file" => assistant_create_file(id, "".to_string(), progress),
|
||||
"modify_file" => assistant_modify_file(id, "".to_string(), progress),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -204,6 +191,12 @@ fn tool_data_catalog_search(
|
|||
}
|
||||
};
|
||||
|
||||
let duration = (data_catalog_result.duration.clone() as f64 / 1000.0 * 10.0).round() / 10.0;
|
||||
|
||||
let result_count = data_catalog_result.results.len();
|
||||
|
||||
let query_params = data_catalog_result.query_params.clone();
|
||||
|
||||
let thought_pill_containters =
|
||||
match proccess_data_catalog_search_results(data_catalog_result) {
|
||||
Ok(object) => object,
|
||||
|
@ -215,10 +208,6 @@ fn tool_data_catalog_search(
|
|||
}
|
||||
};
|
||||
|
||||
let duration = (data_catalog_result.duration as f64 / 1000.0 * 10.0).round() / 10.0;
|
||||
|
||||
let result_count = data_catalog_result.results.len();
|
||||
|
||||
let buster_thought = if result_count > 0 {
|
||||
BusterThreadMessage::Thought(BusterThought {
|
||||
id: id.unwrap_or_else(|| Uuid::new_v4().to_string()),
|
||||
|
@ -234,7 +223,17 @@ fn tool_data_catalog_search(
|
|||
thought_type: "thought".to_string(),
|
||||
thought_title: "No data catalog items found".to_string(),
|
||||
thought_secondary_title: format!("{} seconds", duration),
|
||||
thought_pills: vec![],
|
||||
thought_pills: Some(vec![BusterThoughtPillContainer {
|
||||
title: "No results found".to_string(),
|
||||
thought_pills: query_params
|
||||
.iter()
|
||||
.map(|param| BusterThoughtPill {
|
||||
id: "".to_string(),
|
||||
text: param.clone(),
|
||||
thought_file_type: "empty".to_string(),
|
||||
})
|
||||
.collect(),
|
||||
}]),
|
||||
status: "completed".to_string(),
|
||||
})
|
||||
};
|
||||
|
@ -257,55 +256,41 @@ fn proccess_data_catalog_search_results(
|
|||
) -> Result<Vec<BusterThoughtPillContainer>> {
|
||||
if results.results.is_empty() {
|
||||
return Ok(vec![BusterThoughtPillContainer {
|
||||
title: "No datasets found".to_string(),
|
||||
title: "No results found".to_string(),
|
||||
thought_pills: vec![],
|
||||
}]);
|
||||
}
|
||||
|
||||
let mut dataset_results = vec![];
|
||||
let mut terms_results = vec![];
|
||||
let mut verified_metrics = vec![];
|
||||
let mut file_results: HashMap<String, Vec<BusterThoughtPill>> = HashMap::new();
|
||||
|
||||
for result in results.results {
|
||||
match result.name.as_str() {
|
||||
"dataset" => dataset_results.push(BusterThoughtPill {
|
||||
file_results
|
||||
.entry(result.name.clone())
|
||||
.or_insert_with(Vec::new)
|
||||
.push(BusterThoughtPill {
|
||||
id: result.id.to_string(),
|
||||
text: result.name,
|
||||
thought_file_type: "dataset".to_string(),
|
||||
}),
|
||||
"term" => terms_results.push(BusterThoughtPill {
|
||||
id: result.id.to_string(),
|
||||
text: result.name,
|
||||
thought_file_type: "term".to_string(),
|
||||
}),
|
||||
"verified_metric" => verified_metrics.push(BusterThoughtPill {
|
||||
id: result.id.to_string(),
|
||||
text: result.name,
|
||||
thought_file_type: "verified_metric".to_string(),
|
||||
}),
|
||||
_ => (),
|
||||
}
|
||||
text: result.name.clone(),
|
||||
thought_file_type: result.name,
|
||||
});
|
||||
}
|
||||
|
||||
let dataset_count = dataset_results.len();
|
||||
let term_count = terms_results.len();
|
||||
let verified_metric_count = verified_metrics.len();
|
||||
let buster_thought_pill_containers = file_results
|
||||
.into_iter()
|
||||
.map(|(title, thought_pills)| {
|
||||
let count = thought_pills.len();
|
||||
BusterThoughtPillContainer {
|
||||
title: format!(
|
||||
"{count} {} found",
|
||||
title.chars().next().unwrap().to_uppercase().to_string() + &title[1..]
|
||||
),
|
||||
thought_pills,
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(vec![
|
||||
BusterThoughtPillContainer {
|
||||
title: format!("Datasets ({})", dataset_count),
|
||||
thought_pills: dataset_results,
|
||||
},
|
||||
BusterThoughtPillContainer {
|
||||
title: format!("Terms ({})", term_count),
|
||||
thought_pills: terms_results,
|
||||
},
|
||||
BusterThoughtPillContainer {
|
||||
title: format!("Verified Metrics ({})", verified_metric_count),
|
||||
thought_pills: verified_metrics,
|
||||
},
|
||||
])
|
||||
Ok(buster_thought_pill_containers)
|
||||
}
|
||||
|
||||
fn assistant_stored_values_search(
|
||||
id: Option<String>,
|
||||
progress: Option<MessageProgress>,
|
||||
|
@ -315,7 +300,7 @@ fn assistant_stored_values_search(
|
|||
MessageProgress::InProgress => Ok(BusterThreadMessage::Thought(BusterThought {
|
||||
id: id.unwrap_or_else(|| Uuid::new_v4().to_string()),
|
||||
thought_type: "thought".to_string(),
|
||||
thought_title: "Searching your stored values...".to_string(),
|
||||
thought_title: "Searching for relevant values...".to_string(),
|
||||
thought_secondary_title: "".to_string(),
|
||||
thought_pills: None,
|
||||
status: "loading".to_string(),
|
||||
|
@ -331,6 +316,7 @@ fn assistant_stored_values_search(
|
|||
}
|
||||
}
|
||||
|
||||
// TODO: Implmentation for stored values search.
|
||||
fn tool_stored_values_search(
|
||||
id: Option<String>,
|
||||
content: String,
|
||||
|
@ -366,7 +352,7 @@ fn assistant_file_search(
|
|||
MessageProgress::InProgress => Ok(BusterThreadMessage::Thought(BusterThought {
|
||||
id: id.unwrap_or_else(|| Uuid::new_v4().to_string()),
|
||||
thought_type: "thought".to_string(),
|
||||
thought_title: "Searching your files...".to_string(),
|
||||
thought_title: "Searching across your assets...".to_string(),
|
||||
thought_secondary_title: "".to_string(),
|
||||
thought_pills: None,
|
||||
status: "loading".to_string(),
|
||||
|
@ -385,14 +371,99 @@ fn tool_file_search(
|
|||
content: String,
|
||||
progress: Option<MessageProgress>,
|
||||
) -> Result<BusterThreadMessage> {
|
||||
Ok(BusterThreadMessage::Thought(BusterThought {
|
||||
if let Some(progress) = progress {
|
||||
let file_search_result = match serde_json::from_str::<SearchFilesOutput>(&content) {
|
||||
Ok(result) => result,
|
||||
Err(e) => {
|
||||
return Err(anyhow::anyhow!("Failed to parse file search result: {}", e));
|
||||
}
|
||||
};
|
||||
|
||||
let query_params = file_search_result.query_params.clone();
|
||||
let duration = (file_search_result.duration.clone() as f64 / 1000.0 * 10.0).round() / 10.0;
|
||||
let result_count = file_search_result.files.len();
|
||||
|
||||
let thought_pill_containers = match process_file_search_results(file_search_result) {
|
||||
Ok(containers) => containers,
|
||||
Err(e) => {
|
||||
return Err(anyhow::anyhow!(
|
||||
"Failed to process file search results: {}",
|
||||
e
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
let buster_thought = if result_count > 0 {
|
||||
BusterThreadMessage::Thought(BusterThought {
|
||||
id: id.unwrap_or_else(|| Uuid::new_v4().to_string()),
|
||||
thought_type: "thought".to_string(),
|
||||
thought_title: "".to_string(),
|
||||
thought_secondary_title: "".to_string(),
|
||||
thought_pills: None,
|
||||
thought_title: format!("Found {} assets", result_count),
|
||||
thought_secondary_title: format!("{} seconds", duration),
|
||||
thought_pills: Some(thought_pill_containers),
|
||||
status: "completed".to_string(),
|
||||
}))
|
||||
})
|
||||
} else {
|
||||
BusterThreadMessage::Thought(BusterThought {
|
||||
id: id.unwrap_or_else(|| Uuid::new_v4().to_string()),
|
||||
thought_type: "thought".to_string(),
|
||||
thought_title: "No assets found".to_string(),
|
||||
thought_secondary_title: format!("{} seconds", duration),
|
||||
thought_pills: Some(vec![BusterThoughtPillContainer {
|
||||
title: "No assets found".to_string(),
|
||||
thought_pills: query_params
|
||||
.iter()
|
||||
.map(|param| BusterThoughtPill {
|
||||
id: "".to_string(),
|
||||
text: param.clone(),
|
||||
thought_file_type: "empty".to_string(),
|
||||
})
|
||||
.collect(),
|
||||
}]),
|
||||
status: "completed".to_string(),
|
||||
})
|
||||
};
|
||||
|
||||
match progress {
|
||||
MessageProgress::Complete => Ok(buster_thought),
|
||||
_ => Err(anyhow::anyhow!("Tool file search only supports complete.")),
|
||||
}
|
||||
} else {
|
||||
Err(anyhow::anyhow!("Tool file search requires progress."))
|
||||
}
|
||||
}
|
||||
|
||||
fn process_file_search_results(
|
||||
results: SearchFilesOutput,
|
||||
) -> Result<Vec<BusterThoughtPillContainer>> {
|
||||
if results.files.is_empty() {
|
||||
return Ok(vec![BusterThoughtPillContainer {
|
||||
title: "No assets found".to_string(),
|
||||
thought_pills: vec![],
|
||||
}]);
|
||||
}
|
||||
|
||||
let mut file_results: HashMap<String, Vec<BusterThoughtPill>> = HashMap::new();
|
||||
|
||||
for result in results.files {
|
||||
file_results
|
||||
.entry(result.file_type.clone())
|
||||
.or_insert_with(Vec::new)
|
||||
.push(BusterThoughtPill {
|
||||
id: result.id.to_string(),
|
||||
text: result.name,
|
||||
thought_file_type: result.file_type,
|
||||
});
|
||||
}
|
||||
|
||||
let buster_thought_pill_containers = file_results
|
||||
.into_iter()
|
||||
.map(|(title, thought_pills)| BusterThoughtPillContainer {
|
||||
title: title.chars().next().unwrap().to_uppercase().to_string() + &title[1..],
|
||||
thought_pills,
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(buster_thought_pill_containers)
|
||||
}
|
||||
|
||||
fn assistant_open_file(
|
||||
|
|
|
@ -25,7 +25,7 @@ struct SearchDataCatalogParams {
|
|||
query_params: Vec<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct SearchDataCatalogOutput {
|
||||
pub message: String,
|
||||
pub query_params: Vec<String>,
|
||||
|
|
|
@ -29,20 +29,20 @@ struct SearchFilesParams {
|
|||
query_params: Vec<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct SearchFilesOutput {
|
||||
message: String,
|
||||
query_params: Vec<String>,
|
||||
duration: i64,
|
||||
files: Vec<FileSearchResult>,
|
||||
pub message: String,
|
||||
pub query_params: Vec<String>,
|
||||
pub duration: i64,
|
||||
pub files: Vec<FileSearchResult>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct FileSearchResult {
|
||||
id: Uuid,
|
||||
name: String,
|
||||
file_type: String,
|
||||
updated_at: DateTime<Utc>,
|
||||
pub struct FileSearchResult {
|
||||
pub id: Uuid,
|
||||
pub name: String,
|
||||
pub file_type: String,
|
||||
pub updated_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
const FILE_SEARCH_PROMPT: &str = r#"
|
||||
|
|
Loading…
Reference in New Issue