diff --git a/.gitignore b/.gitignore index c12ca579..24234970 100644 --- a/.gitignore +++ b/.gitignore @@ -189,3 +189,5 @@ supabase/.temp/storage-version **/.prompts/ **/__pycache__/ + +.env.scripts \ No newline at end of file diff --git a/backend/services/billing.py b/backend/services/billing.py index a4747f73..df3debec 100644 --- a/backend/services/billing.py +++ b/backend/services/billing.py @@ -320,6 +320,12 @@ async def create_checkout_session( billing_cycle_anchor='now' # Reset billing cycle ) + # Update active status in database to true (customer has active subscription) + await client.schema('basejump').from_('billing_customers').update( + {'active': True} + ).eq('id', customer_id).execute() + logger.info(f"Updated customer {customer_id} active status to TRUE after subscription upgrade") + latest_invoice = None if updated_subscription.get('latest_invoice'): latest_invoice = stripe.Invoice.retrieve(updated_subscription['latest_invoice']) @@ -505,6 +511,14 @@ async def create_checkout_session( 'product_id': product_id } ) + + # Update customer status to potentially active (will be confirmed by webhook) + # This ensures customer is marked as active once payment is completed + await client.schema('basejump').from_('billing_customers').update( + {'active': True} + ).eq('id', customer_id).execute() + logger.info(f"Updated customer {customer_id} active status to TRUE after creating checkout session") + return {"session_id": session['id'], "url": session['url'], "status": "new"} except Exception as e: @@ -745,8 +759,57 @@ async def stripe_webhook(request: Request): # Handle the event if event.type in ['customer.subscription.created', 'customer.subscription.updated', 'customer.subscription.deleted']: - # We don't need to do anything here as we'll query Stripe directly - pass + # Extract the subscription and customer information + subscription = event.data.object + customer_id = subscription.get('customer') + + if not customer_id: + logger.warning(f"No customer ID found in subscription event: {event.type}") + return {"status": "error", "message": "No customer ID found"} + + # Get database connection + db = DBConnection() + client = await db.client + + if event.type == 'customer.subscription.created' or event.type == 'customer.subscription.updated': + # Check if subscription is active + if subscription.get('status') in ['active', 'trialing']: + # Update customer's active status to true + await client.schema('basejump').from_('billing_customers').update( + {'active': True} + ).eq('id', customer_id).execute() + logger.info(f"Webhook: Updated customer {customer_id} active status to TRUE based on {event.type}") + else: + # Subscription is not active (e.g., past_due, canceled, etc.) + # Check if customer has any other active subscriptions before updating status + has_active = len(stripe.Subscription.list( + customer=customer_id, + status='active', + limit=1 + ).get('data', [])) > 0 + + if not has_active: + await client.schema('basejump').from_('billing_customers').update( + {'active': False} + ).eq('id', customer_id).execute() + logger.info(f"Webhook: Updated customer {customer_id} active status to FALSE based on {event.type}") + + elif event.type == 'customer.subscription.deleted': + # Check if customer has any other active subscriptions + has_active = len(stripe.Subscription.list( + customer=customer_id, + status='active', + limit=1 + ).get('data', [])) > 0 + + if not has_active: + # If no active subscriptions left, set active to false + await client.schema('basejump').from_('billing_customers').update( + {'active': False} + ).eq('id', customer_id).execute() + logger.info(f"Webhook: Updated customer {customer_id} active status to FALSE after subscription deletion") + + logger.info(f"Processed {event.type} event for customer {customer_id}") return {"status": "success"} diff --git a/backend/utils/scripts/archive_inactive_sandboxes.py b/backend/utils/scripts/archive_inactive_sandboxes.py new file mode 100644 index 00000000..e01e8593 --- /dev/null +++ b/backend/utils/scripts/archive_inactive_sandboxes.py @@ -0,0 +1,350 @@ +#!/usr/bin/env python +""" +Script to archive sandboxes for projects whose account_id is not associated with an active billing customer. + +Usage: + python archive_inactive_sandboxes.py + +This script: +1. Gets all active account_ids from basejump.billing_customers (active=TRUE) +2. Gets all projects from the projects table +3. Archives sandboxes for any project whose account_id is not in the active billing customers list + +Make sure your environment variables are properly set: +- SUPABASE_URL +- SUPABASE_SERVICE_ROLE_KEY +- DAYTONA_SERVER_URL +""" + +import asyncio +import sys +import os +import argparse +from typing import List, Dict, Any, Set +from dotenv import load_dotenv + +# Load script-specific environment variables +load_dotenv(".env") + +from services.supabase import DBConnection +from sandbox.sandbox import daytona +from utils.logger import logger + +# Global DB connection to reuse +db_connection = None + + +async def get_active_billing_customer_account_ids() -> Set[str]: + """ + Query all account_ids from the basejump.billing_customers table where active=TRUE. + + Returns: + Set of account_ids that have an active billing customer record + """ + global db_connection + if db_connection is None: + db_connection = DBConnection() + + client = await db_connection.client + + # Print the Supabase URL being used + print(f"Using Supabase URL: {os.getenv('SUPABASE_URL')}") + + # Query all account_ids from billing_customers where active=true + result = await client.schema('basejump').from_('billing_customers').select('account_id, active').execute() + + # Print the query result + print(f"Found {len(result.data)} billing customers in database") + print(result.data) + + if not result.data: + logger.info("No billing customers found in database") + return set() + + # Extract account_ids for active customers and return as a set for fast lookups + active_account_ids = {customer.get('account_id') for customer in result.data + if customer.get('account_id') and customer.get('active') is True} + + print(f"Found {len(active_account_ids)} active billing customers") + return active_account_ids + + +async def get_all_projects() -> List[Dict[str, Any]]: + """ + Query all projects with sandbox information. + + Returns: + List of projects with their sandbox information + """ + global db_connection + if db_connection is None: + db_connection = DBConnection() + + client = await db_connection.client + + # Initialize variables for pagination + all_projects = [] + page_size = 1000 + current_page = 0 + has_more = True + + logger.info("Starting to fetch all projects (paginated)") + + # Paginate through all projects + while has_more: + # Query projects with pagination + start_range = current_page * page_size + end_range = start_range + page_size - 1 + + logger.info(f"Fetching projects page {current_page+1} (range: {start_range}-{end_range})") + + result = await client.table('projects').select( + 'project_id', + 'name', + 'account_id', + 'sandbox' + ).range(start_range, end_range).execute() + + if not result.data: + has_more = False + else: + all_projects.extend(result.data) + current_page += 1 + + # Progress update + logger.info(f"Loaded {len(all_projects)} projects so far") + print(f"Loaded {len(all_projects)} projects so far...") + + # Check if we've reached the end + if len(result.data) < page_size: + has_more = False + + # Print the query result + total_projects = len(all_projects) + print(f"Found {total_projects} projects in database") + logger.info(f"Total projects found in database: {total_projects}") + + if not all_projects: + logger.info("No projects found in database") + return [] + + # Filter projects that have sandbox information + projects_with_sandboxes = [ + project for project in all_projects + if project.get('sandbox') and project['sandbox'].get('id') + ] + + logger.info(f"Found {len(projects_with_sandboxes)} projects with sandboxes") + return projects_with_sandboxes + + +async def archive_sandbox(project: Dict[str, Any], dry_run: bool) -> bool: + """ + Archive a single sandbox. + + Args: + project: Project information containing sandbox to archive + dry_run: If True, only simulate archiving + + Returns: + True if successful, False otherwise + """ + sandbox_id = project['sandbox'].get('id') + project_name = project.get('name', 'Unknown') + project_id = project.get('project_id', 'Unknown') + + try: + logger.info(f"Checking sandbox {sandbox_id} for project '{project_name}' (ID: {project_id})") + + if dry_run: + logger.info(f"DRY RUN: Would archive sandbox {sandbox_id}") + print(f"Would archive sandbox {sandbox_id} for project '{project_name}'") + return True + + # Get the sandbox + sandbox = daytona.get_current_sandbox(sandbox_id) + + # Check sandbox state - it must be stopped before archiving + sandbox_info = sandbox.info() + + # Log the current state + logger.info(f"Sandbox {sandbox_id} is in '{sandbox_info.state}' state") + + # Only archive if the sandbox is in the stopped state + if sandbox_info.state == "stopped": + logger.info(f"Archiving sandbox {sandbox_id} as it is in stopped state") + sandbox.archive() + logger.info(f"Successfully archived sandbox {sandbox_id}") + return True + else: + logger.info(f"Skipping sandbox {sandbox_id} as it is not in stopped state (current: {sandbox_info.state})") + return True + + except Exception as e: + import traceback + error_type = type(e).__name__ + stack_trace = traceback.format_exc() + + # Log detailed error information + logger.error(f"Error processing sandbox {sandbox_id}: {str(e)}") + logger.error(f"Error type: {error_type}") + logger.error(f"Stack trace:\n{stack_trace}") + + # If the exception has a response attribute (like in HTTP errors), log it + if hasattr(e, 'response'): + try: + response_data = e.response.json() if hasattr(e.response, 'json') else str(e.response) + logger.error(f"Response data: {response_data}") + except Exception: + logger.error(f"Could not parse response data from error") + + print(f"Failed to process sandbox {sandbox_id}: {error_type} - {str(e)}") + return False + + +async def process_sandboxes(inactive_projects: List[Dict[str, Any]], dry_run: bool) -> tuple[int, int]: + """ + Process all sandboxes sequentially. + + Args: + inactive_projects: List of projects without active billing + dry_run: Whether to actually archive sandboxes or just simulate + + Returns: + Tuple of (processed_count, failed_count) + """ + processed_count = 0 + failed_count = 0 + + if dry_run: + logger.info(f"DRY RUN: Would archive {len(inactive_projects)} sandboxes") + else: + logger.info(f"Archiving {len(inactive_projects)} sandboxes") + + print(f"Processing {len(inactive_projects)} sandboxes...") + + # Process each sandbox sequentially + for i, project in enumerate(inactive_projects): + success = await archive_sandbox(project, dry_run) + + if success: + processed_count += 1 + else: + failed_count += 1 + + # Print progress periodically + if (i + 1) % 20 == 0 or (i + 1) == len(inactive_projects): + progress = (i + 1) / len(inactive_projects) * 100 + print(f"Progress: {i + 1}/{len(inactive_projects)} sandboxes processed ({progress:.1f}%)") + print(f" - Processed: {processed_count}, Failed: {failed_count}") + + return processed_count, failed_count + + +async def main(): + """Main function to run the script.""" + # Parse command line arguments + parser = argparse.ArgumentParser(description='Archive sandboxes for projects without active billing') + parser.add_argument('--dry-run', action='store_true', help='Show what would be archived without actually archiving') + args = parser.parse_args() + + logger.info("Starting sandbox cleanup for projects without active billing") + if args.dry_run: + logger.info("DRY RUN MODE - No sandboxes will be archived") + + # Print environment info + print(f"Environment Mode: {os.getenv('ENV_MODE', 'Not set')}") + print(f"Daytona Server: {os.getenv('DAYTONA_SERVER_URL', 'Not set')}") + + try: + # Initialize global DB connection + global db_connection + db_connection = DBConnection() + + # Get all account_ids that have an active billing customer + active_billing_customer_account_ids = await get_active_billing_customer_account_ids() + + # Get all projects with sandboxes + all_projects = await get_all_projects() + + if not all_projects: + logger.info("No projects with sandboxes to process") + return + + # Filter projects whose account_id is not in the active billing customers list + inactive_projects = [ + project for project in all_projects + if project.get('account_id') not in active_billing_customer_account_ids + ] + + # Print summary of what will be processed + active_projects_count = len(all_projects) - len(inactive_projects) + print("\n===== SANDBOX CLEANUP SUMMARY =====") + print(f"Total projects found: {len(all_projects)}") + print(f"Projects with active billing accounts: {active_projects_count}") + print(f"Projects without active billing accounts: {len(inactive_projects)}") + print(f"Sandboxes that will be archived: {len(inactive_projects)}") + print("===================================") + + logger.info(f"Found {len(inactive_projects)} projects without an active billing customer account") + + if not inactive_projects: + logger.info("No projects to archive sandboxes for") + return + + # Ask for confirmation before proceeding + if not args.dry_run: + print("\n⚠️ WARNING: You are about to archive sandboxes for inactive accounts ⚠️") + print("This action cannot be undone!") + confirmation = input("\nAre you sure you want to proceed with archiving? (TRUE/FALSE): ").strip().upper() + + if confirmation != "TRUE": + print("Archiving cancelled. Exiting script.") + logger.info("Archiving cancelled by user") + return + + print("\nProceeding with sandbox archiving...\n") + logger.info("User confirmed sandbox archiving") + + # List all projects to be processed + for i, project in enumerate(inactive_projects[:5]): # Just show first 5 for brevity + account_id = project.get('account_id', 'Unknown') + project_name = project.get('name', 'Unknown') + project_id = project.get('project_id', 'Unknown') + sandbox_id = project['sandbox'].get('id') + + print(f"{i+1}. Project: {project_name}") + print(f" Project ID: {project_id}") + print(f" Account ID: {account_id}") + print(f" Sandbox ID: {sandbox_id}") + + if len(inactive_projects) > 5: + print(f" ... and {len(inactive_projects) - 5} more projects") + + # Process all sandboxes + processed_count, failed_count = await process_sandboxes(inactive_projects, args.dry_run) + + # Print final summary + print("\nSandbox Cleanup Summary:") + print(f"Total projects without active billing: {len(inactive_projects)}") + print(f"Total sandboxes processed: {len(inactive_projects)}") + + if args.dry_run: + print(f"DRY RUN: No sandboxes were actually archived") + else: + print(f"Successfully processed: {processed_count}") + print(f"Failed to process: {failed_count}") + + logger.info("Sandbox cleanup completed") + + except Exception as e: + logger.error(f"Error during sandbox cleanup: {str(e)}") + sys.exit(1) + finally: + # Clean up database connection + if db_connection: + await DBConnection.disconnect() + + +if __name__ == "__main__": + asyncio.run(main()) \ No newline at end of file diff --git a/backend/utils/scripts/set_all_customers_active.py b/backend/utils/scripts/set_all_customers_active.py new file mode 100644 index 00000000..a64cf75c --- /dev/null +++ b/backend/utils/scripts/set_all_customers_active.py @@ -0,0 +1,142 @@ +#!/usr/bin/env python +""" +Script to set all Stripe customers in the database to active status. + +Usage: + python update_customer_status.py + +This script: +1. Queries all customer IDs from basejump.billing_customers +2. Sets all customers' active field to True in the database + +Make sure your environment variables are properly set: +- SUPABASE_URL +- SUPABASE_SERVICE_ROLE_KEY +""" + +import asyncio +import sys +import os +from typing import List, Dict, Any +from dotenv import load_dotenv + +# Load script-specific environment variables +load_dotenv(".env") + +from services.supabase import DBConnection +from utils.logger import logger + +# Semaphore to limit concurrent database connections +DB_CONNECTION_LIMIT = 20 +db_semaphore = asyncio.Semaphore(DB_CONNECTION_LIMIT) + +# Global DB connection to reuse +db_connection = None + + +async def get_all_customers() -> List[Dict[str, Any]]: + """ + Query all customers from the database. + + Returns: + List of customers with their ID and account_id + """ + global db_connection + if db_connection is None: + db_connection = DBConnection() + + client = await db_connection.client + + # Print the Supabase URL being used + print(f"Using Supabase URL: {os.getenv('SUPABASE_URL')}") + + # Query all customers from billing_customers + result = await client.schema('basejump').from_('billing_customers').select( + 'id', + 'account_id', + 'active' + ).execute() + + # Print the query result + print(f"Found {len(result.data)} customers in database") + print(result.data) + + if not result.data: + logger.info("No customers found in database") + return [] + + return result.data + + +async def update_all_customers_to_active() -> Dict[str, int]: + """ + Update all customers to active status in the database. + + Returns: + Dict with count of updated customers + """ + try: + global db_connection + if db_connection is None: + db_connection = DBConnection() + + client = await db_connection.client + + # Update all customers to active + result = await client.schema('basejump').from_('billing_customers').update( + {'active': True} + ).filter('id', 'neq', None).execute() + + updated_count = len(result.data) if hasattr(result, 'data') else 0 + logger.info(f"Updated {updated_count} customers to active status") + print(f"Updated {updated_count} customers to active status") + print("Result:", result) + + return {'updated': updated_count} + except Exception as e: + logger.error(f"Error updating customers in database: {str(e)}") + return {'updated': 0, 'error': str(e)} + + +async def main(): + """Main function to run the script.""" + logger.info("Starting customer status update process") + + try: + # Initialize global DB connection + global db_connection + db_connection = DBConnection() + + # Get all customers from the database + customers = await get_all_customers() + + if not customers: + logger.info("No customers to process") + return + + # Ask for confirmation before proceeding + confirm = input(f"\nSet all {len(customers)} customers to active? (y/n): ") + if confirm.lower() != 'y': + logger.info("Operation cancelled by user") + return + + # Update all customers to active + results = await update_all_customers_to_active() + + # Print summary + print("\nCustomer Status Update Summary:") + print(f"Total customers set to active: {results.get('updated', 0)}") + + logger.info("Customer status update completed") + + except Exception as e: + logger.error(f"Error during customer status update: {str(e)}") + sys.exit(1) + finally: + # Clean up database connection + if db_connection: + await DBConnection.disconnect() + + +if __name__ == "__main__": + asyncio.run(main()) \ No newline at end of file diff --git a/backend/utils/scripts/update_customer_active_status.py b/backend/utils/scripts/update_customer_active_status.py new file mode 100644 index 00000000..ced3ad2e --- /dev/null +++ b/backend/utils/scripts/update_customer_active_status.py @@ -0,0 +1,326 @@ +#!/usr/bin/env python +""" +Script to check Stripe subscriptions for all customers and update their active status. + +Usage: + python update_customer_active_status.py + +This script: +1. Queries all customers from basejump.billing_customers +2. Checks subscription status directly on Stripe using customer_id +3. Updates customer active status in database + +Make sure your environment variables are properly set: +- SUPABASE_URL +- SUPABASE_SERVICE_ROLE_KEY +- STRIPE_SECRET_KEY +""" + +import asyncio +import sys +import os +import time +from typing import List, Dict, Any, Tuple +from dotenv import load_dotenv +import stripe + +# Load script-specific environment variables +load_dotenv(".env") + +# Import relative modules +from services.supabase import DBConnection +from utils.logger import logger +from utils.config import config + +# Initialize Stripe with the API key +stripe.api_key = config.STRIPE_SECRET_KEY + +# Batch size settings +BATCH_SIZE = 100 # Process customers in batches +MAX_CONCURRENCY = 20 # Maximum concurrent Stripe API calls + +# Global DB connection to reuse +db_connection = None + +async def get_all_customers() -> List[Dict[str, Any]]: + """ + Query all customers from the database. + + Returns: + List of customers with their ID (customer_id is used for Stripe) + """ + global db_connection + if db_connection is None: + db_connection = DBConnection() + + client = await db_connection.client + + # Print the Supabase URL being used + print(f"Using Supabase URL: {os.getenv('SUPABASE_URL')}") + + # Query all customers from billing_customers + result = await client.schema('basejump').from_('billing_customers').select( + 'id', + 'active' + ).execute() + + # Print the query result + print(f"Found {len(result.data)} customers in database") + + if not result.data: + logger.info("No customers found in database") + return [] + + return result.data + +async def check_stripe_subscription(customer_id: str) -> bool: + """ + Check if a customer has an active subscription directly on Stripe. + + Args: + customer_id: Customer ID (billing_customers.id) which is the Stripe customer ID + + Returns: + True if customer has at least one active subscription, False otherwise + """ + if not customer_id: + print(f"⚠️ Empty customer_id") + return False + + try: + # Print what we're checking for debugging + print(f"Checking Stripe subscriptions for customer: {customer_id}") + + # List all subscriptions for this customer directly on Stripe + subscriptions = stripe.Subscription.list( + customer=customer_id, + status='active', # Only get active subscriptions + limit=1 # We only need to know if there's at least one + ) + + # Print the raw data for debugging + print(f"Stripe returned data: {subscriptions.data}") + + # If there's at least one active subscription, the customer is active + has_active_subscription = len(subscriptions.data) > 0 + + if has_active_subscription: + print(f"✅ Customer {customer_id} has ACTIVE subscription") + else: + print(f"❌ Customer {customer_id} has NO active subscription") + + return has_active_subscription + + except Exception as e: + logger.error(f"Error checking Stripe subscription for customer {customer_id}: {str(e)}") + print(f"⚠️ Error checking subscription for {customer_id}: {str(e)}") + return False + +async def process_customer_batch(batch: List[Dict[str, Any]], batch_number: int, total_batches: int) -> Dict[str, bool]: + """ + Process a batch of customers by checking their Stripe subscriptions concurrently. + + Args: + batch: List of customer records in this batch + batch_number: Current batch number (for logging) + total_batches: Total number of batches (for logging) + + Returns: + Dictionary mapping customer IDs to subscription status (True/False) + """ + start_time = time.time() + batch_size = len(batch) + print(f"Processing batch {batch_number}/{total_batches} ({batch_size} customers)...") + + # Create a semaphore to limit concurrency within the batch to avoid rate limiting + semaphore = asyncio.Semaphore(MAX_CONCURRENCY) + + async def check_single_customer(customer: Dict[str, Any]) -> Tuple[str, bool]: + async with semaphore: # Limit concurrent API calls + customer_id = customer['id'] + + # Check directly on Stripe - customer_id IS the Stripe customer ID + is_active = await check_stripe_subscription(customer_id) + return customer_id, is_active + + # Create tasks for all customers in this batch + tasks = [check_single_customer(customer) for customer in batch] + + # Run all tasks in this batch concurrently + results = await asyncio.gather(*tasks) + + # Convert results to dictionary + subscription_status = {customer_id: status for customer_id, status in results} + + end_time = time.time() + + # Count active/inactive in this batch + active_count = sum(1 for status in subscription_status.values() if status) + inactive_count = batch_size - active_count + + print(f"Batch {batch_number} completed in {end_time - start_time:.2f} seconds") + print(f"Results (batch {batch_number}): {active_count} active, {inactive_count} inactive subscriptions") + + return subscription_status + +async def update_customer_batch(subscription_status: Dict[str, bool]) -> Dict[str, int]: + """ + Update a batch of customers in the database. + + Args: + subscription_status: Dictionary mapping customer IDs to active status + + Returns: + Dictionary with statistics about the update + """ + start_time = time.time() + + global db_connection + if db_connection is None: + db_connection = DBConnection() + + client = await db_connection.client + + # Separate customers into active and inactive groups + active_customers = [cid for cid, status in subscription_status.items() if status] + inactive_customers = [cid for cid, status in subscription_status.items() if not status] + + total_count = len(active_customers) + len(inactive_customers) + + # Update statistics + stats = { + 'total': total_count, + 'active_updated': 0, + 'inactive_updated': 0, + 'errors': 0 + } + + # Update active customers in a single operation + if active_customers: + try: + print(f"Updating {len(active_customers)} customers to ACTIVE status") + await client.schema('basejump').from_('billing_customers').update( + {'active': True} + ).in_('id', active_customers).execute() + + stats['active_updated'] = len(active_customers) + logger.info(f"Updated {len(active_customers)} customers to ACTIVE status") + except Exception as e: + logger.error(f"Error updating active customers: {str(e)}") + stats['errors'] += 1 + + # Update inactive customers in a single operation + if inactive_customers: + try: + print(f"Updating {len(inactive_customers)} customers to INACTIVE status") + await client.schema('basejump').from_('billing_customers').update( + {'active': False} + ).in_('id', inactive_customers).execute() + + stats['inactive_updated'] = len(inactive_customers) + logger.info(f"Updated {len(inactive_customers)} customers to INACTIVE status") + except Exception as e: + logger.error(f"Error updating inactive customers: {str(e)}") + stats['errors'] += 1 + + end_time = time.time() + print(f"Database updates completed in {end_time - start_time:.2f} seconds") + + return stats + +async def main(): + """Main function to run the script.""" + total_start_time = time.time() + logger.info("Starting customer active status update process") + + try: + # Check Stripe API key + print(f"Stripe API key configured: {'Yes' if config.STRIPE_SECRET_KEY else 'No'}") + if not config.STRIPE_SECRET_KEY: + print("ERROR: Stripe API key not configured. Please set STRIPE_SECRET_KEY in your environment.") + return + + # Initialize global DB connection + global db_connection + db_connection = DBConnection() + + # Get all customers from the database + all_customers = await get_all_customers() + + if not all_customers: + logger.info("No customers to process") + return + + # Print a small sample of the customer data + print("\nCustomer data sample (customer_id = Stripe customer ID):") + for i, customer in enumerate(all_customers[:5]): # Show first 5 only + print(f" {i+1}. ID: {customer['id']}, Active: {customer.get('active')}") + if len(all_customers) > 5: + print(f" ... and {len(all_customers) - 5} more") + + # Split customers into batches + batches = [all_customers[i:i + BATCH_SIZE] for i in range(0, len(all_customers), BATCH_SIZE)] + total_batches = len(batches) + + # Ask for confirmation before proceeding + confirm = input(f"\nProcess {len(all_customers)} customers in {total_batches} batches of {BATCH_SIZE}? (y/n): ") + if confirm.lower() != 'y': + logger.info("Operation cancelled by user") + return + + # Overall statistics + all_stats = { + 'total': 0, + 'active_updated': 0, + 'inactive_updated': 0, + 'errors': 0 + } + + # Process each batch + for i, batch in enumerate(batches): + batch_number = i + 1 + + # STEP 1: Process this batch of customers + subscription_status = await process_customer_batch(batch, batch_number, total_batches) + + # STEP 2: Update this batch in the database + batch_stats = await update_customer_batch(subscription_status) + + # Accumulate statistics + all_stats['total'] += batch_stats['total'] + all_stats['active_updated'] += batch_stats['active_updated'] + all_stats['inactive_updated'] += batch_stats['inactive_updated'] + all_stats['errors'] += batch_stats['errors'] + + # Show batch completion + print(f"Completed batch {batch_number}/{total_batches}") + + # Brief pause between batches to avoid Stripe rate limiting + if batch_number < total_batches: + await asyncio.sleep(1) # 1 second pause between batches + + # Print summary + total_end_time = time.time() + total_time = total_end_time - total_start_time + + print("\nCustomer Status Update Summary:") + print(f"Total customers processed: {all_stats['total']}") + print(f"Customers set to active: {all_stats['active_updated']}") + print(f"Customers set to inactive: {all_stats['inactive_updated']}") + if all_stats['errors'] > 0: + print(f"Update errors: {all_stats['errors']}") + print(f"Total processing time: {total_time:.2f} seconds") + + logger.info(f"Customer active status update completed in {total_time:.2f} seconds") + + except Exception as e: + logger.error(f"Error during customer status update: {str(e)}") + sys.exit(1) + finally: + # Clean up database connection + if db_connection: + await DBConnection.disconnect() + + +if __name__ == "__main__": + asyncio.run(main()) \ No newline at end of file diff --git a/frontend/src/components/home/sections/pricing-section.tsx b/frontend/src/components/home/sections/pricing-section.tsx index 952cb02a..0dae913e 100644 --- a/frontend/src/components/home/sections/pricing-section.tsx +++ b/frontend/src/components/home/sections/pricing-section.tsx @@ -236,13 +236,11 @@ function PricingTier({ ? new Date(response.effective_date).toLocaleDateString() : 'the end of your billing period'; - const downgradeMessage = response.details?.is_upgrade === false - ? `Subscription downgrade scheduled from $${response.details.current_price} to $${response.details.new_price}` - : 'Subscription change scheduled'; + const statusChangeMessage = 'Subscription change scheduled'; toast.success(
{downgradeMessage}
+{statusChangeMessage}
Your plan will change on {effectiveDate}.
@@ -390,11 +388,20 @@ function PricingTier({ buttonVariant = "secondary"; buttonClassName = "bg-primary/5 hover:bg-primary/10 text-primary"; } else { - buttonText = targetAmount > currentAmount ? "Upgrade" : "Downgrade"; - buttonVariant = tier.buttonColor as ButtonVariant; - buttonClassName = targetAmount > currentAmount - ? "bg-primary hover:bg-primary/90 text-primary-foreground" - : "bg-primary/5 hover:bg-primary/10 text-primary"; + if (targetAmount > currentAmount) { + buttonText = "Upgrade"; + buttonVariant = tier.buttonColor as ButtonVariant; + buttonClassName = "bg-primary hover:bg-primary/90 text-primary-foreground"; + } else if (targetAmount < currentAmount) { + buttonText = "-"; + buttonDisabled = true; + buttonVariant = "secondary"; + buttonClassName = "opacity-50 cursor-not-allowed bg-muted text-muted-foreground"; + } else { + buttonText = "Select Plan"; + buttonVariant = tier.buttonColor as ButtonVariant; + buttonClassName = "bg-primary hover:bg-primary/90 text-primary-foreground"; + } } } @@ -410,8 +417,6 @@ function PricingTier({ : "bg-secondary hover:bg-secondary/90 text-white"; } - - return (