mirror of https://github.com/kortix-ai/suna.git
simplify blling arch
This commit is contained in:
parent
37ca540f24
commit
1f5c0e9be2
|
@ -36,7 +36,7 @@ TIERS: Dict[str, Tier] = {
|
|||
config.STRIPE_TIER_2_20_YEARLY_ID,
|
||||
config.STRIPE_TIER_2_17_YEARLY_COMMITMENT_ID
|
||||
],
|
||||
monthly_credits=Decimal('25.00'),
|
||||
monthly_credits=Decimal('20.00'),
|
||||
display_name='Starter',
|
||||
can_purchase_credits=True,
|
||||
models=['gpt-4o-mini', 'gpt-4o', 'claude-3-haiku', 'claude-3-5-sonnet'],
|
||||
|
@ -49,7 +49,7 @@ TIERS: Dict[str, Tier] = {
|
|||
config.STRIPE_TIER_6_50_YEARLY_ID,
|
||||
config.STRIPE_TIER_6_42_YEARLY_COMMITMENT_ID
|
||||
],
|
||||
monthly_credits=Decimal('65.00'),
|
||||
monthly_credits=Decimal('50.00'),
|
||||
display_name='Professional',
|
||||
can_purchase_credits=True,
|
||||
models=['gpt-4o-mini', 'gpt-4o', 'claude-3-haiku', 'claude-3-5-sonnet', 'claude-3-opus'],
|
||||
|
@ -61,7 +61,7 @@ TIERS: Dict[str, Tier] = {
|
|||
config.STRIPE_TIER_12_100_ID,
|
||||
config.STRIPE_TIER_12_100_YEARLY_ID
|
||||
],
|
||||
monthly_credits=Decimal('130.00'),
|
||||
monthly_credits=Decimal('100.00'),
|
||||
display_name='Team',
|
||||
can_purchase_credits=True,
|
||||
models=['gpt-4o-mini', 'gpt-4o', 'claude-3-haiku', 'claude-3-5-sonnet', 'claude-3-opus'],
|
||||
|
@ -74,7 +74,7 @@ TIERS: Dict[str, Tier] = {
|
|||
config.STRIPE_TIER_25_200_YEARLY_ID,
|
||||
config.STRIPE_TIER_25_170_YEARLY_COMMITMENT_ID
|
||||
],
|
||||
monthly_credits=Decimal('260.00'),
|
||||
monthly_credits=Decimal('200.00'),
|
||||
display_name='Business',
|
||||
can_purchase_credits=True,
|
||||
models=['gpt-4o-mini', 'gpt-4o', 'claude-3-haiku', 'claude-3-5-sonnet', 'claude-3-opus', 'o1-preview'],
|
||||
|
@ -86,7 +86,7 @@ TIERS: Dict[str, Tier] = {
|
|||
config.STRIPE_TIER_50_400_ID,
|
||||
config.STRIPE_TIER_50_400_YEARLY_ID
|
||||
],
|
||||
monthly_credits=Decimal('520.00'),
|
||||
monthly_credits=Decimal('400.00'),
|
||||
display_name='Enterprise',
|
||||
can_purchase_credits=True,
|
||||
models=['gpt-4o-mini', 'gpt-4o', 'claude-3-haiku', 'claude-3-5-sonnet', 'claude-3-opus', 'o1-preview', 'o1'],
|
||||
|
@ -98,7 +98,7 @@ TIERS: Dict[str, Tier] = {
|
|||
config.STRIPE_TIER_125_800_ID,
|
||||
config.STRIPE_TIER_125_800_YEARLY_ID
|
||||
],
|
||||
monthly_credits=Decimal('1040.00'),
|
||||
monthly_credits=Decimal('800.00'),
|
||||
display_name='Enterprise Plus',
|
||||
can_purchase_credits=True,
|
||||
models=['gpt-4o-mini', 'gpt-4o', 'claude-3-haiku', 'claude-3-5-sonnet', 'claude-3-opus', 'o1-preview', 'o1'],
|
||||
|
@ -110,7 +110,7 @@ TIERS: Dict[str, Tier] = {
|
|||
config.STRIPE_TIER_200_1000_ID,
|
||||
config.STRIPE_TIER_200_1000_YEARLY_ID
|
||||
],
|
||||
monthly_credits=Decimal('1300.00'),
|
||||
monthly_credits=Decimal('1000.00'),
|
||||
display_name='Ultimate',
|
||||
can_purchase_credits=True,
|
||||
models=['gpt-4o-mini', 'gpt-4o', 'claude-3-haiku', 'claude-3-5-sonnet', 'claude-3-opus', 'o1-preview', 'o1'],
|
||||
|
|
|
@ -158,71 +158,38 @@ async def sync_stripe_to_db():
|
|||
print(f" ✓ Synced subscription to database")
|
||||
synced_count += 1
|
||||
|
||||
# For active/trialing subscriptions, use handle_subscription_change for comprehensive update
|
||||
if sub.status in ['active', 'trialing']:
|
||||
print(f" Processing active subscription...")
|
||||
|
||||
# Get current state
|
||||
account_before = await client.from_('credit_accounts')\
|
||||
.select('*')\
|
||||
.eq('user_id', account_id)\
|
||||
.maybe_single()\
|
||||
.execute()
|
||||
|
||||
# Use handle_subscription_change to properly update all fields
|
||||
# This handles billing_cycle_anchor, stripe_subscription_id, next_credit_grant, etc.
|
||||
try:
|
||||
await handle_subscription_change(sub)
|
||||
print(f" ✓ Updated billing cycle and credit fields")
|
||||
except Exception as e:
|
||||
print(f" ⚠ handle_subscription_change failed: {e}")
|
||||
|
||||
# Now fix the balance based on actual usage from agent_runs
|
||||
tier = get_tier_by_price_id(price_id) if price_id else None
|
||||
if tier:
|
||||
tier_name = tier.name
|
||||
|
||||
# Calculate actual monthly credits (subscription cost + free tier base)
|
||||
parts = tier_name.split('_')
|
||||
if len(parts) == 3 and parts[0] == 'tier':
|
||||
subscription_cost = Decimal(parts[2])
|
||||
monthly_credits = subscription_cost + FREE_TIER_INITIAL_CREDITS
|
||||
else:
|
||||
monthly_credits = FREE_TIER_INITIAL_CREDITS
|
||||
|
||||
# Get current month usage from agent_runs (old billing system style)
|
||||
month_start = datetime.now().replace(day=1, hour=0, minute=0, second=0, microsecond=0)
|
||||
runs_result = await client.from_('agent_runs')\
|
||||
.select('total_cost')\
|
||||
.eq('user_id', account_id)\
|
||||
.gte('created_at', month_start.isoformat())\
|
||||
.execute()
|
||||
|
||||
current_usage = Decimal('0')
|
||||
for run in runs_result.data or []:
|
||||
if run.get('total_cost'):
|
||||
current_usage += Decimal(str(run['total_cost']))
|
||||
|
||||
# Calculate correct balance
|
||||
correct_balance = monthly_credits - current_usage
|
||||
|
||||
# Update/create credit account with all the proper fields
|
||||
credit_account = await client.from_('credit_accounts')\
|
||||
.select('*')\
|
||||
.eq('user_id', account_id)\
|
||||
.maybe_single()\
|
||||
.execute()
|
||||
|
||||
if credit_account.data:
|
||||
# Update existing account
|
||||
account_after = await client.from_('credit_accounts')\
|
||||
.select('*')\
|
||||
.eq('user_id', account_id)\
|
||||
.maybe_single()\
|
||||
.execute()
|
||||
|
||||
if account_after.data:
|
||||
tier = get_tier_by_price_id(price_id) if price_id else None
|
||||
if tier:
|
||||
tier_name = tier.name
|
||||
|
||||
update_data = {
|
||||
'tier': tier_name,
|
||||
'balance': str(correct_balance),
|
||||
'stripe_subscription_id': sub.id,
|
||||
'billing_cycle_anchor': datetime.fromtimestamp(sub.created).isoformat() if sub.created else None,
|
||||
}
|
||||
|
||||
# Add next_credit_grant if we have current_period_end
|
||||
if sub.current_period_end:
|
||||
update_data['next_credit_grant'] = datetime.fromtimestamp(sub.current_period_end).isoformat()
|
||||
|
||||
|
@ -231,27 +198,14 @@ async def sync_stripe_to_db():
|
|||
.eq('user_id', account_id)\
|
||||
.execute()
|
||||
|
||||
current_balance = Decimal(str(account_after.data['balance']))
|
||||
print(f" ✓ Updated credit account:")
|
||||
print(f" - Tier: {tier_name}")
|
||||
print(f" - Balance: ${correct_balance} (${monthly_credits} - ${current_usage} usage)")
|
||||
print(f" - Balance: ${current_balance} (managed by handle_subscription_change)")
|
||||
print(f" - Stripe subscription: {sub.id}")
|
||||
print(f" - Next credit grant: {update_data.get('next_credit_grant', 'N/A')}")
|
||||
else:
|
||||
# Create new account with all fields
|
||||
await client.from_('credit_accounts').insert({
|
||||
'user_id': account_id,
|
||||
'balance': str(correct_balance),
|
||||
'tier': tier_name,
|
||||
'stripe_subscription_id': sub.id,
|
||||
'billing_cycle_anchor': datetime.fromtimestamp(sub.created).isoformat() if sub.created else None,
|
||||
'next_credit_grant': datetime.fromtimestamp(sub.current_period_end).isoformat() if sub.current_period_end else None,
|
||||
'last_grant_date': datetime.now().isoformat()
|
||||
}).execute()
|
||||
|
||||
print(f" ✓ Created credit account:")
|
||||
print(f" - Tier: {tier_name}")
|
||||
print(f" - Balance: ${correct_balance} (${monthly_credits} - ${current_usage} usage)")
|
||||
print(f" - Stripe subscription: {sub.id}")
|
||||
else:
|
||||
print(f" ⚠ No credit account found after handle_subscription_change - this is unexpected")
|
||||
|
||||
else:
|
||||
print(f" ✗ Failed to sync")
|
||||
|
|
|
@ -0,0 +1,97 @@
|
|||
#!/usr/bin/env python3
|
||||
import asyncio
|
||||
import sys
|
||||
from pathlib import Path
|
||||
sys.path.append(str(Path(__file__).parent.parent))
|
||||
|
||||
from core.services.supabase import DBConnection
|
||||
from core.billing_config import get_tier_by_price_id
|
||||
from core.utils.logger import logger
|
||||
|
||||
async def verify_and_fix_tiers():
|
||||
"""Verify that all users have the correct tier in credit_accounts"""
|
||||
db = DBConnection()
|
||||
await db.initialize()
|
||||
client = await db.client
|
||||
|
||||
print("\n" + "="*60)
|
||||
print("VERIFYING USER TIERS")
|
||||
print("="*60)
|
||||
|
||||
# Get all active/trialing subscriptions
|
||||
subs_result = await client.schema('basejump').from_('billing_subscriptions')\
|
||||
.select('*')\
|
||||
.in_('status', ['active', 'trialing'])\
|
||||
.execute()
|
||||
|
||||
if not subs_result.data:
|
||||
print("No active subscriptions found")
|
||||
return
|
||||
|
||||
print(f"Found {len(subs_result.data)} active/trialing subscriptions")
|
||||
|
||||
fixed = 0
|
||||
already_correct = 0
|
||||
errors = 0
|
||||
|
||||
for sub in subs_result.data:
|
||||
user_id = sub['account_id']
|
||||
price_id = sub.get('price_id')
|
||||
|
||||
if not price_id:
|
||||
print(f"\nUser {user_id[:8]}... has no price_id in subscription")
|
||||
continue
|
||||
|
||||
# Get the tier from price_id
|
||||
tier = get_tier_by_price_id(price_id)
|
||||
if not tier:
|
||||
print(f"\nUser {user_id[:8]}... has unknown price_id: {price_id}")
|
||||
continue
|
||||
|
||||
expected_tier = tier.name
|
||||
|
||||
# Get current credit account
|
||||
credit_result = await client.from_('credit_accounts')\
|
||||
.select('tier')\
|
||||
.eq('user_id', user_id)\
|
||||
.single()\
|
||||
.execute()
|
||||
|
||||
if not credit_result.data:
|
||||
print(f"\nUser {user_id[:8]}... has no credit account!")
|
||||
errors += 1
|
||||
continue
|
||||
|
||||
current_tier = credit_result.data['tier']
|
||||
|
||||
if current_tier != expected_tier:
|
||||
print(f"\nUser {user_id[:8]}...")
|
||||
print(f" Current tier: {current_tier}")
|
||||
print(f" Expected tier: {expected_tier}")
|
||||
print(f" Price ID: {price_id}")
|
||||
|
||||
# Fix the tier
|
||||
update_result = await client.from_('credit_accounts')\
|
||||
.update({'tier': expected_tier})\
|
||||
.eq('user_id', user_id)\
|
||||
.execute()
|
||||
|
||||
if update_result.data:
|
||||
print(f" ✓ Fixed tier to: {expected_tier}")
|
||||
fixed += 1
|
||||
else:
|
||||
print(f" ✗ Failed to update tier")
|
||||
errors += 1
|
||||
else:
|
||||
already_correct += 1
|
||||
|
||||
print("\n" + "="*60)
|
||||
print("SUMMARY")
|
||||
print("="*60)
|
||||
print(f"Already correct: {already_correct}")
|
||||
print(f"Fixed: {fixed}")
|
||||
print(f"Errors: {errors}")
|
||||
print(f"Total: {already_correct + fixed + errors}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(verify_and_fix_tiers())
|
|
@ -69,7 +69,7 @@ async def get_user_subscription_tier(user_id: str) -> Dict:
|
|||
if credit_result.data and len(credit_result.data) > 0:
|
||||
tier_name = credit_result.data[0].get('tier', 'free')
|
||||
else:
|
||||
result = await client.schema('basejump').from_('billing_subscriptions').select('price_id, status').eq('account_id', user_id).eq('status', 'active').order('created', desc=True).limit(1).execute()
|
||||
result = await client.schema('basejump').from_('billing_subscriptions').select('price_id, status').eq('account_id', user_id).in_('status', ['active', 'trialing']).order('created', desc=True).limit(1).execute()
|
||||
|
||||
if result.data:
|
||||
price_id = result.data[0]['price_id']
|
||||
|
@ -357,12 +357,11 @@ async def stripe_webhook(request: Request):
|
|||
|
||||
elif event.type in ['customer.subscription.created', 'customer.subscription.updated']:
|
||||
subscription = event.data.object
|
||||
if subscription.status == 'active':
|
||||
if subscription.status in ['active', 'trialing']:
|
||||
await handle_subscription_change(subscription)
|
||||
|
||||
elif event.type == 'invoice.payment_succeeded':
|
||||
invoice = event.data.object
|
||||
# Only process subscription invoices (not one-time payments)
|
||||
if invoice.get('subscription') and invoice.get('billing_reason') in ['subscription_cycle', 'subscription_update']:
|
||||
await handle_subscription_renewal(invoice)
|
||||
|
||||
|
@ -394,7 +393,6 @@ async def handle_subscription_change(subscription: Dict):
|
|||
'credits': float(new_tier_info.monthly_credits)
|
||||
}
|
||||
|
||||
# Get billing cycle anchor from subscription
|
||||
billing_anchor = datetime.fromtimestamp(subscription['current_period_start'], tz=timezone.utc)
|
||||
next_grant_date = datetime.fromtimestamp(subscription['current_period_end'], tz=timezone.utc)
|
||||
|
||||
|
@ -412,16 +410,18 @@ async def handle_subscription_change(subscription: Dict):
|
|||
'credits': float(current_tier_info.monthly_credits)
|
||||
}
|
||||
|
||||
# Only grant credits for upgrades if it's not a renewal
|
||||
if current_tier and new_tier['credits'] > current_tier['credits'] and not existing_anchor:
|
||||
credit_diff = Decimal(new_tier['credits'] - current_tier['credits'])
|
||||
await credit_service.add_credits(
|
||||
user_id=user_id,
|
||||
amount=credit_diff,
|
||||
type='tier_upgrade',
|
||||
description=f"Upgrade from {current_tier['name']} to {new_tier['name']}"
|
||||
)
|
||||
logger.info(f"Granted {credit_diff} credits for tier upgrade: {current_tier['name']} -> {new_tier['name']}")
|
||||
if current_tier and current_tier['name'] != new_tier['name']:
|
||||
if new_tier['credits'] > current_tier['credits']:
|
||||
full_amount = Decimal(new_tier['credits'])
|
||||
await credit_service.add_credits(
|
||||
user_id=user_id,
|
||||
amount=full_amount,
|
||||
type='tier_upgrade',
|
||||
description=f"Upgrade to {new_tier['name']} - Full tier credits"
|
||||
)
|
||||
logger.info(f"Granted {full_amount} credits for tier upgrade: {current_tier['name']} -> {new_tier['name']}")
|
||||
elif new_tier['credits'] < current_tier['credits']:
|
||||
logger.info(f"Tier downgrade: {current_tier['name']} -> {new_tier['name']} - No credit adjustment")
|
||||
|
||||
await client.from_('credit_accounts').update({
|
||||
'tier': new_tier['name'],
|
||||
|
@ -430,7 +430,6 @@ async def handle_subscription_change(subscription: Dict):
|
|||
'next_credit_grant': next_grant_date.isoformat()
|
||||
}).eq('user_id', user_id).execute()
|
||||
else:
|
||||
# New subscription - grant initial credits
|
||||
await client.from_('credit_accounts').insert({
|
||||
'user_id': user_id,
|
||||
'balance': new_tier['credits'],
|
||||
|
@ -449,7 +448,6 @@ async def handle_subscription_change(subscription: Dict):
|
|||
)
|
||||
|
||||
async def handle_subscription_renewal(invoice: Dict):
|
||||
"""Handle subscription renewal and grant monthly credits based on billing cycle"""
|
||||
try:
|
||||
db = DBConnection()
|
||||
client = await db.client
|
||||
|
@ -458,7 +456,6 @@ async def handle_subscription_renewal(invoice: Dict):
|
|||
if not subscription_id:
|
||||
return
|
||||
|
||||
# Get user from subscription
|
||||
customer_result = await client.schema('basejump').from_('billing_customers')\
|
||||
.select('account_id')\
|
||||
.eq('id', invoice['customer'])\
|
||||
|
@ -469,7 +466,6 @@ async def handle_subscription_renewal(invoice: Dict):
|
|||
|
||||
user_id = customer_result.data[0]['account_id']
|
||||
|
||||
# Get credit account
|
||||
account_result = await client.from_('credit_accounts')\
|
||||
.select('tier, last_grant_date, next_credit_grant')\
|
||||
.eq('user_id', user_id)\
|
||||
|
@ -481,14 +477,12 @@ async def handle_subscription_renewal(invoice: Dict):
|
|||
account = account_result.data[0]
|
||||
tier = account['tier']
|
||||
|
||||
# Skip if we already granted credits recently (within 25 days to handle edge cases)
|
||||
if account.get('last_grant_date'):
|
||||
last_grant = datetime.fromisoformat(account['last_grant_date'].replace('Z', '+00:00'))
|
||||
if (datetime.now(timezone.utc) - last_grant).days < 25:
|
||||
logger.info(f"Skipping credit grant for {user_id} - already granted {(datetime.now(timezone.utc) - last_grant).days} days ago")
|
||||
return
|
||||
|
||||
# Grant monthly credits
|
||||
monthly_credits = get_monthly_credits(tier)
|
||||
if monthly_credits > 0:
|
||||
await credit_service.add_credits(
|
||||
|
@ -499,7 +493,6 @@ async def handle_subscription_renewal(invoice: Dict):
|
|||
metadata={'invoice_id': invoice['id'], 'subscription_id': subscription_id}
|
||||
)
|
||||
|
||||
# Update next grant date
|
||||
next_grant = datetime.now(timezone.utc) + timedelta(days=30)
|
||||
await client.from_('credit_accounts').update({
|
||||
'last_grant_date': datetime.now(timezone.utc).isoformat(),
|
||||
|
@ -510,19 +503,6 @@ async def handle_subscription_renewal(invoice: Dict):
|
|||
|
||||
except Exception as e:
|
||||
logger.error(f"Error handling subscription renewal: {e}")
|
||||
|
||||
await client.schema('basejump').from_('billing_subscriptions').upsert({
|
||||
'id': subscription['id'],
|
||||
'account_id': user_id,
|
||||
'billing_customer_id': subscription['customer'],
|
||||
'status': subscription['status'],
|
||||
'price_id': price_id,
|
||||
'current_period_end': datetime.fromtimestamp(subscription['current_period_end']).isoformat() if subscription.get('current_period_end') else None,
|
||||
'cancel_at': datetime.fromtimestamp(subscription['cancel_at']).isoformat() if subscription.get('cancel_at') else None,
|
||||
'cancel_at_period_end': subscription.get('cancel_at_period_end', False)
|
||||
}).execute()
|
||||
|
||||
await Cache.invalidate(f"subscription_tier:{user_id}")
|
||||
|
||||
@router.get("/subscription")
|
||||
async def get_subscription(
|
||||
|
@ -532,29 +512,55 @@ async def get_subscription(
|
|||
db = DBConnection()
|
||||
client = await db.client
|
||||
|
||||
subscription_result = await client.schema('basejump').from_('billing_subscriptions').select('*').eq('account_id', user_id).eq('status', 'active').order('created', desc=True).limit(1).execute()
|
||||
subscription_result = await client.schema('basejump').from_('billing_subscriptions')\
|
||||
.select('*')\
|
||||
.eq('account_id', user_id)\
|
||||
.in_('status', ['active', 'trialing'])\
|
||||
.order('created', desc=True)\
|
||||
.limit(1)\
|
||||
.execute()
|
||||
|
||||
subscription_data = None
|
||||
|
||||
credit_result = await client.from_('credit_accounts').select('tier').eq('user_id', user_id).execute()
|
||||
tier_name = credit_result.data[0]['tier'] if credit_result.data else 'free'
|
||||
tier_obj = TIERS.get(tier_name, TIERS['free'])
|
||||
credit_result = await client.from_('credit_accounts').select('*').eq('user_id', user_id).execute()
|
||||
|
||||
tier_info = {
|
||||
'name': tier_obj.name,
|
||||
'credits': float(tier_obj.monthly_credits)
|
||||
}
|
||||
price_id = config.STRIPE_FREE_TIER_ID
|
||||
if credit_result.data:
|
||||
credit_account = credit_result.data[0]
|
||||
tier_name = credit_account.get('tier', 'free')
|
||||
tier_obj = TIERS.get(tier_name, TIERS['free'])
|
||||
|
||||
actual_credits = float(tier_obj.monthly_credits)
|
||||
if tier_name != 'free':
|
||||
parts = tier_name.split('_')
|
||||
if len(parts) >= 3 and parts[0] == 'tier':
|
||||
subscription_cost = float(parts[-1])
|
||||
actual_credits = subscription_cost + 5.0
|
||||
|
||||
tier_info = {
|
||||
'name': tier_obj.name,
|
||||
'credits': actual_credits
|
||||
}
|
||||
|
||||
if tier_obj and len(tier_obj.price_ids) > 0:
|
||||
price_id = tier_obj.price_ids[0]
|
||||
|
||||
if subscription_result.data:
|
||||
subscription = subscription_result.data[0]
|
||||
if subscription.get('price_id'):
|
||||
price_id = subscription['price_id']
|
||||
else:
|
||||
price_id = config.STRIPE_FREE_TIER_ID
|
||||
else:
|
||||
tier_name = 'free'
|
||||
tier_obj = TIERS['free']
|
||||
tier_info = {
|
||||
'name': 'free',
|
||||
'credits': float(TIERS['free'].monthly_credits)
|
||||
}
|
||||
price_id = config.STRIPE_FREE_TIER_ID
|
||||
|
||||
if subscription_result.data:
|
||||
subscription = subscription_result.data[0]
|
||||
price_id = subscription['price_id']
|
||||
price_tier_obj = get_tier_by_price_id(price_id)
|
||||
if price_tier_obj:
|
||||
tier_info = {
|
||||
'name': price_tier_obj.name,
|
||||
'credits': float(price_tier_obj.monthly_credits)
|
||||
}
|
||||
|
||||
stripe_subscription = None
|
||||
try:
|
||||
|
@ -598,8 +604,15 @@ async def get_subscription(
|
|||
balance = await credit_service.get_balance(user_id)
|
||||
summary = await credit_service.get_account_summary(user_id)
|
||||
|
||||
if subscription_data:
|
||||
status = 'active'
|
||||
elif tier_name != 'free':
|
||||
status = 'cancelled'
|
||||
else:
|
||||
status = 'no_subscription'
|
||||
|
||||
return {
|
||||
'status': 'active' if subscription_data else 'no_subscription',
|
||||
'status': status,
|
||||
'plan_name': tier_info['name'],
|
||||
'price_id': price_id,
|
||||
'subscription': subscription_data,
|
||||
|
@ -607,7 +620,7 @@ async def get_subscription(
|
|||
'current_usage': float(summary['lifetime_used']),
|
||||
'cost_limit': tier_info['credits'],
|
||||
'credit_balance': float(balance),
|
||||
'can_purchase_credits': tier_info['name'] == 'tier_200_1000',
|
||||
'can_purchase_credits': tier_info.get('name') in ['tier_200_1000', 'tier_50_400', 'tier_125_800'],
|
||||
'tier': tier_info,
|
||||
'credits': {
|
||||
'balance': float(balance),
|
||||
|
@ -615,7 +628,7 @@ async def get_subscription(
|
|||
'lifetime_granted': float(summary['lifetime_granted']),
|
||||
'lifetime_purchased': float(summary['lifetime_purchased']),
|
||||
'lifetime_used': float(summary['lifetime_used']),
|
||||
'can_purchase': tier_info['name'] == 'tier_200_1000'
|
||||
'can_purchase': tier_info.get('name') in ['tier_200_1000', 'tier_50_400', 'tier_125_800']
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -658,7 +671,7 @@ async def create_checkout_session(
|
|||
db = DBConnection()
|
||||
client = await db.client
|
||||
|
||||
existing_sub = await client.schema('basejump').from_('billing_subscriptions').select('id').eq('account_id', user_id).eq('status', 'active').execute()
|
||||
existing_sub = await client.schema('basejump').from_('billing_subscriptions').select('id').eq('account_id', user_id).in_('status', ['active', 'trialing']).execute()
|
||||
|
||||
if existing_sub.data and len(existing_sub.data) > 0:
|
||||
subscription = await stripe.Subscription.retrieve_async(existing_sub.data[0]['id'])
|
||||
|
@ -673,12 +686,28 @@ async def create_checkout_session(
|
|||
payment_behavior='pending_if_incomplete'
|
||||
)
|
||||
|
||||
await handle_subscription_change(updated_subscription)
|
||||
|
||||
await Cache.invalidate(f"subscription_tier:{user_id}")
|
||||
await Cache.invalidate(f"credit_balance:{user_id}")
|
||||
await Cache.invalidate(f"credit_summary:{user_id}")
|
||||
|
||||
old_price_id = subscription['items']['data'][0].price.id
|
||||
old_tier = get_tier_by_price_id(old_price_id)
|
||||
new_tier = get_tier_by_price_id(request.price_id)
|
||||
|
||||
old_amount = float(old_tier.monthly_credits) if old_tier else 0
|
||||
new_amount = float(new_tier.monthly_credits) if new_tier else 0
|
||||
|
||||
return {
|
||||
'success': True,
|
||||
'status': 'upgraded' if new_amount > old_amount else 'updated',
|
||||
'subscription_id': updated_subscription.id,
|
||||
'message': 'Subscription updated successfully'
|
||||
'message': 'Subscription updated successfully',
|
||||
'details': {
|
||||
'is_upgrade': new_amount > old_amount,
|
||||
'current_price': old_amount,
|
||||
'new_price': new_amount
|
||||
}
|
||||
}
|
||||
else:
|
||||
session = await stripe.checkout.Session.create_async(
|
||||
|
@ -695,7 +724,6 @@ async def create_checkout_session(
|
|||
}
|
||||
}
|
||||
)
|
||||
|
||||
return {'checkout_url': session.url}
|
||||
|
||||
except Exception as e:
|
||||
|
@ -721,6 +749,58 @@ async def create_portal_session(
|
|||
logger.error(f"Error creating portal session: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@router.post("/sync-subscription")
|
||||
async def sync_subscription(
|
||||
user_id: str = Depends(verify_and_get_user_id_from_jwt)
|
||||
) -> Dict:
|
||||
try:
|
||||
db = DBConnection()
|
||||
client = await db.client
|
||||
|
||||
sub_result = await client.schema('basejump').from_('billing_subscriptions')\
|
||||
.select('*')\
|
||||
.eq('account_id', user_id)\
|
||||
.in_('status', ['active', 'trialing'])\
|
||||
.limit(1)\
|
||||
.execute()
|
||||
|
||||
if not sub_result.data or len(sub_result.data) == 0:
|
||||
return {
|
||||
'success': False,
|
||||
'message': 'No active subscription found'
|
||||
}
|
||||
|
||||
subscription = await stripe.Subscription.retrieve_async(
|
||||
sub_result.data[0]['id'],
|
||||
expand=['items.data.price']
|
||||
)
|
||||
|
||||
await handle_subscription_change(subscription)
|
||||
|
||||
await Cache.invalidate(f"subscription_tier:{user_id}")
|
||||
await Cache.invalidate(f"credit_balance:{user_id}")
|
||||
await Cache.invalidate(f"credit_summary:{user_id}")
|
||||
|
||||
balance = await credit_service.get_balance(user_id)
|
||||
summary = await credit_service.get_account_summary(user_id)
|
||||
|
||||
return {
|
||||
'success': True,
|
||||
'message': 'Subscription synced successfully',
|
||||
'credits': {
|
||||
'balance': float(balance),
|
||||
'lifetime_granted': float(summary['lifetime_granted']),
|
||||
'lifetime_used': float(summary['lifetime_used'])
|
||||
}
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error syncing subscription: {e}", exc_info=True)
|
||||
return {
|
||||
'success': False,
|
||||
'message': f'Failed to sync subscription: {str(e)}'
|
||||
}
|
||||
|
||||
@router.post("/cancel-subscription")
|
||||
async def cancel_subscription(
|
||||
request: CancelSubscriptionRequest,
|
||||
|
@ -730,7 +810,7 @@ async def cancel_subscription(
|
|||
db = DBConnection()
|
||||
client = await db.client
|
||||
|
||||
sub_result = await client.schema('basejump').from_('billing_subscriptions').select('id').eq('account_id', user_id).eq('status', 'active').execute()
|
||||
sub_result = await client.schema('basejump').from_('billing_subscriptions').select('id').eq('account_id', user_id).in_('status', ['active', 'trialing']).execute()
|
||||
|
||||
if not sub_result.data or len(sub_result.data) == 0:
|
||||
raise HTTPException(status_code=404, detail="No active subscription found")
|
||||
|
@ -761,7 +841,7 @@ async def reactivate_subscription(
|
|||
db = DBConnection()
|
||||
client = await db.client
|
||||
|
||||
sub_result = await client.schema('basejump').from_('billing_subscriptions').select('id').eq('account_id', user_id).eq('status', 'active').execute()
|
||||
sub_result = await client.schema('basejump').from_('billing_subscriptions').select('id').eq('account_id', user_id).in_('status', ['active', 'trialing']).execute()
|
||||
|
||||
if not sub_result.data or len(sub_result.data) == 0:
|
||||
raise HTTPException(status_code=404, detail="No active subscription found")
|
||||
|
@ -830,54 +910,6 @@ async def get_usage_history(
|
|||
logger.error(f"Error getting usage history: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@router.post("/sync-subscription")
|
||||
async def sync_subscription_from_stripe(
|
||||
user_id: str = Depends(verify_and_get_user_id_from_jwt)
|
||||
) -> Dict:
|
||||
"""Manually sync subscription from Stripe to local database and credit system"""
|
||||
try:
|
||||
db = DBConnection()
|
||||
client = await db.client
|
||||
|
||||
# Get customer ID
|
||||
customer_result = await client.schema('basejump').from_('billing_customers').select('id').eq('account_id', user_id).execute()
|
||||
|
||||
if not customer_result.data or len(customer_result.data) == 0:
|
||||
return {'success': False, 'message': 'No billing customer found'}
|
||||
|
||||
customer_id = customer_result.data[0]['id']
|
||||
|
||||
# Get active subscriptions from Stripe
|
||||
subscriptions = await stripe.Subscription.list_async(
|
||||
customer=customer_id,
|
||||
status='active',
|
||||
expand=['data.items.data.price']
|
||||
)
|
||||
|
||||
if not subscriptions.data:
|
||||
return {'success': False, 'message': 'No active subscription found in Stripe'}
|
||||
|
||||
# Process the first active subscription
|
||||
subscription = subscriptions.data[0]
|
||||
|
||||
# Handle the subscription change
|
||||
await handle_subscription_change(subscription)
|
||||
|
||||
# Get updated balance and tier
|
||||
balance = await credit_service.get_balance(user_id)
|
||||
account_result = await client.from_('credit_accounts').select('tier').eq('user_id', user_id).execute()
|
||||
|
||||
return {
|
||||
'success': True,
|
||||
'message': 'Subscription synced successfully',
|
||||
'subscription_id': subscription['id'],
|
||||
'tier': account_result.data[0].get('tier') if account_result.data and len(account_result.data) > 0 else 'unknown',
|
||||
'balance': float(balance)
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error syncing subscription: {e}", exc_info=True)
|
||||
return {'success': False, 'message': str(e)}
|
||||
|
||||
@router.get("/available-models")
|
||||
async def get_available_models(
|
||||
|
@ -974,4 +1006,86 @@ async def get_subscription_commitment(
|
|||
'commitment_type': None,
|
||||
'months_remaining': None,
|
||||
'commitment_end_date': None
|
||||
}
|
||||
}
|
||||
|
||||
@router.post("/test/trigger-renewal")
|
||||
async def test_trigger_renewal(
|
||||
user_id: str = Depends(verify_and_get_user_id_from_jwt)
|
||||
) -> Dict:
|
||||
if config.ENV_MODE == EnvMode.PRODUCTION:
|
||||
raise HTTPException(status_code=403, detail="Test endpoints disabled in production")
|
||||
|
||||
try:
|
||||
db = DBConnection()
|
||||
client = await db.client
|
||||
|
||||
# Check credit_accounts as primary source of truth
|
||||
account_result = await client.from_('credit_accounts')\
|
||||
.select('*')\
|
||||
.eq('user_id', user_id)\
|
||||
.execute()
|
||||
|
||||
if not account_result.data or len(account_result.data) == 0:
|
||||
return {
|
||||
'success': False,
|
||||
'message': 'No credit account found. Please subscribe to a plan first.'
|
||||
}
|
||||
|
||||
account = account_result.data[0]
|
||||
|
||||
# Only check if tier is not free (means they have a subscription)
|
||||
if account.get('tier', 'free') == 'free':
|
||||
return {
|
||||
'success': False,
|
||||
'message': 'No active subscription found. Please subscribe to a plan first.'
|
||||
}
|
||||
|
||||
tier = account['tier']
|
||||
|
||||
yesterday = datetime.now(timezone.utc) - timedelta(days=26)
|
||||
await client.from_('credit_accounts').update({
|
||||
'last_grant_date': yesterday.isoformat()
|
||||
}).eq('user_id', user_id).execute()
|
||||
|
||||
monthly_credits = get_monthly_credits(tier)
|
||||
if monthly_credits > 0:
|
||||
await credit_service.add_credits(
|
||||
user_id=user_id,
|
||||
amount=monthly_credits,
|
||||
type='tier_grant',
|
||||
description=f"TEST: Monthly {tier} tier credits",
|
||||
metadata={'test': True, 'triggered_by': 'manual_test'}
|
||||
)
|
||||
|
||||
next_grant = datetime.now(timezone.utc) + timedelta(days=30)
|
||||
await client.from_('credit_accounts').update({
|
||||
'last_grant_date': datetime.now(timezone.utc).isoformat(),
|
||||
'next_credit_grant': next_grant.isoformat()
|
||||
}).eq('user_id', user_id).execute()
|
||||
|
||||
await Cache.invalidate(f"credit_balance:{user_id}")
|
||||
await Cache.invalidate(f"credit_summary:{user_id}")
|
||||
await Cache.invalidate(f"subscription_tier:{user_id}")
|
||||
|
||||
new_balance = await credit_service.get_balance(user_id)
|
||||
|
||||
return {
|
||||
'success': True,
|
||||
'message': f'Successfully granted {monthly_credits} credits',
|
||||
'tier': tier,
|
||||
'credits_granted': float(monthly_credits),
|
||||
'new_balance': float(new_balance),
|
||||
'next_grant_date': next_grant.isoformat()
|
||||
}
|
||||
else:
|
||||
return {
|
||||
'success': False,
|
||||
'message': f'No credits to grant for tier: {tier}'
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in test renewal trigger: {e}", exc_info=True)
|
||||
return {
|
||||
'success': False,
|
||||
'message': str(e)
|
||||
}
|
|
@ -0,0 +1,337 @@
|
|||
"""
|
||||
Simplified Billing Service v3
|
||||
=============================
|
||||
Uses only essential sources of truth:
|
||||
1. Stripe API - Real subscription data
|
||||
2. credit_accounts - Tier, balance, billing cycle
|
||||
3. billing_customers - Stripe customer ID mapping
|
||||
4. credit_ledger - Transaction history
|
||||
5. agent_runs - Usage tracking
|
||||
|
||||
Removed:
|
||||
- billing_subscriptions (redundant)
|
||||
- credit_grants (unused)
|
||||
- Complex fallback logic
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Depends, Request
|
||||
from typing import Optional, Dict
|
||||
from decimal import Decimal
|
||||
from datetime import datetime, timezone, timedelta
|
||||
import stripe
|
||||
from core.credits import credit_service
|
||||
from core.services.supabase import DBConnection
|
||||
from core.utils.auth_utils import verify_and_get_user_id_from_jwt
|
||||
from core.utils.config import config, EnvMode
|
||||
from core.utils.logger import logger
|
||||
from core.billing_config import get_tier_by_price_id, TIERS, get_monthly_credits, get_tier_by_name
|
||||
|
||||
router = APIRouter(prefix="/billing/v3", tags=["billing-v3"])
|
||||
stripe.api_key = config.STRIPE_SECRET_KEY
|
||||
|
||||
async def get_user_tier(user_id: str) -> str:
|
||||
"""Get user's tier from credit_accounts (single source of truth)"""
|
||||
db = DBConnection()
|
||||
client = await db.client
|
||||
|
||||
result = await client.from_('credit_accounts')\
|
||||
.select('tier')\
|
||||
.eq('user_id', user_id)\
|
||||
.single()\
|
||||
.execute()
|
||||
|
||||
if result.data:
|
||||
return result.data.get('tier', 'free')
|
||||
return 'free'
|
||||
|
||||
async def ensure_credit_account(user_id: str) -> None:
|
||||
"""Ensure user has a credit account (create if missing)"""
|
||||
db = DBConnection()
|
||||
client = await db.client
|
||||
|
||||
# Check if account exists
|
||||
result = await client.from_('credit_accounts')\
|
||||
.select('user_id')\
|
||||
.eq('user_id', user_id)\
|
||||
.execute()
|
||||
|
||||
if not result.data:
|
||||
# Create free tier account
|
||||
await client.from_('credit_accounts').insert({
|
||||
'user_id': user_id,
|
||||
'balance': 5.0, # Free tier initial credits
|
||||
'tier': 'free',
|
||||
'last_grant_date': datetime.now(timezone.utc).isoformat()
|
||||
}).execute()
|
||||
|
||||
# Add initial credits to ledger
|
||||
await credit_service.add_credits(
|
||||
user_id=user_id,
|
||||
amount=Decimal('5.0'),
|
||||
type='tier_grant',
|
||||
description='Welcome credits - Free tier'
|
||||
)
|
||||
|
||||
@router.get("/subscription")
|
||||
async def get_subscription_simplified(
|
||||
user_id: str = Depends(verify_and_get_user_id_from_jwt)
|
||||
) -> Dict:
|
||||
"""Simplified subscription endpoint using minimal sources"""
|
||||
try:
|
||||
# Ensure user has credit account
|
||||
await ensure_credit_account(user_id)
|
||||
|
||||
db = DBConnection()
|
||||
client = await db.client
|
||||
|
||||
# Get all user data in one query
|
||||
account_result = await client.from_('credit_accounts')\
|
||||
.select('*')\
|
||||
.eq('user_id', user_id)\
|
||||
.single()\
|
||||
.execute()
|
||||
|
||||
account = account_result.data
|
||||
tier_name = account.get('tier', 'free')
|
||||
|
||||
# Get tier configuration
|
||||
tier = get_tier_by_name(tier_name) or TIERS['free']
|
||||
|
||||
# Calculate actual credits (subscription + free base)
|
||||
actual_credits = float(tier.monthly_credits)
|
||||
if tier_name != 'free':
|
||||
parts = tier_name.split('_')
|
||||
if len(parts) >= 3 and parts[0] == 'tier':
|
||||
subscription_cost = float(parts[-1])
|
||||
actual_credits = subscription_cost + 5.0
|
||||
|
||||
# Get Stripe subscription if needed (only for subscription management)
|
||||
stripe_subscription = None
|
||||
subscription_id = account.get('stripe_subscription_id')
|
||||
|
||||
if subscription_id:
|
||||
try:
|
||||
# Only fetch from Stripe when needed for management
|
||||
stripe_subscription = await stripe.Subscription.retrieve_async(subscription_id)
|
||||
except:
|
||||
# If Stripe fails, we still have local data
|
||||
pass
|
||||
|
||||
# Get balance and usage
|
||||
balance = await credit_service.get_balance(user_id)
|
||||
summary = await credit_service.get_account_summary(user_id)
|
||||
|
||||
return {
|
||||
'status': 'active' if tier_name != 'free' else 'no_subscription',
|
||||
'tier': tier_name,
|
||||
'tier_display': tier.display_name,
|
||||
'credits': {
|
||||
'balance': float(balance),
|
||||
'monthly_allocation': actual_credits,
|
||||
'lifetime_used': float(summary['lifetime_used']),
|
||||
'can_purchase': tier.can_purchase_credits
|
||||
},
|
||||
'subscription': {
|
||||
'id': subscription_id,
|
||||
'status': stripe_subscription.status if stripe_subscription else None,
|
||||
'cancel_at_period_end': stripe_subscription.cancel_at_period_end if stripe_subscription else False,
|
||||
'current_period_end': stripe_subscription.current_period_end if stripe_subscription else None
|
||||
} if subscription_id else None,
|
||||
'next_credit_grant': account.get('next_credit_grant'),
|
||||
'project_limit': tier.project_limit
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting subscription: {e}", exc_info=True)
|
||||
# Return free tier as fallback
|
||||
return {
|
||||
'status': 'no_subscription',
|
||||
'tier': 'free',
|
||||
'tier_display': 'Free',
|
||||
'credits': {
|
||||
'balance': 0,
|
||||
'monthly_allocation': 5,
|
||||
'lifetime_used': 0,
|
||||
'can_purchase': False
|
||||
},
|
||||
'subscription': None,
|
||||
'next_credit_grant': None,
|
||||
'project_limit': 3
|
||||
}
|
||||
|
||||
@router.post("/webhook")
|
||||
async def stripe_webhook_simplified(request: Request):
|
||||
"""Simplified webhook handler - only updates credit_accounts"""
|
||||
try:
|
||||
payload = await request.body()
|
||||
sig_header = request.headers.get('stripe-signature')
|
||||
|
||||
event = stripe.Webhook.construct_event(
|
||||
payload, sig_header, config.STRIPE_WEBHOOK_SECRET
|
||||
)
|
||||
|
||||
if event.type in ['customer.subscription.created', 'customer.subscription.updated']:
|
||||
subscription = event.data.object
|
||||
|
||||
if subscription.status in ['active', 'trialing']:
|
||||
await handle_subscription_simplified(subscription)
|
||||
|
||||
elif event.type == 'customer.subscription.deleted':
|
||||
subscription = event.data.object
|
||||
await handle_cancellation_simplified(subscription)
|
||||
|
||||
elif event.type == 'invoice.payment_succeeded':
|
||||
invoice = event.data.object
|
||||
if invoice.get('billing_reason') == 'subscription_cycle':
|
||||
await handle_renewal_simplified(invoice)
|
||||
|
||||
return {'status': 'success'}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Webhook error: {e}")
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
async def handle_subscription_simplified(subscription: Dict):
|
||||
"""Handle subscription changes - update only credit_accounts"""
|
||||
db = DBConnection()
|
||||
client = await db.client
|
||||
|
||||
# Get user from customer
|
||||
customer_result = await client.schema('basejump').from_('billing_customers')\
|
||||
.select('account_id')\
|
||||
.eq('id', subscription['customer'])\
|
||||
.single()\
|
||||
.execute()
|
||||
|
||||
if not customer_result.data:
|
||||
return
|
||||
|
||||
user_id = customer_result.data['account_id']
|
||||
price_id = subscription['items']['data'][0]['price']['id']
|
||||
|
||||
# Get tier from price
|
||||
tier_info = get_tier_by_price_id(price_id)
|
||||
if not tier_info:
|
||||
return
|
||||
|
||||
# Ensure account exists
|
||||
await ensure_credit_account(user_id)
|
||||
|
||||
# Get current account
|
||||
current_result = await client.from_('credit_accounts')\
|
||||
.select('tier')\
|
||||
.eq('user_id', user_id)\
|
||||
.single()\
|
||||
.execute()
|
||||
|
||||
old_tier = current_result.data.get('tier', 'free')
|
||||
new_tier = tier_info.name
|
||||
|
||||
# Update account
|
||||
billing_anchor = datetime.fromtimestamp(subscription['current_period_start'], tz=timezone.utc)
|
||||
next_grant = datetime.fromtimestamp(subscription['current_period_end'], tz=timezone.utc)
|
||||
|
||||
await client.from_('credit_accounts').update({
|
||||
'tier': new_tier,
|
||||
'stripe_subscription_id': subscription['id'],
|
||||
'billing_cycle_anchor': billing_anchor.isoformat(),
|
||||
'next_credit_grant': next_grant.isoformat()
|
||||
}).eq('user_id', user_id).execute()
|
||||
|
||||
# Handle tier change credits
|
||||
if old_tier != new_tier:
|
||||
old_tier_obj = get_tier_by_name(old_tier) or TIERS['free']
|
||||
new_credits = float(tier_info.monthly_credits)
|
||||
old_credits = float(old_tier_obj.monthly_credits)
|
||||
|
||||
if new_credits > old_credits:
|
||||
# Upgrade - grant full new tier credits
|
||||
await credit_service.add_credits(
|
||||
user_id=user_id,
|
||||
amount=Decimal(str(new_credits)),
|
||||
type='tier_upgrade',
|
||||
description=f"Upgrade to {tier_info.display_name}"
|
||||
)
|
||||
|
||||
async def handle_cancellation_simplified(subscription: Dict):
|
||||
"""Handle subscription cancellation - just update tier to free"""
|
||||
db = DBConnection()
|
||||
client = await db.client
|
||||
|
||||
# Get user from customer
|
||||
customer_result = await client.schema('basejump').from_('billing_customers')\
|
||||
.select('account_id')\
|
||||
.eq('id', subscription['customer'])\
|
||||
.single()\
|
||||
.execute()
|
||||
|
||||
if not customer_result.data:
|
||||
return
|
||||
|
||||
user_id = customer_result.data['account_id']
|
||||
|
||||
# Update to free tier (keep existing balance)
|
||||
await client.from_('credit_accounts').update({
|
||||
'tier': 'free',
|
||||
'stripe_subscription_id': None,
|
||||
'billing_cycle_anchor': None,
|
||||
'next_credit_grant': None
|
||||
}).eq('user_id', user_id).execute()
|
||||
|
||||
async def handle_renewal_simplified(invoice: Dict):
|
||||
"""Handle monthly renewal - grant credits"""
|
||||
db = DBConnection()
|
||||
client = await db.client
|
||||
|
||||
subscription_id = invoice.get('subscription')
|
||||
if not subscription_id:
|
||||
return
|
||||
|
||||
# Get user from customer
|
||||
customer_result = await client.schema('basejump').from_('billing_customers')\
|
||||
.select('account_id')\
|
||||
.eq('id', invoice['customer'])\
|
||||
.single()\
|
||||
.execute()
|
||||
|
||||
if not customer_result.data:
|
||||
return
|
||||
|
||||
user_id = customer_result.data['account_id']
|
||||
|
||||
# Get account
|
||||
account_result = await client.from_('credit_accounts')\
|
||||
.select('tier, last_grant_date')\
|
||||
.eq('user_id', user_id)\
|
||||
.single()\
|
||||
.execute()
|
||||
|
||||
if not account_result.data:
|
||||
return
|
||||
|
||||
account = account_result.data
|
||||
|
||||
# Check if enough time has passed (prevent double grants)
|
||||
if account.get('last_grant_date'):
|
||||
last_grant = datetime.fromisoformat(account['last_grant_date'].replace('Z', '+00:00'))
|
||||
if (datetime.now(timezone.utc) - last_grant).days < 25:
|
||||
return
|
||||
|
||||
# Grant monthly credits
|
||||
tier = account['tier']
|
||||
monthly_credits = get_monthly_credits(tier)
|
||||
|
||||
if monthly_credits > 0:
|
||||
await credit_service.add_credits(
|
||||
user_id=user_id,
|
||||
amount=monthly_credits,
|
||||
type='tier_grant',
|
||||
description=f"Monthly {tier} credits"
|
||||
)
|
||||
|
||||
# Update grant dates
|
||||
next_grant = datetime.now(timezone.utc) + timedelta(days=30)
|
||||
await client.from_('credit_accounts').update({
|
||||
'last_grant_date': datetime.now(timezone.utc).isoformat(),
|
||||
'next_credit_grant': next_grant.isoformat()
|
||||
}).eq('user_id', user_id).execute()
|
|
@ -8,15 +8,17 @@ import { Skeleton } from '@/components/ui/skeleton';
|
|||
import { Alert, AlertTitle, AlertDescription } from '@/components/ui/alert';
|
||||
import { Button } from '@/components/ui/button';
|
||||
import { useSharedSubscription, useSubscriptionContext } from '@/contexts/SubscriptionContext';
|
||||
import { isLocalMode } from '@/lib/config';
|
||||
import { isLocalMode, isStagingMode } from '@/lib/config';
|
||||
import Link from 'next/link';
|
||||
import { useCreatePortalSession } from '@/hooks/react-query/use-billing-v2';
|
||||
import { useCreatePortalSession, useTriggerTestRenewal } from '@/hooks/react-query/use-billing-v2';
|
||||
import { toast } from 'sonner';
|
||||
|
||||
const returnUrl = process.env.NEXT_PUBLIC_URL as string;
|
||||
|
||||
export default function PersonalAccountBillingPage() {
|
||||
const { data: accounts, isLoading, error } = useAccounts();
|
||||
const [showBillingModal, setShowBillingModal] = useState(false);
|
||||
const triggerTestRenewal = useTriggerTestRenewal();
|
||||
|
||||
const {
|
||||
data: subscriptionData,
|
||||
|
@ -146,6 +148,53 @@ export default function PersonalAccountBillingPage() {
|
|||
</div>
|
||||
</>
|
||||
)}
|
||||
{isStagingMode() && (
|
||||
<div className="mt-4 p-3 bg-yellow-50 dark:bg-yellow-900/20 border border-yellow-200 dark:border-yellow-800 rounded-lg">
|
||||
<p className="text-xs text-yellow-800 dark:text-yellow-200 mb-2">
|
||||
🧪 <strong>Test Mode:</strong> Simulate monthly credit renewal
|
||||
</p>
|
||||
<Button
|
||||
onClick={() => {
|
||||
triggerTestRenewal.mutate(undefined, {
|
||||
onSuccess: (result) => {
|
||||
if (result.success) {
|
||||
toast.success(
|
||||
<div>
|
||||
<p>Credits renewed successfully!</p>
|
||||
{result.credits_granted && (
|
||||
<p className="text-sm mt-1">
|
||||
+{result.credits_granted} credits added
|
||||
</p>
|
||||
)}
|
||||
{result.new_balance !== undefined && (
|
||||
<p className="text-xs mt-1 opacity-80">
|
||||
New balance: ${result.new_balance}
|
||||
</p>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
} else {
|
||||
toast.error(result.message || 'Failed to trigger renewal');
|
||||
}
|
||||
},
|
||||
onError: (error) => {
|
||||
console.error('Test renewal error:', error);
|
||||
toast.error('Failed to trigger test renewal');
|
||||
}
|
||||
});
|
||||
}}
|
||||
size="sm"
|
||||
className="w-full text-xs bg-yellow-600 hover:bg-yellow-700"
|
||||
disabled={triggerTestRenewal.isPending}
|
||||
>
|
||||
{triggerTestRenewal.isPending ? (
|
||||
'🔄 Triggering...'
|
||||
) : (
|
||||
'🔄 Trigger Monthly Credit Renewal (Test)'
|
||||
)}
|
||||
</Button>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
|
|
|
@ -236,55 +236,15 @@ export function BillingModal({ open, onOpenChange, returnUrl = typeof window !==
|
|||
</DialogHeader>
|
||||
|
||||
<>
|
||||
{/* Usage Limit Alert
|
||||
{showUsageLimitAlert && (
|
||||
<div className="mb-6">
|
||||
<div className="flex items-start p-3 sm:p-4 bg-destructive/5 border border-destructive/50 rounded-lg">
|
||||
<div className="flex items-start space-x-3">
|
||||
<div className="flex-shrink-0 mt-0.5">
|
||||
<Zap className="w-4 h-4 sm:w-5 sm:h-5 text-destructive" />
|
||||
</div>
|
||||
<div className="text-xs sm:text-sm min-w-0">
|
||||
<p className="font-medium text-destructive">Usage Limit Reached</p>
|
||||
<p className="text-destructive break-words">
|
||||
Your current plan has been exhausted for this billing period.
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)} */}
|
||||
|
||||
{/* Usage section - show loading state or actual data */}
|
||||
{isLoading || authLoading ? (
|
||||
<div className="mb-6">
|
||||
<div className="rounded-lg border bg-background p-4">
|
||||
<div className="flex justify-between items-center">
|
||||
<Skeleton className="h-4 w-40" />
|
||||
<Skeleton className="h-4 w-24" />
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
) : subscriptionData && (
|
||||
<div className="mb-6">
|
||||
<div className="rounded-lg border bg-background p-4">
|
||||
<div className="flex justify-between items-center">
|
||||
<span className="text-sm font-medium text-foreground/90">
|
||||
Agent Usage This Month
|
||||
</span>
|
||||
<span className="text-sm font-medium">
|
||||
${subscriptionData.current_usage?.toFixed(2) || '0'} /{' '}
|
||||
${subscriptionData.cost_limit || '0'}
|
||||
</span>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Show pricing section immediately - no loading state */}
|
||||
<PricingSection returnUrl={returnUrl} showTitleAndTabs={false} />
|
||||
|
||||
{/* Subscription Management Section - only show if there's actual subscription data */}
|
||||
<PricingSection
|
||||
returnUrl={returnUrl}
|
||||
showTitleAndTabs={false}
|
||||
onSubscriptionUpdate={() => {
|
||||
setTimeout(() => {
|
||||
fetchSubscriptionData();
|
||||
}, 500);
|
||||
}}
|
||||
/>
|
||||
{error ? (
|
||||
<div className="mt-6 pt-4 border-t border-border">
|
||||
<div className="p-4 bg-destructive/10 border border-destructive/20 rounded-lg text-center">
|
||||
|
@ -293,7 +253,6 @@ export function BillingModal({ open, onOpenChange, returnUrl = typeof window !==
|
|||
</div>
|
||||
) : subscriptionData?.subscription && (
|
||||
<div className="mt-6 pt-4 border-t border-border">
|
||||
{/* Subscription Status Info Box */}
|
||||
<div className="bg-muted/30 border border-border rounded-lg p-3 mb-3">
|
||||
<div className="flex items-center justify-between">
|
||||
<div className="flex items-center gap-2">
|
||||
|
|
|
@ -584,6 +584,7 @@ interface PricingSectionProps {
|
|||
insideDialog?: boolean;
|
||||
showInfo?: boolean;
|
||||
noPadding?: boolean;
|
||||
onSubscriptionUpdate?: () => void;
|
||||
}
|
||||
|
||||
export function PricingSection({
|
||||
|
@ -593,6 +594,7 @@ export function PricingSection({
|
|||
insideDialog = false,
|
||||
showInfo = true,
|
||||
noPadding = false,
|
||||
onSubscriptionUpdate,
|
||||
}: PricingSectionProps) {
|
||||
|
||||
const { data: subscriptionData, isLoading: isFetchingPlan, error: subscriptionQueryError, refetch: refetchSubscription } = useSubscription();
|
||||
|
@ -613,25 +615,21 @@ export function PricingSection({
|
|||
);
|
||||
|
||||
if (currentTier) {
|
||||
// Check if current subscription is yearly commitment (new yearly)
|
||||
if (currentTier.monthlyCommitmentStripePriceId === currentSubscription.price_id) {
|
||||
return 'yearly_commitment';
|
||||
} else if (currentTier.yearlyStripePriceId === currentSubscription.price_id) {
|
||||
// Legacy yearly plans
|
||||
return 'yearly';
|
||||
} else if (currentTier.stripePriceId === currentSubscription.price_id) {
|
||||
return 'monthly';
|
||||
}
|
||||
}
|
||||
|
||||
// Default to yearly_commitment (new yearly) if we can't determine current plan type
|
||||
return 'yearly_commitment';
|
||||
}, [isAuthenticated, currentSubscription]);
|
||||
|
||||
const [billingPeriod, setBillingPeriod] = useState<'monthly' | 'yearly' | 'yearly_commitment'>(getDefaultBillingPeriod());
|
||||
const [planLoadingStates, setPlanLoadingStates] = useState<Record<string, boolean>>({});
|
||||
|
||||
// Update billing period when subscription data changes
|
||||
useEffect(() => {
|
||||
setBillingPeriod(getDefaultBillingPeriod());
|
||||
}, [getDefaultBillingPeriod]);
|
||||
|
@ -647,6 +645,10 @@ export function PricingSection({
|
|||
setTimeout(() => {
|
||||
setPlanLoadingStates({});
|
||||
}, 1000);
|
||||
// Call parent's update handler if provided
|
||||
if (onSubscriptionUpdate) {
|
||||
onSubscriptionUpdate();
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
|
|
|
@ -137,4 +137,15 @@ export const useDeductTokenUsage = () => {
|
|||
queryClient.invalidateQueries({ queryKey: billingKeys.status() });
|
||||
},
|
||||
});
|
||||
};
|
||||
|
||||
export const useTriggerTestRenewal = () => {
|
||||
const queryClient = useQueryClient();
|
||||
|
||||
return useMutation({
|
||||
mutationFn: () => billingApiV2.triggerTestRenewal(),
|
||||
onSuccess: () => {
|
||||
queryClient.invalidateQueries({ queryKey: billingKeys.all });
|
||||
},
|
||||
});
|
||||
};
|
|
@ -1963,8 +1963,7 @@ export const getSubscription = async (): Promise<SubscriptionStatus> => {
|
|||
}
|
||||
|
||||
const data = await response.json();
|
||||
|
||||
// Map the new billing v2 format to the old format for backward compatibility
|
||||
|
||||
return {
|
||||
subscription: data.subscription ? {
|
||||
...data.subscription,
|
||||
|
@ -1974,7 +1973,7 @@ export const getSubscription = async (): Promise<SubscriptionStatus> => {
|
|||
cost_limit: data.tier?.credits || 0,
|
||||
credit_balance: data.credits?.balance || 0,
|
||||
can_purchase_credits: data.credits?.can_purchase || false,
|
||||
...data // Include any other fields that might be used
|
||||
...data
|
||||
} as SubscriptionStatus;
|
||||
} catch (error) {
|
||||
if (error instanceof NoAccessTokenAvailableError) {
|
||||
|
|
|
@ -121,6 +121,13 @@ export interface ReactivateSubscriptionResponse {
|
|||
message: string;
|
||||
}
|
||||
|
||||
export interface TestRenewalResponse {
|
||||
success: boolean;
|
||||
message?: string;
|
||||
credits_granted?: number;
|
||||
new_balance?: number;
|
||||
}
|
||||
|
||||
export const billingApiV2 = {
|
||||
async getSubscription() {
|
||||
const response = await backendApi.get<SubscriptionInfo>('/billing/v2/subscription');
|
||||
|
@ -204,6 +211,12 @@ export const billingApiV2 = {
|
|||
);
|
||||
if (response.error) throw response.error;
|
||||
return response.data!;
|
||||
},
|
||||
|
||||
async triggerTestRenewal() {
|
||||
const response = await backendApi.post<TestRenewalResponse>('/billing/v2/test/trigger-renewal');
|
||||
if (response.error) throw response.error;
|
||||
return response.data!;
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -222,4 +235,5 @@ export const purchaseCredits = (request: PurchaseCreditsRequest) =>
|
|||
billingApiV2.purchaseCredits(request);
|
||||
export const getTransactions = (limit?: number, offset?: number) =>
|
||||
billingApiV2.getTransactions(limit, offset);
|
||||
export const getUsageHistory = (days?: number) => billingApiV2.getUsageHistory(days);
|
||||
export const getUsageHistory = (days?: number) => billingApiV2.getUsageHistory(days);
|
||||
export const triggerTestRenewal = () => billingApiV2.triggerTestRenewal();
|
Loading…
Reference in New Issue