diff --git a/api/libs/agents/src/agents/buster_super_agent.rs b/api/libs/agents/src/agents/buster_super_agent.rs index 25059268a..bc763607f 100644 --- a/api/libs/agents/src/agents/buster_super_agent.rs +++ b/api/libs/agents/src/agents/buster_super_agent.rs @@ -1,6 +1,7 @@ use anyhow::Result; +use braintrust::{get_prompt_system_message, BraintrustClient}; use serde::{Deserialize, Serialize}; -use std::collections::HashMap; +use std::{collections::HashMap, env}; use std::sync::Arc; use tokio::sync::broadcast; use uuid::Uuid; @@ -9,17 +10,17 @@ use crate::{ tools::{ categories::{ file_tools::{ - SearchDataCatalogTool, CreateDashboardFilesTool, CreateMetricFilesTool, - ModifyDashboardFilesTool, ModifyMetricFilesTool, + CreateDashboardFilesTool, CreateMetricFilesTool, ModifyDashboardFilesTool, + ModifyMetricFilesTool, SearchDataCatalogTool, }, planning_tools::CreatePlan, }, - ToolExecutor, IntoToolCallExecutor, + IntoToolCallExecutor, ToolExecutor, }, Agent, AgentError, AgentExt, AgentThread, }; -use litellm::AgentMessage as AgentMessage; +use litellm::AgentMessage; #[derive(Debug, Serialize, Deserialize)] pub struct BusterSuperAgentOutput { @@ -127,7 +128,7 @@ impl BusterSuperAgent { &self, thread: &mut AgentThread, ) -> Result>> { - thread.set_developer_message(BUSTER_SUPER_AGENT_PROMPT.to_string()); + thread.set_developer_message(get_system_message().await); // Get shutdown receiver let rx = self.stream_process_thread(thread).await?; @@ -141,6 +142,21 @@ impl BusterSuperAgent { } } +async fn get_system_message() -> String { + if env::var("USE_BRAINTRUST_PROMPTS").is_err() { + return BUSTER_SUPER_AGENT_PROMPT.to_string(); + } + + let client = BraintrustClient::new(None, "96af8b2b-cf3c-494f-9092-44eb3d5b96ff").unwrap(); + match get_prompt_system_message(&client, "12e4cf21-0b49-4de7-9c3f-a73c3e233dad").await { + Ok(message) => message, + Err(e) => { + eprintln!("Failed to get prompt system message: {}", e); + BUSTER_SUPER_AGENT_PROMPT.to_string() + } + } +} + const BUSTER_SUPER_AGENT_PROMPT: &str = r##"### Role & Task You are Buster, an expert analytics and data engineer. Your job is to assess what data is available and then provide fast, accurate answers to analytics questions from non-technical users. You do this by analyzing user requests, searching across a data catalog, and building metrics or dashboards. --- diff --git a/api/libs/braintrust/src/client.rs b/api/libs/braintrust/src/client.rs index 3a4905a27..9eb0ffc6d 100644 --- a/api/libs/braintrust/src/client.rs +++ b/api/libs/braintrust/src/client.rs @@ -6,7 +6,7 @@ use tracing::{debug, error}; use std::env; use uuid::Uuid; -use crate::types::{Span, EventPayload}; +use crate::types::{Span, EventPayload, Prompt}; use crate::API_BASE; /// Environment variable name for Braintrust API key @@ -157,4 +157,41 @@ impl BraintrustClient { pub fn project_id(&self) -> &str { &self.project_id } + + /// Fetch a prompt by its ID + /// + /// # Arguments + /// * `prompt_id` - ID of the prompt to fetch + /// + /// # Returns + /// Result containing the Prompt if successful, or an error if the request fails + /// + /// # Errors + /// Returns an error if the API request fails or if the response cannot be parsed + pub async fn get_prompt(&self, prompt_id: &str) -> Result { + let url = format!("{}/prompt/{}", API_BASE, prompt_id); + + debug!("Fetching prompt: {}", prompt_id); + + let response = self.client + .get(&url) + .header("Authorization", format!("Bearer {}", self.api_key)) + .header("Content-Type", "application/json") + .send() + .await + .map_err(|e| anyhow!("Failed to fetch prompt: {}", e))?; + + if !response.status().is_success() { + let status = response.status(); + let error_text = response.text().await.unwrap_or_else(|_| "Unknown error".to_string()); + return Err(anyhow!("Failed to fetch prompt: HTTP {}, error: {}", status, error_text)); + } + + let prompt = response.json::().await + .map_err(|e| anyhow!("Failed to parse prompt response: {}", e))?; + + debug!("Successfully fetched prompt: {}", prompt_id); + + Ok(prompt) + } } diff --git a/api/libs/braintrust/src/helpers.rs b/api/libs/braintrust/src/helpers.rs new file mode 100644 index 000000000..e3e0dd4d3 --- /dev/null +++ b/api/libs/braintrust/src/helpers.rs @@ -0,0 +1,68 @@ +use anyhow::{Result, anyhow}; +use crate::BraintrustClient; + +/// Fetch a prompt from Braintrust and extract its system message +/// +/// # Returns +/// The system message content from the prompt's messages +/// +/// # Errors +/// Returns an error if: +/// - The prompt cannot be fetched +/// - The prompt has no prompt_data +/// - The prompt has no messages +/// - No system message is found in the messages +pub async fn get_prompt_system_message(client: &BraintrustClient, prompt_id: &str) -> Result { + // Fetch the prompt + let prompt = client.get_prompt(prompt_id).await?; + + // Extract the prompt data + let prompt_data = prompt.prompt_data + .ok_or_else(|| anyhow!("Prompt has no prompt_data"))?; + + // Get the prompt content + let prompt_content = prompt_data.prompt + .ok_or_else(|| anyhow!("Prompt has no content"))?; + + // Get the messages + let messages = prompt_content.messages + .ok_or_else(|| anyhow!("Prompt has no messages"))?; + + // Find the system message + let system_message = messages.iter() + .find(|msg| msg.role == "system") + .ok_or_else(|| anyhow!("No system message found in prompt"))?; + + Ok(system_message.content.clone()) +} + +#[cfg(test)] +mod tests { + use super::*; + use std::env; + use dotenv::dotenv; + + #[tokio::test] + async fn test_get_prompt_system_message() -> Result<()> { + // Load environment variables + dotenv().ok(); + + // Skip test if no API key is available + if env::var("BRAINTRUST_API_KEY").is_err() { + println!("Skipping test_get_prompt_system_message: No API key available"); + return Ok(()); + } + + // Create client + let client = BraintrustClient::new(None, "c7b996a6-1c7c-482d-b23f-3d39de16f433")?; + + // Test with known prompt ID + let prompt_id = "7f6fbd7a-d03a-42e7-a115-b87f5e9f86ee"; + let system_message = get_prompt_system_message(&client, prompt_id).await?; + + // Verify the message content + assert_eq!(system_message, "this is just a test {{input}}\n\n{{other_variable}}"); + + Ok(()) + } +} \ No newline at end of file diff --git a/api/libs/braintrust/src/lib.rs b/api/libs/braintrust/src/lib.rs index 61ed893fc..61e14db56 100644 --- a/api/libs/braintrust/src/lib.rs +++ b/api/libs/braintrust/src/lib.rs @@ -6,11 +6,13 @@ mod client; mod types; mod trace; +mod helpers; // Re-export public API pub use client::BraintrustClient; pub use trace::TraceBuilder; -pub use types::{Span, Metrics, EventPayload}; +pub use types::{Span, Metrics, EventPayload, Prompt, PromptData, PromptContent, PromptOptions}; +pub use helpers::*; // Constants pub const API_BASE: &str = "https://api.braintrust.dev/v1"; diff --git a/api/libs/braintrust/src/types.rs b/api/libs/braintrust/src/types.rs index 172b47e79..c2c73236d 100644 --- a/api/libs/braintrust/src/types.rs +++ b/api/libs/braintrust/src/types.rs @@ -124,6 +124,157 @@ impl Span { } } +/// Prompt data structure from Braintrust API +#[derive(Serialize, Deserialize, Clone, Debug)] +pub struct Prompt { + pub id: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub _xact_id: Option, + pub project_id: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub log_id: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub org_id: Option, + pub name: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub slug: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub description: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub created: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub prompt_data: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub tags: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub metadata: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub function_type: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub function_data: Option, +} + +/// Function data containing type information +#[derive(Serialize, Deserialize, Clone, Debug)] +pub struct FunctionData { + #[serde(rename = "type")] + pub function_type: String, +} + +/// Prompt data containing the actual prompt content and configuration +#[derive(Serialize, Deserialize, Clone, Debug)] +pub struct PromptData { + #[serde(skip_serializing_if = "Option::is_none")] + pub prompt: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub options: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub parser: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_functions: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub origin: Option, +} + +/// Content of the prompt - can be either completion or chat type +#[derive(Serialize, Deserialize, Clone, Debug)] +pub struct PromptContent { + #[serde(rename = "type")] + pub content_type: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub content: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub tools: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub messages: Option>, +} + +/// Chat message for chat-type prompts +#[derive(Serialize, Deserialize, Clone, Debug)] +pub struct ChatMessage { + pub role: String, + pub content: String, +} + +/// Options for prompt execution +#[derive(Serialize, Deserialize, Clone, Debug)] +pub struct PromptOptions { + #[serde(skip_serializing_if = "Option::is_none")] + pub model: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub params: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub position: Option, +} + +/// Parameters for prompt execution +#[derive(Serialize, Deserialize, Clone, Debug)] +pub struct PromptParams { + #[serde(skip_serializing_if = "Option::is_none")] + pub use_cache: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub temperature: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub top_p: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub max_tokens: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub max_completion_tokens: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub frequency_penalty: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub presence_penalty: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub response_format: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_choice: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub function_call: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub n: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub stop: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub reasoning_effort: Option, +} + +/// Response format specification +#[derive(Serialize, Deserialize, Clone, Debug)] +pub struct ResponseFormat { + #[serde(rename = "type")] + pub format_type: String, +} + +/// Parser configuration for prompt outputs +#[derive(Serialize, Deserialize, Clone, Debug)] +pub struct PromptParser { + #[serde(rename = "type")] + pub parser_type: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub use_cot: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub choice_scores: Option>, +} + +/// Tool function definition +#[derive(Serialize, Deserialize, Clone, Debug)] +pub struct ToolFunction { + #[serde(rename = "type")] + pub function_type: String, + pub id: String, +} + +/// Origin information for the prompt +#[derive(Serialize, Deserialize, Clone, Debug)] +pub struct PromptOrigin { + #[serde(skip_serializing_if = "Option::is_none")] + pub prompt_id: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub project_id: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub prompt_version: Option, +} + // Custom serializer for DateTime to convert to timestamp fn serialize_datetime_as_timestamp(date: &DateTime, serializer: S) -> Result where diff --git a/api/libs/braintrust/tests/client_tests.rs b/api/libs/braintrust/tests/client_tests.rs index c99d5686e..d09619ba2 100644 --- a/api/libs/braintrust/tests/client_tests.rs +++ b/api/libs/braintrust/tests/client_tests.rs @@ -170,3 +170,128 @@ async fn test_env_var_api_key() -> Result<()> { Ok(()) } + +#[tokio::test] +async fn test_get_prompt() -> Result<()> { + // Create a mock server + let mut server = mockito::Server::new_async().await; + + // Override the API_BASE constant for testing + env::set_var("BRAINTRUST_API_BASE", &server.url()); + + // Sample prompt response data based on the actual API response + let prompt_response = r#"{ + "id": "7f6fbd7a-d03a-42e7-a115-b87f5e9f86ee", + "_xact_id": "1000194776110029234", + "project_id": "96af8b2b-cf3c-494f-9092-44eb3d5b96ff", + "log_id": "p", + "org_id": "a931db1f-4a90-480c-a915-a66f143b79ab", + "name": "Testing", + "slug": "testing", + "description": null, + "created": "2025-03-18T14:16:28.539Z", + "prompt_data": { + "prompt": { + "type": "chat", + "tools": "", + "messages": [ + { + "role": "system", + "content": "this is just a test {{input}}\n\n{{other_variable}}" + } + ] + }, + "options": { + "model": "gpt-4o" + } + }, + "tags": null, + "metadata": null, + "function_type": null, + "function_data": { + "type": "prompt" + } + }"#; + + // Create a mock for the API endpoint + let m = server.mock("GET", "/prompt/7f6fbd7a-d03a-42e7-a115-b87f5e9f86ee") + .match_header("Authorization", "Bearer test_api_key") + .match_header("Content-Type", "application/json") + .with_status(200) + .with_body(prompt_response) + .create_async() + .await; + + // Create a test client + let client = BraintrustClient::new(Some("test_api_key"), "test_project")?; + + // Fetch the prompt + let prompt = client.get_prompt("7f6fbd7a-d03a-42e7-a115-b87f5e9f86ee").await?; + + // Verify the mock was called + m.assert_async().await; + + // Verify the prompt data + assert_eq!(prompt.id, "7f6fbd7a-d03a-42e7-a115-b87f5e9f86ee"); + assert_eq!(prompt.project_id, "96af8b2b-cf3c-494f-9092-44eb3d5b96ff"); + assert_eq!(prompt.name, "Testing"); + assert_eq!(prompt.slug.as_ref().unwrap(), "testing"); + + // Verify prompt content + let prompt_content = prompt.prompt_data.as_ref().unwrap().prompt.as_ref().unwrap(); + assert_eq!(prompt_content.content_type, "chat"); + + // Verify messages + let messages = prompt_content.messages.as_ref().unwrap(); + assert_eq!(messages.len(), 1); + assert_eq!(messages[0].role, "system"); + assert_eq!(messages[0].content, "this is just a test {{input}}\n\n{{other_variable}}"); + + // Verify model options + let options = prompt.prompt_data.as_ref().unwrap().options.as_ref().unwrap(); + assert_eq!(options.model.as_ref().unwrap(), "gpt-4o"); + + // Verify function data + let function_data = prompt.function_data.as_ref().unwrap(); + assert_eq!(function_data.function_type, "prompt"); + + // Reset the environment variable + env::remove_var("BRAINTRUST_API_BASE"); + + Ok(()) +} + +#[tokio::test] +async fn test_get_prompt_error() -> Result<()> { + // Create a mock server + let mut server = mockito::Server::new_async().await; + + // Override the API_BASE constant for testing + env::set_var("BRAINTRUST_API_BASE", &server.url()); + + // Create a mock for the API endpoint with an error response + let m = server.mock("GET", "/prompt/nonexistent") + .match_header("Authorization", "Bearer test_api_key") + .match_header("Content-Type", "application/json") + .with_status(404) + .with_body(r#"{"error": "Prompt not found"}"#) + .create_async() + .await; + + // Create a test client + let client = BraintrustClient::new(Some("test_api_key"), "test_project")?; + + // Attempt to fetch a nonexistent prompt + let result = client.get_prompt("nonexistent").await; + + // Verify the mock was called + m.assert_async().await; + + // Verify that an error was returned + assert!(result.is_err()); + + // Reset the environment variable + env::remove_var("BRAINTRUST_API_BASE"); + + Ok(()) +} diff --git a/api/libs/braintrust/tests/integration_tests.rs b/api/libs/braintrust/tests/integration_tests.rs index 9100b7736..1bc43cf26 100644 --- a/api/libs/braintrust/tests/integration_tests.rs +++ b/api/libs/braintrust/tests/integration_tests.rs @@ -236,3 +236,66 @@ async fn test_real_error_handling() -> Result<()> { Ok(()) } + +#[tokio::test] +async fn test_real_get_prompt() -> Result<()> { + // Initialize environment + init_env()?; + + // Skip test if no API key is available + if env::var("BRAINTRUST_API_KEY").is_err() { + println!("Skipping test_real_get_prompt: No API key available"); + return Ok(()); + } + + // Create client (None means use env var) + let client = BraintrustClient::new(None, "c7b996a6-1c7c-482d-b23f-3d39de16f433")?; + + // Attempt to fetch the prompt with ID "7f6fbd7a-d03a-42e7-a115-b87f5e9f86ee" + let prompt_id = "7f6fbd7a-d03a-42e7-a115-b87f5e9f86ee"; + + println!("Fetching prompt with ID: {}", prompt_id); + + match client.get_prompt(prompt_id).await { + Ok(prompt) => { + println!("Successfully fetched prompt: {}", prompt.name); + println!("Prompt ID: {}", prompt.id); + println!("Project ID: {}", prompt.project_id); + + // Verify the prompt ID matches what we requested + assert_eq!(prompt.id, prompt_id, "Prompt ID should match the requested ID"); + + if let Some(description) = &prompt.description { + println!("Description: {}", description); + } + + if let Some(prompt_data) = &prompt.prompt_data { + if let Some(content) = &prompt_data.prompt { + println!("Prompt type: {}", content.content_type); + println!("Prompt content: {:?}", content.content); + println!("Prompt messages: {:?}", content.messages); + } + + if let Some(options) = &prompt_data.options { + if let Some(model) = &options.model { + println!("Model: {}", model); + } + } + } + + if let Some(tags) = &prompt.tags { + println!("Tags: {:?}", tags); + } + }, + Err(e) => { + println!("Failed to fetch prompt '{}': {}", prompt_id, e); + println!("This is expected if the prompt doesn't exist in your Braintrust project"); + println!("You can create a prompt with this ID in your Braintrust project for this test to pass"); + + // Fail the test if we can't fetch the prompt + panic!("Could not fetch prompt with ID: {}", prompt_id); + } + } + + Ok(()) +} diff --git a/api/libs/database/src/enums.rs b/api/libs/database/src/enums.rs index 33d7740fd..3d57095a6 100644 --- a/api/libs/database/src/enums.rs +++ b/api/libs/database/src/enums.rs @@ -236,6 +236,7 @@ impl FromSql for StoredValuesStatus { #[diesel(sql_type = sql_types::AssetTypeEnum)] #[serde(rename_all = "snake_case")] pub enum AssetType { + #[serde(rename = "dashboard_deprecated")] Dashboard, Thread, Collection, diff --git a/api/src/routes/rest/routes/assets/get_asset_access.rs b/api/src/routes/rest/routes/assets/get_asset_access.rs index 45df97ce7..d1d741a20 100644 --- a/api/src/routes/rest/routes/assets/get_asset_access.rs +++ b/api/src/routes/rest/routes/assets/get_asset_access.rs @@ -16,8 +16,7 @@ use crate::utils::user::user_info::get_user_organization_id; use database::enums::{AssetPermissionRole, AssetType, UserOrganizationRole}; use database::pool::{get_pg_pool, PgPool}; use database::schema::{ - asset_permissions, collections_to_assets, dashboards, metric_files, teams_to_users, - threads_deprecated, threads_to_dashboards, users_to_organizations, + asset_permissions, collections_to_assets, dashboard_files, dashboards, metric_files, teams_to_users, threads_deprecated, threads_to_dashboards, users_to_organizations }; pub async fn get_asset_access( @@ -135,15 +134,28 @@ async fn get_asset_access_handler( .first::<(Uuid, bool, Option>)>(&mut conn) .await?; - let metric_info = ( - metric_info.0, - metric_info.1, - false, - metric_info.2, - ); + let metric_info = (metric_info.0, metric_info.1, false, metric_info.2); (metric_info, Some(AssetPermissionRole::Owner)) } + AssetType::DashboardFile => { + let mut conn = pg_pool.get().await?; + + let dashboard_info = dashboard_files::table + .select(( + dashboard_files::id, + dashboard_files::publicly_accessible, + dashboard_files::public_expiry_date, + )) + .filter(dashboard_files::id.eq(&asset_id)) + .filter(dashboard_files::deleted_at.is_null()) + .first::<(Uuid, bool, Option>)>(&mut conn) + .await?; + + let dashboard_info = (dashboard_info.0, dashboard_info.1, false, dashboard_info.2); + + (dashboard_info, Some(AssetPermissionRole::Owner)) + } _ => { return Err(anyhow!("Public access is not supported for chats yet")); }