ok we got in the init

This commit is contained in:
dal 2025-02-25 20:17:00 -07:00
parent 9c7e217077
commit e7588c1d12
No known key found for this signature in database
GPG Key ID: 16F4B0E1E9F61122
3 changed files with 138 additions and 70 deletions

View File

@ -1,16 +1,19 @@
use anyhow::Result; use anyhow::Result;
use colored::*; use colored::*;
use inquire::{Select, Text, Password, validator::Validation, Confirm}; use indicatif::{ProgressBar, ProgressStyle};
use inquire::{validator::Validation, Confirm, Password, Select, Text};
use regex::Regex; use regex::Regex;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_yaml;
use std::error::Error; use std::error::Error;
use indicatif::{ProgressBar, ProgressStyle}; use std::fs;
use std::path::{Path, PathBuf};
use std::time::Duration; use std::time::Duration;
use crate::utils::{ use crate::utils::{
buster_credentials::get_and_validate_buster_credentials, buster_credentials::get_and_validate_buster_credentials,
profiles::{Credential, PostgresCredentials}, profiles::{Credential, PostgresCredentials},
BusterClient, PostDataSourcesRequest, BusterClient, BusterConfig, PostDataSourcesRequest,
}; };
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
@ -42,9 +45,39 @@ struct RedshiftCredentials {
pub schemas: Option<Vec<String>>, pub schemas: Option<Vec<String>>,
} }
pub async fn init() -> Result<()> { pub async fn init(destination_path: Option<&str>) -> Result<()> {
println!("{}", "Initializing Buster...".bold().green()); println!("{}", "Initializing Buster...".bold().green());
// Determine the destination path for buster.yml
let dest_path = match destination_path {
Some(path) => PathBuf::from(path),
None => std::env::current_dir()?,
};
// Ensure destination directory exists
if !dest_path.exists() {
fs::create_dir_all(&dest_path)?;
}
let config_path = dest_path.join("buster.yml");
if config_path.exists() {
let overwrite = Confirm::new(&format!(
"A buster.yml file already exists at {}. Do you want to overwrite it?",
config_path.display().to_string().cyan()
))
.with_default(false)
.prompt()?;
if !overwrite {
println!(
"{}",
"Keeping existing buster.yml file. Configuration will be skipped.".yellow()
);
return Ok(());
}
}
// Check for Buster credentials with progress indicator // Check for Buster credentials with progress indicator
let spinner = ProgressBar::new_spinner(); let spinner = ProgressBar::new_spinner();
spinner.set_style( spinner.set_style(
@ -60,7 +93,7 @@ pub async fn init() -> Result<()> {
Ok(creds) => { Ok(creds) => {
spinner.finish_with_message("✓ Buster credentials found".green().to_string()); spinner.finish_with_message("✓ Buster credentials found".green().to_string());
creds creds
}, }
Err(_) => { Err(_) => {
spinner.finish_with_message("✗ No valid Buster credentials found".red().to_string()); spinner.finish_with_message("✗ No valid Buster credentials found".red().to_string());
println!("Please run {} first.", "buster auth".cyan()); println!("Please run {} first.", "buster auth".cyan());
@ -76,25 +109,30 @@ pub async fn init() -> Result<()> {
DatabaseType::Snowflake, DatabaseType::Snowflake,
]; ];
let db_type = Select::new( let db_type = Select::new("Select your database type:", db_types).prompt()?;
"Select your database type:",
db_types,
)
.prompt()?;
println!("You selected: {}", db_type.to_string().cyan()); println!("You selected: {}", db_type.to_string().cyan());
match db_type { match db_type {
DatabaseType::Redshift => setup_redshift(buster_creds.url, buster_creds.api_key).await, DatabaseType::Redshift => {
setup_redshift(buster_creds.url, buster_creds.api_key, &config_path).await
}
_ => { _ => {
println!("{}", format!("{} support is coming soon!", db_type).yellow()); println!(
"{}",
format!("{} support is coming soon!", db_type).yellow()
);
println!("Currently, only Redshift is supported."); println!("Currently, only Redshift is supported.");
Err(anyhow::anyhow!("Database type not yet implemented")) Err(anyhow::anyhow!("Database type not yet implemented"))
} }
} }
} }
async fn setup_redshift(buster_url: String, buster_api_key: String) -> Result<()> { async fn setup_redshift(
buster_url: String,
buster_api_key: String,
config_path: &Path,
) -> Result<()> {
println!("{}", "Setting up Redshift connection...".bold().green()); println!("{}", "Setting up Redshift connection...".bold().green());
// Collect name (with validation) // Collect name (with validation)
@ -108,7 +146,10 @@ async fn setup_redshift(buster_url: String, buster_api_key: String) -> Result<()
if name_regex.is_match(input) { if name_regex.is_match(input) {
Ok(Validation::Valid) Ok(Validation::Valid)
} else { } else {
Ok(Validation::Invalid("Name must contain only alphanumeric characters, dash (-) or underscore (_)".into())) Ok(Validation::Invalid(
"Name must contain only alphanumeric characters, dash (-) or underscore (_)"
.into(),
))
} }
}) })
.prompt()?; .prompt()?;
@ -128,11 +169,11 @@ async fn setup_redshift(buster_url: String, buster_api_key: String) -> Result<()
let port_str = Text::new("Enter the Redshift port:") let port_str = Text::new("Enter the Redshift port:")
.with_default("5439") .with_default("5439")
.with_help_message("Default Redshift port is 5439") .with_help_message("Default Redshift port is 5439")
.with_validator(|input: &str| { .with_validator(|input: &str| match input.parse::<u16>() {
match input.parse::<u16>() { Ok(_) => Ok(Validation::Valid),
Ok(_) => Ok(Validation::Valid), Err(_) => Ok(Validation::Invalid(
Err(_) => Ok(Validation::Invalid("Port must be a valid number between 1 and 65535".into())), "Port must be a valid number between 1 and 65535".into(),
} )),
}) })
.prompt()?; .prompt()?;
let port = port_str.parse::<u16>()?; let port = port_str.parse::<u16>()?;
@ -214,8 +255,8 @@ async fn setup_redshift(buster_url: String, buster_api_key: String) -> Result<()
port, port,
username, username,
password, password,
database, database: database.clone(),
schemas: schema.map(|s| vec![s]), schemas: schema.as_ref().map(|s| vec![s.clone()]),
}; };
// Create API request // Create API request
@ -224,19 +265,21 @@ async fn setup_redshift(buster_url: String, buster_api_key: String) -> Result<()
let request = PostDataSourcesRequest { let request = PostDataSourcesRequest {
name: name.clone(), name: name.clone(),
env: "dev".to_string(), // Default to dev environment env: "dev".to_string(), // Default to dev environment
credential: Credential::Redshift( credential: Credential::Redshift(PostgresCredentials {
PostgresCredentials { host: redshift_creds.host,
host: redshift_creds.host, port: redshift_creds.port,
port: redshift_creds.port, username: redshift_creds.username,
username: redshift_creds.username, password: redshift_creds.password,
password: redshift_creds.password, database: redshift_creds.database.clone().unwrap_or_default(),
database: redshift_creds.database.clone().unwrap_or_default(), schema: redshift_creds
schema: redshift_creds.schemas.clone().and_then(|s| s.first().cloned()).unwrap_or_default(), .schemas
jump_host: None, .clone()
ssh_username: None, .and_then(|s| s.first().cloned())
ssh_private_key: None, .unwrap_or_default(),
} jump_host: None,
), ssh_username: None,
ssh_private_key: None,
}),
}; };
// Send to API with progress indicator // Send to API with progress indicator
@ -254,11 +297,32 @@ async fn setup_redshift(buster_url: String, buster_api_key: String) -> Result<()
match client.post_data_sources(vec![request]).await { match client.post_data_sources(vec![request]).await {
Ok(_) => { Ok(_) => {
spinner.finish_with_message("✓ Data source created successfully!".green().bold().to_string()); spinner.finish_with_message(
println!("\nData source '{}' is now available for use with Buster.", name.cyan()); "✓ Data source created successfully!"
.green()
.bold()
.to_string(),
);
println!(
"\nData source '{}' is now available for use with Buster.",
name.cyan()
);
// Create a copy of the values we need for the config file
let db_copy = database.clone();
let schema_copy = schema.clone();
// Create buster.yml file
create_buster_config_file(
config_path,
&name,
db_copy.as_deref(),
schema_copy.as_deref(),
)?;
println!("You can now use this data source with other Buster commands."); println!("You can now use this data source with other Buster commands.");
Ok(()) Ok(())
}, }
Err(e) => { Err(e) => {
spinner.finish_with_message("✗ Failed to create data source".red().bold().to_string()); spinner.finish_with_message("✗ Failed to create data source".red().bold().to_string());
println!("\nError: {}", e); println!("\nError: {}", e);
@ -267,3 +331,30 @@ async fn setup_redshift(buster_url: String, buster_api_key: String) -> Result<()
} }
} }
} }
// Helper function to create buster.yml file
fn create_buster_config_file(
path: &Path,
data_source_name: &str,
database: Option<&str>,
schema: Option<&str>,
) -> Result<()> {
let config = BusterConfig {
data_source_name: Some(data_source_name.to_string()),
schema: schema.map(String::from),
database: database.map(String::from),
exclude_files: None,
exclude_tags: None,
};
let yaml = serde_yaml::to_string(&config)?;
fs::write(path, yaml)?;
println!(
"{} {}",
"".green(),
format!("Created buster.yml at {}", path.display()).green()
);
Ok(())
}

View File

@ -15,7 +15,11 @@ pub const GIT_HASH: &str = env!("GIT_HASH");
#[derive(Subcommand)] #[derive(Subcommand)]
#[clap(rename_all = "kebab-case")] #[clap(rename_all = "kebab-case")]
pub enum Commands { pub enum Commands {
Init, Init {
/// Path to create the buster.yml file (defaults to current directory)
#[arg(long)]
destination_path: Option<String>,
},
/// Authenticate with Buster API /// Authenticate with Buster API
Auth { Auth {
/// The Buster API host URL /// The Buster API host URL
@ -82,7 +86,7 @@ async fn main() {
// TODO: All commands should check for an update. // TODO: All commands should check for an update.
let result = match args.cmd { let result = match args.cmd {
Commands::Init => init().await, Commands::Init { destination_path } => init(destination_path.as_deref()).await,
Commands::Auth { Commands::Auth {
host, host,
api_key, api_key,

View File

@ -1,27 +0,0 @@
use anyhow::Result;
use buster_cli::utils::file::profiles::{create_dbt_project_yml};
use tempfile::tempdir;
use std::fs::read_to_string;
#[tokio::test]
async fn test_create_dbt_project_yml() -> Result<()> {
// Create a temporary directory for the test
let dir = tempdir()?;
std::env::set_current_dir(dir.path())?;
// Create the project file
create_dbt_project_yml("test_project", "test_profile", "view").await?;
// Read the created file
let contents = read_to_string("dbt_project.yml")?;
let yaml: serde_yaml::Value = serde_yaml::from_str(&contents)?;
// Assert expected values
assert_eq!(yaml["name"], "test_project");
assert_eq!(yaml["version"], "1.0.0");
assert_eq!(yaml["profile"], "test_profile");
assert_eq!(yaml["model-paths"][0], "models");
assert_eq!(yaml["models"]["test_project"]["example"]["+materialized"], "view");
Ok(())
}