transforms for events

This commit is contained in:
dal 2025-02-10 11:53:19 -07:00
parent 233b580e1c
commit d0400b5226
No known key found for this signature in database
GPG Key ID: 16F4B0E1E9F61122
3 changed files with 159 additions and 88 deletions

View File

@ -1,3 +1,5 @@
use std::collections::HashMap;
use anyhow::Result; use anyhow::Result;
use serde::Serialize; use serde::Serialize;
use uuid::Uuid; use uuid::Uuid;
@ -5,6 +7,7 @@ use uuid::Uuid;
use crate::utils::clients::ai::litellm::{Message, MessageProgress, ToolCall}; use crate::utils::clients::ai::litellm::{Message, MessageProgress, ToolCall};
use crate::utils::tools::file_tools::search_data_catalog::SearchDataCatalogOutput; use crate::utils::tools::file_tools::search_data_catalog::SearchDataCatalogOutput;
use crate::utils::tools::file_tools::search_files::SearchFilesOutput;
#[derive(Debug, Serialize)] #[derive(Debug, Serialize)]
#[serde(untagged)] #[serde(untagged)]
@ -48,7 +51,7 @@ pub struct BusterThoughtPill {
} }
pub fn transform_message(message: Message) -> Result<BusterThreadMessage> { pub fn transform_message(message: Message) -> Result<BusterThreadMessage> {
match message { let buster_message = match message {
Message::Assistant { Message::Assistant {
id, id,
content, content,
@ -63,13 +66,6 @@ pub fn transform_message(message: Message) -> Result<BusterThreadMessage> {
if let (Some(name), Some(tool_calls)) = (name, tool_calls) { if let (Some(name), Some(tool_calls)) = (name, tool_calls) {
return transform_assistant_tool_message(id, name, tool_calls, progress); return transform_assistant_tool_message(id, name, tool_calls, progress);
} }
Ok(BusterThreadMessage::ChatMessage(BusterChatMessage {
id,
message_type: "text".to_string(),
message: None,
message_chunk: Some(content),
}))
} }
Message::Tool { Message::Tool {
id, id,
@ -78,21 +74,12 @@ pub fn transform_message(message: Message) -> Result<BusterThreadMessage> {
name, name,
progress, progress,
} => { } => {
if let (Some(name), Some(content)) = (name, content) { if let Some(name) = name {
return transform_tool_message(id, name, content, progress); return transform_tool_message(id, name, content, progress);
} }
Ok(BusterThreadMessage::Thought(BusterThought {
id: tool_call_id.clone(),
thought_type: "text".to_string(),
thought_title: "".to_string(),
thought_secondary_title: "".to_string(),
thought_pills: None,
status: "".to_string(),
}))
} }
_ => Err(anyhow::anyhow!("Unsupported message type")), _ => Err(anyhow::anyhow!("Unsupported message type")),
} };
} }
fn transform_text_message( fn transform_text_message(
@ -151,11 +138,11 @@ fn transform_assistant_tool_message(
progress: Option<MessageProgress>, progress: Option<MessageProgress>,
) -> Result<BusterThreadMessage> { ) -> Result<BusterThreadMessage> {
match name.as_str() { match name.as_str() {
"data_catalog_search" => assistant_data_catalog_search(id, content, progress), "data_catalog_search" => assistant_data_catalog_search(id, progress),
"stored_values_search" => assistant_stored_values_search(id, content, progress), "stored_values_search" => assistant_stored_values_search(id, progress),
"file_search" => assistant_file_search(id, content, progress), "file_search" => assistant_file_search(id, progress),
"create_file" => assistant_create_file(id, content, progress), "create_file" => assistant_create_file(id, "".to_string(), progress),
"modify_file" => assistant_modify_file(id, content, progress), "modify_file" => assistant_modify_file(id, "".to_string(), progress),
} }
} }
@ -204,6 +191,12 @@ fn tool_data_catalog_search(
} }
}; };
let duration = (data_catalog_result.duration.clone() as f64 / 1000.0 * 10.0).round() / 10.0;
let result_count = data_catalog_result.results.len();
let query_params = data_catalog_result.query_params.clone();
let thought_pill_containters = let thought_pill_containters =
match proccess_data_catalog_search_results(data_catalog_result) { match proccess_data_catalog_search_results(data_catalog_result) {
Ok(object) => object, Ok(object) => object,
@ -215,10 +208,6 @@ fn tool_data_catalog_search(
} }
}; };
let duration = (data_catalog_result.duration as f64 / 1000.0 * 10.0).round() / 10.0;
let result_count = data_catalog_result.results.len();
let buster_thought = if result_count > 0 { let buster_thought = if result_count > 0 {
BusterThreadMessage::Thought(BusterThought { BusterThreadMessage::Thought(BusterThought {
id: id.unwrap_or_else(|| Uuid::new_v4().to_string()), id: id.unwrap_or_else(|| Uuid::new_v4().to_string()),
@ -234,7 +223,17 @@ fn tool_data_catalog_search(
thought_type: "thought".to_string(), thought_type: "thought".to_string(),
thought_title: "No data catalog items found".to_string(), thought_title: "No data catalog items found".to_string(),
thought_secondary_title: format!("{} seconds", duration), thought_secondary_title: format!("{} seconds", duration),
thought_pills: vec![], thought_pills: Some(vec![BusterThoughtPillContainer {
title: "No results found".to_string(),
thought_pills: query_params
.iter()
.map(|param| BusterThoughtPill {
id: "".to_string(),
text: param.clone(),
thought_file_type: "empty".to_string(),
})
.collect(),
}]),
status: "completed".to_string(), status: "completed".to_string(),
}) })
}; };
@ -257,55 +256,41 @@ fn proccess_data_catalog_search_results(
) -> Result<Vec<BusterThoughtPillContainer>> { ) -> Result<Vec<BusterThoughtPillContainer>> {
if results.results.is_empty() { if results.results.is_empty() {
return Ok(vec![BusterThoughtPillContainer { return Ok(vec![BusterThoughtPillContainer {
title: "No datasets found".to_string(), title: "No results found".to_string(),
thought_pills: vec![], thought_pills: vec![],
}]); }]);
} }
let mut dataset_results = vec![]; let mut file_results: HashMap<String, Vec<BusterThoughtPill>> = HashMap::new();
let mut terms_results = vec![];
let mut verified_metrics = vec![];
for result in results.results { for result in results.results {
match result.name.as_str() { file_results
"dataset" => dataset_results.push(BusterThoughtPill { .entry(result.name.clone())
.or_insert_with(Vec::new)
.push(BusterThoughtPill {
id: result.id.to_string(), id: result.id.to_string(),
text: result.name, text: result.name.clone(),
thought_file_type: "dataset".to_string(), thought_file_type: result.name,
}), });
"term" => terms_results.push(BusterThoughtPill {
id: result.id.to_string(),
text: result.name,
thought_file_type: "term".to_string(),
}),
"verified_metric" => verified_metrics.push(BusterThoughtPill {
id: result.id.to_string(),
text: result.name,
thought_file_type: "verified_metric".to_string(),
}),
_ => (),
}
} }
let dataset_count = dataset_results.len(); let buster_thought_pill_containers = file_results
let term_count = terms_results.len(); .into_iter()
let verified_metric_count = verified_metrics.len(); .map(|(title, thought_pills)| {
let count = thought_pills.len();
BusterThoughtPillContainer {
title: format!(
"{count} {} found",
title.chars().next().unwrap().to_uppercase().to_string() + &title[1..]
),
thought_pills,
}
})
.collect();
Ok(vec![ Ok(buster_thought_pill_containers)
BusterThoughtPillContainer {
title: format!("Datasets ({})", dataset_count),
thought_pills: dataset_results,
},
BusterThoughtPillContainer {
title: format!("Terms ({})", term_count),
thought_pills: terms_results,
},
BusterThoughtPillContainer {
title: format!("Verified Metrics ({})", verified_metric_count),
thought_pills: verified_metrics,
},
])
} }
fn assistant_stored_values_search( fn assistant_stored_values_search(
id: Option<String>, id: Option<String>,
progress: Option<MessageProgress>, progress: Option<MessageProgress>,
@ -315,7 +300,7 @@ fn assistant_stored_values_search(
MessageProgress::InProgress => Ok(BusterThreadMessage::Thought(BusterThought { MessageProgress::InProgress => Ok(BusterThreadMessage::Thought(BusterThought {
id: id.unwrap_or_else(|| Uuid::new_v4().to_string()), id: id.unwrap_or_else(|| Uuid::new_v4().to_string()),
thought_type: "thought".to_string(), thought_type: "thought".to_string(),
thought_title: "Searching your stored values...".to_string(), thought_title: "Searching for relevant values...".to_string(),
thought_secondary_title: "".to_string(), thought_secondary_title: "".to_string(),
thought_pills: None, thought_pills: None,
status: "loading".to_string(), status: "loading".to_string(),
@ -331,6 +316,7 @@ fn assistant_stored_values_search(
} }
} }
// TODO: Implementation for stored values search.
fn tool_stored_values_search( fn tool_stored_values_search(
id: Option<String>, id: Option<String>,
content: String, content: String,
@ -366,7 +352,7 @@ fn assistant_file_search(
MessageProgress::InProgress => Ok(BusterThreadMessage::Thought(BusterThought { MessageProgress::InProgress => Ok(BusterThreadMessage::Thought(BusterThought {
id: id.unwrap_or_else(|| Uuid::new_v4().to_string()), id: id.unwrap_or_else(|| Uuid::new_v4().to_string()),
thought_type: "thought".to_string(), thought_type: "thought".to_string(),
thought_title: "Searching your files...".to_string(), thought_title: "Searching across your assets...".to_string(),
thought_secondary_title: "".to_string(), thought_secondary_title: "".to_string(),
thought_pills: None, thought_pills: None,
status: "loading".to_string(), status: "loading".to_string(),
@ -385,14 +371,99 @@ fn tool_file_search(
content: String, content: String,
progress: Option<MessageProgress>, progress: Option<MessageProgress>,
) -> Result<BusterThreadMessage> { ) -> Result<BusterThreadMessage> {
Ok(BusterThreadMessage::Thought(BusterThought { if let Some(progress) = progress {
id: id.unwrap_or_else(|| Uuid::new_v4().to_string()), let file_search_result = match serde_json::from_str::<SearchFilesOutput>(&content) {
thought_type: "thought".to_string(), Ok(result) => result,
thought_title: "".to_string(), Err(e) => {
thought_secondary_title: "".to_string(), return Err(anyhow::anyhow!("Failed to parse file search result: {}", e));
thought_pills: None, }
status: "completed".to_string(), };
}))
let query_params = file_search_result.query_params.clone();
let duration = (file_search_result.duration.clone() as f64 / 1000.0 * 10.0).round() / 10.0;
let result_count = file_search_result.files.len();
let thought_pill_containers = match process_file_search_results(file_search_result) {
Ok(containers) => containers,
Err(e) => {
return Err(anyhow::anyhow!(
"Failed to process file search results: {}",
e
));
}
};
let buster_thought = if result_count > 0 {
BusterThreadMessage::Thought(BusterThought {
id: id.unwrap_or_else(|| Uuid::new_v4().to_string()),
thought_type: "thought".to_string(),
thought_title: format!("Found {} assets", result_count),
thought_secondary_title: format!("{} seconds", duration),
thought_pills: Some(thought_pill_containers),
status: "completed".to_string(),
})
} else {
BusterThreadMessage::Thought(BusterThought {
id: id.unwrap_or_else(|| Uuid::new_v4().to_string()),
thought_type: "thought".to_string(),
thought_title: "No assets found".to_string(),
thought_secondary_title: format!("{} seconds", duration),
thought_pills: Some(vec![BusterThoughtPillContainer {
title: "No assets found".to_string(),
thought_pills: query_params
.iter()
.map(|param| BusterThoughtPill {
id: "".to_string(),
text: param.clone(),
thought_file_type: "empty".to_string(),
})
.collect(),
}]),
status: "completed".to_string(),
})
};
match progress {
MessageProgress::Complete => Ok(buster_thought),
_ => Err(anyhow::anyhow!("Tool file search only supports complete.")),
}
} else {
Err(anyhow::anyhow!("Tool file search requires progress."))
}
}
fn process_file_search_results(
results: SearchFilesOutput,
) -> Result<Vec<BusterThoughtPillContainer>> {
if results.files.is_empty() {
return Ok(vec![BusterThoughtPillContainer {
title: "No assets found".to_string(),
thought_pills: vec![],
}]);
}
let mut file_results: HashMap<String, Vec<BusterThoughtPill>> = HashMap::new();
for result in results.files {
file_results
.entry(result.file_type.clone())
.or_insert_with(Vec::new)
.push(BusterThoughtPill {
id: result.id.to_string(),
text: result.name,
thought_file_type: result.file_type,
});
}
let buster_thought_pill_containers = file_results
.into_iter()
.map(|(title, thought_pills)| BusterThoughtPillContainer {
title: title.chars().next().unwrap().to_uppercase().to_string() + &title[1..],
thought_pills,
})
.collect();
Ok(buster_thought_pill_containers)
} }
fn assistant_open_file( fn assistant_open_file(

View File

@ -25,7 +25,7 @@ struct SearchDataCatalogParams {
query_params: Vec<String>, query_params: Vec<String>,
} }
#[derive(Debug, Serialize)] #[derive(Debug, Serialize, Deserialize)]
pub struct SearchDataCatalogOutput { pub struct SearchDataCatalogOutput {
pub message: String, pub message: String,
pub query_params: Vec<String>, pub query_params: Vec<String>,

View File

@ -29,20 +29,20 @@ struct SearchFilesParams {
query_params: Vec<String>, query_params: Vec<String>,
} }
#[derive(Debug, Serialize)] #[derive(Debug, Serialize, Deserialize)]
pub struct SearchFilesOutput { pub struct SearchFilesOutput {
message: String, pub message: String,
query_params: Vec<String>, pub query_params: Vec<String>,
duration: i64, pub duration: i64,
files: Vec<FileSearchResult>, pub files: Vec<FileSearchResult>,
} }
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
struct FileSearchResult { pub struct FileSearchResult {
id: Uuid, pub id: Uuid,
name: String, pub name: String,
file_type: String, pub file_type: String,
updated_at: DateTime<Utc>, pub updated_at: DateTime<Utc>,
} }
const FILE_SEARCH_PROMPT: &str = r#" const FILE_SEARCH_PROMPT: &str = r#"