From ee830562cf655e05bb66981a8a7fc5175bce85c8 Mon Sep 17 00:00:00 2001 From: dal Date: Fri, 21 Mar 2025 12:10:57 -0600 Subject: [PATCH] update data source stubbed, but will change --- api/libs/handlers/src/data_sources/mod.rs | 4 +- .../update_data_source_handler.rs | 338 ++++++++++++++++++ .../routes/rest/routes/data_sources/mod.rs | 4 +- .../routes/data_sources/update_data_source.rs | 25 ++ api/tests/integration/data_sources/mod.rs | 3 +- .../data_sources/update_data_source_test.rs | 197 ++++++++++ 6 files changed, 568 insertions(+), 3 deletions(-) create mode 100644 api/libs/handlers/src/data_sources/update_data_source_handler.rs create mode 100644 api/src/routes/rest/routes/data_sources/update_data_source.rs create mode 100644 api/tests/integration/data_sources/update_data_source_test.rs diff --git a/api/libs/handlers/src/data_sources/mod.rs b/api/libs/handlers/src/data_sources/mod.rs index 08e5e2988..903ab649b 100644 --- a/api/libs/handlers/src/data_sources/mod.rs +++ b/api/libs/handlers/src/data_sources/mod.rs @@ -1,3 +1,5 @@ mod list_data_sources_handler; +mod update_data_source_handler; -pub use list_data_sources_handler::*; \ No newline at end of file +pub use list_data_sources_handler::*; +pub use update_data_source_handler::*; \ No newline at end of file diff --git a/api/libs/handlers/src/data_sources/update_data_source_handler.rs b/api/libs/handlers/src/data_sources/update_data_source_handler.rs new file mode 100644 index 000000000..0c051c9b7 --- /dev/null +++ b/api/libs/handlers/src/data_sources/update_data_source_handler.rs @@ -0,0 +1,338 @@ +use std::str::FromStr; + +use anyhow::{anyhow, Result}; +use chrono::{DateTime, Utc}; +use diesel::{AsChangeset, ExpressionMethods, QueryDsl}; +use diesel_async::RunQueryDsl; +use serde::{Deserialize, Serialize}; +use uuid::Uuid; + +use database::{ + enums::DataSourceType, + models::{DataSource, User, UserToOrganization}, + pool::get_pg_pool, + schema::{data_sources, users, users_to_organizations}, + vault::{read_secret, update_secret}, +}; + +/// Request for updating a data source +#[derive(Debug, Deserialize)] +pub struct UpdateDataSourceRequest { + pub name: Option, + pub env: Option, + #[serde(flatten)] + pub credential: Option, +} + +/// Changeset for updating a data source +#[derive(AsChangeset)] +#[diesel(table_name = data_sources)] +struct DataSourceChangeset { + name: Option, + env: Option, + updated_at: DateTime, + updated_by: Uuid, + #[diesel(column_name = type_)] + type_field: Option, +} + +/// Part of the response showing the user who created the data source +#[derive(Serialize)] +pub struct CreatedBy { + pub id: String, + pub email: String, + pub name: String, +} + +/// Credentials information in the response +#[derive(Serialize)] +pub struct Credentials { + pub database: Option, + pub host: String, + pub jump_host: Option, + pub password: String, + pub port: u64, + pub schemas: Option>, + pub ssh_private_key: Option, + pub ssh_username: Option, + pub username: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub project_id: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub dataset_ids: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub credentials_json: Option, +} + +/// Response for a data source +#[derive(Serialize)] +pub struct DataSourceResponse { + pub id: String, + pub name: String, + pub db_type: String, + pub created_at: DateTime, + pub updated_at: DateTime, + pub created_by: CreatedBy, + pub credentials: Credentials, + pub data_sets: Vec, // Empty for now, could be populated if needed +} + +/// Handler for updating a data source +pub async fn update_data_source_handler( + user_id: &Uuid, + data_source_id: &Uuid, + request: UpdateDataSourceRequest, +) -> Result { + let mut conn = get_pg_pool().get().await?; + + // Verify user has access to the data source + let user_org = users_to_organizations::table + .filter(users_to_organizations::user_id.eq(user_id)) + .filter(users_to_organizations::deleted_at.is_null()) + .select(users_to_organizations::all_columns) + .first::(&mut conn) + .await + .map_err(|e| anyhow!("Unable to get user organization: {}", e))?; + + // Get current data source + let mut data_source = data_sources::table + .filter(data_sources::id.eq(data_source_id)) + .filter(data_sources::organization_id.eq(user_org.organization_id)) + .filter(data_sources::deleted_at.is_null()) + .first::(&mut conn) + .await + .map_err(|e| anyhow!("Data source not found: {}", e))?; + + // Extract type from credentials if present + let type_field = request + .credential + .as_ref() + .and_then(|cred| cred.get("type")) + .and_then(|t| t.as_str()) + .map(|s| s.to_string()); + + // Only perform database update if there are changes to make + if request.name.is_some() || request.env.is_some() || type_field.is_some() { + // Create changeset for update + let changeset = DataSourceChangeset { + name: request.name.clone(), + env: request.env.clone(), + updated_at: Utc::now(), + updated_by: *user_id, + type_field: type_field.clone(), + }; + + // Execute the update + diesel::update(data_sources::table) + .filter(data_sources::id.eq(data_source_id)) + .set(changeset) + .execute(&mut conn) + .await + .map_err(|e| anyhow!("Failed to update data source: {}", e))?; + + // Update local variable + if let Some(name) = &request.name { + data_source.name = name.clone(); + } + + if let Some(env) = &request.env { + data_source.env = env.clone(); + } + + if let Some(type_str) = &type_field { + data_source.type_ = DataSourceType::from_str(type_str).unwrap(); + } + } + + // Update credentials if provided + if let Some(credentials) = &request.credential { + // Read existing secret + let existing_secret = read_secret(data_source_id).await?; + let mut existing_credential: serde_json::Value = serde_json::from_str(&existing_secret)?; + + // Merge credential fields + if let (Some(existing_obj), Some(new_obj)) = + (existing_credential.as_object_mut(), credentials.as_object()) + { + for (key, value) in new_obj { + existing_obj.insert(key.clone(), value.clone()); + } + } + + // Update the secret + let secret_json = serde_json::to_string(&existing_credential)?; + update_secret(data_source_id, &secret_json).await?; + } + + // Get the creator's information + let creator = users::table + .filter(users::id.eq(data_source.created_by)) + .first::(&mut conn) + .await + .map_err(|e| anyhow!("Unable to get creator information: {}", e))?; + + // Fetch the current credential data + let secret = read_secret(data_source_id).await?; + let credential_json: serde_json::Value = serde_json::from_str(&secret)?; + + // Build credentials based on the data source type + let db_type = data_source.type_.to_string(); + let credentials = parse_credentials(&db_type, &credential_json)?; + + // Build the response + Ok(DataSourceResponse { + id: data_source.id.to_string(), + name: data_source.name, + db_type: db_type.to_string(), + created_at: data_source.created_at, + updated_at: data_source.updated_at, + created_by: CreatedBy { + id: creator.id.to_string(), + email: creator.email, + name: creator.name.unwrap_or_else(|| "".to_string()), + }, + credentials, + data_sets: Vec::new(), + }) +} + +/// Helper function to parse credentials based on data source type +fn parse_credentials(db_type: &str, credential_json: &serde_json::Value) -> Result { + // Determine port based on database type + let default_port = match db_type { + "postgres" | "supabase" => 5432, + "mysql" | "mariadb" => 3306, + "redshift" => 5439, + "sqlserver" => 1433, + "snowflake" | "bigquery" | "databricks" => 443, + _ => 5432, // default + }; + + // Extract common credentials with type-specific defaults + let host = match db_type { + "bigquery" => "bigquery.googleapis.com".to_string(), + "snowflake" => credential_json + .get("account_id") + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string(), + _ => credential_json + .get("host") + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string(), + }; + + let username = match db_type { + "bigquery" => "bigquery".to_string(), + "databricks" => "databricks".to_string(), + _ => credential_json + .get("username") + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string(), + }; + + let password = match db_type { + "bigquery" => "".to_string(), + "databricks" => credential_json + .get("api_key") + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string(), + _ => credential_json + .get("password") + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string(), + }; + + // Handle special database field names by type + let database = match db_type { + "mysql" | "mariadb" => None, + "snowflake" => credential_json + .get("database_id") + .and_then(|v| v.as_str()) + .map(|s| s.to_string()), + _ => credential_json + .get("database") + .and_then(|v| v.as_str()) + .map(|s| s.to_string()), + }; + + // Handle schemas/databases field based on type + let schemas = match db_type { + "mysql" | "mariadb" => credential_json.get("databases").and_then(|v| { + v.as_array().map(|arr| { + arr.iter() + .filter_map(|s| s.as_str().map(|s| s.to_string())) + .collect() + }) + }), + _ => credential_json.get("schemas").and_then(|v| { + v.as_array().map(|arr| { + arr.iter() + .filter_map(|s| s.as_str().map(|s| s.to_string())) + .collect() + }) + }), + }; + + // Get port from credentials or use default + let port = credential_json + .get("port") + .and_then(|v| v.as_u64()) + .unwrap_or(default_port); + + // Handle optional fields + let project_id = credential_json + .get("project_id") + .and_then(|v| v.as_str()) + .map(|s| s.to_string()); + + // Extract dataset IDs for BigQuery + let dataset_ids = if db_type == "bigquery" { + credential_json + .get("dataset_ids") + .and_then(|v| v.as_array()) + .map(|arr| { + arr.iter() + .filter_map(|v| v.as_str().map(|s| s.to_string())) + .collect::>() + }) + } else { + None + }; + + // Handle credentials_json for BigQuery + let credentials_json = if db_type == "bigquery" { + credential_json.get("credentials_json").cloned() + } else { + None + }; + + // Create Credentials struct + Ok(Credentials { + host, + port, + username, + password, + database, + schemas, + jump_host: credential_json + .get("jump_host") + .and_then(|v| v.as_str()) + .map(|s| s.to_string()), + ssh_username: credential_json + .get("ssh_username") + .and_then(|v| v.as_str()) + .map(|s| s.to_string()), + ssh_private_key: credential_json + .get("ssh_private_key") + .and_then(|v| v.as_str()) + .map(|s| s.to_string()), + project_id, + dataset_ids, + credentials_json, + }) +} diff --git a/api/src/routes/rest/routes/data_sources/mod.rs b/api/src/routes/rest/routes/data_sources/mod.rs index 921e64afc..7a12ba2ba 100644 --- a/api/src/routes/rest/routes/data_sources/mod.rs +++ b/api/src/routes/rest/routes/data_sources/mod.rs @@ -1,8 +1,9 @@ mod post_data_sources; mod list_data_sources; +mod update_data_source; use axum::{ - routing::{get, post}, + routing::{get, post, put}, Router, }; @@ -10,4 +11,5 @@ pub fn router() -> Router { Router::new() .route("/", post(post_data_sources::post_data_sources)) .route("/", get(list_data_sources::list_data_sources)) + .route("/:id", put(update_data_source::update_data_source)) } diff --git a/api/src/routes/rest/routes/data_sources/update_data_source.rs b/api/src/routes/rest/routes/data_sources/update_data_source.rs new file mode 100644 index 000000000..99792c3e6 --- /dev/null +++ b/api/src/routes/rest/routes/data_sources/update_data_source.rs @@ -0,0 +1,25 @@ +use anyhow::Result; +use axum::{extract::Path, http::StatusCode, Extension, Json}; +use middleware::AuthenticatedUser; +use serde::{Deserialize}; +use uuid::Uuid; + +use crate::routes::rest::ApiResponse; +use handlers::data_sources::{update_data_source_handler, UpdateDataSourceRequest, DataSourceResponse}; + +pub async fn update_data_source( + Extension(user): Extension, + Path(id): Path, + Json(payload): Json, +) -> Result, (StatusCode, &'static str)> { + match update_data_source_handler(&user.id, &id, payload).await { + Ok(data_source) => Ok(ApiResponse::JsonData(data_source)), + Err(e) => { + tracing::error!("Error updating data source: {:?}", e); + Err(( + StatusCode::INTERNAL_SERVER_ERROR, + "Failed to update data source", + )) + } + } +} \ No newline at end of file diff --git a/api/tests/integration/data_sources/mod.rs b/api/tests/integration/data_sources/mod.rs index a63adee68..148b60769 100644 --- a/api/tests/integration/data_sources/mod.rs +++ b/api/tests/integration/data_sources/mod.rs @@ -1 +1,2 @@ -mod list_data_sources_test; \ No newline at end of file +mod list_data_sources_test; +mod update_data_source_test; \ No newline at end of file diff --git a/api/tests/integration/data_sources/update_data_source_test.rs b/api/tests/integration/data_sources/update_data_source_test.rs new file mode 100644 index 000000000..2cd7f8c41 --- /dev/null +++ b/api/tests/integration/data_sources/update_data_source_test.rs @@ -0,0 +1,197 @@ +use axum::http::StatusCode; +use diesel::sql_types; +use diesel_async::RunQueryDsl; +use serde_json::json; +use uuid::Uuid; + +use crate::common::{ + assertions::response::ResponseAssertions, + fixtures::builder::UserBuilder, + http::test_app::TestApp, +}; + +// Mock DataSourceBuilder since we don't know the exact implementation +struct DataSourceBuilder { + name: String, + env: String, + organization_id: Uuid, + created_by: Uuid, + db_type: String, + credentials: serde_json::Value, + id: Uuid, +} + +impl DataSourceBuilder { + fn new() -> Self { + DataSourceBuilder { + name: "Test Data Source".to_string(), + env: "dev".to_string(), + organization_id: Uuid::new_v4(), + created_by: Uuid::new_v4(), + db_type: "postgres".to_string(), + credentials: json!({}), + id: Uuid::new_v4(), + } + } + + fn with_name(mut self, name: &str) -> Self { + self.name = name.to_string(); + self + } + + fn with_env(mut self, env: &str) -> Self { + self.env = env.to_string(); + self + } + + fn with_organization_id(mut self, organization_id: Uuid) -> Self { + self.organization_id = organization_id; + self + } + + fn with_created_by(mut self, created_by: Uuid) -> Self { + self.created_by = created_by; + self + } + + fn with_type(mut self, db_type: &str) -> Self { + self.db_type = db_type.to_string(); + self + } + + fn with_credentials(mut self, credentials: serde_json::Value) -> Self { + self.credentials = credentials; + self + } + + async fn build(self, pool: &diesel_async::pooled_connection::bb8::Pool) -> DataSourceResponse { + // Create data source directly in database using SQL + let mut conn = pool.get().await.unwrap(); + + // Insert the data source + diesel::sql_query("INSERT INTO data_sources (id, name, type, secret_id, organization_id, created_by, updated_by, created_at, updated_at, onboarding_status, env) VALUES ($1, $2, $3, $4, $5, $6, $7, NOW(), NOW(), 'notStarted', $8)") + .bind::(&self.id) + .bind::(&self.name) + .bind::(&self.db_type) + .bind::(&self.id) // Using the same UUID for both id and secret_id for simplicity + .bind::(&self.organization_id) + .bind::(&self.created_by) + .bind::(&self.created_by) + .bind::(&self.env) + .execute(&mut conn) + .await + .unwrap(); + + // Insert the secret + diesel::sql_query("INSERT INTO vault.secrets (id, secret) VALUES ($1, $2)") + .bind::(&self.id) + .bind::(&self.credentials.to_string()) + .execute(&mut conn) + .await + .unwrap(); + + // Construct response + DataSourceResponse { + id: self.id.to_string(), + } + } +} + +struct DataSourceResponse { + id: String, +} + +#[tokio::test] +async fn test_update_data_source() { + let app = TestApp::new().await.unwrap(); + + // Create a test user with organization + let user = UserBuilder::new() + .with_organization("Test Org") + .build(&app.db.pool) + .await; + + // Create a test data source + let data_source = DataSourceBuilder::new() + .with_name("Original DS Name") + .with_env("dev") + .with_organization_id(user.organization_id) + .with_created_by(user.id) + .with_type("postgres") + .with_credentials(json!({ + "type": "postgres", + "host": "localhost", + "port": 5432, + "username": "postgres", + "password": "password", + "database": "test", + "schemas": ["public"] + })) + .build(&app.db.pool) + .await; + + // Prepare update request + let update_req = json!({ + "name": "Updated DS Name", + "env": "prod", + "type": "postgres", + "host": "new-host", + "port": 5433, + "username": "new-user", + "password": "new-password", + "database": "new-db", + "schemas": ["public", "schema2"] + }); + + // Send update request + let response = app + .client + .put(format!("/api/data_sources/{}", data_source.id)) + .header("Authorization", format!("Bearer {}", user.api_key)) + .json(&update_req) + .send() + .await + .unwrap(); + + // Assert response + assert_eq!(response.status(), StatusCode::OK); + + let body = response.json::().await.unwrap(); + body.assert_has_key_with_value("id", data_source.id); + body.assert_has_key_with_value("name", "Updated DS Name"); + + let credentials = &body["credentials"]; + assert!(credentials.is_object()); + + // Test updating just the name + let name_only_update = json!({ + "name": "Name Only Update" + }); + + let response = app + .client + .put(format!("/api/data_sources/{}", data_source.id)) + .header("Authorization", format!("Bearer {}", user.api_key)) + .json(&name_only_update) + .send() + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::OK); + + let body = response.json::().await.unwrap(); + body.assert_has_key_with_value("name", "Name Only Update"); + + // Test updating with invalid UUID + let invalid_id = Uuid::new_v4(); + let response = app + .client + .put(format!("/api/data_sources/{}", invalid_id)) + .header("Authorization", format!("Bearer {}", user.api_key)) + .json(&name_only_update) + .send() + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR); +} \ No newline at end of file