From 4e2246e91df38f1e6dacfc73ed8d635e621d3abf Mon Sep 17 00:00:00 2001 From: dal Date: Wed, 9 Apr 2025 16:48:40 -0600 Subject: [PATCH] auth improvements for cli --- cli/Cargo.toml | 1 + cli/cli/Cargo.toml | 2 + cli/cli/src/commands/auth.rs | 356 +++++++++++++++++++++++++++++++---- cli/cli/src/main.rs | 13 +- 4 files changed, 334 insertions(+), 38 deletions(-) diff --git a/cli/Cargo.toml b/cli/Cargo.toml index 1969f9f1d..898f022c8 100644 --- a/cli/Cargo.toml +++ b/cli/Cargo.toml @@ -14,6 +14,7 @@ panic = "abort" [workspace.dependencies] anyhow = "1.0.79" +async-trait = "0.1.80" clap = { version = "4.4.18", features = ["derive", "env"] } confy = "0.6.0" dirs = "6.0.0" diff --git a/cli/cli/Cargo.toml b/cli/cli/Cargo.toml index 204cf5c30..574900334 100644 --- a/cli/cli/Cargo.toml +++ b/cli/cli/Cargo.toml @@ -12,6 +12,7 @@ path = "src/lib.rs" [dependencies] anyhow = { workspace = true } +async-trait = { workspace = true } clap = { workspace = true, features = ["derive", "env"] } confy = { workspace = true } dirs = { workspace = true } @@ -47,3 +48,4 @@ chrono = { workspace = true } [dev-dependencies] tempfile = "3.16.0" +mockall = "0.13.1" diff --git a/cli/cli/src/commands/auth.rs b/cli/cli/src/commands/auth.rs index 64c3a36eb..49d231549 100644 --- a/cli/cli/src/commands/auth.rs +++ b/cli/cli/src/commands/auth.rs @@ -1,6 +1,7 @@ use anyhow::{Context, Result}; +use async_trait::async_trait; use clap::Parser; -use inquire::{Password, Text}; +use inquire::{Confirm, Password, Text}; use thiserror::Error; use crate::utils::{ @@ -8,20 +9,16 @@ use crate::utils::{ file::buster_credentials::{get_buster_credentials, set_buster_credentials, BusterCredentials}, }; -const DEFAULT_HOST: &str = "https://api2.buster.so"; +const DEFAULT_HOST: &str = "https://api.buster.so"; #[derive(Error, Debug)] pub enum AuthError { - #[error("URL is required")] - MissingUrl, #[error("API key is required")] MissingApiKey, #[error("Invalid API key")] InvalidApiKey, #[error("Failed to validate credentials: {0}")] ValidationError(String), - #[error("Failed to save credentials: {0}")] - StorageError(String), } #[derive(Parser, Debug)] @@ -40,21 +37,99 @@ pub struct AuthArgs { pub no_save: bool, } -async fn validate_credentials(url: &str, api_key: &str) -> Result<(), AuthError> { - let buster_client = BusterClient::new(url.to_string(), api_key.to_string()) - .map_err(|e| AuthError::ValidationError(e.to_string()))?; - - if !buster_client.validate_api_key().await - .map_err(|e| AuthError::ValidationError(e.to_string()))? { - return Err(AuthError::InvalidApiKey); - } - - Ok(()) +// --- Credentials Validation Trait --- +#[cfg_attr(test, mockall::automock)] +#[async_trait] +pub trait CredentialsValidator { + async fn validate(&self, url: &str, api_key: &str) -> Result<(), AuthError>; } -pub async fn auth() -> Result<()> { - let args = AuthArgs::parse(); - auth_with_args(args).await +pub struct RealCredentialsValidator; + +#[async_trait] +impl CredentialsValidator for RealCredentialsValidator { + async fn validate(&self, url: &str, api_key: &str) -> Result<(), AuthError> { + let buster_client = BusterClient::new(url.to_string(), api_key.to_string()) + .map_err(|e| AuthError::ValidationError(e.to_string()))?; + + if !buster_client.validate_api_key().await + .map_err(|e| AuthError::ValidationError(e.to_string()))? { + return Err(AuthError::InvalidApiKey); + } + + Ok(()) + } +} + +/// Inner logic for checking authentication, testable without file system access. +async fn check_authentication_inner( + cached_credentials_result: Result, + validator: &dyn CredentialsValidator, +) -> Result<()> { + let host_env = std::env::var("BUSTER_HOST"); + let api_key_env = std::env::var("BUSTER_API_KEY"); + + let credentials = match cached_credentials_result { + Ok(mut creds) => { + // Override with env vars if they exist + if let Ok(host) = host_env { + creds.url = host; + } + if let Ok(api_key) = api_key_env { + creds.api_key = api_key; + } + // Use default host if still empty after checking cache and env var + if creds.url.is_empty() { + creds.url = DEFAULT_HOST.to_string(); + } + Some(creds) + } + Err(_) => { + // No cached creds, rely solely on env vars or defaults + match (host_env, api_key_env) { + (Ok(host), Ok(api_key)) => Some(BusterCredentials { url: host, api_key }), + (Err(_), Ok(api_key)) => Some(BusterCredentials { + url: DEFAULT_HOST.to_string(), + api_key, + }), + _ => None, // Can't proceed without at least an API key + } + } + }; + + match credentials { + Some(creds) => { + if creds.api_key.is_empty() { + Err(anyhow::anyhow!( + "Authentication required. Please run `buster auth` or set BUSTER_API_KEY." + )) + } else { + // Use the validator trait + validator.validate(&creds.url, &creds.api_key) + .await + .map_err(|e| { + anyhow::anyhow!( + "Authentication failed ({}). Please run `buster auth` to configure credentials.", + e + ) + }) + } + } + None => Err(anyhow::anyhow!( + "Authentication required. Please run `buster auth` or set BUSTER_API_KEY." + )), + } +} + +/// Checks if the user is authenticated by loading credentials and validating them. +/// Prioritizes environment variables (BUSTER_HOST, BUSTER_API_KEY) over cached credentials. +/// Returns Ok(()) if authenticated, otherwise returns an Err prompting the user to run `buster auth`. +pub async fn check_authentication() -> Result<()> { + let cached_credentials_result = get_buster_credentials() + .await + .map_err(anyhow::Error::from); + let validator = RealCredentialsValidator; + check_authentication_inner(cached_credentials_result, &validator).await } pub async fn auth_with_args(args: AuthArgs) -> Result<()> { @@ -66,6 +141,26 @@ pub async fn auth_with_args(args: AuthArgs) -> Result<()> { api_key: String::new(), }, }; + let existing_creds_present = !buster_creds.url.is_empty() && !buster_creds.api_key.is_empty(); + + let host_provided = args.host.is_some(); + let api_key_provided = args.api_key.is_some(); + let fully_provided_via_args = host_provided && api_key_provided; + + // If existing credentials are found and the user hasn't provided everything via args/env, + // prompt for confirmation before proceeding with potential overwrites. + if existing_creds_present && !fully_provided_via_args { + let confirm = Confirm::new("Existing credentials found. Do you want to overwrite them?") + .with_default(false) + .with_help_message("Select 'y' to proceed with entering new credentials, or 'n' to cancel.") + .prompt()?; + + if !confirm { + println!("Authentication cancelled."); + return Ok(()); + } + // If confirmed, we will proceed, potentially overwriting existing values below. + } // Apply host from args or use default if let Some(host) = args.host { @@ -73,15 +168,16 @@ pub async fn auth_with_args(args: AuthArgs) -> Result<()> { } // Check if API key was provided via args or environment - let api_key_from_env = args.api_key.is_some(); - + let api_key_from_env_or_args = args.api_key.is_some(); + // Apply API key from args or environment if let Some(api_key) = args.api_key { buster_creds.api_key = api_key; } // Interactive mode for missing values - if buster_creds.url.is_empty() { + // Only prompt if the value wasn't provided via args/env + if !host_provided && buster_creds.url.is_empty() { let url_input = Text::new("Enter the URL of your Buster API") .with_default(DEFAULT_HOST) .with_help_message("Press Enter to use the default URL") @@ -95,41 +191,233 @@ pub async fn auth_with_args(args: AuthArgs) -> Result<()> { } } - // Always prompt for API key if it wasn't found in environment variables - if !api_key_from_env || buster_creds.api_key.is_empty() { + // Always prompt for API key if it wasn't found in environment variables or args + // unless it's already present from the loaded credentials + if !api_key_from_env_or_args { let obfuscated_api_key = if buster_creds.api_key.is_empty() { String::from("None") } else { - format!("{}...", &buster_creds.api_key[0..4]) + format!("{}...", &buster_creds.api_key[0..std::cmp::min(4, buster_creds.api_key.len())]) // Ensure safe slicing }; - let api_key_input = Password::new(&format!("Enter your API key [{obfuscated_api_key}]:")) + let prompt_message = if existing_creds_present && !fully_provided_via_args { + format!("Enter new API key (current: [{obfuscated_api_key}]):") + } else { + format!("Enter your API key [{obfuscated_api_key}]:") + }; + + let api_key_input = Password::new(&prompt_message) .without_confirmation() - .with_help_message("Your API key can be found in your Buster dashboard") + .with_help_message("Your API key can be found in your Buster dashboard. Leave blank to keep current key.") .prompt() .context("Failed to get API key input")?; if api_key_input.is_empty() && buster_creds.api_key.is_empty() { + // Only error if no key exists *and* none was entered return Err(AuthError::MissingApiKey.into()); } else if !api_key_input.is_empty() { + // Update only if new input was provided buster_creds.api_key = api_key_input; } } - // Validate credentials - validate_credentials(&buster_creds.url, &buster_creds.api_key).await?; + // Validate credentials using the trait + let validator = RealCredentialsValidator; + validator.validate(&buster_creds.url, &buster_creds.api_key).await?; // Save credentials unless --no-save is specified if !args.no_save { set_buster_credentials(buster_creds).await .context("Failed to save credentials")?; println!("Credentials saved successfully!"); - } - - println!("Authentication successful!"); - if args.no_save { - println!("Note: Credentials were not saved due to --no-save flag"); + } else { + // Only print success if we actually went through validation. + // If validation failed, error would have been returned above. + println!("Authentication successful!"); + println!("Note: Credentials were not saved due to --no-save flag"); } Ok(()) } + +// --- Tests --- +#[cfg(test)] +mod tests { + use super::*; + use crate::utils::file::buster_credentials::BusterCredentials; + use mockall::predicate::*; + use std::env; + + // Helper to run async tests with env var setup/teardown + async fn run_test_with_env(env_vars: Vec<(&str, &str)>, test_fn: F) + where + F: FnOnce() -> Fut, + Fut: std::future::Future, + { + // Use a mutex to ensure env vars don't interfere between parallel tests + static ENV_MUTEX: tokio::sync::Mutex<()> = tokio::sync::Mutex::const_new(()); + let _guard = ENV_MUTEX.lock().await; + + let original_vars: Vec<(&str, Option)> = env_vars + .iter() + .map(|(k, _)| (*k, env::var(*k).ok())) + .collect(); + + for (k, v) in &env_vars { + env::set_var(k, v); + } + + (test_fn)().await; + + // Teardown: Restore original environment variables + for (k, v) in original_vars { + match v { + Some(val) => env::set_var(k, val), + None => env::remove_var(k), + } + } + } + + #[tokio::test] + async fn check_inner_success_env_only() { + let test_host = "http://env.host"; + let test_key = "env_key"; + run_test_with_env(vec![("BUSTER_HOST", test_host), ("BUSTER_API_KEY", test_key)], || async { + let mut mock_validator = MockCredentialsValidator::new(); + mock_validator + .expect_validate() + .with(eq(test_host), eq(test_key)) + .times(1) + .returning(|_, _| Ok(())); + + // Pass Err to simulate no cache file found + let result = check_authentication_inner(Err(anyhow::anyhow!("No cache")), &mock_validator).await; + assert!(result.is_ok()); + }) + .await; + } + + #[tokio::test] + async fn check_inner_success_cache_only() { + // Define creds outside the test closure for lifetime reasons + let cached_creds = BusterCredentials { + url: "http://cache.host".to_string(), + api_key: "cache_key".to_string(), + }; + // Clone fields needed for the closure *before* the test closure + let expected_url = cached_creds.url.clone(); + let expected_key = cached_creds.api_key.clone(); + + run_test_with_env(vec![], || async { + let mut mock_validator = MockCredentialsValidator::new(); + mock_validator + .expect_validate() + // Use withf with the cloned values + .withf(move |url: &str, key: &str| { + url == expected_url && key == expected_key + }) + .times(1) + .returning(|_, _| Ok(())); + + // Clone creds when passing into the function + let result = check_authentication_inner(Ok(cached_creds.clone()), &mock_validator).await; + assert!(result.is_ok()); + }) + .await; + } + + #[tokio::test] + async fn check_inner_success_env_overrides_cache() { + // Define creds outside the test closure + let cached_creds = BusterCredentials { + url: "http://cache.host".to_string(), + api_key: "cache_key".to_string(), + }; + let test_host = "http://env.host"; + let test_key = "env_key"; + // Clone fields needed for the closure + let expected_host = test_host.to_string(); + let expected_key = test_key.to_string(); + + run_test_with_env(vec![("BUSTER_HOST", test_host), ("BUSTER_API_KEY", test_key)], || async { + let mut mock_validator = MockCredentialsValidator::new(); + mock_validator + .expect_validate() + // Use withf with the cloned values + .withf(move |url: &str, key: &str| { + url == expected_host && key == expected_key + }) + .times(1) + .returning(|_, _| Ok(())); + + // Clone creds when passing into the function + let result = check_authentication_inner(Ok(cached_creds.clone()), &mock_validator).await; + assert!(result.is_ok()); + }) + .await; + } + + #[tokio::test] + async fn check_inner_fail_missing_api_key_env() { + run_test_with_env(vec![("BUSTER_HOST", "http://some.host")], || async { // No API Key + let mut mock_validator = MockCredentialsValidator::new(); + mock_validator.expect_validate().times(0); // Validator should not be called + + let result = check_authentication_inner(Err(anyhow::anyhow!("No cache")), &mock_validator).await; + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("BUSTER_API_KEY")); + }) + .await; + } + + #[tokio::test] + async fn check_inner_fail_missing_api_key_cache() { + // Define creds outside the test closure + let cached_creds = BusterCredentials { + url: "http://cache.host".to_string(), + api_key: "".to_string(), // Empty API Key + }; + run_test_with_env(vec![], || async { + let mut mock_validator = MockCredentialsValidator::new(); + mock_validator.expect_validate().times(0); + + // Clone creds when passing into the function + let result = check_authentication_inner(Ok(cached_creds.clone()), &mock_validator).await; + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("BUSTER_API_KEY")); + }) + .await; + } + + #[tokio::test] + async fn check_inner_fail_validation() { + let test_host = "http://env.host"; + let test_key = "env_key_invalid"; + run_test_with_env(vec![("BUSTER_HOST", test_host), ("BUSTER_API_KEY", test_key)], || async { + let mut mock_validator = MockCredentialsValidator::new(); + mock_validator + .expect_validate() + .with(eq(test_host), eq(test_key)) + .times(1) + .returning(|_, _| Err(AuthError::InvalidApiKey)); // Return error + + let result = check_authentication_inner(Err(anyhow::anyhow!("No cache")), &mock_validator).await; + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("Authentication failed (Invalid API key)")); + }) + .await; + } + + #[tokio::test] + async fn check_inner_fail_no_creds() { + run_test_with_env(vec![], || async { // No env vars + let mut mock_validator = MockCredentialsValidator::new(); + mock_validator.expect_validate().times(0); + + let result = check_authentication_inner(Err(anyhow::anyhow!("No cache")), &mock_validator).await; + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("BUSTER_API_KEY")); + }) + .await; + } +} diff --git a/cli/cli/src/main.rs b/cli/cli/src/main.rs index 462fb9317..6259e87ee 100644 --- a/cli/cli/src/main.rs +++ b/cli/cli/src/main.rs @@ -5,7 +5,7 @@ mod utils; use clap::{Parser, Subcommand}; use colored::*; -use commands::{auth::AuthArgs, deploy, init}; +use commands::{auth::AuthArgs, deploy, init, auth::check_authentication}; use utils::updater::check_for_updates; pub const APP_NAME: &str = "buster"; @@ -121,6 +121,7 @@ async fn main() { Ok(None) => println!("\n{}", "Unable to check for updates".yellow()), Err(e) => println!("\n{}: {}", "Error checking for updates".red(), e), } + // Explicitly return Ok(()) to match the other arms' types Ok(()) } Commands::Update { @@ -138,7 +139,8 @@ async fn main() { schema, database, flat_structure, - } => { + } => async move { + check_authentication().await?; commands::generate( source_path.as_deref(), destination_path.as_deref(), @@ -148,12 +150,15 @@ async fn main() { flat_structure, ) .await - } + }.await, Commands::Deploy { path, dry_run, recursive, - } => deploy(path.as_deref(), dry_run, recursive).await, + } => async move { + check_authentication().await?; + deploy(path.as_deref(), dry_run, recursive).await + }.await, }; if let Err(e) = result {