added dashboard check and started on braintrust prompt injection

This commit is contained in:
dal 2025-03-18 09:46:54 -06:00
parent bd2cbf781c
commit 54ef8971af
No known key found for this signature in database
GPG Key ID: 16F4B0E1E9F61122
9 changed files with 491 additions and 16 deletions

View File

@ -1,6 +1,7 @@
use anyhow::Result;
use braintrust::{get_prompt_system_message, BraintrustClient};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::{collections::HashMap, env};
use std::sync::Arc;
use tokio::sync::broadcast;
use uuid::Uuid;
@ -9,17 +10,17 @@ use crate::{
tools::{
categories::{
file_tools::{
SearchDataCatalogTool, CreateDashboardFilesTool, CreateMetricFilesTool,
ModifyDashboardFilesTool, ModifyMetricFilesTool,
CreateDashboardFilesTool, CreateMetricFilesTool, ModifyDashboardFilesTool,
ModifyMetricFilesTool, SearchDataCatalogTool,
},
planning_tools::CreatePlan,
},
ToolExecutor, IntoToolCallExecutor,
IntoToolCallExecutor, ToolExecutor,
},
Agent, AgentError, AgentExt, AgentThread,
};
use litellm::AgentMessage as AgentMessage;
use litellm::AgentMessage;
#[derive(Debug, Serialize, Deserialize)]
pub struct BusterSuperAgentOutput {
@ -127,7 +128,7 @@ impl BusterSuperAgent {
&self,
thread: &mut AgentThread,
) -> Result<broadcast::Receiver<Result<AgentMessage, AgentError>>> {
thread.set_developer_message(BUSTER_SUPER_AGENT_PROMPT.to_string());
thread.set_developer_message(get_system_message().await);
// Get shutdown receiver
let rx = self.stream_process_thread(thread).await?;
@ -141,6 +142,21 @@ impl BusterSuperAgent {
}
}
/// Resolve the system (developer) prompt used by the Buster super agent.
///
/// The built-in `BUSTER_SUPER_AGENT_PROMPT` is returned when the
/// `USE_BRAINTRUST_PROMPTS` environment variable cannot be read (unset or not
/// valid Unicode). Any readable value, including an empty one, enables the
/// Braintrust lookup.
///
/// This function never fails: if the Braintrust client cannot be constructed
/// or the prompt cannot be fetched, the problem is reported on stderr and the
/// built-in prompt is used instead.
async fn get_system_message() -> String {
    const BRAINTRUST_PROJECT_ID: &str = "96af8b2b-cf3c-494f-9092-44eb3d5b96ff";
    const BRAINTRUST_PROMPT_ID: &str = "12e4cf21-0b49-4de7-9c3f-a73c3e233dad";

    if env::var("USE_BRAINTRUST_PROMPTS").is_err() {
        return BUSTER_SUPER_AGENT_PROMPT.to_string();
    }

    // Client construction is fallible; degrade to the built-in prompt instead
    // of panicking in the middle of a request.
    let client = match BraintrustClient::new(None, BRAINTRUST_PROJECT_ID) {
        Ok(client) => client,
        Err(e) => {
            eprintln!("Failed to create Braintrust client: {}", e);
            return BUSTER_SUPER_AGENT_PROMPT.to_string();
        }
    };

    match get_prompt_system_message(&client, BRAINTRUST_PROMPT_ID).await {
        Ok(message) => message,
        Err(e) => {
            eprintln!("Failed to get prompt system message: {}", e);
            BUSTER_SUPER_AGENT_PROMPT.to_string()
        }
    }
}
const BUSTER_SUPER_AGENT_PROMPT: &str = r##"### Role & Task
You are Buster, an expert analytics and data engineer. Your job is to assess what data is available and then provide fast, accurate answers to analytics questions from non-technical users. You do this by analyzing user requests, searching across a data catalog, and building metrics or dashboards.
---

View File

@ -6,7 +6,7 @@ use tracing::{debug, error};
use std::env;
use uuid::Uuid;
use crate::types::{Span, EventPayload};
use crate::types::{Span, EventPayload, Prompt};
use crate::API_BASE;
/// Environment variable name for Braintrust API key
@ -157,4 +157,41 @@ impl BraintrustClient {
pub fn project_id(&self) -> &str {
&self.project_id
}
/// Fetch a prompt by its ID
///
/// The request is sent to `{base}/prompt/{prompt_id}`, where `base` is the
/// value of the `BRAINTRUST_API_BASE` environment variable when it is set and
/// the `API_BASE` constant otherwise. The override lets callers (for example
/// tests using a local mock server) redirect the request without rebuilding.
///
/// # Arguments
/// * `prompt_id` - ID of the prompt to fetch
///
/// # Returns
/// Result containing the Prompt if successful, or an error if the request fails
///
/// # Errors
/// Returns an error if the request cannot be sent, if the server answers with
/// a non-success status, or if the response cannot be parsed
pub async fn get_prompt(&self, prompt_id: &str) -> Result<Prompt> {
    // Resolve the base URL at call time so an override set after the client
    // was created is still honored.
    let base = std::env::var("BRAINTRUST_API_BASE").unwrap_or_else(|_| API_BASE.to_string());
    let url = format!("{}/prompt/{}", base.trim_end_matches('/'), prompt_id);

    debug!("Fetching prompt: {}", prompt_id);

    let response = self
        .client
        .get(&url)
        .header("Authorization", format!("Bearer {}", self.api_key))
        .header("Content-Type", "application/json")
        .send()
        .await
        .map_err(|e| anyhow!("Failed to fetch prompt: {}", e))?;

    let status = response.status();
    if !status.is_success() {
        // The body is only used for diagnostics; a failure to read it must not
        // mask the HTTP status.
        let error_text = response
            .text()
            .await
            .unwrap_or_else(|_| "Unknown error".to_string());
        return Err(anyhow!("Failed to fetch prompt: HTTP {}, error: {}", status, error_text));
    }

    let prompt = response
        .json::<Prompt>()
        .await
        .map_err(|e| anyhow!("Failed to parse prompt response: {}", e))?;

    debug!("Successfully fetched prompt: {}", prompt_id);
    Ok(prompt)
}
}

View File

@ -0,0 +1,68 @@
use anyhow::{Result, anyhow};
use crate::BraintrustClient;
/// Fetch a prompt from Braintrust and extract its system message
///
/// # Arguments
/// * `client` - Braintrust client used to fetch the prompt
/// * `prompt_id` - ID of the prompt to fetch
///
/// # Returns
/// The content of the first message whose role is `"system"`
///
/// # Errors
/// Returns an error if:
/// - The prompt cannot be fetched
/// - The prompt has no prompt_data
/// - The prompt_data has no prompt content
/// - The prompt content has no messages
/// - No system message is found in the messages
pub async fn get_prompt_system_message(client: &BraintrustClient, prompt_id: &str) -> Result<String> {
    let prompt = client.get_prompt(prompt_id).await?;

    // Peel the optional layers one at a time so each missing level reports
    // its own error message.
    let prompt_data = prompt
        .prompt_data
        .ok_or_else(|| anyhow!("Prompt has no prompt_data"))?;
    let prompt_content = prompt_data
        .prompt
        .ok_or_else(|| anyhow!("Prompt has no content"))?;
    let messages = prompt_content
        .messages
        .ok_or_else(|| anyhow!("Prompt has no messages"))?;

    // The vector is owned here, so consume it and move the content out
    // instead of cloning a potentially large prompt string.
    messages
        .into_iter()
        .find(|msg| msg.role == "system")
        .map(|msg| msg.content)
        .ok_or_else(|| anyhow!("No system message found in prompt"))
}
#[cfg(test)]
mod tests {
    use super::*;
    use dotenv::dotenv;
    use std::env;

    /// Live check that the system message of a known prompt is returned verbatim.
    ///
    /// The test returns `Ok(())` without contacting the API when
    /// `BRAINTRUST_API_KEY` cannot be read from the environment.
    #[tokio::test]
    async fn test_get_prompt_system_message() -> Result<()> {
        // A missing .env file is fine; the result is deliberately ignored.
        dotenv().ok();

        let has_api_key = env::var("BRAINTRUST_API_KEY").is_ok();
        if !has_api_key {
            println!("Skipping test_get_prompt_system_message: No API key available");
            return Ok(());
        }

        let expected = "this is just a test {{input}}\n\n{{other_variable}}";

        // `None` makes the client pick the API key up from the environment.
        let client = BraintrustClient::new(None, "c7b996a6-1c7c-482d-b23f-3d39de16f433")?;
        let actual =
            get_prompt_system_message(&client, "7f6fbd7a-d03a-42e7-a115-b87f5e9f86ee").await?;

        assert_eq!(actual, expected);
        Ok(())
    }
}

View File

@ -6,11 +6,13 @@
mod client;
mod types;
mod trace;
mod helpers;
// Re-export public API
pub use client::BraintrustClient;
pub use trace::TraceBuilder;
pub use types::{Span, Metrics, EventPayload};
pub use types::{Span, Metrics, EventPayload, Prompt, PromptData, PromptContent, PromptOptions};
pub use helpers::*;
// Constants
pub const API_BASE: &str = "https://api.braintrust.dev/v1";

View File

@ -124,6 +124,157 @@ impl Span {
}
}
/// Prompt data structure from Braintrust API
///
/// `id`, `project_id` and `name` are plain `String`s, so deserialization fails
/// when any of them is missing. Every other field is an `Option` and is left
/// out of the serialized output when it is `None`.
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct Prompt {
/// Identifier of the prompt (required)
pub id: String,
// NOTE(review): presumably a transaction id assigned by Braintrust — confirm
#[serde(skip_serializing_if = "Option::is_none")]
pub _xact_id: Option<String>,
/// Identifier of the project the prompt belongs to (required)
pub project_id: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub log_id: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub org_id: Option<String>,
/// Name of the prompt (required)
pub name: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub slug: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub description: Option<String>,
// Kept as the raw string received; no date parsing is applied to it here
#[serde(skip_serializing_if = "Option::is_none")]
pub created: Option<String>,
/// Prompt content and configuration; see [`PromptData`]
#[serde(skip_serializing_if = "Option::is_none")]
pub prompt_data: Option<PromptData>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tags: Option<Vec<String>>,
// Arbitrary JSON; no schema is imposed on it by this type
#[serde(skip_serializing_if = "Option::is_none")]
pub metadata: Option<serde_json::Value>,
#[serde(skip_serializing_if = "Option::is_none")]
pub function_type: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub function_data: Option<FunctionData>,
}
/// Function data containing type information
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct FunctionData {
/// Read from / written to the JSON key `type`; the Rust field has a different
/// name because `type` is a reserved word
#[serde(rename = "type")]
pub function_type: String,
}
/// Prompt data containing the actual prompt content and configuration
///
/// Every field is optional and is left out of the serialized output when it
/// is `None`.
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct PromptData {
/// The prompt body; see [`PromptContent`]
#[serde(skip_serializing_if = "Option::is_none")]
pub prompt: Option<PromptContent>,
/// Execution options; see [`PromptOptions`]
#[serde(skip_serializing_if = "Option::is_none")]
pub options: Option<PromptOptions>,
/// Output parser configuration; see [`PromptParser`]
#[serde(skip_serializing_if = "Option::is_none")]
pub parser: Option<PromptParser>,
/// Tool function references; see [`ToolFunction`]
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_functions: Option<Vec<ToolFunction>>,
/// Origin information; see [`PromptOrigin`]
#[serde(skip_serializing_if = "Option::is_none")]
pub origin: Option<PromptOrigin>,
}
/// Content of the prompt - can be either completion or chat type
///
/// Both `content` and `messages` are optional; this type does not enforce
/// which of them accompanies a given `content_type`.
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct PromptContent {
/// Read from / written to the JSON key `type` (required)
#[serde(rename = "type")]
pub content_type: String,
// NOTE(review): presumably used by completion-type prompts — confirm
#[serde(skip_serializing_if = "Option::is_none")]
pub content: Option<String>,
// Held as a plain string; it is not parsed into a structured value here
#[serde(skip_serializing_if = "Option::is_none")]
pub tools: Option<String>,
// NOTE(review): presumably used by chat-type prompts — confirm
#[serde(skip_serializing_if = "Option::is_none")]
pub messages: Option<Vec<ChatMessage>>,
}
/// Chat message for chat-type prompts
///
/// Both fields are required strings.
// NOTE(review): a message whose `content` is not a JSON string (for example
// `null` or a list of structured parts) fails to deserialize with this
// definition — confirm the API never returns such messages for the prompts
// consumed here.
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct ChatMessage {
/// Role of the message author, e.g. `"system"`
pub role: String,
/// Text of the message
pub content: String,
}
/// Options for prompt execution
///
/// Every field is optional and is left out of the serialized output when it
/// is `None`.
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct PromptOptions {
/// Model name, e.g. `"gpt-4o"`
#[serde(skip_serializing_if = "Option::is_none")]
pub model: Option<String>,
/// Model parameters; see [`PromptParams`]
#[serde(skip_serializing_if = "Option::is_none")]
pub params: Option<PromptParams>,
// NOTE(review): meaning of `position` is not evident from this file — confirm against the API docs
#[serde(skip_serializing_if = "Option::is_none")]
pub position: Option<String>,
}
/// Parameters for prompt execution
///
/// Every field is optional and is left out of the serialized output when it
/// is `None`.
// NOTE(review): `tool_choice` and `function_call` are typed as plain strings
// and `stop` as a list of strings; if the API can return other JSON shapes for
// these keys, deserialization of the whole prompt fails — confirm.
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct PromptParams {
#[serde(skip_serializing_if = "Option::is_none")]
pub use_cache: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
pub temperature: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub top_p: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub max_tokens: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub max_completion_tokens: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub frequency_penalty: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub presence_penalty: Option<f32>,
/// Requested response format; see [`ResponseFormat`]
#[serde(skip_serializing_if = "Option::is_none")]
pub response_format: Option<ResponseFormat>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_choice: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub function_call: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub n: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub stop: Option<Vec<String>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub reasoning_effort: Option<String>,
}
/// Response format specification
///
/// Only the format's type tag is modelled by this struct.
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct ResponseFormat {
/// Read from / written to the JSON key `type` (required)
#[serde(rename = "type")]
pub format_type: String,
}
/// Parser configuration for prompt outputs
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct PromptParser {
/// Read from / written to the JSON key `type` (required)
#[serde(rename = "type")]
pub parser_type: String,
// NOTE(review): presumably toggles chain-of-thought parsing — confirm
#[serde(skip_serializing_if = "Option::is_none")]
pub use_cot: Option<bool>,
/// Map from a choice label to its numeric score
#[serde(skip_serializing_if = "Option::is_none")]
pub choice_scores: Option<HashMap<String, f32>>,
}
/// Tool function definition
///
/// Both fields are required strings.
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct ToolFunction {
/// Read from / written to the JSON key `type`
#[serde(rename = "type")]
pub function_type: String,
/// Identifier of the referenced function
pub id: String,
}
/// Origin information for the prompt
///
/// Every field is optional and is left out of the serialized output when it
/// is `None`.
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct PromptOrigin {
#[serde(skip_serializing_if = "Option::is_none")]
pub prompt_id: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub project_id: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub prompt_version: Option<String>,
}
// Custom serializer for DateTime<Utc> to convert to timestamp
fn serialize_datetime_as_timestamp<S>(date: &DateTime<Utc>, serializer: S) -> Result<S::Ok, S::Error>
where

View File

@ -170,3 +170,128 @@ async fn test_env_var_api_key() -> Result<()> {
Ok(())
}
/// Checks that `get_prompt` deserializes a successful response served by a
/// local mock server into a `Prompt` with the expected field values.
// NOTE(review): redirecting the client to the mock only works if the client
// resolves its base URL from `BRAINTRUST_API_BASE` when the request is made —
// confirm that it does; otherwise the mock is never hit and the request goes
// to the real API.
// NOTE(review): environment variables are process-wide, so this test can
// interfere with other tests that set or remove the same variable while
// running in parallel.
#[tokio::test]
async fn test_get_prompt() -> Result<()> {
// Create a mock server
let mut server = mockito::Server::new_async().await;
// Point the client at the mock server through the environment
env::set_var("BRAINTRUST_API_BASE", &server.url());
// Sample prompt response data based on the actual API response
let prompt_response = r#"{
"id": "7f6fbd7a-d03a-42e7-a115-b87f5e9f86ee",
"_xact_id": "1000194776110029234",
"project_id": "96af8b2b-cf3c-494f-9092-44eb3d5b96ff",
"log_id": "p",
"org_id": "a931db1f-4a90-480c-a915-a66f143b79ab",
"name": "Testing",
"slug": "testing",
"description": null,
"created": "2025-03-18T14:16:28.539Z",
"prompt_data": {
"prompt": {
"type": "chat",
"tools": "",
"messages": [
{
"role": "system",
"content": "this is just a test {{input}}\n\n{{other_variable}}"
}
]
},
"options": {
"model": "gpt-4o"
}
},
"tags": null,
"metadata": null,
"function_type": null,
"function_data": {
"type": "prompt"
}
}"#;
// Create a mock for the API endpoint; it only matches requests carrying
// the expected Authorization and Content-Type headers
let m = server.mock("GET", "/prompt/7f6fbd7a-d03a-42e7-a115-b87f5e9f86ee")
.match_header("Authorization", "Bearer test_api_key")
.match_header("Content-Type", "application/json")
.with_status(200)
.with_body(prompt_response)
.create_async()
.await;
// Create a test client with an explicit API key
let client = BraintrustClient::new(Some("test_api_key"), "test_project")?;
// Fetch the prompt
let prompt = client.get_prompt("7f6fbd7a-d03a-42e7-a115-b87f5e9f86ee").await?;
// Verify the mock was called
m.assert_async().await;
// Verify the prompt data
assert_eq!(prompt.id, "7f6fbd7a-d03a-42e7-a115-b87f5e9f86ee");
assert_eq!(prompt.project_id, "96af8b2b-cf3c-494f-9092-44eb3d5b96ff");
assert_eq!(prompt.name, "Testing");
assert_eq!(prompt.slug.as_ref().unwrap(), "testing");
// Verify prompt content
let prompt_content = prompt.prompt_data.as_ref().unwrap().prompt.as_ref().unwrap();
assert_eq!(prompt_content.content_type, "chat");
// Verify messages; the `\n` escapes in the JSON fixture decode to real newlines
let messages = prompt_content.messages.as_ref().unwrap();
assert_eq!(messages.len(), 1);
assert_eq!(messages[0].role, "system");
assert_eq!(messages[0].content, "this is just a test {{input}}\n\n{{other_variable}}");
// Verify model options
let options = prompt.prompt_data.as_ref().unwrap().options.as_ref().unwrap();
assert_eq!(options.model.as_ref().unwrap(), "gpt-4o");
// Verify function data
let function_data = prompt.function_data.as_ref().unwrap();
assert_eq!(function_data.function_type, "prompt");
// Reset the environment variable (not reached if an earlier `?` returns
// or an assertion panics)
env::remove_var("BRAINTRUST_API_BASE");
Ok(())
}
/// Checks that `get_prompt` returns an error when the mock server answers
/// with HTTP 404.
// NOTE(review): redirecting the client to the mock only works if the client
// resolves its base URL from `BRAINTRUST_API_BASE` when the request is made —
// confirm that it does; otherwise `result.is_err()` could hold for an
// unrelated reason while the mock assertion fails.
// NOTE(review): environment variables are process-wide, so this test can
// interfere with other tests that set or remove the same variable while
// running in parallel.
#[tokio::test]
async fn test_get_prompt_error() -> Result<()> {
// Create a mock server
let mut server = mockito::Server::new_async().await;
// Point the client at the mock server through the environment
env::set_var("BRAINTRUST_API_BASE", &server.url());
// Create a mock for the API endpoint with an error response
let m = server.mock("GET", "/prompt/nonexistent")
.match_header("Authorization", "Bearer test_api_key")
.match_header("Content-Type", "application/json")
.with_status(404)
.with_body(r#"{"error": "Prompt not found"}"#)
.create_async()
.await;
// Create a test client with an explicit API key
let client = BraintrustClient::new(Some("test_api_key"), "test_project")?;
// Attempt to fetch a nonexistent prompt; the result is inspected below
// rather than propagated with `?`
let result = client.get_prompt("nonexistent").await;
// Verify the mock was called
m.assert_async().await;
// Verify that an error was returned
assert!(result.is_err());
// Reset the environment variable (not reached if an earlier `?` returns
// or an assertion panics)
env::remove_var("BRAINTRUST_API_BASE");
Ok(())
}

View File

@ -236,3 +236,66 @@ async fn test_real_error_handling() -> Result<()> {
Ok(())
}
/// Live test that fetches a known prompt and prints its details.
///
/// Returns `Ok(())` without contacting the API when `BRAINTRUST_API_KEY`
/// cannot be read from the environment; panics when the fetch itself fails.
#[tokio::test]
async fn test_real_get_prompt() -> Result<()> {
    init_env()?;

    // Nothing to exercise without credentials.
    if env::var("BRAINTRUST_API_KEY").is_err() {
        println!("Skipping test_real_get_prompt: No API key available");
        return Ok(());
    }

    // `None` makes the client pick the API key up from the environment.
    let client = BraintrustClient::new(None, "c7b996a6-1c7c-482d-b23f-3d39de16f433")?;

    let prompt_id = "7f6fbd7a-d03a-42e7-a115-b87f5e9f86ee";
    println!("Fetching prompt with ID: {}", prompt_id);

    // Handle the failure path first so the happy path reads straight down.
    let prompt = match client.get_prompt(prompt_id).await {
        Ok(fetched) => fetched,
        Err(e) => {
            println!("Failed to fetch prompt '{}': {}", prompt_id, e);
            println!("This is expected if the prompt doesn't exist in your Braintrust project");
            println!("You can create a prompt with this ID in your Braintrust project for this test to pass");
            panic!("Could not fetch prompt with ID: {}", prompt_id);
        }
    };

    println!("Successfully fetched prompt: {}", prompt.name);
    println!("Prompt ID: {}", prompt.id);
    println!("Project ID: {}", prompt.project_id);

    assert_eq!(prompt.id, prompt_id, "Prompt ID should match the requested ID");

    if let Some(description) = prompt.description.as_ref() {
        println!("Description: {}", description);
    }

    let data = prompt.prompt_data.as_ref();

    if let Some(content) = data.and_then(|d| d.prompt.as_ref()) {
        println!("Prompt type: {}", content.content_type);
        println!("Prompt content: {:?}", content.content);
        println!("Prompt messages: {:?}", content.messages);
    }

    let model = data
        .and_then(|d| d.options.as_ref())
        .and_then(|o| o.model.as_ref());
    if let Some(model) = model {
        println!("Model: {}", model);
    }

    if let Some(tags) = prompt.tags.as_ref() {
        println!("Tags: {:?}", tags);
    }

    Ok(())
}

View File

@ -236,6 +236,7 @@ impl FromSql<sql_types::StoredValuesStatusEnum, Pg> for StoredValuesStatus {
#[diesel(sql_type = sql_types::AssetTypeEnum)]
#[serde(rename_all = "snake_case")]
pub enum AssetType {
#[serde(rename = "dashboard_deprecated")]
Dashboard,
Thread,
Collection,

View File

@ -16,8 +16,7 @@ use crate::utils::user::user_info::get_user_organization_id;
use database::enums::{AssetPermissionRole, AssetType, UserOrganizationRole};
use database::pool::{get_pg_pool, PgPool};
use database::schema::{
asset_permissions, collections_to_assets, dashboards, metric_files, teams_to_users,
threads_deprecated, threads_to_dashboards, users_to_organizations,
asset_permissions, collections_to_assets, dashboard_files, dashboards, metric_files, teams_to_users, threads_deprecated, threads_to_dashboards, users_to_organizations
};
pub async fn get_asset_access(
@ -135,15 +134,28 @@ async fn get_asset_access_handler(
.first::<(Uuid, bool, Option<DateTime<Utc>>)>(&mut conn)
.await?;
let metric_info = (
metric_info.0,
metric_info.1,
false,
metric_info.2,
);
let metric_info = (metric_info.0, metric_info.1, false, metric_info.2);
(metric_info, Some(AssetPermissionRole::Owner))
}
AssetType::DashboardFile => {
let mut conn = pg_pool.get().await?;
let dashboard_info = dashboard_files::table
.select((
dashboard_files::id,
dashboard_files::publicly_accessible,
dashboard_files::public_expiry_date,
))
.filter(dashboard_files::id.eq(&asset_id))
.filter(dashboard_files::deleted_at.is_null())
.first::<(Uuid, bool, Option<DateTime<Utc>>)>(&mut conn)
.await?;
let dashboard_info = (dashboard_info.0, dashboard_info.1, false, dashboard_info.2);
(dashboard_info, Some(AssetPermissionRole::Owner))
}
_ => {
return Err(anyhow!("Public access is not supported for chats yet"));
}