simplify blling arch

This commit is contained in:
Saumya 2025-09-06 15:26:09 +05:30
parent 37ca540f24
commit 1f5c0e9be2
11 changed files with 772 additions and 236 deletions

View File

@ -36,7 +36,7 @@ TIERS: Dict[str, Tier] = {
config.STRIPE_TIER_2_20_YEARLY_ID,
config.STRIPE_TIER_2_17_YEARLY_COMMITMENT_ID
],
monthly_credits=Decimal('25.00'),
monthly_credits=Decimal('20.00'),
display_name='Starter',
can_purchase_credits=True,
models=['gpt-4o-mini', 'gpt-4o', 'claude-3-haiku', 'claude-3-5-sonnet'],
@ -49,7 +49,7 @@ TIERS: Dict[str, Tier] = {
config.STRIPE_TIER_6_50_YEARLY_ID,
config.STRIPE_TIER_6_42_YEARLY_COMMITMENT_ID
],
monthly_credits=Decimal('65.00'),
monthly_credits=Decimal('50.00'),
display_name='Professional',
can_purchase_credits=True,
models=['gpt-4o-mini', 'gpt-4o', 'claude-3-haiku', 'claude-3-5-sonnet', 'claude-3-opus'],
@ -61,7 +61,7 @@ TIERS: Dict[str, Tier] = {
config.STRIPE_TIER_12_100_ID,
config.STRIPE_TIER_12_100_YEARLY_ID
],
monthly_credits=Decimal('130.00'),
monthly_credits=Decimal('100.00'),
display_name='Team',
can_purchase_credits=True,
models=['gpt-4o-mini', 'gpt-4o', 'claude-3-haiku', 'claude-3-5-sonnet', 'claude-3-opus'],
@ -74,7 +74,7 @@ TIERS: Dict[str, Tier] = {
config.STRIPE_TIER_25_200_YEARLY_ID,
config.STRIPE_TIER_25_170_YEARLY_COMMITMENT_ID
],
monthly_credits=Decimal('260.00'),
monthly_credits=Decimal('200.00'),
display_name='Business',
can_purchase_credits=True,
models=['gpt-4o-mini', 'gpt-4o', 'claude-3-haiku', 'claude-3-5-sonnet', 'claude-3-opus', 'o1-preview'],
@ -86,7 +86,7 @@ TIERS: Dict[str, Tier] = {
config.STRIPE_TIER_50_400_ID,
config.STRIPE_TIER_50_400_YEARLY_ID
],
monthly_credits=Decimal('520.00'),
monthly_credits=Decimal('400.00'),
display_name='Enterprise',
can_purchase_credits=True,
models=['gpt-4o-mini', 'gpt-4o', 'claude-3-haiku', 'claude-3-5-sonnet', 'claude-3-opus', 'o1-preview', 'o1'],
@ -98,7 +98,7 @@ TIERS: Dict[str, Tier] = {
config.STRIPE_TIER_125_800_ID,
config.STRIPE_TIER_125_800_YEARLY_ID
],
monthly_credits=Decimal('1040.00'),
monthly_credits=Decimal('800.00'),
display_name='Enterprise Plus',
can_purchase_credits=True,
models=['gpt-4o-mini', 'gpt-4o', 'claude-3-haiku', 'claude-3-5-sonnet', 'claude-3-opus', 'o1-preview', 'o1'],
@ -110,7 +110,7 @@ TIERS: Dict[str, Tier] = {
config.STRIPE_TIER_200_1000_ID,
config.STRIPE_TIER_200_1000_YEARLY_ID
],
monthly_credits=Decimal('1300.00'),
monthly_credits=Decimal('1000.00'),
display_name='Ultimate',
can_purchase_credits=True,
models=['gpt-4o-mini', 'gpt-4o', 'claude-3-haiku', 'claude-3-5-sonnet', 'claude-3-opus', 'o1-preview', 'o1'],

View File

@ -158,71 +158,38 @@ async def sync_stripe_to_db():
print(f" ✓ Synced subscription to database")
synced_count += 1
# For active/trialing subscriptions, use handle_subscription_change for comprehensive update
if sub.status in ['active', 'trialing']:
print(f" Processing active subscription...")
# Get current state
account_before = await client.from_('credit_accounts')\
.select('*')\
.eq('user_id', account_id)\
.maybe_single()\
.execute()
# Use handle_subscription_change to properly update all fields
# This handles billing_cycle_anchor, stripe_subscription_id, next_credit_grant, etc.
try:
await handle_subscription_change(sub)
print(f" ✓ Updated billing cycle and credit fields")
except Exception as e:
print(f" ⚠ handle_subscription_change failed: {e}")
# Now fix the balance based on actual usage from agent_runs
tier = get_tier_by_price_id(price_id) if price_id else None
if tier:
tier_name = tier.name
# Calculate actual monthly credits (subscription cost + free tier base)
parts = tier_name.split('_')
if len(parts) == 3 and parts[0] == 'tier':
subscription_cost = Decimal(parts[2])
monthly_credits = subscription_cost + FREE_TIER_INITIAL_CREDITS
else:
monthly_credits = FREE_TIER_INITIAL_CREDITS
# Get current month usage from agent_runs (old billing system style)
month_start = datetime.now().replace(day=1, hour=0, minute=0, second=0, microsecond=0)
runs_result = await client.from_('agent_runs')\
.select('total_cost')\
.eq('user_id', account_id)\
.gte('created_at', month_start.isoformat())\
.execute()
current_usage = Decimal('0')
for run in runs_result.data or []:
if run.get('total_cost'):
current_usage += Decimal(str(run['total_cost']))
# Calculate correct balance
correct_balance = monthly_credits - current_usage
# Update/create credit account with all the proper fields
credit_account = await client.from_('credit_accounts')\
.select('*')\
.eq('user_id', account_id)\
.maybe_single()\
.execute()
if credit_account.data:
# Update existing account
account_after = await client.from_('credit_accounts')\
.select('*')\
.eq('user_id', account_id)\
.maybe_single()\
.execute()
if account_after.data:
tier = get_tier_by_price_id(price_id) if price_id else None
if tier:
tier_name = tier.name
update_data = {
'tier': tier_name,
'balance': str(correct_balance),
'stripe_subscription_id': sub.id,
'billing_cycle_anchor': datetime.fromtimestamp(sub.created).isoformat() if sub.created else None,
}
# Add next_credit_grant if we have current_period_end
if sub.current_period_end:
update_data['next_credit_grant'] = datetime.fromtimestamp(sub.current_period_end).isoformat()
@ -231,27 +198,14 @@ async def sync_stripe_to_db():
.eq('user_id', account_id)\
.execute()
current_balance = Decimal(str(account_after.data['balance']))
print(f" ✓ Updated credit account:")
print(f" - Tier: {tier_name}")
print(f" - Balance: ${correct_balance} (${monthly_credits} - ${current_usage} usage)")
print(f" - Balance: ${current_balance} (managed by handle_subscription_change)")
print(f" - Stripe subscription: {sub.id}")
print(f" - Next credit grant: {update_data.get('next_credit_grant', 'N/A')}")
else:
# Create new account with all fields
await client.from_('credit_accounts').insert({
'user_id': account_id,
'balance': str(correct_balance),
'tier': tier_name,
'stripe_subscription_id': sub.id,
'billing_cycle_anchor': datetime.fromtimestamp(sub.created).isoformat() if sub.created else None,
'next_credit_grant': datetime.fromtimestamp(sub.current_period_end).isoformat() if sub.current_period_end else None,
'last_grant_date': datetime.now().isoformat()
}).execute()
print(f" ✓ Created credit account:")
print(f" - Tier: {tier_name}")
print(f" - Balance: ${correct_balance} (${monthly_credits} - ${current_usage} usage)")
print(f" - Stripe subscription: {sub.id}")
else:
print(f" ⚠ No credit account found after handle_subscription_change - this is unexpected")
else:
print(f" ✗ Failed to sync")

View File

@ -0,0 +1,97 @@
#!/usr/bin/env python3
import asyncio
import sys
from pathlib import Path
sys.path.append(str(Path(__file__).parent.parent))
from core.services.supabase import DBConnection
from core.billing_config import get_tier_by_price_id
from core.utils.logger import logger
async def verify_and_fix_tiers():
"""Verify that all users have the correct tier in credit_accounts"""
db = DBConnection()
await db.initialize()
client = await db.client
print("\n" + "="*60)
print("VERIFYING USER TIERS")
print("="*60)
# Get all active/trialing subscriptions
subs_result = await client.schema('basejump').from_('billing_subscriptions')\
.select('*')\
.in_('status', ['active', 'trialing'])\
.execute()
if not subs_result.data:
print("No active subscriptions found")
return
print(f"Found {len(subs_result.data)} active/trialing subscriptions")
fixed = 0
already_correct = 0
errors = 0
for sub in subs_result.data:
user_id = sub['account_id']
price_id = sub.get('price_id')
if not price_id:
print(f"\nUser {user_id[:8]}... has no price_id in subscription")
continue
# Get the tier from price_id
tier = get_tier_by_price_id(price_id)
if not tier:
print(f"\nUser {user_id[:8]}... has unknown price_id: {price_id}")
continue
expected_tier = tier.name
# Get current credit account
credit_result = await client.from_('credit_accounts')\
.select('tier')\
.eq('user_id', user_id)\
.single()\
.execute()
if not credit_result.data:
print(f"\nUser {user_id[:8]}... has no credit account!")
errors += 1
continue
current_tier = credit_result.data['tier']
if current_tier != expected_tier:
print(f"\nUser {user_id[:8]}...")
print(f" Current tier: {current_tier}")
print(f" Expected tier: {expected_tier}")
print(f" Price ID: {price_id}")
# Fix the tier
update_result = await client.from_('credit_accounts')\
.update({'tier': expected_tier})\
.eq('user_id', user_id)\
.execute()
if update_result.data:
print(f" ✓ Fixed tier to: {expected_tier}")
fixed += 1
else:
print(f" ✗ Failed to update tier")
errors += 1
else:
already_correct += 1
print("\n" + "="*60)
print("SUMMARY")
print("="*60)
print(f"Already correct: {already_correct}")
print(f"Fixed: {fixed}")
print(f"Errors: {errors}")
print(f"Total: {already_correct + fixed + errors}")
if __name__ == "__main__":
asyncio.run(verify_and_fix_tiers())

View File

@ -69,7 +69,7 @@ async def get_user_subscription_tier(user_id: str) -> Dict:
if credit_result.data and len(credit_result.data) > 0:
tier_name = credit_result.data[0].get('tier', 'free')
else:
result = await client.schema('basejump').from_('billing_subscriptions').select('price_id, status').eq('account_id', user_id).eq('status', 'active').order('created', desc=True).limit(1).execute()
result = await client.schema('basejump').from_('billing_subscriptions').select('price_id, status').eq('account_id', user_id).in_('status', ['active', 'trialing']).order('created', desc=True).limit(1).execute()
if result.data:
price_id = result.data[0]['price_id']
@ -357,12 +357,11 @@ async def stripe_webhook(request: Request):
elif event.type in ['customer.subscription.created', 'customer.subscription.updated']:
subscription = event.data.object
if subscription.status == 'active':
if subscription.status in ['active', 'trialing']:
await handle_subscription_change(subscription)
elif event.type == 'invoice.payment_succeeded':
invoice = event.data.object
# Only process subscription invoices (not one-time payments)
if invoice.get('subscription') and invoice.get('billing_reason') in ['subscription_cycle', 'subscription_update']:
await handle_subscription_renewal(invoice)
@ -394,7 +393,6 @@ async def handle_subscription_change(subscription: Dict):
'credits': float(new_tier_info.monthly_credits)
}
# Get billing cycle anchor from subscription
billing_anchor = datetime.fromtimestamp(subscription['current_period_start'], tz=timezone.utc)
next_grant_date = datetime.fromtimestamp(subscription['current_period_end'], tz=timezone.utc)
@ -412,16 +410,18 @@ async def handle_subscription_change(subscription: Dict):
'credits': float(current_tier_info.monthly_credits)
}
# Only grant credits for upgrades if it's not a renewal
if current_tier and new_tier['credits'] > current_tier['credits'] and not existing_anchor:
credit_diff = Decimal(new_tier['credits'] - current_tier['credits'])
await credit_service.add_credits(
user_id=user_id,
amount=credit_diff,
type='tier_upgrade',
description=f"Upgrade from {current_tier['name']} to {new_tier['name']}"
)
logger.info(f"Granted {credit_diff} credits for tier upgrade: {current_tier['name']} -> {new_tier['name']}")
if current_tier and current_tier['name'] != new_tier['name']:
if new_tier['credits'] > current_tier['credits']:
full_amount = Decimal(new_tier['credits'])
await credit_service.add_credits(
user_id=user_id,
amount=full_amount,
type='tier_upgrade',
description=f"Upgrade to {new_tier['name']} - Full tier credits"
)
logger.info(f"Granted {full_amount} credits for tier upgrade: {current_tier['name']} -> {new_tier['name']}")
elif new_tier['credits'] < current_tier['credits']:
logger.info(f"Tier downgrade: {current_tier['name']} -> {new_tier['name']} - No credit adjustment")
await client.from_('credit_accounts').update({
'tier': new_tier['name'],
@ -430,7 +430,6 @@ async def handle_subscription_change(subscription: Dict):
'next_credit_grant': next_grant_date.isoformat()
}).eq('user_id', user_id).execute()
else:
# New subscription - grant initial credits
await client.from_('credit_accounts').insert({
'user_id': user_id,
'balance': new_tier['credits'],
@ -449,7 +448,6 @@ async def handle_subscription_change(subscription: Dict):
)
async def handle_subscription_renewal(invoice: Dict):
"""Handle subscription renewal and grant monthly credits based on billing cycle"""
try:
db = DBConnection()
client = await db.client
@ -458,7 +456,6 @@ async def handle_subscription_renewal(invoice: Dict):
if not subscription_id:
return
# Get user from subscription
customer_result = await client.schema('basejump').from_('billing_customers')\
.select('account_id')\
.eq('id', invoice['customer'])\
@ -469,7 +466,6 @@ async def handle_subscription_renewal(invoice: Dict):
user_id = customer_result.data[0]['account_id']
# Get credit account
account_result = await client.from_('credit_accounts')\
.select('tier, last_grant_date, next_credit_grant')\
.eq('user_id', user_id)\
@ -481,14 +477,12 @@ async def handle_subscription_renewal(invoice: Dict):
account = account_result.data[0]
tier = account['tier']
# Skip if we already granted credits recently (within 25 days to handle edge cases)
if account.get('last_grant_date'):
last_grant = datetime.fromisoformat(account['last_grant_date'].replace('Z', '+00:00'))
if (datetime.now(timezone.utc) - last_grant).days < 25:
logger.info(f"Skipping credit grant for {user_id} - already granted {(datetime.now(timezone.utc) - last_grant).days} days ago")
return
# Grant monthly credits
monthly_credits = get_monthly_credits(tier)
if monthly_credits > 0:
await credit_service.add_credits(
@ -499,7 +493,6 @@ async def handle_subscription_renewal(invoice: Dict):
metadata={'invoice_id': invoice['id'], 'subscription_id': subscription_id}
)
# Update next grant date
next_grant = datetime.now(timezone.utc) + timedelta(days=30)
await client.from_('credit_accounts').update({
'last_grant_date': datetime.now(timezone.utc).isoformat(),
@ -510,19 +503,6 @@ async def handle_subscription_renewal(invoice: Dict):
except Exception as e:
logger.error(f"Error handling subscription renewal: {e}")
await client.schema('basejump').from_('billing_subscriptions').upsert({
'id': subscription['id'],
'account_id': user_id,
'billing_customer_id': subscription['customer'],
'status': subscription['status'],
'price_id': price_id,
'current_period_end': datetime.fromtimestamp(subscription['current_period_end']).isoformat() if subscription.get('current_period_end') else None,
'cancel_at': datetime.fromtimestamp(subscription['cancel_at']).isoformat() if subscription.get('cancel_at') else None,
'cancel_at_period_end': subscription.get('cancel_at_period_end', False)
}).execute()
await Cache.invalidate(f"subscription_tier:{user_id}")
@router.get("/subscription")
async def get_subscription(
@ -532,29 +512,55 @@ async def get_subscription(
db = DBConnection()
client = await db.client
subscription_result = await client.schema('basejump').from_('billing_subscriptions').select('*').eq('account_id', user_id).eq('status', 'active').order('created', desc=True).limit(1).execute()
subscription_result = await client.schema('basejump').from_('billing_subscriptions')\
.select('*')\
.eq('account_id', user_id)\
.in_('status', ['active', 'trialing'])\
.order('created', desc=True)\
.limit(1)\
.execute()
subscription_data = None
credit_result = await client.from_('credit_accounts').select('tier').eq('user_id', user_id).execute()
tier_name = credit_result.data[0]['tier'] if credit_result.data else 'free'
tier_obj = TIERS.get(tier_name, TIERS['free'])
credit_result = await client.from_('credit_accounts').select('*').eq('user_id', user_id).execute()
tier_info = {
'name': tier_obj.name,
'credits': float(tier_obj.monthly_credits)
}
price_id = config.STRIPE_FREE_TIER_ID
if credit_result.data:
credit_account = credit_result.data[0]
tier_name = credit_account.get('tier', 'free')
tier_obj = TIERS.get(tier_name, TIERS['free'])
actual_credits = float(tier_obj.monthly_credits)
if tier_name != 'free':
parts = tier_name.split('_')
if len(parts) >= 3 and parts[0] == 'tier':
subscription_cost = float(parts[-1])
actual_credits = subscription_cost + 5.0
tier_info = {
'name': tier_obj.name,
'credits': actual_credits
}
if tier_obj and len(tier_obj.price_ids) > 0:
price_id = tier_obj.price_ids[0]
if subscription_result.data:
subscription = subscription_result.data[0]
if subscription.get('price_id'):
price_id = subscription['price_id']
else:
price_id = config.STRIPE_FREE_TIER_ID
else:
tier_name = 'free'
tier_obj = TIERS['free']
tier_info = {
'name': 'free',
'credits': float(TIERS['free'].monthly_credits)
}
price_id = config.STRIPE_FREE_TIER_ID
if subscription_result.data:
subscription = subscription_result.data[0]
price_id = subscription['price_id']
price_tier_obj = get_tier_by_price_id(price_id)
if price_tier_obj:
tier_info = {
'name': price_tier_obj.name,
'credits': float(price_tier_obj.monthly_credits)
}
stripe_subscription = None
try:
@ -598,8 +604,15 @@ async def get_subscription(
balance = await credit_service.get_balance(user_id)
summary = await credit_service.get_account_summary(user_id)
if subscription_data:
status = 'active'
elif tier_name != 'free':
status = 'cancelled'
else:
status = 'no_subscription'
return {
'status': 'active' if subscription_data else 'no_subscription',
'status': status,
'plan_name': tier_info['name'],
'price_id': price_id,
'subscription': subscription_data,
@ -607,7 +620,7 @@ async def get_subscription(
'current_usage': float(summary['lifetime_used']),
'cost_limit': tier_info['credits'],
'credit_balance': float(balance),
'can_purchase_credits': tier_info['name'] == 'tier_200_1000',
'can_purchase_credits': tier_info.get('name') in ['tier_200_1000', 'tier_50_400', 'tier_125_800'],
'tier': tier_info,
'credits': {
'balance': float(balance),
@ -615,7 +628,7 @@ async def get_subscription(
'lifetime_granted': float(summary['lifetime_granted']),
'lifetime_purchased': float(summary['lifetime_purchased']),
'lifetime_used': float(summary['lifetime_used']),
'can_purchase': tier_info['name'] == 'tier_200_1000'
'can_purchase': tier_info.get('name') in ['tier_200_1000', 'tier_50_400', 'tier_125_800']
}
}
@ -658,7 +671,7 @@ async def create_checkout_session(
db = DBConnection()
client = await db.client
existing_sub = await client.schema('basejump').from_('billing_subscriptions').select('id').eq('account_id', user_id).eq('status', 'active').execute()
existing_sub = await client.schema('basejump').from_('billing_subscriptions').select('id').eq('account_id', user_id).in_('status', ['active', 'trialing']).execute()
if existing_sub.data and len(existing_sub.data) > 0:
subscription = await stripe.Subscription.retrieve_async(existing_sub.data[0]['id'])
@ -673,12 +686,28 @@ async def create_checkout_session(
payment_behavior='pending_if_incomplete'
)
await handle_subscription_change(updated_subscription)
await Cache.invalidate(f"subscription_tier:{user_id}")
await Cache.invalidate(f"credit_balance:{user_id}")
await Cache.invalidate(f"credit_summary:{user_id}")
old_price_id = subscription['items']['data'][0].price.id
old_tier = get_tier_by_price_id(old_price_id)
new_tier = get_tier_by_price_id(request.price_id)
old_amount = float(old_tier.monthly_credits) if old_tier else 0
new_amount = float(new_tier.monthly_credits) if new_tier else 0
return {
'success': True,
'status': 'upgraded' if new_amount > old_amount else 'updated',
'subscription_id': updated_subscription.id,
'message': 'Subscription updated successfully'
'message': 'Subscription updated successfully',
'details': {
'is_upgrade': new_amount > old_amount,
'current_price': old_amount,
'new_price': new_amount
}
}
else:
session = await stripe.checkout.Session.create_async(
@ -695,7 +724,6 @@ async def create_checkout_session(
}
}
)
return {'checkout_url': session.url}
except Exception as e:
@ -721,6 +749,58 @@ async def create_portal_session(
logger.error(f"Error creating portal session: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=str(e))
@router.post("/sync-subscription")
async def sync_subscription(
user_id: str = Depends(verify_and_get_user_id_from_jwt)
) -> Dict:
try:
db = DBConnection()
client = await db.client
sub_result = await client.schema('basejump').from_('billing_subscriptions')\
.select('*')\
.eq('account_id', user_id)\
.in_('status', ['active', 'trialing'])\
.limit(1)\
.execute()
if not sub_result.data or len(sub_result.data) == 0:
return {
'success': False,
'message': 'No active subscription found'
}
subscription = await stripe.Subscription.retrieve_async(
sub_result.data[0]['id'],
expand=['items.data.price']
)
await handle_subscription_change(subscription)
await Cache.invalidate(f"subscription_tier:{user_id}")
await Cache.invalidate(f"credit_balance:{user_id}")
await Cache.invalidate(f"credit_summary:{user_id}")
balance = await credit_service.get_balance(user_id)
summary = await credit_service.get_account_summary(user_id)
return {
'success': True,
'message': 'Subscription synced successfully',
'credits': {
'balance': float(balance),
'lifetime_granted': float(summary['lifetime_granted']),
'lifetime_used': float(summary['lifetime_used'])
}
}
except Exception as e:
logger.error(f"Error syncing subscription: {e}", exc_info=True)
return {
'success': False,
'message': f'Failed to sync subscription: {str(e)}'
}
@router.post("/cancel-subscription")
async def cancel_subscription(
request: CancelSubscriptionRequest,
@ -730,7 +810,7 @@ async def cancel_subscription(
db = DBConnection()
client = await db.client
sub_result = await client.schema('basejump').from_('billing_subscriptions').select('id').eq('account_id', user_id).eq('status', 'active').execute()
sub_result = await client.schema('basejump').from_('billing_subscriptions').select('id').eq('account_id', user_id).in_('status', ['active', 'trialing']).execute()
if not sub_result.data or len(sub_result.data) == 0:
raise HTTPException(status_code=404, detail="No active subscription found")
@ -761,7 +841,7 @@ async def reactivate_subscription(
db = DBConnection()
client = await db.client
sub_result = await client.schema('basejump').from_('billing_subscriptions').select('id').eq('account_id', user_id).eq('status', 'active').execute()
sub_result = await client.schema('basejump').from_('billing_subscriptions').select('id').eq('account_id', user_id).in_('status', ['active', 'trialing']).execute()
if not sub_result.data or len(sub_result.data) == 0:
raise HTTPException(status_code=404, detail="No active subscription found")
@ -830,54 +910,6 @@ async def get_usage_history(
logger.error(f"Error getting usage history: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=str(e))
@router.post("/sync-subscription")
async def sync_subscription_from_stripe(
user_id: str = Depends(verify_and_get_user_id_from_jwt)
) -> Dict:
"""Manually sync subscription from Stripe to local database and credit system"""
try:
db = DBConnection()
client = await db.client
# Get customer ID
customer_result = await client.schema('basejump').from_('billing_customers').select('id').eq('account_id', user_id).execute()
if not customer_result.data or len(customer_result.data) == 0:
return {'success': False, 'message': 'No billing customer found'}
customer_id = customer_result.data[0]['id']
# Get active subscriptions from Stripe
subscriptions = await stripe.Subscription.list_async(
customer=customer_id,
status='active',
expand=['data.items.data.price']
)
if not subscriptions.data:
return {'success': False, 'message': 'No active subscription found in Stripe'}
# Process the first active subscription
subscription = subscriptions.data[0]
# Handle the subscription change
await handle_subscription_change(subscription)
# Get updated balance and tier
balance = await credit_service.get_balance(user_id)
account_result = await client.from_('credit_accounts').select('tier').eq('user_id', user_id).execute()
return {
'success': True,
'message': 'Subscription synced successfully',
'subscription_id': subscription['id'],
'tier': account_result.data[0].get('tier') if account_result.data and len(account_result.data) > 0 else 'unknown',
'balance': float(balance)
}
except Exception as e:
logger.error(f"Error syncing subscription: {e}", exc_info=True)
return {'success': False, 'message': str(e)}
@router.get("/available-models")
async def get_available_models(
@ -974,4 +1006,86 @@ async def get_subscription_commitment(
'commitment_type': None,
'months_remaining': None,
'commitment_end_date': None
}
}
@router.post("/test/trigger-renewal")
async def test_trigger_renewal(
user_id: str = Depends(verify_and_get_user_id_from_jwt)
) -> Dict:
if config.ENV_MODE == EnvMode.PRODUCTION:
raise HTTPException(status_code=403, detail="Test endpoints disabled in production")
try:
db = DBConnection()
client = await db.client
# Check credit_accounts as primary source of truth
account_result = await client.from_('credit_accounts')\
.select('*')\
.eq('user_id', user_id)\
.execute()
if not account_result.data or len(account_result.data) == 0:
return {
'success': False,
'message': 'No credit account found. Please subscribe to a plan first.'
}
account = account_result.data[0]
# Only check if tier is not free (means they have a subscription)
if account.get('tier', 'free') == 'free':
return {
'success': False,
'message': 'No active subscription found. Please subscribe to a plan first.'
}
tier = account['tier']
yesterday = datetime.now(timezone.utc) - timedelta(days=26)
await client.from_('credit_accounts').update({
'last_grant_date': yesterday.isoformat()
}).eq('user_id', user_id).execute()
monthly_credits = get_monthly_credits(tier)
if monthly_credits > 0:
await credit_service.add_credits(
user_id=user_id,
amount=monthly_credits,
type='tier_grant',
description=f"TEST: Monthly {tier} tier credits",
metadata={'test': True, 'triggered_by': 'manual_test'}
)
next_grant = datetime.now(timezone.utc) + timedelta(days=30)
await client.from_('credit_accounts').update({
'last_grant_date': datetime.now(timezone.utc).isoformat(),
'next_credit_grant': next_grant.isoformat()
}).eq('user_id', user_id).execute()
await Cache.invalidate(f"credit_balance:{user_id}")
await Cache.invalidate(f"credit_summary:{user_id}")
await Cache.invalidate(f"subscription_tier:{user_id}")
new_balance = await credit_service.get_balance(user_id)
return {
'success': True,
'message': f'Successfully granted {monthly_credits} credits',
'tier': tier,
'credits_granted': float(monthly_credits),
'new_balance': float(new_balance),
'next_grant_date': next_grant.isoformat()
}
else:
return {
'success': False,
'message': f'No credits to grant for tier: {tier}'
}
except Exception as e:
logger.error(f"Error in test renewal trigger: {e}", exc_info=True)
return {
'success': False,
'message': str(e)
}

View File

@ -0,0 +1,337 @@
"""
Simplified Billing Service v3
=============================
Uses only essential sources of truth:
1. Stripe API - Real subscription data
2. credit_accounts - Tier, balance, billing cycle
3. billing_customers - Stripe customer ID mapping
4. credit_ledger - Transaction history
5. agent_runs - Usage tracking
Removed:
- billing_subscriptions (redundant)
- credit_grants (unused)
- Complex fallback logic
"""
from fastapi import APIRouter, HTTPException, Depends, Request
from typing import Optional, Dict
from decimal import Decimal
from datetime import datetime, timezone, timedelta
import stripe
from core.credits import credit_service
from core.services.supabase import DBConnection
from core.utils.auth_utils import verify_and_get_user_id_from_jwt
from core.utils.config import config, EnvMode
from core.utils.logger import logger
from core.billing_config import get_tier_by_price_id, TIERS, get_monthly_credits, get_tier_by_name
router = APIRouter(prefix="/billing/v3", tags=["billing-v3"])
stripe.api_key = config.STRIPE_SECRET_KEY
async def get_user_tier(user_id: str) -> str:
"""Get user's tier from credit_accounts (single source of truth)"""
db = DBConnection()
client = await db.client
result = await client.from_('credit_accounts')\
.select('tier')\
.eq('user_id', user_id)\
.single()\
.execute()
if result.data:
return result.data.get('tier', 'free')
return 'free'
async def ensure_credit_account(user_id: str) -> None:
"""Ensure user has a credit account (create if missing)"""
db = DBConnection()
client = await db.client
# Check if account exists
result = await client.from_('credit_accounts')\
.select('user_id')\
.eq('user_id', user_id)\
.execute()
if not result.data:
# Create free tier account
await client.from_('credit_accounts').insert({
'user_id': user_id,
'balance': 5.0, # Free tier initial credits
'tier': 'free',
'last_grant_date': datetime.now(timezone.utc).isoformat()
}).execute()
# Add initial credits to ledger
await credit_service.add_credits(
user_id=user_id,
amount=Decimal('5.0'),
type='tier_grant',
description='Welcome credits - Free tier'
)
@router.get("/subscription")
async def get_subscription_simplified(
user_id: str = Depends(verify_and_get_user_id_from_jwt)
) -> Dict:
"""Simplified subscription endpoint using minimal sources"""
try:
# Ensure user has credit account
await ensure_credit_account(user_id)
db = DBConnection()
client = await db.client
# Get all user data in one query
account_result = await client.from_('credit_accounts')\
.select('*')\
.eq('user_id', user_id)\
.single()\
.execute()
account = account_result.data
tier_name = account.get('tier', 'free')
# Get tier configuration
tier = get_tier_by_name(tier_name) or TIERS['free']
# Calculate actual credits (subscription + free base)
actual_credits = float(tier.monthly_credits)
if tier_name != 'free':
parts = tier_name.split('_')
if len(parts) >= 3 and parts[0] == 'tier':
subscription_cost = float(parts[-1])
actual_credits = subscription_cost + 5.0
# Get Stripe subscription if needed (only for subscription management)
stripe_subscription = None
subscription_id = account.get('stripe_subscription_id')
if subscription_id:
try:
# Only fetch from Stripe when needed for management
stripe_subscription = await stripe.Subscription.retrieve_async(subscription_id)
except:
# If Stripe fails, we still have local data
pass
# Get balance and usage
balance = await credit_service.get_balance(user_id)
summary = await credit_service.get_account_summary(user_id)
return {
'status': 'active' if tier_name != 'free' else 'no_subscription',
'tier': tier_name,
'tier_display': tier.display_name,
'credits': {
'balance': float(balance),
'monthly_allocation': actual_credits,
'lifetime_used': float(summary['lifetime_used']),
'can_purchase': tier.can_purchase_credits
},
'subscription': {
'id': subscription_id,
'status': stripe_subscription.status if stripe_subscription else None,
'cancel_at_period_end': stripe_subscription.cancel_at_period_end if stripe_subscription else False,
'current_period_end': stripe_subscription.current_period_end if stripe_subscription else None
} if subscription_id else None,
'next_credit_grant': account.get('next_credit_grant'),
'project_limit': tier.project_limit
}
except Exception as e:
logger.error(f"Error getting subscription: {e}", exc_info=True)
# Return free tier as fallback
return {
'status': 'no_subscription',
'tier': 'free',
'tier_display': 'Free',
'credits': {
'balance': 0,
'monthly_allocation': 5,
'lifetime_used': 0,
'can_purchase': False
},
'subscription': None,
'next_credit_grant': None,
'project_limit': 3
}
@router.post("/webhook")
async def stripe_webhook_simplified(request: Request):
"""Simplified webhook handler - only updates credit_accounts"""
try:
payload = await request.body()
sig_header = request.headers.get('stripe-signature')
event = stripe.Webhook.construct_event(
payload, sig_header, config.STRIPE_WEBHOOK_SECRET
)
if event.type in ['customer.subscription.created', 'customer.subscription.updated']:
subscription = event.data.object
if subscription.status in ['active', 'trialing']:
await handle_subscription_simplified(subscription)
elif event.type == 'customer.subscription.deleted':
subscription = event.data.object
await handle_cancellation_simplified(subscription)
elif event.type == 'invoice.payment_succeeded':
invoice = event.data.object
if invoice.get('billing_reason') == 'subscription_cycle':
await handle_renewal_simplified(invoice)
return {'status': 'success'}
except Exception as e:
logger.error(f"Webhook error: {e}")
raise HTTPException(status_code=400, detail=str(e))
async def handle_subscription_simplified(subscription: Dict):
"""Handle subscription changes - update only credit_accounts"""
db = DBConnection()
client = await db.client
# Get user from customer
customer_result = await client.schema('basejump').from_('billing_customers')\
.select('account_id')\
.eq('id', subscription['customer'])\
.single()\
.execute()
if not customer_result.data:
return
user_id = customer_result.data['account_id']
price_id = subscription['items']['data'][0]['price']['id']
# Get tier from price
tier_info = get_tier_by_price_id(price_id)
if not tier_info:
return
# Ensure account exists
await ensure_credit_account(user_id)
# Get current account
current_result = await client.from_('credit_accounts')\
.select('tier')\
.eq('user_id', user_id)\
.single()\
.execute()
old_tier = current_result.data.get('tier', 'free')
new_tier = tier_info.name
# Update account
billing_anchor = datetime.fromtimestamp(subscription['current_period_start'], tz=timezone.utc)
next_grant = datetime.fromtimestamp(subscription['current_period_end'], tz=timezone.utc)
await client.from_('credit_accounts').update({
'tier': new_tier,
'stripe_subscription_id': subscription['id'],
'billing_cycle_anchor': billing_anchor.isoformat(),
'next_credit_grant': next_grant.isoformat()
}).eq('user_id', user_id).execute()
# Handle tier change credits
if old_tier != new_tier:
old_tier_obj = get_tier_by_name(old_tier) or TIERS['free']
new_credits = float(tier_info.monthly_credits)
old_credits = float(old_tier_obj.monthly_credits)
if new_credits > old_credits:
# Upgrade - grant full new tier credits
await credit_service.add_credits(
user_id=user_id,
amount=Decimal(str(new_credits)),
type='tier_upgrade',
description=f"Upgrade to {tier_info.display_name}"
)
async def handle_cancellation_simplified(subscription: Dict):
"""Handle subscription cancellation - just update tier to free"""
db = DBConnection()
client = await db.client
# Get user from customer
customer_result = await client.schema('basejump').from_('billing_customers')\
.select('account_id')\
.eq('id', subscription['customer'])\
.single()\
.execute()
if not customer_result.data:
return
user_id = customer_result.data['account_id']
# Update to free tier (keep existing balance)
await client.from_('credit_accounts').update({
'tier': 'free',
'stripe_subscription_id': None,
'billing_cycle_anchor': None,
'next_credit_grant': None
}).eq('user_id', user_id).execute()
async def handle_renewal_simplified(invoice: Dict):
"""Handle monthly renewal - grant credits"""
db = DBConnection()
client = await db.client
subscription_id = invoice.get('subscription')
if not subscription_id:
return
# Get user from customer
customer_result = await client.schema('basejump').from_('billing_customers')\
.select('account_id')\
.eq('id', invoice['customer'])\
.single()\
.execute()
if not customer_result.data:
return
user_id = customer_result.data['account_id']
# Get account
account_result = await client.from_('credit_accounts')\
.select('tier, last_grant_date')\
.eq('user_id', user_id)\
.single()\
.execute()
if not account_result.data:
return
account = account_result.data
# Check if enough time has passed (prevent double grants)
if account.get('last_grant_date'):
last_grant = datetime.fromisoformat(account['last_grant_date'].replace('Z', '+00:00'))
if (datetime.now(timezone.utc) - last_grant).days < 25:
return
# Grant monthly credits
tier = account['tier']
monthly_credits = get_monthly_credits(tier)
if monthly_credits > 0:
await credit_service.add_credits(
user_id=user_id,
amount=monthly_credits,
type='tier_grant',
description=f"Monthly {tier} credits"
)
# Update grant dates
next_grant = datetime.now(timezone.utc) + timedelta(days=30)
await client.from_('credit_accounts').update({
'last_grant_date': datetime.now(timezone.utc).isoformat(),
'next_credit_grant': next_grant.isoformat()
}).eq('user_id', user_id).execute()

View File

@ -8,15 +8,17 @@ import { Skeleton } from '@/components/ui/skeleton';
import { Alert, AlertTitle, AlertDescription } from '@/components/ui/alert';
import { Button } from '@/components/ui/button';
import { useSharedSubscription, useSubscriptionContext } from '@/contexts/SubscriptionContext';
import { isLocalMode } from '@/lib/config';
import { isLocalMode, isStagingMode } from '@/lib/config';
import Link from 'next/link';
import { useCreatePortalSession } from '@/hooks/react-query/use-billing-v2';
import { useCreatePortalSession, useTriggerTestRenewal } from '@/hooks/react-query/use-billing-v2';
import { toast } from 'sonner';
const returnUrl = process.env.NEXT_PUBLIC_URL as string;
export default function PersonalAccountBillingPage() {
const { data: accounts, isLoading, error } = useAccounts();
const [showBillingModal, setShowBillingModal] = useState(false);
const triggerTestRenewal = useTriggerTestRenewal();
const {
data: subscriptionData,
@ -146,6 +148,53 @@ export default function PersonalAccountBillingPage() {
</div>
</>
)}
{isStagingMode() && (
<div className="mt-4 p-3 bg-yellow-50 dark:bg-yellow-900/20 border border-yellow-200 dark:border-yellow-800 rounded-lg">
<p className="text-xs text-yellow-800 dark:text-yellow-200 mb-2">
🧪 <strong>Test Mode:</strong> Simulate monthly credit renewal
</p>
<Button
onClick={() => {
triggerTestRenewal.mutate(undefined, {
onSuccess: (result) => {
if (result.success) {
toast.success(
<div>
<p>Credits renewed successfully!</p>
{result.credits_granted && (
<p className="text-sm mt-1">
+{result.credits_granted} credits added
</p>
)}
{result.new_balance !== undefined && (
<p className="text-xs mt-1 opacity-80">
New balance: ${result.new_balance}
</p>
)}
</div>
);
} else {
toast.error(result.message || 'Failed to trigger renewal');
}
},
onError: (error) => {
console.error('Test renewal error:', error);
toast.error('Failed to trigger test renewal');
}
});
}}
size="sm"
className="w-full text-xs bg-yellow-600 hover:bg-yellow-700"
disabled={triggerTestRenewal.isPending}
>
{triggerTestRenewal.isPending ? (
'🔄 Triggering...'
) : (
'🔄 Trigger Monthly Credit Renewal (Test)'
)}
</Button>
</div>
)}
</div>
</div>
);

View File

@ -236,55 +236,15 @@ export function BillingModal({ open, onOpenChange, returnUrl = typeof window !==
</DialogHeader>
<>
{/* Usage Limit Alert
{showUsageLimitAlert && (
<div className="mb-6">
<div className="flex items-start p-3 sm:p-4 bg-destructive/5 border border-destructive/50 rounded-lg">
<div className="flex items-start space-x-3">
<div className="flex-shrink-0 mt-0.5">
<Zap className="w-4 h-4 sm:w-5 sm:h-5 text-destructive" />
</div>
<div className="text-xs sm:text-sm min-w-0">
<p className="font-medium text-destructive">Usage Limit Reached</p>
<p className="text-destructive break-words">
Your current plan has been exhausted for this billing period.
</p>
</div>
</div>
</div>
</div>
)} */}
{/* Usage section - show loading state or actual data */}
{isLoading || authLoading ? (
<div className="mb-6">
<div className="rounded-lg border bg-background p-4">
<div className="flex justify-between items-center">
<Skeleton className="h-4 w-40" />
<Skeleton className="h-4 w-24" />
</div>
</div>
</div>
) : subscriptionData && (
<div className="mb-6">
<div className="rounded-lg border bg-background p-4">
<div className="flex justify-between items-center">
<span className="text-sm font-medium text-foreground/90">
Agent Usage This Month
</span>
<span className="text-sm font-medium">
${subscriptionData.current_usage?.toFixed(2) || '0'} /{' '}
${subscriptionData.cost_limit || '0'}
</span>
</div>
</div>
</div>
)}
{/* Show pricing section immediately - no loading state */}
<PricingSection returnUrl={returnUrl} showTitleAndTabs={false} />
{/* Subscription Management Section - only show if there's actual subscription data */}
<PricingSection
returnUrl={returnUrl}
showTitleAndTabs={false}
onSubscriptionUpdate={() => {
setTimeout(() => {
fetchSubscriptionData();
}, 500);
}}
/>
{error ? (
<div className="mt-6 pt-4 border-t border-border">
<div className="p-4 bg-destructive/10 border border-destructive/20 rounded-lg text-center">
@ -293,7 +253,6 @@ export function BillingModal({ open, onOpenChange, returnUrl = typeof window !==
</div>
) : subscriptionData?.subscription && (
<div className="mt-6 pt-4 border-t border-border">
{/* Subscription Status Info Box */}
<div className="bg-muted/30 border border-border rounded-lg p-3 mb-3">
<div className="flex items-center justify-between">
<div className="flex items-center gap-2">

View File

@ -584,6 +584,7 @@ interface PricingSectionProps {
insideDialog?: boolean;
showInfo?: boolean;
noPadding?: boolean;
onSubscriptionUpdate?: () => void;
}
export function PricingSection({
@ -593,6 +594,7 @@ export function PricingSection({
insideDialog = false,
showInfo = true,
noPadding = false,
onSubscriptionUpdate,
}: PricingSectionProps) {
const { data: subscriptionData, isLoading: isFetchingPlan, error: subscriptionQueryError, refetch: refetchSubscription } = useSubscription();
@ -613,25 +615,21 @@ export function PricingSection({
);
if (currentTier) {
// Check if current subscription is yearly commitment (new yearly)
if (currentTier.monthlyCommitmentStripePriceId === currentSubscription.price_id) {
return 'yearly_commitment';
} else if (currentTier.yearlyStripePriceId === currentSubscription.price_id) {
// Legacy yearly plans
return 'yearly';
} else if (currentTier.stripePriceId === currentSubscription.price_id) {
return 'monthly';
}
}
// Default to yearly_commitment (new yearly) if we can't determine current plan type
return 'yearly_commitment';
}, [isAuthenticated, currentSubscription]);
const [billingPeriod, setBillingPeriod] = useState<'monthly' | 'yearly' | 'yearly_commitment'>(getDefaultBillingPeriod());
const [planLoadingStates, setPlanLoadingStates] = useState<Record<string, boolean>>({});
// Update billing period when subscription data changes
useEffect(() => {
setBillingPeriod(getDefaultBillingPeriod());
}, [getDefaultBillingPeriod]);
@ -647,6 +645,10 @@ export function PricingSection({
setTimeout(() => {
setPlanLoadingStates({});
}, 1000);
// Call parent's update handler if provided
if (onSubscriptionUpdate) {
onSubscriptionUpdate();
}
};

View File

@ -137,4 +137,15 @@ export const useDeductTokenUsage = () => {
queryClient.invalidateQueries({ queryKey: billingKeys.status() });
},
});
};
export const useTriggerTestRenewal = () => {
const queryClient = useQueryClient();
return useMutation({
mutationFn: () => billingApiV2.triggerTestRenewal(),
onSuccess: () => {
queryClient.invalidateQueries({ queryKey: billingKeys.all });
},
});
};

View File

@ -1963,8 +1963,7 @@ export const getSubscription = async (): Promise<SubscriptionStatus> => {
}
const data = await response.json();
// Map the new billing v2 format to the old format for backward compatibility
return {
subscription: data.subscription ? {
...data.subscription,
@ -1974,7 +1973,7 @@ export const getSubscription = async (): Promise<SubscriptionStatus> => {
cost_limit: data.tier?.credits || 0,
credit_balance: data.credits?.balance || 0,
can_purchase_credits: data.credits?.can_purchase || false,
...data // Include any other fields that might be used
...data
} as SubscriptionStatus;
} catch (error) {
if (error instanceof NoAccessTokenAvailableError) {

View File

@ -121,6 +121,13 @@ export interface ReactivateSubscriptionResponse {
message: string;
}
export interface TestRenewalResponse {
success: boolean;
message?: string;
credits_granted?: number;
new_balance?: number;
}
export const billingApiV2 = {
async getSubscription() {
const response = await backendApi.get<SubscriptionInfo>('/billing/v2/subscription');
@ -204,6 +211,12 @@ export const billingApiV2 = {
);
if (response.error) throw response.error;
return response.data!;
},
async triggerTestRenewal() {
const response = await backendApi.post<TestRenewalResponse>('/billing/v2/test/trigger-renewal');
if (response.error) throw response.error;
return response.data!;
}
};
@ -222,4 +235,5 @@ export const purchaseCredits = (request: PurchaseCreditsRequest) =>
billingApiV2.purchaseCredits(request);
export const getTransactions = (limit?: number, offset?: number) =>
billingApiV2.getTransactions(limit, offset);
export const getUsageHistory = (days?: number) => billingApiV2.getUsageHistory(days);
export const getUsageHistory = (days?: number) => billingApiV2.getUsageHistory(days);
export const triggerTestRenewal = () => billingApiV2.triggerTestRenewal();