added dashboard check and started on braintrust prompt injection

This commit is contained in:
dal 2025-03-18 09:46:54 -06:00
parent bd2cbf781c
commit 54ef8971af
No known key found for this signature in database
GPG Key ID: 16F4B0E1E9F61122
9 changed files with 491 additions and 16 deletions

View File

@ -1,6 +1,7 @@
use anyhow::Result; use anyhow::Result;
use braintrust::{get_prompt_system_message, BraintrustClient};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::collections::HashMap; use std::{collections::HashMap, env};
use std::sync::Arc; use std::sync::Arc;
use tokio::sync::broadcast; use tokio::sync::broadcast;
use uuid::Uuid; use uuid::Uuid;
@ -9,17 +10,17 @@ use crate::{
tools::{ tools::{
categories::{ categories::{
file_tools::{ file_tools::{
SearchDataCatalogTool, CreateDashboardFilesTool, CreateMetricFilesTool, CreateDashboardFilesTool, CreateMetricFilesTool, ModifyDashboardFilesTool,
ModifyDashboardFilesTool, ModifyMetricFilesTool, ModifyMetricFilesTool, SearchDataCatalogTool,
}, },
planning_tools::CreatePlan, planning_tools::CreatePlan,
}, },
ToolExecutor, IntoToolCallExecutor, IntoToolCallExecutor, ToolExecutor,
}, },
Agent, AgentError, AgentExt, AgentThread, Agent, AgentError, AgentExt, AgentThread,
}; };
use litellm::AgentMessage as AgentMessage; use litellm::AgentMessage;
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
pub struct BusterSuperAgentOutput { pub struct BusterSuperAgentOutput {
@ -127,7 +128,7 @@ impl BusterSuperAgent {
&self, &self,
thread: &mut AgentThread, thread: &mut AgentThread,
) -> Result<broadcast::Receiver<Result<AgentMessage, AgentError>>> { ) -> Result<broadcast::Receiver<Result<AgentMessage, AgentError>>> {
thread.set_developer_message(BUSTER_SUPER_AGENT_PROMPT.to_string()); thread.set_developer_message(get_system_message().await);
// Get shutdown receiver // Get shutdown receiver
let rx = self.stream_process_thread(thread).await?; let rx = self.stream_process_thread(thread).await?;
@ -141,6 +142,21 @@ impl BusterSuperAgent {
} }
} }
async fn get_system_message() -> String {
if env::var("USE_BRAINTRUST_PROMPTS").is_err() {
return BUSTER_SUPER_AGENT_PROMPT.to_string();
}
let client = BraintrustClient::new(None, "96af8b2b-cf3c-494f-9092-44eb3d5b96ff").unwrap();
match get_prompt_system_message(&client, "12e4cf21-0b49-4de7-9c3f-a73c3e233dad").await {
Ok(message) => message,
Err(e) => {
eprintln!("Failed to get prompt system message: {}", e);
BUSTER_SUPER_AGENT_PROMPT.to_string()
}
}
}
const BUSTER_SUPER_AGENT_PROMPT: &str = r##"### Role & Task const BUSTER_SUPER_AGENT_PROMPT: &str = r##"### Role & Task
You are Buster, an expert analytics and data engineer. Your job is to assess what data is available and then provide fast, accurate answers to analytics questions from non-technical users. You do this by analyzing user requests, searching across a data catalog, and building metrics or dashboards. You are Buster, an expert analytics and data engineer. Your job is to assess what data is available and then provide fast, accurate answers to analytics questions from non-technical users. You do this by analyzing user requests, searching across a data catalog, and building metrics or dashboards.
--- ---

View File

@ -6,7 +6,7 @@ use tracing::{debug, error};
use std::env; use std::env;
use uuid::Uuid; use uuid::Uuid;
use crate::types::{Span, EventPayload}; use crate::types::{Span, EventPayload, Prompt};
use crate::API_BASE; use crate::API_BASE;
/// Environment variable name for Braintrust API key /// Environment variable name for Braintrust API key
@ -157,4 +157,41 @@ impl BraintrustClient {
pub fn project_id(&self) -> &str { pub fn project_id(&self) -> &str {
&self.project_id &self.project_id
} }
/// Fetch a prompt by its ID
///
/// # Arguments
/// * `prompt_id` - ID of the prompt to fetch
///
/// # Returns
/// Result containing the Prompt if successful, or an error if the request fails
///
/// # Errors
/// Returns an error if the API request fails or if the response cannot be parsed
pub async fn get_prompt(&self, prompt_id: &str) -> Result<Prompt> {
let url = format!("{}/prompt/{}", API_BASE, prompt_id);
debug!("Fetching prompt: {}", prompt_id);
let response = self.client
.get(&url)
.header("Authorization", format!("Bearer {}", self.api_key))
.header("Content-Type", "application/json")
.send()
.await
.map_err(|e| anyhow!("Failed to fetch prompt: {}", e))?;
if !response.status().is_success() {
let status = response.status();
let error_text = response.text().await.unwrap_or_else(|_| "Unknown error".to_string());
return Err(anyhow!("Failed to fetch prompt: HTTP {}, error: {}", status, error_text));
}
let prompt = response.json::<Prompt>().await
.map_err(|e| anyhow!("Failed to parse prompt response: {}", e))?;
debug!("Successfully fetched prompt: {}", prompt_id);
Ok(prompt)
}
} }

View File

@ -0,0 +1,68 @@
use anyhow::{Result, anyhow};
use crate::BraintrustClient;
/// Fetch a prompt from Braintrust and extract its system message
///
/// # Returns
/// The system message content from the prompt's messages
///
/// # Errors
/// Returns an error if:
/// - The prompt cannot be fetched
/// - The prompt has no prompt_data
/// - The prompt has no messages
/// - No system message is found in the messages
pub async fn get_prompt_system_message(client: &BraintrustClient, prompt_id: &str) -> Result<String> {
// Fetch the prompt
let prompt = client.get_prompt(prompt_id).await?;
// Extract the prompt data
let prompt_data = prompt.prompt_data
.ok_or_else(|| anyhow!("Prompt has no prompt_data"))?;
// Get the prompt content
let prompt_content = prompt_data.prompt
.ok_or_else(|| anyhow!("Prompt has no content"))?;
// Get the messages
let messages = prompt_content.messages
.ok_or_else(|| anyhow!("Prompt has no messages"))?;
// Find the system message
let system_message = messages.iter()
.find(|msg| msg.role == "system")
.ok_or_else(|| anyhow!("No system message found in prompt"))?;
Ok(system_message.content.clone())
}
#[cfg(test)]
mod tests {
use super::*;
use std::env;
use dotenv::dotenv;
#[tokio::test]
async fn test_get_prompt_system_message() -> Result<()> {
// Load environment variables
dotenv().ok();
// Skip test if no API key is available
if env::var("BRAINTRUST_API_KEY").is_err() {
println!("Skipping test_get_prompt_system_message: No API key available");
return Ok(());
}
// Create client
let client = BraintrustClient::new(None, "c7b996a6-1c7c-482d-b23f-3d39de16f433")?;
// Test with known prompt ID
let prompt_id = "7f6fbd7a-d03a-42e7-a115-b87f5e9f86ee";
let system_message = get_prompt_system_message(&client, prompt_id).await?;
// Verify the message content
assert_eq!(system_message, "this is just a test {{input}}\n\n{{other_variable}}");
Ok(())
}
}

View File

@ -6,11 +6,13 @@
mod client; mod client;
mod types; mod types;
mod trace; mod trace;
mod helpers;
// Re-export public API // Re-export public API
pub use client::BraintrustClient; pub use client::BraintrustClient;
pub use trace::TraceBuilder; pub use trace::TraceBuilder;
pub use types::{Span, Metrics, EventPayload}; pub use types::{Span, Metrics, EventPayload, Prompt, PromptData, PromptContent, PromptOptions};
pub use helpers::*;
// Constants // Constants
pub const API_BASE: &str = "https://api.braintrust.dev/v1"; pub const API_BASE: &str = "https://api.braintrust.dev/v1";

View File

@ -124,6 +124,157 @@ impl Span {
} }
} }
/// Prompt data structure from Braintrust API
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct Prompt {
pub id: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub _xact_id: Option<String>,
pub project_id: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub log_id: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub org_id: Option<String>,
pub name: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub slug: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub description: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub created: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub prompt_data: Option<PromptData>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tags: Option<Vec<String>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub metadata: Option<serde_json::Value>,
#[serde(skip_serializing_if = "Option::is_none")]
pub function_type: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub function_data: Option<FunctionData>,
}
/// Function data containing type information
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct FunctionData {
#[serde(rename = "type")]
pub function_type: String,
}
/// Prompt data containing the actual prompt content and configuration
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct PromptData {
#[serde(skip_serializing_if = "Option::is_none")]
pub prompt: Option<PromptContent>,
#[serde(skip_serializing_if = "Option::is_none")]
pub options: Option<PromptOptions>,
#[serde(skip_serializing_if = "Option::is_none")]
pub parser: Option<PromptParser>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_functions: Option<Vec<ToolFunction>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub origin: Option<PromptOrigin>,
}
/// Content of the prompt - can be either completion or chat type
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct PromptContent {
#[serde(rename = "type")]
pub content_type: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub content: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tools: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub messages: Option<Vec<ChatMessage>>,
}
/// Chat message for chat-type prompts
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct ChatMessage {
pub role: String,
pub content: String,
}
/// Options for prompt execution
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct PromptOptions {
#[serde(skip_serializing_if = "Option::is_none")]
pub model: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub params: Option<PromptParams>,
#[serde(skip_serializing_if = "Option::is_none")]
pub position: Option<String>,
}
/// Parameters for prompt execution
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct PromptParams {
#[serde(skip_serializing_if = "Option::is_none")]
pub use_cache: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
pub temperature: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub top_p: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub max_tokens: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub max_completion_tokens: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub frequency_penalty: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub presence_penalty: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub response_format: Option<ResponseFormat>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_choice: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub function_call: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub n: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub stop: Option<Vec<String>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub reasoning_effort: Option<String>,
}
/// Response format specification
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct ResponseFormat {
#[serde(rename = "type")]
pub format_type: String,
}
/// Parser configuration for prompt outputs
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct PromptParser {
#[serde(rename = "type")]
pub parser_type: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub use_cot: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
pub choice_scores: Option<HashMap<String, f32>>,
}
/// Tool function definition
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct ToolFunction {
#[serde(rename = "type")]
pub function_type: String,
pub id: String,
}
/// Origin information for the prompt
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct PromptOrigin {
#[serde(skip_serializing_if = "Option::is_none")]
pub prompt_id: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub project_id: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub prompt_version: Option<String>,
}
// Custom serializer for DateTime<Utc> to convert to timestamp // Custom serializer for DateTime<Utc> to convert to timestamp
fn serialize_datetime_as_timestamp<S>(date: &DateTime<Utc>, serializer: S) -> Result<S::Ok, S::Error> fn serialize_datetime_as_timestamp<S>(date: &DateTime<Utc>, serializer: S) -> Result<S::Ok, S::Error>
where where

View File

@ -170,3 +170,128 @@ async fn test_env_var_api_key() -> Result<()> {
Ok(()) Ok(())
} }
#[tokio::test]
async fn test_get_prompt() -> Result<()> {
// Create a mock server
let mut server = mockito::Server::new_async().await;
// Override the API_BASE constant for testing
env::set_var("BRAINTRUST_API_BASE", &server.url());
// Sample prompt response data based on the actual API response
let prompt_response = r#"{
"id": "7f6fbd7a-d03a-42e7-a115-b87f5e9f86ee",
"_xact_id": "1000194776110029234",
"project_id": "96af8b2b-cf3c-494f-9092-44eb3d5b96ff",
"log_id": "p",
"org_id": "a931db1f-4a90-480c-a915-a66f143b79ab",
"name": "Testing",
"slug": "testing",
"description": null,
"created": "2025-03-18T14:16:28.539Z",
"prompt_data": {
"prompt": {
"type": "chat",
"tools": "",
"messages": [
{
"role": "system",
"content": "this is just a test {{input}}\n\n{{other_variable}}"
}
]
},
"options": {
"model": "gpt-4o"
}
},
"tags": null,
"metadata": null,
"function_type": null,
"function_data": {
"type": "prompt"
}
}"#;
// Create a mock for the API endpoint
let m = server.mock("GET", "/prompt/7f6fbd7a-d03a-42e7-a115-b87f5e9f86ee")
.match_header("Authorization", "Bearer test_api_key")
.match_header("Content-Type", "application/json")
.with_status(200)
.with_body(prompt_response)
.create_async()
.await;
// Create a test client
let client = BraintrustClient::new(Some("test_api_key"), "test_project")?;
// Fetch the prompt
let prompt = client.get_prompt("7f6fbd7a-d03a-42e7-a115-b87f5e9f86ee").await?;
// Verify the mock was called
m.assert_async().await;
// Verify the prompt data
assert_eq!(prompt.id, "7f6fbd7a-d03a-42e7-a115-b87f5e9f86ee");
assert_eq!(prompt.project_id, "96af8b2b-cf3c-494f-9092-44eb3d5b96ff");
assert_eq!(prompt.name, "Testing");
assert_eq!(prompt.slug.as_ref().unwrap(), "testing");
// Verify prompt content
let prompt_content = prompt.prompt_data.as_ref().unwrap().prompt.as_ref().unwrap();
assert_eq!(prompt_content.content_type, "chat");
// Verify messages
let messages = prompt_content.messages.as_ref().unwrap();
assert_eq!(messages.len(), 1);
assert_eq!(messages[0].role, "system");
assert_eq!(messages[0].content, "this is just a test {{input}}\n\n{{other_variable}}");
// Verify model options
let options = prompt.prompt_data.as_ref().unwrap().options.as_ref().unwrap();
assert_eq!(options.model.as_ref().unwrap(), "gpt-4o");
// Verify function data
let function_data = prompt.function_data.as_ref().unwrap();
assert_eq!(function_data.function_type, "prompt");
// Reset the environment variable
env::remove_var("BRAINTRUST_API_BASE");
Ok(())
}
#[tokio::test]
async fn test_get_prompt_error() -> Result<()> {
// Create a mock server
let mut server = mockito::Server::new_async().await;
// Override the API_BASE constant for testing
env::set_var("BRAINTRUST_API_BASE", &server.url());
// Create a mock for the API endpoint with an error response
let m = server.mock("GET", "/prompt/nonexistent")
.match_header("Authorization", "Bearer test_api_key")
.match_header("Content-Type", "application/json")
.with_status(404)
.with_body(r#"{"error": "Prompt not found"}"#)
.create_async()
.await;
// Create a test client
let client = BraintrustClient::new(Some("test_api_key"), "test_project")?;
// Attempt to fetch a nonexistent prompt
let result = client.get_prompt("nonexistent").await;
// Verify the mock was called
m.assert_async().await;
// Verify that an error was returned
assert!(result.is_err());
// Reset the environment variable
env::remove_var("BRAINTRUST_API_BASE");
Ok(())
}

View File

@ -236,3 +236,66 @@ async fn test_real_error_handling() -> Result<()> {
Ok(()) Ok(())
} }
#[tokio::test]
async fn test_real_get_prompt() -> Result<()> {
// Initialize environment
init_env()?;
// Skip test if no API key is available
if env::var("BRAINTRUST_API_KEY").is_err() {
println!("Skipping test_real_get_prompt: No API key available");
return Ok(());
}
// Create client (None means use env var)
let client = BraintrustClient::new(None, "c7b996a6-1c7c-482d-b23f-3d39de16f433")?;
// Attempt to fetch the prompt with ID "7f6fbd7a-d03a-42e7-a115-b87f5e9f86ee"
let prompt_id = "7f6fbd7a-d03a-42e7-a115-b87f5e9f86ee";
println!("Fetching prompt with ID: {}", prompt_id);
match client.get_prompt(prompt_id).await {
Ok(prompt) => {
println!("Successfully fetched prompt: {}", prompt.name);
println!("Prompt ID: {}", prompt.id);
println!("Project ID: {}", prompt.project_id);
// Verify the prompt ID matches what we requested
assert_eq!(prompt.id, prompt_id, "Prompt ID should match the requested ID");
if let Some(description) = &prompt.description {
println!("Description: {}", description);
}
if let Some(prompt_data) = &prompt.prompt_data {
if let Some(content) = &prompt_data.prompt {
println!("Prompt type: {}", content.content_type);
println!("Prompt content: {:?}", content.content);
println!("Prompt messages: {:?}", content.messages);
}
if let Some(options) = &prompt_data.options {
if let Some(model) = &options.model {
println!("Model: {}", model);
}
}
}
if let Some(tags) = &prompt.tags {
println!("Tags: {:?}", tags);
}
},
Err(e) => {
println!("Failed to fetch prompt '{}': {}", prompt_id, e);
println!("This is expected if the prompt doesn't exist in your Braintrust project");
println!("You can create a prompt with this ID in your Braintrust project for this test to pass");
// Fail the test if we can't fetch the prompt
panic!("Could not fetch prompt with ID: {}", prompt_id);
}
}
Ok(())
}

View File

@ -236,6 +236,7 @@ impl FromSql<sql_types::StoredValuesStatusEnum, Pg> for StoredValuesStatus {
#[diesel(sql_type = sql_types::AssetTypeEnum)] #[diesel(sql_type = sql_types::AssetTypeEnum)]
#[serde(rename_all = "snake_case")] #[serde(rename_all = "snake_case")]
pub enum AssetType { pub enum AssetType {
#[serde(rename = "dashboard_deprecated")]
Dashboard, Dashboard,
Thread, Thread,
Collection, Collection,

View File

@ -16,8 +16,7 @@ use crate::utils::user::user_info::get_user_organization_id;
use database::enums::{AssetPermissionRole, AssetType, UserOrganizationRole}; use database::enums::{AssetPermissionRole, AssetType, UserOrganizationRole};
use database::pool::{get_pg_pool, PgPool}; use database::pool::{get_pg_pool, PgPool};
use database::schema::{ use database::schema::{
asset_permissions, collections_to_assets, dashboards, metric_files, teams_to_users, asset_permissions, collections_to_assets, dashboard_files, dashboards, metric_files, teams_to_users, threads_deprecated, threads_to_dashboards, users_to_organizations
threads_deprecated, threads_to_dashboards, users_to_organizations,
}; };
pub async fn get_asset_access( pub async fn get_asset_access(
@ -135,15 +134,28 @@ async fn get_asset_access_handler(
.first::<(Uuid, bool, Option<DateTime<Utc>>)>(&mut conn) .first::<(Uuid, bool, Option<DateTime<Utc>>)>(&mut conn)
.await?; .await?;
let metric_info = ( let metric_info = (metric_info.0, metric_info.1, false, metric_info.2);
metric_info.0,
metric_info.1,
false,
metric_info.2,
);
(metric_info, Some(AssetPermissionRole::Owner)) (metric_info, Some(AssetPermissionRole::Owner))
} }
AssetType::DashboardFile => {
let mut conn = pg_pool.get().await?;
let dashboard_info = dashboard_files::table
.select((
dashboard_files::id,
dashboard_files::publicly_accessible,
dashboard_files::public_expiry_date,
))
.filter(dashboard_files::id.eq(&asset_id))
.filter(dashboard_files::deleted_at.is_null())
.first::<(Uuid, bool, Option<DateTime<Utc>>)>(&mut conn)
.await?;
let dashboard_info = (dashboard_info.0, dashboard_info.1, false, dashboard_info.2);
(dashboard_info, Some(AssetPermissionRole::Owner))
}
_ => { _ => {
return Err(anyhow!("Public access is not supported for chats yet")); return Err(anyhow!("Public access is not supported for chats yet"));
} }