From 1f5c0e9be2e6e70f471fc5f53f155e828323e2b0 Mon Sep 17 00:00:00 2001 From: Saumya Date: Sat, 6 Sep 2025 15:26:09 +0530 Subject: [PATCH] simplify blling arch --- backend/core/billing_config.py | 14 +- backend/scripts/sync_stripe_subscriptions.py | 76 +--- backend/scripts/verify_user_tiers.py | 97 +++++ backend/services/billing_v2.py | 330 +++++++++++------ backend/services/billing_v3.py | 337 ++++++++++++++++++ .../settings/billing/page.tsx | 53 ++- .../src/components/billing/billing-modal.tsx | 59 +-- .../home/sections/pricing-section.tsx | 10 +- .../src/hooks/react-query/use-billing-v2.ts | 11 + frontend/src/lib/api.ts | 5 +- frontend/src/lib/api/billing-v2.ts | 16 +- 11 files changed, 772 insertions(+), 236 deletions(-) create mode 100644 backend/scripts/verify_user_tiers.py create mode 100644 backend/services/billing_v3.py diff --git a/backend/core/billing_config.py b/backend/core/billing_config.py index bd250687..7d2e7926 100644 --- a/backend/core/billing_config.py +++ b/backend/core/billing_config.py @@ -36,7 +36,7 @@ TIERS: Dict[str, Tier] = { config.STRIPE_TIER_2_20_YEARLY_ID, config.STRIPE_TIER_2_17_YEARLY_COMMITMENT_ID ], - monthly_credits=Decimal('25.00'), + monthly_credits=Decimal('20.00'), display_name='Starter', can_purchase_credits=True, models=['gpt-4o-mini', 'gpt-4o', 'claude-3-haiku', 'claude-3-5-sonnet'], @@ -49,7 +49,7 @@ TIERS: Dict[str, Tier] = { config.STRIPE_TIER_6_50_YEARLY_ID, config.STRIPE_TIER_6_42_YEARLY_COMMITMENT_ID ], - monthly_credits=Decimal('65.00'), + monthly_credits=Decimal('50.00'), display_name='Professional', can_purchase_credits=True, models=['gpt-4o-mini', 'gpt-4o', 'claude-3-haiku', 'claude-3-5-sonnet', 'claude-3-opus'], @@ -61,7 +61,7 @@ TIERS: Dict[str, Tier] = { config.STRIPE_TIER_12_100_ID, config.STRIPE_TIER_12_100_YEARLY_ID ], - monthly_credits=Decimal('130.00'), + monthly_credits=Decimal('100.00'), display_name='Team', can_purchase_credits=True, models=['gpt-4o-mini', 'gpt-4o', 'claude-3-haiku', 'claude-3-5-sonnet', 'claude-3-opus'], @@ -74,7 +74,7 @@ TIERS: Dict[str, Tier] = { config.STRIPE_TIER_25_200_YEARLY_ID, config.STRIPE_TIER_25_170_YEARLY_COMMITMENT_ID ], - monthly_credits=Decimal('260.00'), + monthly_credits=Decimal('200.00'), display_name='Business', can_purchase_credits=True, models=['gpt-4o-mini', 'gpt-4o', 'claude-3-haiku', 'claude-3-5-sonnet', 'claude-3-opus', 'o1-preview'], @@ -86,7 +86,7 @@ TIERS: Dict[str, Tier] = { config.STRIPE_TIER_50_400_ID, config.STRIPE_TIER_50_400_YEARLY_ID ], - monthly_credits=Decimal('520.00'), + monthly_credits=Decimal('400.00'), display_name='Enterprise', can_purchase_credits=True, models=['gpt-4o-mini', 'gpt-4o', 'claude-3-haiku', 'claude-3-5-sonnet', 'claude-3-opus', 'o1-preview', 'o1'], @@ -98,7 +98,7 @@ TIERS: Dict[str, Tier] = { config.STRIPE_TIER_125_800_ID, config.STRIPE_TIER_125_800_YEARLY_ID ], - monthly_credits=Decimal('1040.00'), + monthly_credits=Decimal('800.00'), display_name='Enterprise Plus', can_purchase_credits=True, models=['gpt-4o-mini', 'gpt-4o', 'claude-3-haiku', 'claude-3-5-sonnet', 'claude-3-opus', 'o1-preview', 'o1'], @@ -110,7 +110,7 @@ TIERS: Dict[str, Tier] = { config.STRIPE_TIER_200_1000_ID, config.STRIPE_TIER_200_1000_YEARLY_ID ], - monthly_credits=Decimal('1300.00'), + monthly_credits=Decimal('1000.00'), display_name='Ultimate', can_purchase_credits=True, models=['gpt-4o-mini', 'gpt-4o', 'claude-3-haiku', 'claude-3-5-sonnet', 'claude-3-opus', 'o1-preview', 'o1'], diff --git a/backend/scripts/sync_stripe_subscriptions.py b/backend/scripts/sync_stripe_subscriptions.py index cd7299c3..e0d68737 100644 --- a/backend/scripts/sync_stripe_subscriptions.py +++ b/backend/scripts/sync_stripe_subscriptions.py @@ -158,71 +158,38 @@ async def sync_stripe_to_db(): print(f" ✓ Synced subscription to database") synced_count += 1 - # For active/trialing subscriptions, use handle_subscription_change for comprehensive update if sub.status in ['active', 'trialing']: print(f" Processing active subscription...") - # Get current state account_before = await client.from_('credit_accounts')\ .select('*')\ .eq('user_id', account_id)\ .maybe_single()\ .execute() - # Use handle_subscription_change to properly update all fields - # This handles billing_cycle_anchor, stripe_subscription_id, next_credit_grant, etc. try: await handle_subscription_change(sub) print(f" ✓ Updated billing cycle and credit fields") except Exception as e: print(f" ⚠ handle_subscription_change failed: {e}") - # Now fix the balance based on actual usage from agent_runs - tier = get_tier_by_price_id(price_id) if price_id else None - if tier: - tier_name = tier.name - - # Calculate actual monthly credits (subscription cost + free tier base) - parts = tier_name.split('_') - if len(parts) == 3 and parts[0] == 'tier': - subscription_cost = Decimal(parts[2]) - monthly_credits = subscription_cost + FREE_TIER_INITIAL_CREDITS - else: - monthly_credits = FREE_TIER_INITIAL_CREDITS - - # Get current month usage from agent_runs (old billing system style) - month_start = datetime.now().replace(day=1, hour=0, minute=0, second=0, microsecond=0) - runs_result = await client.from_('agent_runs')\ - .select('total_cost')\ - .eq('user_id', account_id)\ - .gte('created_at', month_start.isoformat())\ - .execute() - - current_usage = Decimal('0') - for run in runs_result.data or []: - if run.get('total_cost'): - current_usage += Decimal(str(run['total_cost'])) - - # Calculate correct balance - correct_balance = monthly_credits - current_usage - - # Update/create credit account with all the proper fields - credit_account = await client.from_('credit_accounts')\ - .select('*')\ - .eq('user_id', account_id)\ - .maybe_single()\ - .execute() - - if credit_account.data: - # Update existing account + account_after = await client.from_('credit_accounts')\ + .select('*')\ + .eq('user_id', account_id)\ + .maybe_single()\ + .execute() + + if account_after.data: + tier = get_tier_by_price_id(price_id) if price_id else None + if tier: + tier_name = tier.name + update_data = { 'tier': tier_name, - 'balance': str(correct_balance), 'stripe_subscription_id': sub.id, 'billing_cycle_anchor': datetime.fromtimestamp(sub.created).isoformat() if sub.created else None, } - # Add next_credit_grant if we have current_period_end if sub.current_period_end: update_data['next_credit_grant'] = datetime.fromtimestamp(sub.current_period_end).isoformat() @@ -231,27 +198,14 @@ async def sync_stripe_to_db(): .eq('user_id', account_id)\ .execute() + current_balance = Decimal(str(account_after.data['balance'])) print(f" ✓ Updated credit account:") print(f" - Tier: {tier_name}") - print(f" - Balance: ${correct_balance} (${monthly_credits} - ${current_usage} usage)") + print(f" - Balance: ${current_balance} (managed by handle_subscription_change)") print(f" - Stripe subscription: {sub.id}") print(f" - Next credit grant: {update_data.get('next_credit_grant', 'N/A')}") - else: - # Create new account with all fields - await client.from_('credit_accounts').insert({ - 'user_id': account_id, - 'balance': str(correct_balance), - 'tier': tier_name, - 'stripe_subscription_id': sub.id, - 'billing_cycle_anchor': datetime.fromtimestamp(sub.created).isoformat() if sub.created else None, - 'next_credit_grant': datetime.fromtimestamp(sub.current_period_end).isoformat() if sub.current_period_end else None, - 'last_grant_date': datetime.now().isoformat() - }).execute() - - print(f" ✓ Created credit account:") - print(f" - Tier: {tier_name}") - print(f" - Balance: ${correct_balance} (${monthly_credits} - ${current_usage} usage)") - print(f" - Stripe subscription: {sub.id}") + else: + print(f" ⚠ No credit account found after handle_subscription_change - this is unexpected") else: print(f" ✗ Failed to sync") diff --git a/backend/scripts/verify_user_tiers.py b/backend/scripts/verify_user_tiers.py new file mode 100644 index 00000000..0d93f38f --- /dev/null +++ b/backend/scripts/verify_user_tiers.py @@ -0,0 +1,97 @@ +#!/usr/bin/env python3 +import asyncio +import sys +from pathlib import Path +sys.path.append(str(Path(__file__).parent.parent)) + +from core.services.supabase import DBConnection +from core.billing_config import get_tier_by_price_id +from core.utils.logger import logger + +async def verify_and_fix_tiers(): + """Verify that all users have the correct tier in credit_accounts""" + db = DBConnection() + await db.initialize() + client = await db.client + + print("\n" + "="*60) + print("VERIFYING USER TIERS") + print("="*60) + + # Get all active/trialing subscriptions + subs_result = await client.schema('basejump').from_('billing_subscriptions')\ + .select('*')\ + .in_('status', ['active', 'trialing'])\ + .execute() + + if not subs_result.data: + print("No active subscriptions found") + return + + print(f"Found {len(subs_result.data)} active/trialing subscriptions") + + fixed = 0 + already_correct = 0 + errors = 0 + + for sub in subs_result.data: + user_id = sub['account_id'] + price_id = sub.get('price_id') + + if not price_id: + print(f"\nUser {user_id[:8]}... has no price_id in subscription") + continue + + # Get the tier from price_id + tier = get_tier_by_price_id(price_id) + if not tier: + print(f"\nUser {user_id[:8]}... has unknown price_id: {price_id}") + continue + + expected_tier = tier.name + + # Get current credit account + credit_result = await client.from_('credit_accounts')\ + .select('tier')\ + .eq('user_id', user_id)\ + .single()\ + .execute() + + if not credit_result.data: + print(f"\nUser {user_id[:8]}... has no credit account!") + errors += 1 + continue + + current_tier = credit_result.data['tier'] + + if current_tier != expected_tier: + print(f"\nUser {user_id[:8]}...") + print(f" Current tier: {current_tier}") + print(f" Expected tier: {expected_tier}") + print(f" Price ID: {price_id}") + + # Fix the tier + update_result = await client.from_('credit_accounts')\ + .update({'tier': expected_tier})\ + .eq('user_id', user_id)\ + .execute() + + if update_result.data: + print(f" ✓ Fixed tier to: {expected_tier}") + fixed += 1 + else: + print(f" ✗ Failed to update tier") + errors += 1 + else: + already_correct += 1 + + print("\n" + "="*60) + print("SUMMARY") + print("="*60) + print(f"Already correct: {already_correct}") + print(f"Fixed: {fixed}") + print(f"Errors: {errors}") + print(f"Total: {already_correct + fixed + errors}") + +if __name__ == "__main__": + asyncio.run(verify_and_fix_tiers()) \ No newline at end of file diff --git a/backend/services/billing_v2.py b/backend/services/billing_v2.py index d0a0df2f..4d5a387b 100644 --- a/backend/services/billing_v2.py +++ b/backend/services/billing_v2.py @@ -69,7 +69,7 @@ async def get_user_subscription_tier(user_id: str) -> Dict: if credit_result.data and len(credit_result.data) > 0: tier_name = credit_result.data[0].get('tier', 'free') else: - result = await client.schema('basejump').from_('billing_subscriptions').select('price_id, status').eq('account_id', user_id).eq('status', 'active').order('created', desc=True).limit(1).execute() + result = await client.schema('basejump').from_('billing_subscriptions').select('price_id, status').eq('account_id', user_id).in_('status', ['active', 'trialing']).order('created', desc=True).limit(1).execute() if result.data: price_id = result.data[0]['price_id'] @@ -357,12 +357,11 @@ async def stripe_webhook(request: Request): elif event.type in ['customer.subscription.created', 'customer.subscription.updated']: subscription = event.data.object - if subscription.status == 'active': + if subscription.status in ['active', 'trialing']: await handle_subscription_change(subscription) elif event.type == 'invoice.payment_succeeded': invoice = event.data.object - # Only process subscription invoices (not one-time payments) if invoice.get('subscription') and invoice.get('billing_reason') in ['subscription_cycle', 'subscription_update']: await handle_subscription_renewal(invoice) @@ -394,7 +393,6 @@ async def handle_subscription_change(subscription: Dict): 'credits': float(new_tier_info.monthly_credits) } - # Get billing cycle anchor from subscription billing_anchor = datetime.fromtimestamp(subscription['current_period_start'], tz=timezone.utc) next_grant_date = datetime.fromtimestamp(subscription['current_period_end'], tz=timezone.utc) @@ -412,16 +410,18 @@ async def handle_subscription_change(subscription: Dict): 'credits': float(current_tier_info.monthly_credits) } - # Only grant credits for upgrades if it's not a renewal - if current_tier and new_tier['credits'] > current_tier['credits'] and not existing_anchor: - credit_diff = Decimal(new_tier['credits'] - current_tier['credits']) - await credit_service.add_credits( - user_id=user_id, - amount=credit_diff, - type='tier_upgrade', - description=f"Upgrade from {current_tier['name']} to {new_tier['name']}" - ) - logger.info(f"Granted {credit_diff} credits for tier upgrade: {current_tier['name']} -> {new_tier['name']}") + if current_tier and current_tier['name'] != new_tier['name']: + if new_tier['credits'] > current_tier['credits']: + full_amount = Decimal(new_tier['credits']) + await credit_service.add_credits( + user_id=user_id, + amount=full_amount, + type='tier_upgrade', + description=f"Upgrade to {new_tier['name']} - Full tier credits" + ) + logger.info(f"Granted {full_amount} credits for tier upgrade: {current_tier['name']} -> {new_tier['name']}") + elif new_tier['credits'] < current_tier['credits']: + logger.info(f"Tier downgrade: {current_tier['name']} -> {new_tier['name']} - No credit adjustment") await client.from_('credit_accounts').update({ 'tier': new_tier['name'], @@ -430,7 +430,6 @@ async def handle_subscription_change(subscription: Dict): 'next_credit_grant': next_grant_date.isoformat() }).eq('user_id', user_id).execute() else: - # New subscription - grant initial credits await client.from_('credit_accounts').insert({ 'user_id': user_id, 'balance': new_tier['credits'], @@ -449,7 +448,6 @@ async def handle_subscription_change(subscription: Dict): ) async def handle_subscription_renewal(invoice: Dict): - """Handle subscription renewal and grant monthly credits based on billing cycle""" try: db = DBConnection() client = await db.client @@ -458,7 +456,6 @@ async def handle_subscription_renewal(invoice: Dict): if not subscription_id: return - # Get user from subscription customer_result = await client.schema('basejump').from_('billing_customers')\ .select('account_id')\ .eq('id', invoice['customer'])\ @@ -469,7 +466,6 @@ async def handle_subscription_renewal(invoice: Dict): user_id = customer_result.data[0]['account_id'] - # Get credit account account_result = await client.from_('credit_accounts')\ .select('tier, last_grant_date, next_credit_grant')\ .eq('user_id', user_id)\ @@ -481,14 +477,12 @@ async def handle_subscription_renewal(invoice: Dict): account = account_result.data[0] tier = account['tier'] - # Skip if we already granted credits recently (within 25 days to handle edge cases) if account.get('last_grant_date'): last_grant = datetime.fromisoformat(account['last_grant_date'].replace('Z', '+00:00')) if (datetime.now(timezone.utc) - last_grant).days < 25: logger.info(f"Skipping credit grant for {user_id} - already granted {(datetime.now(timezone.utc) - last_grant).days} days ago") return - # Grant monthly credits monthly_credits = get_monthly_credits(tier) if monthly_credits > 0: await credit_service.add_credits( @@ -499,7 +493,6 @@ async def handle_subscription_renewal(invoice: Dict): metadata={'invoice_id': invoice['id'], 'subscription_id': subscription_id} ) - # Update next grant date next_grant = datetime.now(timezone.utc) + timedelta(days=30) await client.from_('credit_accounts').update({ 'last_grant_date': datetime.now(timezone.utc).isoformat(), @@ -510,19 +503,6 @@ async def handle_subscription_renewal(invoice: Dict): except Exception as e: logger.error(f"Error handling subscription renewal: {e}") - - await client.schema('basejump').from_('billing_subscriptions').upsert({ - 'id': subscription['id'], - 'account_id': user_id, - 'billing_customer_id': subscription['customer'], - 'status': subscription['status'], - 'price_id': price_id, - 'current_period_end': datetime.fromtimestamp(subscription['current_period_end']).isoformat() if subscription.get('current_period_end') else None, - 'cancel_at': datetime.fromtimestamp(subscription['cancel_at']).isoformat() if subscription.get('cancel_at') else None, - 'cancel_at_period_end': subscription.get('cancel_at_period_end', False) - }).execute() - - await Cache.invalidate(f"subscription_tier:{user_id}") @router.get("/subscription") async def get_subscription( @@ -532,29 +512,55 @@ async def get_subscription( db = DBConnection() client = await db.client - subscription_result = await client.schema('basejump').from_('billing_subscriptions').select('*').eq('account_id', user_id).eq('status', 'active').order('created', desc=True).limit(1).execute() + subscription_result = await client.schema('basejump').from_('billing_subscriptions')\ + .select('*')\ + .eq('account_id', user_id)\ + .in_('status', ['active', 'trialing'])\ + .order('created', desc=True)\ + .limit(1)\ + .execute() subscription_data = None - credit_result = await client.from_('credit_accounts').select('tier').eq('user_id', user_id).execute() - tier_name = credit_result.data[0]['tier'] if credit_result.data else 'free' - tier_obj = TIERS.get(tier_name, TIERS['free']) + credit_result = await client.from_('credit_accounts').select('*').eq('user_id', user_id).execute() - tier_info = { - 'name': tier_obj.name, - 'credits': float(tier_obj.monthly_credits) - } - price_id = config.STRIPE_FREE_TIER_ID + if credit_result.data: + credit_account = credit_result.data[0] + tier_name = credit_account.get('tier', 'free') + tier_obj = TIERS.get(tier_name, TIERS['free']) + + actual_credits = float(tier_obj.monthly_credits) + if tier_name != 'free': + parts = tier_name.split('_') + if len(parts) >= 3 and parts[0] == 'tier': + subscription_cost = float(parts[-1]) + actual_credits = subscription_cost + 5.0 + + tier_info = { + 'name': tier_obj.name, + 'credits': actual_credits + } + + if tier_obj and len(tier_obj.price_ids) > 0: + price_id = tier_obj.price_ids[0] + + if subscription_result.data: + subscription = subscription_result.data[0] + if subscription.get('price_id'): + price_id = subscription['price_id'] + else: + price_id = config.STRIPE_FREE_TIER_ID + else: + tier_name = 'free' + tier_obj = TIERS['free'] + tier_info = { + 'name': 'free', + 'credits': float(TIERS['free'].monthly_credits) + } + price_id = config.STRIPE_FREE_TIER_ID if subscription_result.data: subscription = subscription_result.data[0] - price_id = subscription['price_id'] - price_tier_obj = get_tier_by_price_id(price_id) - if price_tier_obj: - tier_info = { - 'name': price_tier_obj.name, - 'credits': float(price_tier_obj.monthly_credits) - } stripe_subscription = None try: @@ -598,8 +604,15 @@ async def get_subscription( balance = await credit_service.get_balance(user_id) summary = await credit_service.get_account_summary(user_id) + if subscription_data: + status = 'active' + elif tier_name != 'free': + status = 'cancelled' + else: + status = 'no_subscription' + return { - 'status': 'active' if subscription_data else 'no_subscription', + 'status': status, 'plan_name': tier_info['name'], 'price_id': price_id, 'subscription': subscription_data, @@ -607,7 +620,7 @@ async def get_subscription( 'current_usage': float(summary['lifetime_used']), 'cost_limit': tier_info['credits'], 'credit_balance': float(balance), - 'can_purchase_credits': tier_info['name'] == 'tier_200_1000', + 'can_purchase_credits': tier_info.get('name') in ['tier_200_1000', 'tier_50_400', 'tier_125_800'], 'tier': tier_info, 'credits': { 'balance': float(balance), @@ -615,7 +628,7 @@ async def get_subscription( 'lifetime_granted': float(summary['lifetime_granted']), 'lifetime_purchased': float(summary['lifetime_purchased']), 'lifetime_used': float(summary['lifetime_used']), - 'can_purchase': tier_info['name'] == 'tier_200_1000' + 'can_purchase': tier_info.get('name') in ['tier_200_1000', 'tier_50_400', 'tier_125_800'] } } @@ -658,7 +671,7 @@ async def create_checkout_session( db = DBConnection() client = await db.client - existing_sub = await client.schema('basejump').from_('billing_subscriptions').select('id').eq('account_id', user_id).eq('status', 'active').execute() + existing_sub = await client.schema('basejump').from_('billing_subscriptions').select('id').eq('account_id', user_id).in_('status', ['active', 'trialing']).execute() if existing_sub.data and len(existing_sub.data) > 0: subscription = await stripe.Subscription.retrieve_async(existing_sub.data[0]['id']) @@ -673,12 +686,28 @@ async def create_checkout_session( payment_behavior='pending_if_incomplete' ) + await handle_subscription_change(updated_subscription) + await Cache.invalidate(f"subscription_tier:{user_id}") + await Cache.invalidate(f"credit_balance:{user_id}") + await Cache.invalidate(f"credit_summary:{user_id}") + + old_price_id = subscription['items']['data'][0].price.id + old_tier = get_tier_by_price_id(old_price_id) + new_tier = get_tier_by_price_id(request.price_id) + + old_amount = float(old_tier.monthly_credits) if old_tier else 0 + new_amount = float(new_tier.monthly_credits) if new_tier else 0 return { - 'success': True, + 'status': 'upgraded' if new_amount > old_amount else 'updated', 'subscription_id': updated_subscription.id, - 'message': 'Subscription updated successfully' + 'message': 'Subscription updated successfully', + 'details': { + 'is_upgrade': new_amount > old_amount, + 'current_price': old_amount, + 'new_price': new_amount + } } else: session = await stripe.checkout.Session.create_async( @@ -695,7 +724,6 @@ async def create_checkout_session( } } ) - return {'checkout_url': session.url} except Exception as e: @@ -721,6 +749,58 @@ async def create_portal_session( logger.error(f"Error creating portal session: {e}", exc_info=True) raise HTTPException(status_code=500, detail=str(e)) +@router.post("/sync-subscription") +async def sync_subscription( + user_id: str = Depends(verify_and_get_user_id_from_jwt) +) -> Dict: + try: + db = DBConnection() + client = await db.client + + sub_result = await client.schema('basejump').from_('billing_subscriptions')\ + .select('*')\ + .eq('account_id', user_id)\ + .in_('status', ['active', 'trialing'])\ + .limit(1)\ + .execute() + + if not sub_result.data or len(sub_result.data) == 0: + return { + 'success': False, + 'message': 'No active subscription found' + } + + subscription = await stripe.Subscription.retrieve_async( + sub_result.data[0]['id'], + expand=['items.data.price'] + ) + + await handle_subscription_change(subscription) + + await Cache.invalidate(f"subscription_tier:{user_id}") + await Cache.invalidate(f"credit_balance:{user_id}") + await Cache.invalidate(f"credit_summary:{user_id}") + + balance = await credit_service.get_balance(user_id) + summary = await credit_service.get_account_summary(user_id) + + return { + 'success': True, + 'message': 'Subscription synced successfully', + 'credits': { + 'balance': float(balance), + 'lifetime_granted': float(summary['lifetime_granted']), + 'lifetime_used': float(summary['lifetime_used']) + } + } + + except Exception as e: + logger.error(f"Error syncing subscription: {e}", exc_info=True) + return { + 'success': False, + 'message': f'Failed to sync subscription: {str(e)}' + } + @router.post("/cancel-subscription") async def cancel_subscription( request: CancelSubscriptionRequest, @@ -730,7 +810,7 @@ async def cancel_subscription( db = DBConnection() client = await db.client - sub_result = await client.schema('basejump').from_('billing_subscriptions').select('id').eq('account_id', user_id).eq('status', 'active').execute() + sub_result = await client.schema('basejump').from_('billing_subscriptions').select('id').eq('account_id', user_id).in_('status', ['active', 'trialing']).execute() if not sub_result.data or len(sub_result.data) == 0: raise HTTPException(status_code=404, detail="No active subscription found") @@ -761,7 +841,7 @@ async def reactivate_subscription( db = DBConnection() client = await db.client - sub_result = await client.schema('basejump').from_('billing_subscriptions').select('id').eq('account_id', user_id).eq('status', 'active').execute() + sub_result = await client.schema('basejump').from_('billing_subscriptions').select('id').eq('account_id', user_id).in_('status', ['active', 'trialing']).execute() if not sub_result.data or len(sub_result.data) == 0: raise HTTPException(status_code=404, detail="No active subscription found") @@ -830,54 +910,6 @@ async def get_usage_history( logger.error(f"Error getting usage history: {e}", exc_info=True) raise HTTPException(status_code=500, detail=str(e)) -@router.post("/sync-subscription") -async def sync_subscription_from_stripe( - user_id: str = Depends(verify_and_get_user_id_from_jwt) -) -> Dict: - """Manually sync subscription from Stripe to local database and credit system""" - try: - db = DBConnection() - client = await db.client - - # Get customer ID - customer_result = await client.schema('basejump').from_('billing_customers').select('id').eq('account_id', user_id).execute() - - if not customer_result.data or len(customer_result.data) == 0: - return {'success': False, 'message': 'No billing customer found'} - - customer_id = customer_result.data[0]['id'] - - # Get active subscriptions from Stripe - subscriptions = await stripe.Subscription.list_async( - customer=customer_id, - status='active', - expand=['data.items.data.price'] - ) - - if not subscriptions.data: - return {'success': False, 'message': 'No active subscription found in Stripe'} - - # Process the first active subscription - subscription = subscriptions.data[0] - - # Handle the subscription change - await handle_subscription_change(subscription) - - # Get updated balance and tier - balance = await credit_service.get_balance(user_id) - account_result = await client.from_('credit_accounts').select('tier').eq('user_id', user_id).execute() - - return { - 'success': True, - 'message': 'Subscription synced successfully', - 'subscription_id': subscription['id'], - 'tier': account_result.data[0].get('tier') if account_result.data and len(account_result.data) > 0 else 'unknown', - 'balance': float(balance) - } - - except Exception as e: - logger.error(f"Error syncing subscription: {e}", exc_info=True) - return {'success': False, 'message': str(e)} @router.get("/available-models") async def get_available_models( @@ -974,4 +1006,86 @@ async def get_subscription_commitment( 'commitment_type': None, 'months_remaining': None, 'commitment_end_date': None - } \ No newline at end of file + } + +@router.post("/test/trigger-renewal") +async def test_trigger_renewal( + user_id: str = Depends(verify_and_get_user_id_from_jwt) +) -> Dict: + if config.ENV_MODE == EnvMode.PRODUCTION: + raise HTTPException(status_code=403, detail="Test endpoints disabled in production") + + try: + db = DBConnection() + client = await db.client + + # Check credit_accounts as primary source of truth + account_result = await client.from_('credit_accounts')\ + .select('*')\ + .eq('user_id', user_id)\ + .execute() + + if not account_result.data or len(account_result.data) == 0: + return { + 'success': False, + 'message': 'No credit account found. Please subscribe to a plan first.' + } + + account = account_result.data[0] + + # Only check if tier is not free (means they have a subscription) + if account.get('tier', 'free') == 'free': + return { + 'success': False, + 'message': 'No active subscription found. Please subscribe to a plan first.' + } + + tier = account['tier'] + + yesterday = datetime.now(timezone.utc) - timedelta(days=26) + await client.from_('credit_accounts').update({ + 'last_grant_date': yesterday.isoformat() + }).eq('user_id', user_id).execute() + + monthly_credits = get_monthly_credits(tier) + if monthly_credits > 0: + await credit_service.add_credits( + user_id=user_id, + amount=monthly_credits, + type='tier_grant', + description=f"TEST: Monthly {tier} tier credits", + metadata={'test': True, 'triggered_by': 'manual_test'} + ) + + next_grant = datetime.now(timezone.utc) + timedelta(days=30) + await client.from_('credit_accounts').update({ + 'last_grant_date': datetime.now(timezone.utc).isoformat(), + 'next_credit_grant': next_grant.isoformat() + }).eq('user_id', user_id).execute() + + await Cache.invalidate(f"credit_balance:{user_id}") + await Cache.invalidate(f"credit_summary:{user_id}") + await Cache.invalidate(f"subscription_tier:{user_id}") + + new_balance = await credit_service.get_balance(user_id) + + return { + 'success': True, + 'message': f'Successfully granted {monthly_credits} credits', + 'tier': tier, + 'credits_granted': float(monthly_credits), + 'new_balance': float(new_balance), + 'next_grant_date': next_grant.isoformat() + } + else: + return { + 'success': False, + 'message': f'No credits to grant for tier: {tier}' + } + + except Exception as e: + logger.error(f"Error in test renewal trigger: {e}", exc_info=True) + return { + 'success': False, + 'message': str(e) + } \ No newline at end of file diff --git a/backend/services/billing_v3.py b/backend/services/billing_v3.py new file mode 100644 index 00000000..12335758 --- /dev/null +++ b/backend/services/billing_v3.py @@ -0,0 +1,337 @@ +""" +Simplified Billing Service v3 +============================= +Uses only essential sources of truth: +1. Stripe API - Real subscription data +2. credit_accounts - Tier, balance, billing cycle +3. billing_customers - Stripe customer ID mapping +4. credit_ledger - Transaction history +5. agent_runs - Usage tracking + +Removed: +- billing_subscriptions (redundant) +- credit_grants (unused) +- Complex fallback logic +""" + +from fastapi import APIRouter, HTTPException, Depends, Request +from typing import Optional, Dict +from decimal import Decimal +from datetime import datetime, timezone, timedelta +import stripe +from core.credits import credit_service +from core.services.supabase import DBConnection +from core.utils.auth_utils import verify_and_get_user_id_from_jwt +from core.utils.config import config, EnvMode +from core.utils.logger import logger +from core.billing_config import get_tier_by_price_id, TIERS, get_monthly_credits, get_tier_by_name + +router = APIRouter(prefix="/billing/v3", tags=["billing-v3"]) +stripe.api_key = config.STRIPE_SECRET_KEY + +async def get_user_tier(user_id: str) -> str: + """Get user's tier from credit_accounts (single source of truth)""" + db = DBConnection() + client = await db.client + + result = await client.from_('credit_accounts')\ + .select('tier')\ + .eq('user_id', user_id)\ + .single()\ + .execute() + + if result.data: + return result.data.get('tier', 'free') + return 'free' + +async def ensure_credit_account(user_id: str) -> None: + """Ensure user has a credit account (create if missing)""" + db = DBConnection() + client = await db.client + + # Check if account exists + result = await client.from_('credit_accounts')\ + .select('user_id')\ + .eq('user_id', user_id)\ + .execute() + + if not result.data: + # Create free tier account + await client.from_('credit_accounts').insert({ + 'user_id': user_id, + 'balance': 5.0, # Free tier initial credits + 'tier': 'free', + 'last_grant_date': datetime.now(timezone.utc).isoformat() + }).execute() + + # Add initial credits to ledger + await credit_service.add_credits( + user_id=user_id, + amount=Decimal('5.0'), + type='tier_grant', + description='Welcome credits - Free tier' + ) + +@router.get("/subscription") +async def get_subscription_simplified( + user_id: str = Depends(verify_and_get_user_id_from_jwt) +) -> Dict: + """Simplified subscription endpoint using minimal sources""" + try: + # Ensure user has credit account + await ensure_credit_account(user_id) + + db = DBConnection() + client = await db.client + + # Get all user data in one query + account_result = await client.from_('credit_accounts')\ + .select('*')\ + .eq('user_id', user_id)\ + .single()\ + .execute() + + account = account_result.data + tier_name = account.get('tier', 'free') + + # Get tier configuration + tier = get_tier_by_name(tier_name) or TIERS['free'] + + # Calculate actual credits (subscription + free base) + actual_credits = float(tier.monthly_credits) + if tier_name != 'free': + parts = tier_name.split('_') + if len(parts) >= 3 and parts[0] == 'tier': + subscription_cost = float(parts[-1]) + actual_credits = subscription_cost + 5.0 + + # Get Stripe subscription if needed (only for subscription management) + stripe_subscription = None + subscription_id = account.get('stripe_subscription_id') + + if subscription_id: + try: + # Only fetch from Stripe when needed for management + stripe_subscription = await stripe.Subscription.retrieve_async(subscription_id) + except: + # If Stripe fails, we still have local data + pass + + # Get balance and usage + balance = await credit_service.get_balance(user_id) + summary = await credit_service.get_account_summary(user_id) + + return { + 'status': 'active' if tier_name != 'free' else 'no_subscription', + 'tier': tier_name, + 'tier_display': tier.display_name, + 'credits': { + 'balance': float(balance), + 'monthly_allocation': actual_credits, + 'lifetime_used': float(summary['lifetime_used']), + 'can_purchase': tier.can_purchase_credits + }, + 'subscription': { + 'id': subscription_id, + 'status': stripe_subscription.status if stripe_subscription else None, + 'cancel_at_period_end': stripe_subscription.cancel_at_period_end if stripe_subscription else False, + 'current_period_end': stripe_subscription.current_period_end if stripe_subscription else None + } if subscription_id else None, + 'next_credit_grant': account.get('next_credit_grant'), + 'project_limit': tier.project_limit + } + + except Exception as e: + logger.error(f"Error getting subscription: {e}", exc_info=True) + # Return free tier as fallback + return { + 'status': 'no_subscription', + 'tier': 'free', + 'tier_display': 'Free', + 'credits': { + 'balance': 0, + 'monthly_allocation': 5, + 'lifetime_used': 0, + 'can_purchase': False + }, + 'subscription': None, + 'next_credit_grant': None, + 'project_limit': 3 + } + +@router.post("/webhook") +async def stripe_webhook_simplified(request: Request): + """Simplified webhook handler - only updates credit_accounts""" + try: + payload = await request.body() + sig_header = request.headers.get('stripe-signature') + + event = stripe.Webhook.construct_event( + payload, sig_header, config.STRIPE_WEBHOOK_SECRET + ) + + if event.type in ['customer.subscription.created', 'customer.subscription.updated']: + subscription = event.data.object + + if subscription.status in ['active', 'trialing']: + await handle_subscription_simplified(subscription) + + elif event.type == 'customer.subscription.deleted': + subscription = event.data.object + await handle_cancellation_simplified(subscription) + + elif event.type == 'invoice.payment_succeeded': + invoice = event.data.object + if invoice.get('billing_reason') == 'subscription_cycle': + await handle_renewal_simplified(invoice) + + return {'status': 'success'} + + except Exception as e: + logger.error(f"Webhook error: {e}") + raise HTTPException(status_code=400, detail=str(e)) + +async def handle_subscription_simplified(subscription: Dict): + """Handle subscription changes - update only credit_accounts""" + db = DBConnection() + client = await db.client + + # Get user from customer + customer_result = await client.schema('basejump').from_('billing_customers')\ + .select('account_id')\ + .eq('id', subscription['customer'])\ + .single()\ + .execute() + + if not customer_result.data: + return + + user_id = customer_result.data['account_id'] + price_id = subscription['items']['data'][0]['price']['id'] + + # Get tier from price + tier_info = get_tier_by_price_id(price_id) + if not tier_info: + return + + # Ensure account exists + await ensure_credit_account(user_id) + + # Get current account + current_result = await client.from_('credit_accounts')\ + .select('tier')\ + .eq('user_id', user_id)\ + .single()\ + .execute() + + old_tier = current_result.data.get('tier', 'free') + new_tier = tier_info.name + + # Update account + billing_anchor = datetime.fromtimestamp(subscription['current_period_start'], tz=timezone.utc) + next_grant = datetime.fromtimestamp(subscription['current_period_end'], tz=timezone.utc) + + await client.from_('credit_accounts').update({ + 'tier': new_tier, + 'stripe_subscription_id': subscription['id'], + 'billing_cycle_anchor': billing_anchor.isoformat(), + 'next_credit_grant': next_grant.isoformat() + }).eq('user_id', user_id).execute() + + # Handle tier change credits + if old_tier != new_tier: + old_tier_obj = get_tier_by_name(old_tier) or TIERS['free'] + new_credits = float(tier_info.monthly_credits) + old_credits = float(old_tier_obj.monthly_credits) + + if new_credits > old_credits: + # Upgrade - grant full new tier credits + await credit_service.add_credits( + user_id=user_id, + amount=Decimal(str(new_credits)), + type='tier_upgrade', + description=f"Upgrade to {tier_info.display_name}" + ) + +async def handle_cancellation_simplified(subscription: Dict): + """Handle subscription cancellation - just update tier to free""" + db = DBConnection() + client = await db.client + + # Get user from customer + customer_result = await client.schema('basejump').from_('billing_customers')\ + .select('account_id')\ + .eq('id', subscription['customer'])\ + .single()\ + .execute() + + if not customer_result.data: + return + + user_id = customer_result.data['account_id'] + + # Update to free tier (keep existing balance) + await client.from_('credit_accounts').update({ + 'tier': 'free', + 'stripe_subscription_id': None, + 'billing_cycle_anchor': None, + 'next_credit_grant': None + }).eq('user_id', user_id).execute() + +async def handle_renewal_simplified(invoice: Dict): + """Handle monthly renewal - grant credits""" + db = DBConnection() + client = await db.client + + subscription_id = invoice.get('subscription') + if not subscription_id: + return + + # Get user from customer + customer_result = await client.schema('basejump').from_('billing_customers')\ + .select('account_id')\ + .eq('id', invoice['customer'])\ + .single()\ + .execute() + + if not customer_result.data: + return + + user_id = customer_result.data['account_id'] + + # Get account + account_result = await client.from_('credit_accounts')\ + .select('tier, last_grant_date')\ + .eq('user_id', user_id)\ + .single()\ + .execute() + + if not account_result.data: + return + + account = account_result.data + + # Check if enough time has passed (prevent double grants) + if account.get('last_grant_date'): + last_grant = datetime.fromisoformat(account['last_grant_date'].replace('Z', '+00:00')) + if (datetime.now(timezone.utc) - last_grant).days < 25: + return + + # Grant monthly credits + tier = account['tier'] + monthly_credits = get_monthly_credits(tier) + + if monthly_credits > 0: + await credit_service.add_credits( + user_id=user_id, + amount=monthly_credits, + type='tier_grant', + description=f"Monthly {tier} credits" + ) + + # Update grant dates + next_grant = datetime.now(timezone.utc) + timedelta(days=30) + await client.from_('credit_accounts').update({ + 'last_grant_date': datetime.now(timezone.utc).isoformat(), + 'next_credit_grant': next_grant.isoformat() + }).eq('user_id', user_id).execute() \ No newline at end of file diff --git a/frontend/src/app/(dashboard)/(personalAccount)/settings/billing/page.tsx b/frontend/src/app/(dashboard)/(personalAccount)/settings/billing/page.tsx index 3c1624d2..5141666a 100644 --- a/frontend/src/app/(dashboard)/(personalAccount)/settings/billing/page.tsx +++ b/frontend/src/app/(dashboard)/(personalAccount)/settings/billing/page.tsx @@ -8,15 +8,17 @@ import { Skeleton } from '@/components/ui/skeleton'; import { Alert, AlertTitle, AlertDescription } from '@/components/ui/alert'; import { Button } from '@/components/ui/button'; import { useSharedSubscription, useSubscriptionContext } from '@/contexts/SubscriptionContext'; -import { isLocalMode } from '@/lib/config'; +import { isLocalMode, isStagingMode } from '@/lib/config'; import Link from 'next/link'; -import { useCreatePortalSession } from '@/hooks/react-query/use-billing-v2'; +import { useCreatePortalSession, useTriggerTestRenewal } from '@/hooks/react-query/use-billing-v2'; +import { toast } from 'sonner'; const returnUrl = process.env.NEXT_PUBLIC_URL as string; export default function PersonalAccountBillingPage() { const { data: accounts, isLoading, error } = useAccounts(); const [showBillingModal, setShowBillingModal] = useState(false); + const triggerTestRenewal = useTriggerTestRenewal(); const { data: subscriptionData, @@ -146,6 +148,53 @@ export default function PersonalAccountBillingPage() { )} + {isStagingMode() && ( +
+

+ 🧪 Test Mode: Simulate monthly credit renewal +

+ +
+ )} ); diff --git a/frontend/src/components/billing/billing-modal.tsx b/frontend/src/components/billing/billing-modal.tsx index 4dcf48bd..4b3f576d 100644 --- a/frontend/src/components/billing/billing-modal.tsx +++ b/frontend/src/components/billing/billing-modal.tsx @@ -236,55 +236,15 @@ export function BillingModal({ open, onOpenChange, returnUrl = typeof window !== <> - {/* Usage Limit Alert - {showUsageLimitAlert && ( -
-
-
-
- -
-
-

Usage Limit Reached

-

- Your current plan has been exhausted for this billing period. -

-
-
-
-
- )} */} - - {/* Usage section - show loading state or actual data */} - {isLoading || authLoading ? ( -
-
-
- - -
-
-
- ) : subscriptionData && ( -
-
-
- - Agent Usage This Month - - - ${subscriptionData.current_usage?.toFixed(2) || '0'} /{' '} - ${subscriptionData.cost_limit || '0'} - -
-
-
- )} - - {/* Show pricing section immediately - no loading state */} - - - {/* Subscription Management Section - only show if there's actual subscription data */} + { + setTimeout(() => { + fetchSubscriptionData(); + }, 500); + }} + /> {error ? (
@@ -293,7 +253,6 @@ export function BillingModal({ open, onOpenChange, returnUrl = typeof window !==
) : subscriptionData?.subscription && (
- {/* Subscription Status Info Box */}
diff --git a/frontend/src/components/home/sections/pricing-section.tsx b/frontend/src/components/home/sections/pricing-section.tsx index c3f31684..64726c58 100644 --- a/frontend/src/components/home/sections/pricing-section.tsx +++ b/frontend/src/components/home/sections/pricing-section.tsx @@ -584,6 +584,7 @@ interface PricingSectionProps { insideDialog?: boolean; showInfo?: boolean; noPadding?: boolean; + onSubscriptionUpdate?: () => void; } export function PricingSection({ @@ -593,6 +594,7 @@ export function PricingSection({ insideDialog = false, showInfo = true, noPadding = false, + onSubscriptionUpdate, }: PricingSectionProps) { const { data: subscriptionData, isLoading: isFetchingPlan, error: subscriptionQueryError, refetch: refetchSubscription } = useSubscription(); @@ -613,25 +615,21 @@ export function PricingSection({ ); if (currentTier) { - // Check if current subscription is yearly commitment (new yearly) if (currentTier.monthlyCommitmentStripePriceId === currentSubscription.price_id) { return 'yearly_commitment'; } else if (currentTier.yearlyStripePriceId === currentSubscription.price_id) { - // Legacy yearly plans return 'yearly'; } else if (currentTier.stripePriceId === currentSubscription.price_id) { return 'monthly'; } } - // Default to yearly_commitment (new yearly) if we can't determine current plan type return 'yearly_commitment'; }, [isAuthenticated, currentSubscription]); const [billingPeriod, setBillingPeriod] = useState<'monthly' | 'yearly' | 'yearly_commitment'>(getDefaultBillingPeriod()); const [planLoadingStates, setPlanLoadingStates] = useState>({}); - // Update billing period when subscription data changes useEffect(() => { setBillingPeriod(getDefaultBillingPeriod()); }, [getDefaultBillingPeriod]); @@ -647,6 +645,10 @@ export function PricingSection({ setTimeout(() => { setPlanLoadingStates({}); }, 1000); + // Call parent's update handler if provided + if (onSubscriptionUpdate) { + onSubscriptionUpdate(); + } }; diff --git a/frontend/src/hooks/react-query/use-billing-v2.ts b/frontend/src/hooks/react-query/use-billing-v2.ts index 250c886b..5d4c758d 100644 --- a/frontend/src/hooks/react-query/use-billing-v2.ts +++ b/frontend/src/hooks/react-query/use-billing-v2.ts @@ -137,4 +137,15 @@ export const useDeductTokenUsage = () => { queryClient.invalidateQueries({ queryKey: billingKeys.status() }); }, }); +}; + +export const useTriggerTestRenewal = () => { + const queryClient = useQueryClient(); + + return useMutation({ + mutationFn: () => billingApiV2.triggerTestRenewal(), + onSuccess: () => { + queryClient.invalidateQueries({ queryKey: billingKeys.all }); + }, + }); }; \ No newline at end of file diff --git a/frontend/src/lib/api.ts b/frontend/src/lib/api.ts index a35c39ef..8b5f4b6e 100644 --- a/frontend/src/lib/api.ts +++ b/frontend/src/lib/api.ts @@ -1963,8 +1963,7 @@ export const getSubscription = async (): Promise => { } const data = await response.json(); - - // Map the new billing v2 format to the old format for backward compatibility + return { subscription: data.subscription ? { ...data.subscription, @@ -1974,7 +1973,7 @@ export const getSubscription = async (): Promise => { cost_limit: data.tier?.credits || 0, credit_balance: data.credits?.balance || 0, can_purchase_credits: data.credits?.can_purchase || false, - ...data // Include any other fields that might be used + ...data } as SubscriptionStatus; } catch (error) { if (error instanceof NoAccessTokenAvailableError) { diff --git a/frontend/src/lib/api/billing-v2.ts b/frontend/src/lib/api/billing-v2.ts index bead018c..1f36902f 100644 --- a/frontend/src/lib/api/billing-v2.ts +++ b/frontend/src/lib/api/billing-v2.ts @@ -121,6 +121,13 @@ export interface ReactivateSubscriptionResponse { message: string; } +export interface TestRenewalResponse { + success: boolean; + message?: string; + credits_granted?: number; + new_balance?: number; +} + export const billingApiV2 = { async getSubscription() { const response = await backendApi.get('/billing/v2/subscription'); @@ -204,6 +211,12 @@ export const billingApiV2 = { ); if (response.error) throw response.error; return response.data!; + }, + + async triggerTestRenewal() { + const response = await backendApi.post('/billing/v2/test/trigger-renewal'); + if (response.error) throw response.error; + return response.data!; } }; @@ -222,4 +235,5 @@ export const purchaseCredits = (request: PurchaseCreditsRequest) => billingApiV2.purchaseCredits(request); export const getTransactions = (limit?: number, offset?: number) => billingApiV2.getTransactions(limit, offset); -export const getUsageHistory = (days?: number) => billingApiV2.getUsageHistory(days); \ No newline at end of file +export const getUsageHistory = (days?: number) => billingApiV2.getUsageHistory(days); +export const triggerTestRenewal = () => billingApiV2.triggerTestRenewal(); \ No newline at end of file