This commit is contained in:
dal 2025-03-18 10:57:04 -06:00
parent a02c05cb58
commit c1ca69966c
No known key found for this signature in database
GPG Key ID: 16F4B0E1E9F61122
3 changed files with 183 additions and 46 deletions

View File

@ -203,6 +203,7 @@ impl ToolExecutor for ModifyDashboardFilesTool {
"properties": {
"files": {
"type": "array",
"description": get_modify_dashboards_yml_description().await,
"items": {
"type": "object",
"required": ["id", "file_name", "modifications"],
@ -213,35 +214,34 @@ impl ToolExecutor for ModifyDashboardFilesTool {
},
"file_name": {
"type": "string",
"description": "The name of the dashboard file being modified"
"description": get_modify_dashboards_file_name_description().await
},
"modifications": {
"type": "array",
"description": get_modify_dashboards_modifications_description().await,
"items": {
"type": "object",
"required": ["content_to_replace", "new_content"],
"properties": {
"content_to_replace": {
"type": "string",
"description": "The exact content in the file that should be replaced. Must match exactly."
"description": get_modify_dashboards_content_to_replace_description().await
},
"new_content": {
"type": "string",
"description": "The new content that will replace the matched content. Make sure to include proper indentation and formatting."
"description": get_modify_dashboards_new_content_description().await
}
},
"additionalProperties": false
},
"description": "List of content replacements to apply to the file."
}
}
},
"additionalProperties": false
},
"description": get_modify_dashboards_yml_description().await
}
}
},
"additionalProperties": false
},
}
})
}
}
@ -282,7 +282,7 @@ async fn get_dashboard_modification_id_description() -> String {
}
let client = BraintrustClient::new(None, "96af8b2b-cf3c-494f-9092-44eb3d5b96ff").unwrap();
match get_prompt_system_message(&client, "modify-dashboards-id-description").await {
match get_prompt_system_message(&client, "1d9cda62-53eb-4c5c-9c33-d3f81667b249").await {
Ok(message) => message,
Err(e) => {
eprintln!("Failed to get prompt system message: {}", e);
@ -291,6 +291,66 @@ async fn get_dashboard_modification_id_description() -> String {
}
}
async fn get_modify_dashboards_file_name_description() -> String {
if env::var("USE_BRAINTRUST_PROMPTS").is_err() {
return "Name of the dashboard file to modify".to_string();
}
let client = BraintrustClient::new(None, "96af8b2b-cf3c-494f-9092-44eb3d5b96ff").unwrap();
match get_prompt_system_message(&client, "5e0761df-2668-40f3-874e-84eb54c66e4d").await {
Ok(message) => message,
Err(e) => {
eprintln!("Failed to get prompt system message: {}", e);
"Name of the dashboard file to modify".to_string()
}
}
}
async fn get_modify_dashboards_modifications_description() -> String {
if env::var("USE_BRAINTRUST_PROMPTS").is_err() {
return "List of content modifications to make to the dashboard file".to_string();
}
let client = BraintrustClient::new(None, "96af8b2b-cf3c-494f-9092-44eb3d5b96ff").unwrap();
match get_prompt_system_message(&client, "ee5789a9-fd99-4afd-a5c4-88f2ebe58fe9").await {
Ok(message) => message,
Err(e) => {
eprintln!("Failed to get prompt system message: {}", e);
"List of content modifications to make to the dashboard file".to_string()
}
}
}
async fn get_modify_dashboards_new_content_description() -> String {
if env::var("USE_BRAINTRUST_PROMPTS").is_err() {
return "The new content to replace the existing content with".to_string();
}
let client = BraintrustClient::new(None, "96af8b2b-cf3c-494f-9092-44eb3d5b96ff").unwrap();
match get_prompt_system_message(&client, "258a84a6-ec1a-4f45-b586-04853272deeb").await {
Ok(message) => message,
Err(e) => {
eprintln!("Failed to get prompt system message: {}", e);
"The new content to replace the existing content with".to_string()
}
}
}
async fn get_modify_dashboards_content_to_replace_description() -> String {
if env::var("USE_BRAINTRUST_PROMPTS").is_err() {
return "The exact content in the file that should be replaced. Must match exactly.".to_string();
}
let client = BraintrustClient::new(None, "96af8b2b-cf3c-494f-9092-44eb3d5b96ff").unwrap();
match get_prompt_system_message(&client, "7e89b1f9-30ed-4f0c-b4da-32ce03f31635").await {
Ok(message) => message,
Err(e) => {
eprintln!("Failed to get prompt system message: {}", e);
"The exact content in the file that should be replaced. Must match exactly.".to_string()
}
}
}
#[cfg(test)]
mod tests {
use std::collections::HashMap;
@ -392,8 +452,8 @@ mod tests {
"id": Uuid::new_v4().to_string(),
"file_name": "test.yml",
"modifications": [{
"content_to_replace": "old content",
"new_content": "new content"
"new_content": "new content",
"line_numbers": [1, 2]
}]
}]
});

View File

@ -4,19 +4,26 @@ use anyhow::Result;
use async_trait::async_trait;
use braintrust::{get_prompt_system_message, BraintrustClient};
use chrono::Utc;
use database::{enums::Verification, models::MetricFile, pool::get_pg_pool, schema::metric_files, types::MetricYml};
use database::{
enums::Verification, models::MetricFile, pool::get_pg_pool, schema::metric_files,
types::MetricYml,
};
use diesel::{upsert::excluded, ExpressionMethods, QueryDsl};
use diesel_async::RunQueryDsl;
use indexmap::IndexMap;
use query_engine::data_types::DataType;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use tracing::{debug, info, error};
use tracing::{debug, error, info};
use uuid::Uuid;
use super::{
common::{FileModification, Modification, ModificationResult, process_metric_file_modification, ModifyFilesParams, ModifyFilesOutput, FileModificationBatch, apply_modifications_to_content},
file_types::{file::FileWithId},
common::{
apply_modifications_to_content, process_metric_file_modification, FileModification,
FileModificationBatch, Modification, ModificationResult, ModifyFilesOutput,
ModifyFilesParams,
},
file_types::file::FileWithId,
FileModificationTool,
};
use crate::{
@ -195,19 +202,21 @@ impl ToolExecutor for ModifyMetricFilesTool {
};
// Add files to output
output.files.extend(batch.files.iter().enumerate().map(|(i, file)| {
let yml = &batch.ymls[i];
FileWithId {
id: file.id,
name: file.name.clone(),
file_type: "metric".to_string(),
yml_content: serde_yaml::to_string(&yml).unwrap_or_default(),
result_message: Some(batch.validation_messages[i].clone()),
results: Some(batch.validation_results[i].clone()),
created_at: file.created_at,
updated_at: file.updated_at,
}
}));
output
.files
.extend(batch.files.iter().enumerate().map(|(i, file)| {
let yml = &batch.ymls[i];
FileWithId {
id: file.id,
name: file.name.clone(),
file_type: "metric".to_string(),
yml_content: serde_yaml::to_string(&yml).unwrap_or_default(),
result_message: Some(batch.validation_messages[i].clone()),
results: Some(batch.validation_results[i].clone()),
created_at: file.created_at,
updated_at: file.updated_at,
}
}));
Ok(output)
}
@ -219,20 +228,14 @@ impl ToolExecutor for ModifyMetricFilesTool {
"strict": true,
"parameters": {
"type": "object",
"required": [
"files"
],
"required": ["files"],
"properties": {
"files": {
"type": "array",
"description": get_modify_metrics_yml_description().await,
"items": {
"type": "object",
"required": [
"id",
"file_name",
"modifications"
],
"required": ["id", "file_name", "modifications"],
"properties": {
"id": {
"type": "string",
@ -240,11 +243,11 @@ impl ToolExecutor for ModifyMetricFilesTool {
},
"file_name": {
"type": "string",
"description": "Name of the file to modify"
"description": get_modify_metrics_file_name_description().await
},
"modifications": {
"type": "array",
"description": "List of content replacements to apply to the file",
"description": get_modify_metrics_modifications_description().await,
"items": {
"type": "object",
"required": [
@ -254,11 +257,11 @@ impl ToolExecutor for ModifyMetricFilesTool {
"properties": {
"content_to_replace": {
"type": "string",
"description": "The exact content in the file that should be replaced. Must match exactly."
"description": get_modify_metrics_content_to_replace_description().await
},
"new_content": {
"type": "string",
"description": "The new content that will replace the matched content. Make sure to include proper indentation and formatting."
"description": get_modify_metrics_new_content_description().await
}
},
"additionalProperties": false
@ -305,6 +308,65 @@ async fn get_modify_metrics_yml_description() -> String {
}
}
async fn get_modify_metrics_file_name_description() -> String {
if env::var("USE_BRAINTRUST_PROMPTS").is_err() {
return "Name of the metric file to modify".to_string();
}
let client = BraintrustClient::new(None, "96af8b2b-cf3c-494f-9092-44eb3d5b96ff").unwrap();
match get_prompt_system_message(&client, "5e9e0a31-760a-483f-8876-41f2027bf731").await {
Ok(message) => message,
Err(e) => {
eprintln!("Failed to get prompt system message: {}", e);
"Name of the metric file to modify".to_string()
}
}
}
async fn get_modify_metrics_modifications_description() -> String {
if env::var("USE_BRAINTRUST_PROMPTS").is_err() {
return "List of content modifications to make to the metric file".to_string();
}
let client = BraintrustClient::new(None, "96af8b2b-cf3c-494f-9092-44eb3d5b96ff").unwrap();
match get_prompt_system_message(&client, "c56d3034-e527-45b6-aa2e-18fb5e3240de").await {
Ok(message) => message,
Err(e) => {
eprintln!("Failed to get prompt system message: {}", e);
"List of content modifications to make to the metric file".to_string()
}
}
}
async fn get_modify_metrics_new_content_description() -> String {
if env::var("USE_BRAINTRUST_PROMPTS").is_err() {
return "The new content to replace the existing content with".to_string();
}
let client = BraintrustClient::new(None, "96af8b2b-cf3c-494f-9092-44eb3d5b96ff").unwrap();
match get_prompt_system_message(&client, "28467bdb-6cab-49ce-bca5-193d26c620b2").await {
Ok(message) => message,
Err(e) => {
eprintln!("Failed to get prompt system message: {}", e);
"The new content to replace the existing content with".to_string()
}
}
}
async fn get_modify_metrics_content_to_replace_description() -> String {
if env::var("USE_BRAINTRUST_PROMPTS").is_err() {
return "The exact content in the file that should be replaced. Must match exactly.".to_string();
}
let client = BraintrustClient::new(None, "96af8b2b-cf3c-494f-9092-44eb3d5b96ff").unwrap();
match get_prompt_system_message(&client, "ad7e79f0-dd3a-4239-9548-ee7f4ef3be5a").await {
Ok(message) => message,
Err(e) => {
eprintln!("Failed to get prompt system message: {}", e);
"The exact content in the file that should be replaced. Must match exactly.".to_string()
}
}
}
async fn get_metric_id_description() -> String {
if env::var("USE_BRAINTRUST_PROMPTS").is_err() {
@ -312,7 +374,7 @@ async fn get_metric_id_description() -> String {
}
let client = BraintrustClient::new(None, "96af8b2b-cf3c-494f-9092-44eb3d5b96ff").unwrap();
match get_prompt_system_message(&client, "modify-metrics-id-description").await {
match get_prompt_system_message(&client, "471a0880-72f9-4989-bf47-397884a944fd").await {
Ok(message) => message,
Err(e) => {
eprintln!("Failed to get prompt system message: {}", e);
@ -417,8 +479,8 @@ mod tests {
"id": Uuid::new_v4().to_string(),
"file_name": "test.yml",
"modifications": [{
"content_to_replace": "old content",
"new_content": "new content"
"new_content": "new content",
"line_numbers": [1, 2]
}]
}]
});

View File

@ -94,13 +94,13 @@ impl SearchDataCatalogTool {
true
}
fn format_search_prompt(query_params: &[String], datasets: &[DatasetRecord]) -> Result<String> {
async fn format_search_prompt(query_params: &[String], datasets: &[DatasetRecord]) -> Result<String> {
let datasets_json = datasets
.iter()
.map(|d| d.to_llm_format())
.collect::<Vec<_>>();
Ok(CATALOG_SEARCH_PROMPT
Ok(SearchDataCatalogTool::get_search_prompt().await
.replace("{queries_joined_with_newlines}", &query_params.join("\n"))
.replace(
"{datasets_array_as_json}",
@ -108,6 +108,21 @@ impl SearchDataCatalogTool {
))
}
async fn get_search_prompt() -> String {
if env::var("USE_BRAINTRUST_PROMPTS").is_err() {
return CATALOG_SEARCH_PROMPT.to_string();
}
let client = BraintrustClient::new(None, "96af8b2b-cf3c-494f-9092-44eb3d5b96ff").unwrap();
match get_prompt_system_message(&client, "812b3f76-20d5-49e3-884c-2c8084800b43").await {
Ok(message) => message,
Err(e) => {
eprintln!("Failed to get prompt system message: {}", e);
CATALOG_SEARCH_PROMPT.to_string()
}
}
}
async fn perform_llm_search(
prompt: String,
user_id: &Uuid,
@ -277,7 +292,7 @@ impl ToolExecutor for SearchDataCatalogTool {
}
// Format prompt and perform search
let prompt = Self::format_search_prompt(&[params.search_requirements.clone()], &datasets)?;
let prompt = Self::format_search_prompt(&[params.search_requirements.clone()], &datasets).await?;
let search_results = match Self::perform_llm_search(
prompt,
&self.agent.get_user_id(),