From c1ca69966c2bd1b96a7b5da57cd1cbed49f0e786 Mon Sep 17 00:00:00 2001 From: dal Date: Tue, 18 Mar 2025 10:57:04 -0600 Subject: [PATCH] prompts --- .../file_tools/modify_dashboards.rs | 82 ++++++++++-- .../categories/file_tools/modify_metrics.rs | 126 +++++++++++++----- .../file_tools/search_data_catalog.rs | 21 ++- 3 files changed, 183 insertions(+), 46 deletions(-) diff --git a/api/libs/agents/src/tools/categories/file_tools/modify_dashboards.rs b/api/libs/agents/src/tools/categories/file_tools/modify_dashboards.rs index 43c5bc3c4..1a83e2b69 100644 --- a/api/libs/agents/src/tools/categories/file_tools/modify_dashboards.rs +++ b/api/libs/agents/src/tools/categories/file_tools/modify_dashboards.rs @@ -203,6 +203,7 @@ impl ToolExecutor for ModifyDashboardFilesTool { "properties": { "files": { "type": "array", + "description": get_modify_dashboards_yml_description().await, "items": { "type": "object", "required": ["id", "file_name", "modifications"], @@ -213,35 +214,34 @@ impl ToolExecutor for ModifyDashboardFilesTool { }, "file_name": { "type": "string", - "description": "The name of the dashboard file being modified" + "description": get_modify_dashboards_file_name_description().await }, "modifications": { "type": "array", + "description": get_modify_dashboards_modifications_description().await, "items": { "type": "object", "required": ["content_to_replace", "new_content"], "properties": { "content_to_replace": { "type": "string", - "description": "The exact content in the file that should be replaced. Must match exactly." + "description": get_modify_dashboards_content_to_replace_description().await }, "new_content": { "type": "string", - "description": "The new content that will replace the matched content. Make sure to include proper indentation and formatting." + "description": get_modify_dashboards_new_content_description().await } }, "additionalProperties": false - }, - "description": "List of content replacements to apply to the file." + } } }, "additionalProperties": false - }, - "description": get_modify_dashboards_yml_description().await + } } }, "additionalProperties": false - }, + } }) } } @@ -282,7 +282,7 @@ async fn get_dashboard_modification_id_description() -> String { } let client = BraintrustClient::new(None, "96af8b2b-cf3c-494f-9092-44eb3d5b96ff").unwrap(); - match get_prompt_system_message(&client, "modify-dashboards-id-description").await { + match get_prompt_system_message(&client, "1d9cda62-53eb-4c5c-9c33-d3f81667b249").await { Ok(message) => message, Err(e) => { eprintln!("Failed to get prompt system message: {}", e); @@ -291,6 +291,66 @@ async fn get_dashboard_modification_id_description() -> String { } } +async fn get_modify_dashboards_file_name_description() -> String { + if env::var("USE_BRAINTRUST_PROMPTS").is_err() { + return "Name of the dashboard file to modify".to_string(); + } + + let client = BraintrustClient::new(None, "96af8b2b-cf3c-494f-9092-44eb3d5b96ff").unwrap(); + match get_prompt_system_message(&client, "5e0761df-2668-40f3-874e-84eb54c66e4d").await { + Ok(message) => message, + Err(e) => { + eprintln!("Failed to get prompt system message: {}", e); + "Name of the dashboard file to modify".to_string() + } + } +} + +async fn get_modify_dashboards_modifications_description() -> String { + if env::var("USE_BRAINTRUST_PROMPTS").is_err() { + return "List of content modifications to make to the dashboard file".to_string(); + } + + let client = BraintrustClient::new(None, "96af8b2b-cf3c-494f-9092-44eb3d5b96ff").unwrap(); + match get_prompt_system_message(&client, "ee5789a9-fd99-4afd-a5c4-88f2ebe58fe9").await { + Ok(message) => message, + Err(e) => { + eprintln!("Failed to get prompt system message: {}", e); + "List of content modifications to make to the dashboard file".to_string() + } + } +} + +async fn get_modify_dashboards_new_content_description() -> String { + if env::var("USE_BRAINTRUST_PROMPTS").is_err() { + return "The new content to replace the existing content with".to_string(); + } + + let client = BraintrustClient::new(None, "96af8b2b-cf3c-494f-9092-44eb3d5b96ff").unwrap(); + match get_prompt_system_message(&client, "258a84a6-ec1a-4f45-b586-04853272deeb").await { + Ok(message) => message, + Err(e) => { + eprintln!("Failed to get prompt system message: {}", e); + "The new content to replace the existing content with".to_string() + } + } +} + +async fn get_modify_dashboards_content_to_replace_description() -> String { + if env::var("USE_BRAINTRUST_PROMPTS").is_err() { + return "The exact content in the file that should be replaced. Must match exactly.".to_string(); + } + + let client = BraintrustClient::new(None, "96af8b2b-cf3c-494f-9092-44eb3d5b96ff").unwrap(); + match get_prompt_system_message(&client, "7e89b1f9-30ed-4f0c-b4da-32ce03f31635").await { + Ok(message) => message, + Err(e) => { + eprintln!("Failed to get prompt system message: {}", e); + "The exact content in the file that should be replaced. Must match exactly.".to_string() + } + } +} + #[cfg(test)] mod tests { use std::collections::HashMap; @@ -392,8 +452,8 @@ mod tests { "id": Uuid::new_v4().to_string(), "file_name": "test.yml", "modifications": [{ - "content_to_replace": "old content", - "new_content": "new content" + "new_content": "new content", + "line_numbers": [1, 2] }] }] }); diff --git a/api/libs/agents/src/tools/categories/file_tools/modify_metrics.rs b/api/libs/agents/src/tools/categories/file_tools/modify_metrics.rs index ef98a9394..8e65c8596 100644 --- a/api/libs/agents/src/tools/categories/file_tools/modify_metrics.rs +++ b/api/libs/agents/src/tools/categories/file_tools/modify_metrics.rs @@ -4,19 +4,26 @@ use anyhow::Result; use async_trait::async_trait; use braintrust::{get_prompt_system_message, BraintrustClient}; use chrono::Utc; -use database::{enums::Verification, models::MetricFile, pool::get_pg_pool, schema::metric_files, types::MetricYml}; +use database::{ + enums::Verification, models::MetricFile, pool::get_pg_pool, schema::metric_files, + types::MetricYml, +}; use diesel::{upsert::excluded, ExpressionMethods, QueryDsl}; use diesel_async::RunQueryDsl; use indexmap::IndexMap; use query_engine::data_types::DataType; use serde::{Deserialize, Serialize}; use serde_json::Value; -use tracing::{debug, info, error}; +use tracing::{debug, error, info}; use uuid::Uuid; use super::{ - common::{FileModification, Modification, ModificationResult, process_metric_file_modification, ModifyFilesParams, ModifyFilesOutput, FileModificationBatch, apply_modifications_to_content}, - file_types::{file::FileWithId}, + common::{ + apply_modifications_to_content, process_metric_file_modification, FileModification, + FileModificationBatch, Modification, ModificationResult, ModifyFilesOutput, + ModifyFilesParams, + }, + file_types::file::FileWithId, FileModificationTool, }; use crate::{ @@ -195,19 +202,21 @@ impl ToolExecutor for ModifyMetricFilesTool { }; // Add files to output - output.files.extend(batch.files.iter().enumerate().map(|(i, file)| { - let yml = &batch.ymls[i]; - FileWithId { - id: file.id, - name: file.name.clone(), - file_type: "metric".to_string(), - yml_content: serde_yaml::to_string(&yml).unwrap_or_default(), - result_message: Some(batch.validation_messages[i].clone()), - results: Some(batch.validation_results[i].clone()), - created_at: file.created_at, - updated_at: file.updated_at, - } - })); + output + .files + .extend(batch.files.iter().enumerate().map(|(i, file)| { + let yml = &batch.ymls[i]; + FileWithId { + id: file.id, + name: file.name.clone(), + file_type: "metric".to_string(), + yml_content: serde_yaml::to_string(&yml).unwrap_or_default(), + result_message: Some(batch.validation_messages[i].clone()), + results: Some(batch.validation_results[i].clone()), + created_at: file.created_at, + updated_at: file.updated_at, + } + })); Ok(output) } @@ -219,20 +228,14 @@ impl ToolExecutor for ModifyMetricFilesTool { "strict": true, "parameters": { "type": "object", - "required": [ - "files" - ], + "required": ["files"], "properties": { "files": { "type": "array", "description": get_modify_metrics_yml_description().await, "items": { "type": "object", - "required": [ - "id", - "file_name", - "modifications" - ], + "required": ["id", "file_name", "modifications"], "properties": { "id": { "type": "string", @@ -240,11 +243,11 @@ impl ToolExecutor for ModifyMetricFilesTool { }, "file_name": { "type": "string", - "description": "Name of the file to modify" + "description": get_modify_metrics_file_name_description().await }, "modifications": { "type": "array", - "description": "List of content replacements to apply to the file", + "description": get_modify_metrics_modifications_description().await, "items": { "type": "object", "required": [ @@ -254,11 +257,11 @@ impl ToolExecutor for ModifyMetricFilesTool { "properties": { "content_to_replace": { "type": "string", - "description": "The exact content in the file that should be replaced. Must match exactly." + "description": get_modify_metrics_content_to_replace_description().await }, "new_content": { "type": "string", - "description": "The new content that will replace the matched content. Make sure to include proper indentation and formatting." + "description": get_modify_metrics_new_content_description().await } }, "additionalProperties": false @@ -305,6 +308,65 @@ async fn get_modify_metrics_yml_description() -> String { } } +async fn get_modify_metrics_file_name_description() -> String { + if env::var("USE_BRAINTRUST_PROMPTS").is_err() { + return "Name of the metric file to modify".to_string(); + } + + let client = BraintrustClient::new(None, "96af8b2b-cf3c-494f-9092-44eb3d5b96ff").unwrap(); + match get_prompt_system_message(&client, "5e9e0a31-760a-483f-8876-41f2027bf731").await { + Ok(message) => message, + Err(e) => { + eprintln!("Failed to get prompt system message: {}", e); + "Name of the metric file to modify".to_string() + } + } +} + +async fn get_modify_metrics_modifications_description() -> String { + if env::var("USE_BRAINTRUST_PROMPTS").is_err() { + return "List of content modifications to make to the metric file".to_string(); + } + + let client = BraintrustClient::new(None, "96af8b2b-cf3c-494f-9092-44eb3d5b96ff").unwrap(); + match get_prompt_system_message(&client, "c56d3034-e527-45b6-aa2e-18fb5e3240de").await { + Ok(message) => message, + Err(e) => { + eprintln!("Failed to get prompt system message: {}", e); + "List of content modifications to make to the metric file".to_string() + } + } +} + +async fn get_modify_metrics_new_content_description() -> String { + if env::var("USE_BRAINTRUST_PROMPTS").is_err() { + return "The new content to replace the existing content with".to_string(); + } + + let client = BraintrustClient::new(None, "96af8b2b-cf3c-494f-9092-44eb3d5b96ff").unwrap(); + match get_prompt_system_message(&client, "28467bdb-6cab-49ce-bca5-193d26c620b2").await { + Ok(message) => message, + Err(e) => { + eprintln!("Failed to get prompt system message: {}", e); + "The new content to replace the existing content with".to_string() + } + } +} + +async fn get_modify_metrics_content_to_replace_description() -> String { + if env::var("USE_BRAINTRUST_PROMPTS").is_err() { + return "The exact content in the file that should be replaced. Must match exactly.".to_string(); + } + + let client = BraintrustClient::new(None, "96af8b2b-cf3c-494f-9092-44eb3d5b96ff").unwrap(); + match get_prompt_system_message(&client, "ad7e79f0-dd3a-4239-9548-ee7f4ef3be5a").await { + Ok(message) => message, + Err(e) => { + eprintln!("Failed to get prompt system message: {}", e); + "The exact content in the file that should be replaced. Must match exactly.".to_string() + } + } +} async fn get_metric_id_description() -> String { if env::var("USE_BRAINTRUST_PROMPTS").is_err() { @@ -312,7 +374,7 @@ async fn get_metric_id_description() -> String { } let client = BraintrustClient::new(None, "96af8b2b-cf3c-494f-9092-44eb3d5b96ff").unwrap(); - match get_prompt_system_message(&client, "modify-metrics-id-description").await { + match get_prompt_system_message(&client, "471a0880-72f9-4989-bf47-397884a944fd").await { Ok(message) => message, Err(e) => { eprintln!("Failed to get prompt system message: {}", e); @@ -417,8 +479,8 @@ mod tests { "id": Uuid::new_v4().to_string(), "file_name": "test.yml", "modifications": [{ - "content_to_replace": "old content", - "new_content": "new content" + "new_content": "new content", + "line_numbers": [1, 2] }] }] }); diff --git a/api/libs/agents/src/tools/categories/file_tools/search_data_catalog.rs b/api/libs/agents/src/tools/categories/file_tools/search_data_catalog.rs index b740c5e2d..731812084 100644 --- a/api/libs/agents/src/tools/categories/file_tools/search_data_catalog.rs +++ b/api/libs/agents/src/tools/categories/file_tools/search_data_catalog.rs @@ -94,13 +94,13 @@ impl SearchDataCatalogTool { true } - fn format_search_prompt(query_params: &[String], datasets: &[DatasetRecord]) -> Result { + async fn format_search_prompt(query_params: &[String], datasets: &[DatasetRecord]) -> Result { let datasets_json = datasets .iter() .map(|d| d.to_llm_format()) .collect::>(); - Ok(CATALOG_SEARCH_PROMPT + Ok(SearchDataCatalogTool::get_search_prompt().await .replace("{queries_joined_with_newlines}", &query_params.join("\n")) .replace( "{datasets_array_as_json}", @@ -108,6 +108,21 @@ impl SearchDataCatalogTool { )) } + async fn get_search_prompt() -> String { + if env::var("USE_BRAINTRUST_PROMPTS").is_err() { + return CATALOG_SEARCH_PROMPT.to_string(); + } + + let client = BraintrustClient::new(None, "96af8b2b-cf3c-494f-9092-44eb3d5b96ff").unwrap(); + match get_prompt_system_message(&client, "812b3f76-20d5-49e3-884c-2c8084800b43").await { + Ok(message) => message, + Err(e) => { + eprintln!("Failed to get prompt system message: {}", e); + CATALOG_SEARCH_PROMPT.to_string() + } + } + } + async fn perform_llm_search( prompt: String, user_id: &Uuid, @@ -277,7 +292,7 @@ impl ToolExecutor for SearchDataCatalogTool { } // Format prompt and perform search - let prompt = Self::format_search_prompt(&[params.search_requirements.clone()], &datasets)?; + let prompt = Self::format_search_prompt(&[params.search_requirements.clone()], &datasets).await?; let search_results = match Self::perform_llm_search( prompt, &self.agent.get_user_id(),