mirror of https://github.com/buster-so/buster.git
added dashboard check and started on braintrust prompt injection
This commit is contained in:
parent
bd2cbf781c
commit
54ef8971af
|
@ -1,6 +1,7 @@
|
|||
use anyhow::Result;
|
||||
use braintrust::{get_prompt_system_message, BraintrustClient};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::{collections::HashMap, env};
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::broadcast;
|
||||
use uuid::Uuid;
|
||||
|
@ -9,17 +10,17 @@ use crate::{
|
|||
tools::{
|
||||
categories::{
|
||||
file_tools::{
|
||||
SearchDataCatalogTool, CreateDashboardFilesTool, CreateMetricFilesTool,
|
||||
ModifyDashboardFilesTool, ModifyMetricFilesTool,
|
||||
CreateDashboardFilesTool, CreateMetricFilesTool, ModifyDashboardFilesTool,
|
||||
ModifyMetricFilesTool, SearchDataCatalogTool,
|
||||
},
|
||||
planning_tools::CreatePlan,
|
||||
},
|
||||
ToolExecutor, IntoToolCallExecutor,
|
||||
IntoToolCallExecutor, ToolExecutor,
|
||||
},
|
||||
Agent, AgentError, AgentExt, AgentThread,
|
||||
};
|
||||
|
||||
use litellm::AgentMessage as AgentMessage;
|
||||
use litellm::AgentMessage;
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct BusterSuperAgentOutput {
|
||||
|
@ -127,7 +128,7 @@ impl BusterSuperAgent {
|
|||
&self,
|
||||
thread: &mut AgentThread,
|
||||
) -> Result<broadcast::Receiver<Result<AgentMessage, AgentError>>> {
|
||||
thread.set_developer_message(BUSTER_SUPER_AGENT_PROMPT.to_string());
|
||||
thread.set_developer_message(get_system_message().await);
|
||||
|
||||
// Get shutdown receiver
|
||||
let rx = self.stream_process_thread(thread).await?;
|
||||
|
@ -141,6 +142,21 @@ impl BusterSuperAgent {
|
|||
}
|
||||
}
|
||||
|
||||
async fn get_system_message() -> String {
|
||||
if env::var("USE_BRAINTRUST_PROMPTS").is_err() {
|
||||
return BUSTER_SUPER_AGENT_PROMPT.to_string();
|
||||
}
|
||||
|
||||
let client = BraintrustClient::new(None, "96af8b2b-cf3c-494f-9092-44eb3d5b96ff").unwrap();
|
||||
match get_prompt_system_message(&client, "12e4cf21-0b49-4de7-9c3f-a73c3e233dad").await {
|
||||
Ok(message) => message,
|
||||
Err(e) => {
|
||||
eprintln!("Failed to get prompt system message: {}", e);
|
||||
BUSTER_SUPER_AGENT_PROMPT.to_string()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const BUSTER_SUPER_AGENT_PROMPT: &str = r##"### Role & Task
|
||||
You are Buster, an expert analytics and data engineer. Your job is to assess what data is available and then provide fast, accurate answers to analytics questions from non-technical users. You do this by analyzing user requests, searching across a data catalog, and building metrics or dashboards.
|
||||
---
|
||||
|
|
|
@ -6,7 +6,7 @@ use tracing::{debug, error};
|
|||
use std::env;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::types::{Span, EventPayload};
|
||||
use crate::types::{Span, EventPayload, Prompt};
|
||||
use crate::API_BASE;
|
||||
|
||||
/// Environment variable name for Braintrust API key
|
||||
|
@ -157,4 +157,41 @@ impl BraintrustClient {
|
|||
pub fn project_id(&self) -> &str {
|
||||
&self.project_id
|
||||
}
|
||||
|
||||
/// Fetch a prompt by its ID
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `prompt_id` - ID of the prompt to fetch
|
||||
///
|
||||
/// # Returns
|
||||
/// Result containing the Prompt if successful, or an error if the request fails
|
||||
///
|
||||
/// # Errors
|
||||
/// Returns an error if the API request fails or if the response cannot be parsed
|
||||
pub async fn get_prompt(&self, prompt_id: &str) -> Result<Prompt> {
|
||||
let url = format!("{}/prompt/{}", API_BASE, prompt_id);
|
||||
|
||||
debug!("Fetching prompt: {}", prompt_id);
|
||||
|
||||
let response = self.client
|
||||
.get(&url)
|
||||
.header("Authorization", format!("Bearer {}", self.api_key))
|
||||
.header("Content-Type", "application/json")
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| anyhow!("Failed to fetch prompt: {}", e))?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
let status = response.status();
|
||||
let error_text = response.text().await.unwrap_or_else(|_| "Unknown error".to_string());
|
||||
return Err(anyhow!("Failed to fetch prompt: HTTP {}, error: {}", status, error_text));
|
||||
}
|
||||
|
||||
let prompt = response.json::<Prompt>().await
|
||||
.map_err(|e| anyhow!("Failed to parse prompt response: {}", e))?;
|
||||
|
||||
debug!("Successfully fetched prompt: {}", prompt_id);
|
||||
|
||||
Ok(prompt)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,68 @@
|
|||
use anyhow::{Result, anyhow};
|
||||
use crate::BraintrustClient;
|
||||
|
||||
/// Fetch a prompt from Braintrust and extract its system message
|
||||
///
|
||||
/// # Returns
|
||||
/// The system message content from the prompt's messages
|
||||
///
|
||||
/// # Errors
|
||||
/// Returns an error if:
|
||||
/// - The prompt cannot be fetched
|
||||
/// - The prompt has no prompt_data
|
||||
/// - The prompt has no messages
|
||||
/// - No system message is found in the messages
|
||||
pub async fn get_prompt_system_message(client: &BraintrustClient, prompt_id: &str) -> Result<String> {
|
||||
// Fetch the prompt
|
||||
let prompt = client.get_prompt(prompt_id).await?;
|
||||
|
||||
// Extract the prompt data
|
||||
let prompt_data = prompt.prompt_data
|
||||
.ok_or_else(|| anyhow!("Prompt has no prompt_data"))?;
|
||||
|
||||
// Get the prompt content
|
||||
let prompt_content = prompt_data.prompt
|
||||
.ok_or_else(|| anyhow!("Prompt has no content"))?;
|
||||
|
||||
// Get the messages
|
||||
let messages = prompt_content.messages
|
||||
.ok_or_else(|| anyhow!("Prompt has no messages"))?;
|
||||
|
||||
// Find the system message
|
||||
let system_message = messages.iter()
|
||||
.find(|msg| msg.role == "system")
|
||||
.ok_or_else(|| anyhow!("No system message found in prompt"))?;
|
||||
|
||||
Ok(system_message.content.clone())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::env;
|
||||
use dotenv::dotenv;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_prompt_system_message() -> Result<()> {
|
||||
// Load environment variables
|
||||
dotenv().ok();
|
||||
|
||||
// Skip test if no API key is available
|
||||
if env::var("BRAINTRUST_API_KEY").is_err() {
|
||||
println!("Skipping test_get_prompt_system_message: No API key available");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Create client
|
||||
let client = BraintrustClient::new(None, "c7b996a6-1c7c-482d-b23f-3d39de16f433")?;
|
||||
|
||||
// Test with known prompt ID
|
||||
let prompt_id = "7f6fbd7a-d03a-42e7-a115-b87f5e9f86ee";
|
||||
let system_message = get_prompt_system_message(&client, prompt_id).await?;
|
||||
|
||||
// Verify the message content
|
||||
assert_eq!(system_message, "this is just a test {{input}}\n\n{{other_variable}}");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
|
@ -6,11 +6,13 @@
|
|||
mod client;
|
||||
mod types;
|
||||
mod trace;
|
||||
mod helpers;
|
||||
|
||||
// Re-export public API
|
||||
pub use client::BraintrustClient;
|
||||
pub use trace::TraceBuilder;
|
||||
pub use types::{Span, Metrics, EventPayload};
|
||||
pub use types::{Span, Metrics, EventPayload, Prompt, PromptData, PromptContent, PromptOptions};
|
||||
pub use helpers::*;
|
||||
|
||||
// Constants
|
||||
pub const API_BASE: &str = "https://api.braintrust.dev/v1";
|
||||
|
|
|
@ -124,6 +124,157 @@ impl Span {
|
|||
}
|
||||
}
|
||||
|
||||
/// Prompt data structure from Braintrust API
|
||||
#[derive(Serialize, Deserialize, Clone, Debug)]
|
||||
pub struct Prompt {
|
||||
pub id: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub _xact_id: Option<String>,
|
||||
pub project_id: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub log_id: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub org_id: Option<String>,
|
||||
pub name: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub slug: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub description: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub created: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub prompt_data: Option<PromptData>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tags: Option<Vec<String>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub metadata: Option<serde_json::Value>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub function_type: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub function_data: Option<FunctionData>,
|
||||
}
|
||||
|
||||
/// Function data containing type information
|
||||
#[derive(Serialize, Deserialize, Clone, Debug)]
|
||||
pub struct FunctionData {
|
||||
#[serde(rename = "type")]
|
||||
pub function_type: String,
|
||||
}
|
||||
|
||||
/// Prompt data containing the actual prompt content and configuration
|
||||
#[derive(Serialize, Deserialize, Clone, Debug)]
|
||||
pub struct PromptData {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub prompt: Option<PromptContent>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub options: Option<PromptOptions>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub parser: Option<PromptParser>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tool_functions: Option<Vec<ToolFunction>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub origin: Option<PromptOrigin>,
|
||||
}
|
||||
|
||||
/// Content of the prompt - can be either completion or chat type
|
||||
#[derive(Serialize, Deserialize, Clone, Debug)]
|
||||
pub struct PromptContent {
|
||||
#[serde(rename = "type")]
|
||||
pub content_type: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub content: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tools: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub messages: Option<Vec<ChatMessage>>,
|
||||
}
|
||||
|
||||
/// Chat message for chat-type prompts
|
||||
#[derive(Serialize, Deserialize, Clone, Debug)]
|
||||
pub struct ChatMessage {
|
||||
pub role: String,
|
||||
pub content: String,
|
||||
}
|
||||
|
||||
/// Options for prompt execution
|
||||
#[derive(Serialize, Deserialize, Clone, Debug)]
|
||||
pub struct PromptOptions {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub model: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub params: Option<PromptParams>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub position: Option<String>,
|
||||
}
|
||||
|
||||
/// Parameters for prompt execution
|
||||
#[derive(Serialize, Deserialize, Clone, Debug)]
|
||||
pub struct PromptParams {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub use_cache: Option<bool>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub temperature: Option<f32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub top_p: Option<f32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub max_tokens: Option<u32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub max_completion_tokens: Option<u32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub frequency_penalty: Option<f32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub presence_penalty: Option<f32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub response_format: Option<ResponseFormat>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tool_choice: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub function_call: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub n: Option<u32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub stop: Option<Vec<String>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub reasoning_effort: Option<String>,
|
||||
}
|
||||
|
||||
/// Response format specification
|
||||
#[derive(Serialize, Deserialize, Clone, Debug)]
|
||||
pub struct ResponseFormat {
|
||||
#[serde(rename = "type")]
|
||||
pub format_type: String,
|
||||
}
|
||||
|
||||
/// Parser configuration for prompt outputs
|
||||
#[derive(Serialize, Deserialize, Clone, Debug)]
|
||||
pub struct PromptParser {
|
||||
#[serde(rename = "type")]
|
||||
pub parser_type: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub use_cot: Option<bool>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub choice_scores: Option<HashMap<String, f32>>,
|
||||
}
|
||||
|
||||
/// Tool function definition
|
||||
#[derive(Serialize, Deserialize, Clone, Debug)]
|
||||
pub struct ToolFunction {
|
||||
#[serde(rename = "type")]
|
||||
pub function_type: String,
|
||||
pub id: String,
|
||||
}
|
||||
|
||||
/// Origin information for the prompt
|
||||
#[derive(Serialize, Deserialize, Clone, Debug)]
|
||||
pub struct PromptOrigin {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub prompt_id: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub project_id: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub prompt_version: Option<String>,
|
||||
}
|
||||
|
||||
// Custom serializer for DateTime<Utc> to convert to timestamp
|
||||
fn serialize_datetime_as_timestamp<S>(date: &DateTime<Utc>, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
|
|
|
@ -170,3 +170,128 @@ async fn test_env_var_api_key() -> Result<()> {
|
|||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_prompt() -> Result<()> {
|
||||
// Create a mock server
|
||||
let mut server = mockito::Server::new_async().await;
|
||||
|
||||
// Override the API_BASE constant for testing
|
||||
env::set_var("BRAINTRUST_API_BASE", &server.url());
|
||||
|
||||
// Sample prompt response data based on the actual API response
|
||||
let prompt_response = r#"{
|
||||
"id": "7f6fbd7a-d03a-42e7-a115-b87f5e9f86ee",
|
||||
"_xact_id": "1000194776110029234",
|
||||
"project_id": "96af8b2b-cf3c-494f-9092-44eb3d5b96ff",
|
||||
"log_id": "p",
|
||||
"org_id": "a931db1f-4a90-480c-a915-a66f143b79ab",
|
||||
"name": "Testing",
|
||||
"slug": "testing",
|
||||
"description": null,
|
||||
"created": "2025-03-18T14:16:28.539Z",
|
||||
"prompt_data": {
|
||||
"prompt": {
|
||||
"type": "chat",
|
||||
"tools": "",
|
||||
"messages": [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "this is just a test {{input}}\n\n{{other_variable}}"
|
||||
}
|
||||
]
|
||||
},
|
||||
"options": {
|
||||
"model": "gpt-4o"
|
||||
}
|
||||
},
|
||||
"tags": null,
|
||||
"metadata": null,
|
||||
"function_type": null,
|
||||
"function_data": {
|
||||
"type": "prompt"
|
||||
}
|
||||
}"#;
|
||||
|
||||
// Create a mock for the API endpoint
|
||||
let m = server.mock("GET", "/prompt/7f6fbd7a-d03a-42e7-a115-b87f5e9f86ee")
|
||||
.match_header("Authorization", "Bearer test_api_key")
|
||||
.match_header("Content-Type", "application/json")
|
||||
.with_status(200)
|
||||
.with_body(prompt_response)
|
||||
.create_async()
|
||||
.await;
|
||||
|
||||
// Create a test client
|
||||
let client = BraintrustClient::new(Some("test_api_key"), "test_project")?;
|
||||
|
||||
// Fetch the prompt
|
||||
let prompt = client.get_prompt("7f6fbd7a-d03a-42e7-a115-b87f5e9f86ee").await?;
|
||||
|
||||
// Verify the mock was called
|
||||
m.assert_async().await;
|
||||
|
||||
// Verify the prompt data
|
||||
assert_eq!(prompt.id, "7f6fbd7a-d03a-42e7-a115-b87f5e9f86ee");
|
||||
assert_eq!(prompt.project_id, "96af8b2b-cf3c-494f-9092-44eb3d5b96ff");
|
||||
assert_eq!(prompt.name, "Testing");
|
||||
assert_eq!(prompt.slug.as_ref().unwrap(), "testing");
|
||||
|
||||
// Verify prompt content
|
||||
let prompt_content = prompt.prompt_data.as_ref().unwrap().prompt.as_ref().unwrap();
|
||||
assert_eq!(prompt_content.content_type, "chat");
|
||||
|
||||
// Verify messages
|
||||
let messages = prompt_content.messages.as_ref().unwrap();
|
||||
assert_eq!(messages.len(), 1);
|
||||
assert_eq!(messages[0].role, "system");
|
||||
assert_eq!(messages[0].content, "this is just a test {{input}}\n\n{{other_variable}}");
|
||||
|
||||
// Verify model options
|
||||
let options = prompt.prompt_data.as_ref().unwrap().options.as_ref().unwrap();
|
||||
assert_eq!(options.model.as_ref().unwrap(), "gpt-4o");
|
||||
|
||||
// Verify function data
|
||||
let function_data = prompt.function_data.as_ref().unwrap();
|
||||
assert_eq!(function_data.function_type, "prompt");
|
||||
|
||||
// Reset the environment variable
|
||||
env::remove_var("BRAINTRUST_API_BASE");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_prompt_error() -> Result<()> {
|
||||
// Create a mock server
|
||||
let mut server = mockito::Server::new_async().await;
|
||||
|
||||
// Override the API_BASE constant for testing
|
||||
env::set_var("BRAINTRUST_API_BASE", &server.url());
|
||||
|
||||
// Create a mock for the API endpoint with an error response
|
||||
let m = server.mock("GET", "/prompt/nonexistent")
|
||||
.match_header("Authorization", "Bearer test_api_key")
|
||||
.match_header("Content-Type", "application/json")
|
||||
.with_status(404)
|
||||
.with_body(r#"{"error": "Prompt not found"}"#)
|
||||
.create_async()
|
||||
.await;
|
||||
|
||||
// Create a test client
|
||||
let client = BraintrustClient::new(Some("test_api_key"), "test_project")?;
|
||||
|
||||
// Attempt to fetch a nonexistent prompt
|
||||
let result = client.get_prompt("nonexistent").await;
|
||||
|
||||
// Verify the mock was called
|
||||
m.assert_async().await;
|
||||
|
||||
// Verify that an error was returned
|
||||
assert!(result.is_err());
|
||||
|
||||
// Reset the environment variable
|
||||
env::remove_var("BRAINTRUST_API_BASE");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
|
|
@ -236,3 +236,66 @@ async fn test_real_error_handling() -> Result<()> {
|
|||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_real_get_prompt() -> Result<()> {
|
||||
// Initialize environment
|
||||
init_env()?;
|
||||
|
||||
// Skip test if no API key is available
|
||||
if env::var("BRAINTRUST_API_KEY").is_err() {
|
||||
println!("Skipping test_real_get_prompt: No API key available");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Create client (None means use env var)
|
||||
let client = BraintrustClient::new(None, "c7b996a6-1c7c-482d-b23f-3d39de16f433")?;
|
||||
|
||||
// Attempt to fetch the prompt with ID "7f6fbd7a-d03a-42e7-a115-b87f5e9f86ee"
|
||||
let prompt_id = "7f6fbd7a-d03a-42e7-a115-b87f5e9f86ee";
|
||||
|
||||
println!("Fetching prompt with ID: {}", prompt_id);
|
||||
|
||||
match client.get_prompt(prompt_id).await {
|
||||
Ok(prompt) => {
|
||||
println!("Successfully fetched prompt: {}", prompt.name);
|
||||
println!("Prompt ID: {}", prompt.id);
|
||||
println!("Project ID: {}", prompt.project_id);
|
||||
|
||||
// Verify the prompt ID matches what we requested
|
||||
assert_eq!(prompt.id, prompt_id, "Prompt ID should match the requested ID");
|
||||
|
||||
if let Some(description) = &prompt.description {
|
||||
println!("Description: {}", description);
|
||||
}
|
||||
|
||||
if let Some(prompt_data) = &prompt.prompt_data {
|
||||
if let Some(content) = &prompt_data.prompt {
|
||||
println!("Prompt type: {}", content.content_type);
|
||||
println!("Prompt content: {:?}", content.content);
|
||||
println!("Prompt messages: {:?}", content.messages);
|
||||
}
|
||||
|
||||
if let Some(options) = &prompt_data.options {
|
||||
if let Some(model) = &options.model {
|
||||
println!("Model: {}", model);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(tags) = &prompt.tags {
|
||||
println!("Tags: {:?}", tags);
|
||||
}
|
||||
},
|
||||
Err(e) => {
|
||||
println!("Failed to fetch prompt '{}': {}", prompt_id, e);
|
||||
println!("This is expected if the prompt doesn't exist in your Braintrust project");
|
||||
println!("You can create a prompt with this ID in your Braintrust project for this test to pass");
|
||||
|
||||
// Fail the test if we can't fetch the prompt
|
||||
panic!("Could not fetch prompt with ID: {}", prompt_id);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
|
|
@ -236,6 +236,7 @@ impl FromSql<sql_types::StoredValuesStatusEnum, Pg> for StoredValuesStatus {
|
|||
#[diesel(sql_type = sql_types::AssetTypeEnum)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum AssetType {
|
||||
#[serde(rename = "dashboard_deprecated")]
|
||||
Dashboard,
|
||||
Thread,
|
||||
Collection,
|
||||
|
|
|
@ -16,8 +16,7 @@ use crate::utils::user::user_info::get_user_organization_id;
|
|||
use database::enums::{AssetPermissionRole, AssetType, UserOrganizationRole};
|
||||
use database::pool::{get_pg_pool, PgPool};
|
||||
use database::schema::{
|
||||
asset_permissions, collections_to_assets, dashboards, metric_files, teams_to_users,
|
||||
threads_deprecated, threads_to_dashboards, users_to_organizations,
|
||||
asset_permissions, collections_to_assets, dashboard_files, dashboards, metric_files, teams_to_users, threads_deprecated, threads_to_dashboards, users_to_organizations
|
||||
};
|
||||
|
||||
pub async fn get_asset_access(
|
||||
|
@ -135,15 +134,28 @@ async fn get_asset_access_handler(
|
|||
.first::<(Uuid, bool, Option<DateTime<Utc>>)>(&mut conn)
|
||||
.await?;
|
||||
|
||||
let metric_info = (
|
||||
metric_info.0,
|
||||
metric_info.1,
|
||||
false,
|
||||
metric_info.2,
|
||||
);
|
||||
let metric_info = (metric_info.0, metric_info.1, false, metric_info.2);
|
||||
|
||||
(metric_info, Some(AssetPermissionRole::Owner))
|
||||
}
|
||||
AssetType::DashboardFile => {
|
||||
let mut conn = pg_pool.get().await?;
|
||||
|
||||
let dashboard_info = dashboard_files::table
|
||||
.select((
|
||||
dashboard_files::id,
|
||||
dashboard_files::publicly_accessible,
|
||||
dashboard_files::public_expiry_date,
|
||||
))
|
||||
.filter(dashboard_files::id.eq(&asset_id))
|
||||
.filter(dashboard_files::deleted_at.is_null())
|
||||
.first::<(Uuid, bool, Option<DateTime<Utc>>)>(&mut conn)
|
||||
.await?;
|
||||
|
||||
let dashboard_info = (dashboard_info.0, dashboard_info.1, false, dashboard_info.2);
|
||||
|
||||
(dashboard_info, Some(AssetPermissionRole::Owner))
|
||||
}
|
||||
_ => {
|
||||
return Err(anyhow!("Public access is not supported for chats yet"));
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue