Add dataset_security dependency and implement permission checks for dataset access

- Added `dataset_security` as a dependency in `Cargo.toml`.
- Enhanced SQL validation in `common.rs` to check user permissions before executing queries.
- Updated metric file processing functions to include user ID for permission validation.
- Modified dataset retrieval in `search_data_catalog.rs` to return permissioned datasets based on user access.
- Updated planning tools to include guidelines for modifying visualizations in bulk.
This commit is contained in:
dal 2025-04-18 10:36:57 -06:00
parent 53b967ca4a
commit 115d525d96
No known key found for this signature in database
GPG Key ID: 16F4B0E1E9F61122
10 changed files with 202 additions and 111 deletions

View File

@ -27,6 +27,7 @@ once_cell = { workspace = true }
regex = "1"
glob = "0.3"
cohere-rust = { workspace = true }
dataset_security = { path = "../dataset_security" }
# Development dependencies
[dev-dependencies]

View File

@ -2,6 +2,8 @@ use anyhow::Result;
use chrono::Local;
use database::helpers::datasets::get_dataset_names_for_organization;
use database::organization::get_user_organization_id;
use database::pool::get_pg_pool;
use dataset_security::get_permissioned_datasets;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;
@ -95,15 +97,16 @@ impl AgentExt for BusterMultiAgent {
impl BusterMultiAgent {
pub async fn new(user_id: Uuid, session_id: Uuid, is_follow_up: bool) -> Result<Self> {
let organization_id = match get_user_organization_id(&user_id).await {
Ok(Some(org_id)) => org_id,
Ok(None) => return Err(anyhow::anyhow!("User does not belong to any organization")),
Err(e) => return Err(e),
};
// Prepare data for modes
let todays_date = Arc::new(Local::now().format("%Y-%m-%d").to_string());
let dataset_names = Arc::new(get_dataset_names_for_organization(organization_id).await?);
// Get permissioned datasets and extract names
let permissioned_datasets = get_permissioned_datasets(&user_id, 0, 10000).await?;
let dataset_names: Vec<String> = permissioned_datasets
.into_iter()
.map(|ds| ds.name)
.collect();
let dataset_names = Arc::new(dataset_names);
let agent_data = ModeAgentData {
dataset_names,
@ -119,8 +122,8 @@ impl BusterMultiAgent {
user_id,
session_id,
"buster_multi_agent".to_string(),
None, // api_key
None, // base_url
None, // api_key
None, // base_url
mode_provider, // Pass the provider
));
@ -129,16 +132,13 @@ impl BusterMultiAgent {
.set_state_value("is_follow_up".to_string(), Value::Bool(is_follow_up))
.await;
let buster_agent = Self {
agent,
};
let buster_agent = Self { agent };
Ok(buster_agent)
}
pub async fn run(
self: &Arc<Self>,
self: &Arc<Self>,
thread: &mut AgentThread,
) -> Result<broadcast::Receiver<Result<AgentMessage, AgentError>>> {
if let Some(user_prompt) = self.get_latest_user_message(thread) {

View File

@ -320,6 +320,7 @@ By following these guidelines, you can ensure that the visualizations you create
- If the user asks for something that hasn't been created yet — like a different chart or a metric you haven't made yet — create a new metric.
- If the user wants to change something you've already built — like switching a chart from monthly to weekly data or adding a filter — just update the existing metric, don't create a new one.
- **Grouping Modifications**: Just like creating multiple new visualizations is done in a single bulk step, if the user asks to modify multiple existing visualizations in one request, group all these modifications under a single "**Modify existing visualization(s)**" step in the plan.
### Responses With the `finish_and_respond` Tool

View File

@ -1,4 +1,4 @@
use anyhow::{anyhow, Result};
use anyhow::{anyhow, bail, Result};
use chrono::Utc;
use database::{
enums::Verification,
@ -22,6 +22,9 @@ use serde::{Deserialize, Serialize};
use super::file_types::file::FileWithId;
// Import dataset_security for permission check
use dataset_security::has_dataset_access;
// Import the types needed for the modification function
/// Validates SQL query using existing query engine by attempting to run it
@ -29,6 +32,7 @@ use super::file_types::file::FileWithId;
pub async fn validate_sql(
sql: &str,
dataset_id: &Uuid,
user_id: &Uuid,
) -> Result<(
String,
Vec<IndexMap<String, DataType>>,
@ -40,6 +44,15 @@ pub async fn validate_sql(
return Err(anyhow!("SQL query cannot be empty"));
}
// Check dataset access before proceeding
if !has_dataset_access(user_id, dataset_id).await? {
bail!(
"Permission denied: User {} does not have access to dataset {}",
user_id,
dataset_id
);
}
let mut conn = get_pg_pool().get().await?;
let data_source_id = match datasets::table
@ -738,7 +751,7 @@ pub async fn process_metric_file(
let dataset_id = dataset_ids[0];
// Validate SQL with the selected dataset_id and get results
let (message, results, metadata) = match validate_sql(&metric_yml.sql, &dataset_id).await {
let (message, results, metadata) = match validate_sql(&metric_yml.sql, &dataset_id, user_id).await {
Ok(results) => results,
Err(e) => return Err(format!("Invalid SQL query: {}", e)),
};
@ -860,7 +873,7 @@ pub async fn process_metric_file_modification(
}
let dataset_id = new_yml.dataset_ids[0];
match validate_sql(&new_yml.sql, &dataset_id).await {
match validate_sql(&new_yml.sql, &dataset_id, &file.created_by).await {
Ok((message, validation_results, _metadata)) => {
// Update file record
file.content = new_yml.clone();
@ -1274,7 +1287,7 @@ mod tests {
#[tokio::test]
async fn test_validate_sql_empty() {
let dataset_id = Uuid::new_v4();
let result = validate_sql("", &dataset_id).await;
let result = validate_sql("", &dataset_id, &Uuid::new_v4()).await;
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("cannot be empty"));
}

View File

@ -71,6 +71,7 @@ async fn process_metric_file_update(
mut file: MetricFile,
yml_content: String,
duration: i64,
user_id: &Uuid,
) -> Result<(
MetricFile,
MetricYml,
@ -164,8 +165,8 @@ async fn process_metric_file_update(
"Metadata missing, performing validation"
);
}
match validate_sql(&new_yml.sql, &dataset_id).await {
match validate_sql(&new_yml.sql, &dataset_id, user_id).await {
Ok((message, validation_results, metadata)) => {
// Update file record
file.content = new_yml.clone();
@ -295,6 +296,7 @@ impl ToolExecutor for ModifyMetricFilesTool {
file.clone(),
file_update.yml_content.clone(),
start_time_elapsed,
&self.agent.get_user_id(),
).await;
match result {

View File

@ -18,6 +18,7 @@ use serde::{Deserialize, Serialize};
use serde_json::Value;
use tracing::{debug, error, info, warn};
use uuid::Uuid;
use dataset_security::{get_permissioned_datasets, PermissionedDataset};
use crate::{agent::Agent, tools::ToolExecutor};
@ -41,7 +42,7 @@ pub struct DatasetSearchResult {
pub yml_content: Option<String>,
}
#[derive(Debug, Serialize, Clone, PartialEq, Eq, Hash)]
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq, Hash)]
struct DatasetResult {
id: Uuid,
name: Option<String>,
@ -50,7 +51,7 @@ struct DatasetResult {
#[derive(Debug, Clone)]
struct RankedDataset {
dataset: Dataset,
dataset: PermissionedDataset,
relevance_score: f64,
}
@ -112,35 +113,27 @@ impl SearchDataCatalogTool {
true
}
async fn get_datasets() -> Result<Vec<Dataset>> {
debug!("Fetching datasets for agent tool");
let mut conn = get_pg_pool().get().await?;
let datasets_result = datasets::table
.select((
datasets::id,
datasets::name,
datasets::yml_file,
datasets::created_at,
datasets::updated_at,
datasets::deleted_at,
))
.filter(datasets::deleted_at.is_null())
.filter(datasets::yml_file.is_not_null())
.load::<Dataset>(&mut conn)
.await;
async fn get_datasets(user_id: &Uuid) -> Result<Vec<PermissionedDataset>> {
debug!("Fetching permissioned datasets for agent tool for user {}", user_id);
let datasets_result = get_permissioned_datasets(user_id, 0, 10000).await;
match datasets_result {
Ok(datasets) => {
let filtered_datasets: Vec<PermissionedDataset> = datasets
.into_iter()
.filter(|d| d.yml_content.is_some())
.collect();
debug!(
count = datasets.len(),
"Successfully loaded datasets for agent tool"
count = filtered_datasets.len(),
user_id = %user_id,
"Successfully loaded and filtered permissioned datasets for agent tool"
);
Ok(datasets)
Ok(filtered_datasets)
}
Err(e) => {
error!("Failed to load datasets for agent tool: {}", e);
Err(anyhow::anyhow!("Database error fetching datasets: {}", e))
error!(user_id = %user_id, "Failed to load permissioned datasets for agent tool: {}", e);
Err(anyhow::anyhow!("Error fetching permissioned datasets: {}", e))
}
}
}
@ -166,10 +159,10 @@ impl ToolExecutor for SearchDataCatalogTool {
});
}
let all_datasets = match Self::get_datasets().await {
let all_datasets = match Self::get_datasets(&user_id).await {
Ok(datasets) => datasets,
Err(e) => {
error!("Failed to retrieve datasets for tool execution: {}", e);
error!(user_id=%user_id, "Failed to retrieve permissioned datasets for tool execution: {}", e);
return Ok(SearchDataCatalogOutput {
message: format!("Error fetching datasets: {}", e),
queries: params.queries,
@ -344,7 +337,7 @@ async fn get_search_data_catalog_description() -> String {
async fn rerank_datasets(
query: &str,
all_datasets: &[Dataset],
all_datasets: &[PermissionedDataset],
documents: &[String],
) -> Result<Vec<RankedDataset>, anyhow::Error> {
if documents.is_empty() || all_datasets.is_empty() {
@ -485,7 +478,7 @@ async fn filter_datasets_with_llm(
}
};
let dataset_map: HashMap<Uuid, &Dataset> = ranked_datasets
let dataset_map: HashMap<Uuid, &PermissionedDataset> = ranked_datasets
.iter()
.map(|ranked| (ranked.dataset.id, &ranked.dataset))
.collect();
@ -530,19 +523,3 @@ async fn filter_datasets_with_llm(
);
Ok(filtered_datasets)
}
/// Local diesel projection of the `datasets` table used only by this tool.
/// NOTE(review): this struct is the removed side of the diff — the commit
/// replaces it with `dataset_security::PermissionedDataset`.
#[derive(Queryable, Selectable, Clone, Debug)]
#[diesel(table_name = datasets)]
#[diesel(check_for_backend(diesel::pg::Pg))]
struct Dataset {
id: Uuid,
name: String,
// Backed by the `yml_file` column; exposed locally under the name `yml_content`.
#[diesel(column_name = "yml_file")]
yml_content: Option<String>,
// Timestamp columns are selected to satisfy the table mapping but unused here.
#[allow(dead_code)]
created_at: DateTime<Utc>,
#[allow(dead_code)]
updated_at: DateTime<Utc>,
#[allow(dead_code)]
deleted_at: Option<DateTime<Utc>>,
}

View File

@ -166,6 +166,7 @@ Add any assumptions, limitations, or clarifications about the analysis and findi
- **For Multi-Line Charts**: Explicitly state it's a `multi-line chart`. Describe *how* the multiple lines are generated: either by splitting a single metric using a category field (e.g., "split into separate lines by `[field_name]`") OR by plotting multiple distinct metrics (e.g., "plotting separate lines for `[metric1]` and `[metric2]`").
- **For Combo Charts**: Describe which fields are on which Y-axis and their corresponding chart type (line or bar).
- **Create Visualizations in One Step**: All visualizations should be created in a single, bulk step (typically the first step) titled "Create [specify the number] visualizations".
- **Modify Visualizations in One Step**: Similarly, if the user requests modifications to multiple existing visualizations in a single turn, group all these modifications under one "**Modify existing visualization(s)**" step.
- **Review**: Always include a review step to ensure accuracy and relevance.
- **Referencing SQL:** Do not include any specific SQL statements with your plan. The details of the SQL statement will be decided during the workflow. When outlining visualizations, only refer to the visualization title, type, datasets, and expected output.
- **Use Names instead of IDs**: When visualizations or tables include things like people, customers, vendors, products, categories, etc, you should display names instead of IDs (if names are included in the available datasets). IDs are not meaningful to users. For people, you should combine first and last names if they are available. State this clearly in the `Expected Output` (e.g., "...split into separate lines by sales rep full names").

View File

@ -162,6 +162,7 @@ Add context like assumptions, limitations, or acknowledge unsupported aspects of
- **For Multi-Line Charts**: Explicitly state it's a `multi-line chart`. Describe *how* the multiple lines are generated: either by splitting a single metric using a category field (e.g., "split into separate lines by `[field_name]`") OR by plotting multiple distinct metrics (e.g., "plotting separate lines for `[metric1]` and `[metric2]`").
- **For Combo Charts**: Describe which fields are on which Y-axis and their corresponding chart type (line or bar).
- **Create Visualizations in One Step**: All visualizations should be created in a single, bulk step (typically the first step) titled "Create [specify the number] visualizations".
- **Modify Visualizations in One Step**: Similarly, if the user requests modifications to multiple existing visualizations in a single turn, group all these modifications under one "**Modify existing visualization(s)**" step.
- **Broad Requests**: For broad or summary requests (e.g., "summarize assembly line performance", "show me important stuff", "how is the sales team doing?"), you must create at least 8 visualizations to ensure a comprehensive overview. Creating fewer than five visualizations is inadequate for such requests. Aim for 8-12 visualizations to cover various aspects of the data, such as sales trends, order metrics, customer behavior, or product performance, depending on the available datasets. Include lots of trends (time-series data), groupings, segments, etc. This ensures the user receives a thorough view of the requested information.
- **Review**: Always include a review step to ensure accuracy and relevance.
- **Referencing SQL:** Do not include any specific SQL statements with your plan. The details of the SQL statement will be decided during the workflow. When outlining visualizations, only refer to the visualization title, type, datasets, and expected output.

View File

@ -10,11 +10,14 @@ diesel = { workspace = true }
diesel-async = { workspace = true }
uuid = { workspace = true }
tracing = { workspace = true }
serde = { workspace = true }
tokio = { workspace = true }
chrono = { workspace = true }
# Internal workspace dependencies
database = { path = "../database" }
# Development dependencies
[dev-dependencies]
tokio = { workspace = true }
dotenv = { workspace = true }
# Add other workspace dev dependencies as needed

View File

@ -1,69 +1,134 @@
//! Library for handling dataset security and permissions.
use anyhow::{anyhow, Result};
use diesel::{BoolExpressionMethods, ExpressionMethods, JoinOnDsl, QueryDsl};
use chrono::{DateTime, Utc};
use database::enums::UserOrganizationRole;
use diesel::prelude::Queryable;
use diesel::{
BoolExpressionMethods, ExpressionMethods, JoinOnDsl, QueryDsl, Selectable, SelectableHelper,
};
use diesel_async::RunQueryDsl;
use uuid::Uuid;
use database::{
pool::{get_pg_pool, PgPool},
models::Dataset,
schema::{
datasets,
datasets_to_permission_groups,
permission_groups,
permission_groups_to_identities,
teams_to_users,
datasets, datasets_to_permission_groups, permission_groups,
permission_groups_to_identities, teams_to_users, users_to_organizations,
},
};
// Define the new struct mirroring the one in search_data_catalog.rs
/// A dataset row returned through the permission layer: only datasets the
/// requesting user may access are materialized into this type.
/// Mirrors the per-tool `Dataset` struct this commit removes from
/// `search_data_catalog.rs`, so callers can swap types without field renames.
#[derive(Queryable, Selectable, Clone, Debug)]
#[diesel(table_name = datasets)]
#[diesel(check_for_backend(diesel::pg::Pg))]
pub struct PermissionedDataset {
pub id: Uuid,
pub name: String,
// Backed by the `yml_file` column; exposed as `yml_content` to match the
// field name used by the struct it replaces.
#[diesel(column_name = "yml_file")]
pub yml_content: Option<String>,
// Timestamp columns are carried for the table mapping; not read by callers here.
#[allow(dead_code)]
pub created_at: DateTime<Utc>,
#[allow(dead_code)]
pub updated_at: DateTime<Utc>,
#[allow(dead_code)]
pub deleted_at: Option<DateTime<Utc>>,
}
pub async fn get_permissioned_datasets(
pool: &PgPool,
user_id: &Uuid,
page: i64,
page_size: i64,
) -> Result<Vec<Dataset>> {
let mut conn = match pool.get().await {
) -> Result<Vec<PermissionedDataset>> {
let mut conn = match get_pg_pool().get().await {
Ok(conn) => conn,
Err(e) => return Err(anyhow!("Unable to get connection from pool: {}", e)),
};
// TODO: Add logic to check if user is admin, if so, return all datasets
// Fetch user's organization and role
let user_org_info = users_to_organizations::table
.filter(users_to_organizations::user_id.eq(user_id))
.select((
users_to_organizations::organization_id,
users_to_organizations::role,
))
.first::<(Uuid, UserOrganizationRole)>(&mut conn)
.await;
let datasets = match datasets::table
.select(datasets::all_columns)
.inner_join(
datasets_to_permission_groups::table
.on(datasets::id.eq(datasets_to_permission_groups::dataset_id)),
)
.inner_join(
permission_groups::table
.on(datasets_to_permission_groups::permission_group_id.eq(permission_groups::id)),
)
.inner_join(
permission_groups_to_identities::table
.on(permission_groups::id.eq(permission_groups_to_identities::permission_group_id)),
)
.inner_join(
teams_to_users::table
.on(teams_to_users::team_id.eq(permission_groups_to_identities::identity_id)),
)
.filter(
teams_to_users::user_id
.eq(&user_id)
.or(permission_groups_to_identities::identity_id.eq(&user_id)),
)
.filter(datasets::deleted_at.is_null())
.limit(page_size)
.offset(page * page_size)
.load::<Dataset>(&mut conn)
.await
{
Ok(datasets) => datasets,
Err(e) => return Err(anyhow!("Unable to get team datasets from database: {}", e)),
let datasets_query = match user_org_info {
Ok((organization_id, role)) => {
// Check if user has admin/querier role
if matches!(
role,
UserOrganizationRole::WorkspaceAdmin
| UserOrganizationRole::DataAdmin
| UserOrganizationRole::Querier
) {
// User is admin/querier, return all org datasets
datasets::table
.filter(datasets::organization_id.eq(organization_id))
.filter(datasets::deleted_at.is_null())
.select(PermissionedDataset::as_select())
.limit(page_size)
.offset(page * page_size)
.load::<PermissionedDataset>(&mut conn)
.await
} else {
// User is not admin/querier, use permission group logic
datasets::table
.select(PermissionedDataset::as_select())
.inner_join(
datasets_to_permission_groups::table
.on(datasets::id.eq(datasets_to_permission_groups::dataset_id)),
)
.inner_join(
permission_groups::table
.on(datasets_to_permission_groups::permission_group_id
.eq(permission_groups::id)),
)
.inner_join(
permission_groups_to_identities::table.on(permission_groups::id
.eq(permission_groups_to_identities::permission_group_id)),
)
.inner_join(teams_to_users::table.on(
teams_to_users::team_id.eq(permission_groups_to_identities::identity_id),
))
.filter(
teams_to_users::user_id
.eq(user_id)
.or(permission_groups_to_identities::identity_id.eq(user_id)),
)
.filter(datasets::deleted_at.is_null())
// Ensure related permission records are not deleted (important for non-admins)
.filter(
datasets_to_permission_groups::deleted_at
.is_null()
.and(permission_groups::deleted_at.is_null())
.and(permission_groups_to_identities::deleted_at.is_null())
.and(teams_to_users::deleted_at.is_null()),
)
.distinct() // Ensure unique datasets if multiple paths grant access
.limit(page_size)
.offset(page * page_size)
.load::<PermissionedDataset>(&mut conn)
.await
}
}
Err(diesel::NotFound) => {
// User not found in any organization, return empty vec or error?
// Returning empty for now, indicating no datasets accessible.
Ok(Vec::new())
}
Err(e) => {
// Other database error fetching user role
return Err(anyhow!("Error fetching user organization role: {}", e));
}
};
Ok(datasets)
match datasets_query {
Ok(datasets) => Ok(datasets),
Err(e) => Err(anyhow!("Unable to get datasets from database: {}", e)),
}
}
pub async fn has_dataset_access(user_id: &Uuid, dataset_id: &Uuid) -> Result<bool> {
@ -72,7 +137,34 @@ pub async fn has_dataset_access(user_id: &Uuid, dataset_id: &Uuid) -> Result<boo
Err(e) => return Err(anyhow!("Unable to get connection from pool: {}", e)),
};
let has_dataset_access = match datasets::table
// First, check if the user is an admin/querier for the dataset's organization
let admin_access = match datasets::table
.filter(datasets::id.eq(dataset_id))
.inner_join(
users_to_organizations::table
.on(datasets::organization_id.eq(users_to_organizations::organization_id)),
)
.filter(users_to_organizations::user_id.eq(user_id))
.select(users_to_organizations::role)
.first::<UserOrganizationRole>(&mut conn)
.await
{
Ok(role) => matches!(
role,
UserOrganizationRole::WorkspaceAdmin
| UserOrganizationRole::DataAdmin
| UserOrganizationRole::Querier
),
Err(diesel::NotFound) => false, // User not in the dataset's organization or dataset doesn't exist
Err(e) => return Err(anyhow!("Error checking admin access for dataset: {}", e)),
};
if admin_access {
return Ok(true);
}
// If not admin, check permission group access (existing logic)
let group_access = match datasets::table
.select(datasets::id)
.inner_join(
datasets_to_permission_groups::table
@ -112,5 +204,5 @@ pub async fn has_dataset_access(user_id: &Uuid, dataset_id: &Uuid) -> Result<boo
Err(e) => return Err(anyhow!("Unable to get team datasets from database: {}", e)),
};
Ok(has_dataset_access)
}
Ok(group_access)
}