transforms for events

This commit is contained in:
dal 2025-02-10 11:53:19 -07:00
parent 233b580e1c
commit d0400b5226
No known key found for this signature in database
GPG Key ID: 16F4B0E1E9F61122
3 changed files with 159 additions and 88 deletions

View File

@ -1,3 +1,5 @@
use std::collections::HashMap;
use anyhow::Result;
use serde::Serialize;
use uuid::Uuid;
@ -5,6 +7,7 @@ use uuid::Uuid;
use crate::utils::clients::ai::litellm::{Message, MessageProgress, ToolCall};
use crate::utils::tools::file_tools::search_data_catalog::SearchDataCatalogOutput;
use crate::utils::tools::file_tools::search_files::SearchFilesOutput;
#[derive(Debug, Serialize)]
#[serde(untagged)]
@ -48,7 +51,7 @@ pub struct BusterThoughtPill {
}
pub fn transform_message(message: Message) -> Result<BusterThreadMessage> {
match message {
let buster_message = match message {
Message::Assistant {
id,
content,
@ -63,13 +66,6 @@ pub fn transform_message(message: Message) -> Result<BusterThreadMessage> {
if let (Some(name), Some(tool_calls)) = (name, tool_calls) {
return transform_assistant_tool_message(id, name, tool_calls, progress);
}
Ok(BusterThreadMessage::ChatMessage(BusterChatMessage {
id,
message_type: "text".to_string(),
message: None,
message_chunk: Some(content),
}))
}
Message::Tool {
id,
@ -78,21 +74,12 @@ pub fn transform_message(message: Message) -> Result<BusterThreadMessage> {
name,
progress,
} => {
if let (Some(name), Some(content)) = (name, content) {
if let Some(name) = name {
return transform_tool_message(id, name, content, progress);
}
Ok(BusterThreadMessage::Thought(BusterThought {
id: tool_call_id.clone(),
thought_type: "text".to_string(),
thought_title: "".to_string(),
thought_secondary_title: "".to_string(),
thought_pills: None,
status: "".to_string(),
}))
}
_ => Err(anyhow::anyhow!("Unsupported message type")),
}
};
}
fn transform_text_message(
@ -151,11 +138,11 @@ fn transform_assistant_tool_message(
progress: Option<MessageProgress>,
) -> Result<BusterThreadMessage> {
match name.as_str() {
"data_catalog_search" => assistant_data_catalog_search(id, content, progress),
"stored_values_search" => assistant_stored_values_search(id, content, progress),
"file_search" => assistant_file_search(id, content, progress),
"create_file" => assistant_create_file(id, content, progress),
"modify_file" => assistant_modify_file(id, content, progress),
"data_catalog_search" => assistant_data_catalog_search(id, progress),
"stored_values_search" => assistant_stored_values_search(id, progress),
"file_search" => assistant_file_search(id, progress),
"create_file" => assistant_create_file(id, "".to_string(), progress),
"modify_file" => assistant_modify_file(id, "".to_string(), progress),
}
}
@ -204,6 +191,12 @@ fn tool_data_catalog_search(
}
};
let duration = (data_catalog_result.duration.clone() as f64 / 1000.0 * 10.0).round() / 10.0;
let result_count = data_catalog_result.results.len();
let query_params = data_catalog_result.query_params.clone();
let thought_pill_containters =
match proccess_data_catalog_search_results(data_catalog_result) {
Ok(object) => object,
@ -215,10 +208,6 @@ fn tool_data_catalog_search(
}
};
let duration = (data_catalog_result.duration as f64 / 1000.0 * 10.0).round() / 10.0;
let result_count = data_catalog_result.results.len();
let buster_thought = if result_count > 0 {
BusterThreadMessage::Thought(BusterThought {
id: id.unwrap_or_else(|| Uuid::new_v4().to_string()),
@ -234,7 +223,17 @@ fn tool_data_catalog_search(
thought_type: "thought".to_string(),
thought_title: "No data catalog items found".to_string(),
thought_secondary_title: format!("{} seconds", duration),
thought_pills: vec![],
thought_pills: Some(vec![BusterThoughtPillContainer {
title: "No results found".to_string(),
thought_pills: query_params
.iter()
.map(|param| BusterThoughtPill {
id: "".to_string(),
text: param.clone(),
thought_file_type: "empty".to_string(),
})
.collect(),
}]),
status: "completed".to_string(),
})
};
@ -257,55 +256,41 @@ fn proccess_data_catalog_search_results(
) -> Result<Vec<BusterThoughtPillContainer>> {
if results.results.is_empty() {
return Ok(vec![BusterThoughtPillContainer {
title: "No datasets found".to_string(),
title: "No results found".to_string(),
thought_pills: vec![],
}]);
}
let mut dataset_results = vec![];
let mut terms_results = vec![];
let mut verified_metrics = vec![];
let mut file_results: HashMap<String, Vec<BusterThoughtPill>> = HashMap::new();
for result in results.results {
match result.name.as_str() {
"dataset" => dataset_results.push(BusterThoughtPill {
file_results
.entry(result.name.clone())
.or_insert_with(Vec::new)
.push(BusterThoughtPill {
id: result.id.to_string(),
text: result.name,
thought_file_type: "dataset".to_string(),
}),
"term" => terms_results.push(BusterThoughtPill {
id: result.id.to_string(),
text: result.name,
thought_file_type: "term".to_string(),
}),
"verified_metric" => verified_metrics.push(BusterThoughtPill {
id: result.id.to_string(),
text: result.name,
thought_file_type: "verified_metric".to_string(),
}),
_ => (),
}
text: result.name.clone(),
thought_file_type: result.name,
});
}
let dataset_count = dataset_results.len();
let term_count = terms_results.len();
let verified_metric_count = verified_metrics.len();
let buster_thought_pill_containers = file_results
.into_iter()
.map(|(title, thought_pills)| {
let count = thought_pills.len();
BusterThoughtPillContainer {
title: format!(
"{count} {} found",
title.chars().next().unwrap().to_uppercase().to_string() + &title[1..]
),
thought_pills,
}
})
.collect();
Ok(vec![
BusterThoughtPillContainer {
title: format!("Datasets ({})", dataset_count),
thought_pills: dataset_results,
},
BusterThoughtPillContainer {
title: format!("Terms ({})", term_count),
thought_pills: terms_results,
},
BusterThoughtPillContainer {
title: format!("Verified Metrics ({})", verified_metric_count),
thought_pills: verified_metrics,
},
])
Ok(buster_thought_pill_containers)
}
fn assistant_stored_values_search(
id: Option<String>,
progress: Option<MessageProgress>,
@ -315,7 +300,7 @@ fn assistant_stored_values_search(
MessageProgress::InProgress => Ok(BusterThreadMessage::Thought(BusterThought {
id: id.unwrap_or_else(|| Uuid::new_v4().to_string()),
thought_type: "thought".to_string(),
thought_title: "Searching your stored values...".to_string(),
thought_title: "Searching for relevant values...".to_string(),
thought_secondary_title: "".to_string(),
thought_pills: None,
status: "loading".to_string(),
@ -331,6 +316,7 @@ fn assistant_stored_values_search(
}
}
// TODO: Implmentation for stored values search.
fn tool_stored_values_search(
id: Option<String>,
content: String,
@ -366,7 +352,7 @@ fn assistant_file_search(
MessageProgress::InProgress => Ok(BusterThreadMessage::Thought(BusterThought {
id: id.unwrap_or_else(|| Uuid::new_v4().to_string()),
thought_type: "thought".to_string(),
thought_title: "Searching your files...".to_string(),
thought_title: "Searching across your assets...".to_string(),
thought_secondary_title: "".to_string(),
thought_pills: None,
status: "loading".to_string(),
@ -385,14 +371,99 @@ fn tool_file_search(
content: String,
progress: Option<MessageProgress>,
) -> Result<BusterThreadMessage> {
Ok(BusterThreadMessage::Thought(BusterThought {
if let Some(progress) = progress {
let file_search_result = match serde_json::from_str::<SearchFilesOutput>(&content) {
Ok(result) => result,
Err(e) => {
return Err(anyhow::anyhow!("Failed to parse file search result: {}", e));
}
};
let query_params = file_search_result.query_params.clone();
let duration = (file_search_result.duration.clone() as f64 / 1000.0 * 10.0).round() / 10.0;
let result_count = file_search_result.files.len();
let thought_pill_containers = match process_file_search_results(file_search_result) {
Ok(containers) => containers,
Err(e) => {
return Err(anyhow::anyhow!(
"Failed to process file search results: {}",
e
));
}
};
let buster_thought = if result_count > 0 {
BusterThreadMessage::Thought(BusterThought {
id: id.unwrap_or_else(|| Uuid::new_v4().to_string()),
thought_type: "thought".to_string(),
thought_title: "".to_string(),
thought_secondary_title: "".to_string(),
thought_pills: None,
thought_title: format!("Found {} assets", result_count),
thought_secondary_title: format!("{} seconds", duration),
thought_pills: Some(thought_pill_containers),
status: "completed".to_string(),
}))
})
} else {
BusterThreadMessage::Thought(BusterThought {
id: id.unwrap_or_else(|| Uuid::new_v4().to_string()),
thought_type: "thought".to_string(),
thought_title: "No assets found".to_string(),
thought_secondary_title: format!("{} seconds", duration),
thought_pills: Some(vec![BusterThoughtPillContainer {
title: "No assets found".to_string(),
thought_pills: query_params
.iter()
.map(|param| BusterThoughtPill {
id: "".to_string(),
text: param.clone(),
thought_file_type: "empty".to_string(),
})
.collect(),
}]),
status: "completed".to_string(),
})
};
match progress {
MessageProgress::Complete => Ok(buster_thought),
_ => Err(anyhow::anyhow!("Tool file search only supports complete.")),
}
} else {
Err(anyhow::anyhow!("Tool file search requires progress."))
}
}
fn process_file_search_results(
results: SearchFilesOutput,
) -> Result<Vec<BusterThoughtPillContainer>> {
if results.files.is_empty() {
return Ok(vec![BusterThoughtPillContainer {
title: "No assets found".to_string(),
thought_pills: vec![],
}]);
}
let mut file_results: HashMap<String, Vec<BusterThoughtPill>> = HashMap::new();
for result in results.files {
file_results
.entry(result.file_type.clone())
.or_insert_with(Vec::new)
.push(BusterThoughtPill {
id: result.id.to_string(),
text: result.name,
thought_file_type: result.file_type,
});
}
let buster_thought_pill_containers = file_results
.into_iter()
.map(|(title, thought_pills)| BusterThoughtPillContainer {
title: title.chars().next().unwrap().to_uppercase().to_string() + &title[1..],
thought_pills,
})
.collect();
Ok(buster_thought_pill_containers)
}
fn assistant_open_file(

View File

@ -25,7 +25,7 @@ struct SearchDataCatalogParams {
query_params: Vec<String>,
}
#[derive(Debug, Serialize)]
#[derive(Debug, Serialize, Deserialize)]
pub struct SearchDataCatalogOutput {
pub message: String,
pub query_params: Vec<String>,

View File

@ -29,20 +29,20 @@ struct SearchFilesParams {
query_params: Vec<String>,
}
#[derive(Debug, Serialize)]
#[derive(Debug, Serialize, Deserialize)]
pub struct SearchFilesOutput {
message: String,
query_params: Vec<String>,
duration: i64,
files: Vec<FileSearchResult>,
pub message: String,
pub query_params: Vec<String>,
pub duration: i64,
pub files: Vec<FileSearchResult>,
}
#[derive(Debug, Serialize, Deserialize)]
struct FileSearchResult {
id: Uuid,
name: String,
file_type: String,
updated_at: DateTime<Utc>,
pub struct FileSearchResult {
pub id: Uuid,
pub name: String,
pub file_type: String,
pub updated_at: DateTime<Utc>,
}
const FILE_SEARCH_PROMPT: &str = r#"