mirror of https://github.com/kortix-ai/suna.git
Merge pull request #1700 from escapade-mckv/fix-yearly-commitment
Fix yearly commitment
This commit is contained in:
commit
7c7ccbaa46
|
@ -822,6 +822,15 @@ async def get_subscription_commitment(
|
|||
subscription_id: str,
|
||||
account_id: str = Depends(verify_and_get_user_id_from_jwt)
|
||||
) -> Dict:
|
||||
try:
|
||||
commitment_status = await subscription_service.get_commitment_status(account_id)
|
||||
if commitment_status['has_commitment']:
|
||||
logger.info(f"[COMMITMENT] Account {account_id} has active commitment, {commitment_status['months_remaining']} months remaining")
|
||||
|
||||
return commitment_status
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking commitment status for account {account_id}: {e}")
|
||||
return {
|
||||
'has_commitment': False,
|
||||
'can_cancel': True,
|
||||
|
|
|
@ -173,3 +173,35 @@ def is_model_allowed(tier_name: str, model: str) -> bool:
|
|||
def get_project_limit(tier_name: str) -> int:
|
||||
tier = TIERS.get(tier_name)
|
||||
return tier.project_limit if tier else 3
|
||||
|
||||
def is_commitment_price_id(price_id: str) -> bool:
|
||||
commitment_price_ids = [
|
||||
config.STRIPE_TIER_2_17_YEARLY_COMMITMENT_ID,
|
||||
config.STRIPE_TIER_6_42_YEARLY_COMMITMENT_ID,
|
||||
config.STRIPE_TIER_25_170_YEARLY_COMMITMENT_ID
|
||||
]
|
||||
return price_id in commitment_price_ids
|
||||
|
||||
def get_commitment_duration_months(price_id: str) -> int:
|
||||
if is_commitment_price_id(price_id):
|
||||
return 12
|
||||
return 0
|
||||
|
||||
def get_price_type(price_id: str) -> str:
|
||||
if is_commitment_price_id(price_id):
|
||||
return 'yearly_commitment'
|
||||
|
||||
yearly_price_ids = [
|
||||
config.STRIPE_TIER_2_20_YEARLY_ID,
|
||||
config.STRIPE_TIER_6_50_YEARLY_ID,
|
||||
config.STRIPE_TIER_12_100_YEARLY_ID,
|
||||
config.STRIPE_TIER_25_200_YEARLY_ID,
|
||||
config.STRIPE_TIER_50_400_YEARLY_ID,
|
||||
config.STRIPE_TIER_125_800_YEARLY_ID,
|
||||
config.STRIPE_TIER_200_1000_YEARLY_ID
|
||||
]
|
||||
|
||||
if price_id in yearly_price_ids:
|
||||
return 'yearly'
|
||||
|
||||
return 'monthly'
|
|
@ -0,0 +1,432 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Migration script to track commitment plans by querying Stripe directly.
|
||||
|
||||
Usage:
|
||||
# Dry run - see what would be changed without making changes
|
||||
python -m core.billing.migrate_existing_commitments_stripe --dry-run
|
||||
|
||||
# Apply the migration
|
||||
python -m core.billing.migrate_existing_commitments_stripe
|
||||
|
||||
# Only verify existing commitments
|
||||
python -m core.billing.migrate_existing_commitments_stripe --verify-only
|
||||
"""
|
||||
import asyncio
|
||||
import sys
|
||||
import argparse
|
||||
import stripe
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from core.services.supabase import DBConnection
|
||||
from core.utils.config import config
|
||||
from core.utils.logger import logger
|
||||
from .config import is_commitment_price_id, get_commitment_duration_months
|
||||
|
||||
if config.STRIPE_SECRET_KEY:
|
||||
stripe.api_key = config.STRIPE_SECRET_KEY
|
||||
else:
|
||||
logger.warning("[COMMITMENT MIGRATION] No STRIPE_SECRET_KEY configured")
|
||||
|
||||
async def migrate_existing_commitments(dry_run=False):
|
||||
db = DBConnection()
|
||||
client = await db.client
|
||||
|
||||
mode = "DRY RUN" if dry_run else "LIVE"
|
||||
logger.info(f"[COMMITMENT MIGRATION] Starting migration of existing commitment users - {mode} MODE")
|
||||
|
||||
if dry_run:
|
||||
logger.info("[COMMITMENT MIGRATION] ⚠️ DRY RUN - No changes will be made to the database")
|
||||
|
||||
commitment_price_ids = [
|
||||
config.STRIPE_TIER_2_17_YEARLY_COMMITMENT_ID,
|
||||
config.STRIPE_TIER_6_42_YEARLY_COMMITMENT_ID,
|
||||
config.STRIPE_TIER_25_170_YEARLY_COMMITMENT_ID
|
||||
]
|
||||
|
||||
commitment_price_ids = [pid for pid in commitment_price_ids if pid]
|
||||
|
||||
if not commitment_price_ids:
|
||||
logger.error("[COMMITMENT MIGRATION] No commitment price IDs configured!")
|
||||
return
|
||||
|
||||
logger.info(f"[COMMITMENT MIGRATION] Looking for subscriptions with commitment price IDs: {commitment_price_ids}")
|
||||
|
||||
logger.info("[COMMITMENT MIGRATION] Querying Stripe for ALL subscriptions to find commitments...")
|
||||
|
||||
commitment_subscriptions = []
|
||||
migrated_count = 0
|
||||
error_count = 0
|
||||
skipped_count = 0
|
||||
accounts_to_migrate = []
|
||||
|
||||
all_price_ids_seen = set()
|
||||
price_id_counts = {}
|
||||
|
||||
try:
|
||||
has_more = True
|
||||
starting_after = None
|
||||
total_checked = 0
|
||||
|
||||
while has_more:
|
||||
params = {
|
||||
'status': 'active',
|
||||
'limit': 100
|
||||
}
|
||||
if starting_after:
|
||||
params['starting_after'] = starting_after
|
||||
|
||||
logger.info(f"[COMMITMENT MIGRATION] Fetching batch of subscriptions from Stripe (checked {total_checked} so far)...")
|
||||
subscriptions = await stripe.Subscription.list_async(**params)
|
||||
|
||||
logger.info(f"[COMMITMENT MIGRATION] Retrieved {len(subscriptions.data)} subscriptions in this batch")
|
||||
|
||||
for subscription in subscriptions.data:
|
||||
total_checked += 1
|
||||
|
||||
try:
|
||||
price_id = None
|
||||
|
||||
if total_checked == 1:
|
||||
logger.info(f"[COMMITMENT MIGRATION] First subscription structure: {type(subscription)}")
|
||||
logger.info(f"[COMMITMENT MIGRATION] Subscription has items attr: {hasattr(subscription, 'items')}")
|
||||
if hasattr(subscription, 'items'):
|
||||
logger.info(f"[COMMITMENT MIGRATION] Items type: {type(subscription.items)}")
|
||||
|
||||
if hasattr(subscription, 'items'):
|
||||
items = subscription.items
|
||||
if hasattr(items, 'data') and len(items.data) > 0:
|
||||
price_id = items.data[0].price.id
|
||||
|
||||
if not price_id:
|
||||
continue
|
||||
|
||||
all_price_ids_seen.add(price_id)
|
||||
price_id_counts[price_id] = price_id_counts.get(price_id, 0) + 1
|
||||
|
||||
if is_commitment_price_id(price_id):
|
||||
commitment_subscriptions.append(subscription)
|
||||
|
||||
account_id = subscription.metadata.get('account_id')
|
||||
|
||||
if not account_id:
|
||||
customer_result = await client.schema('basejump').from_('billing_customers')\
|
||||
.select('account_id')\
|
||||
.eq('id', subscription.customer)\
|
||||
.execute()
|
||||
|
||||
if customer_result.data:
|
||||
account_id = customer_result.data[0]['account_id']
|
||||
|
||||
if account_id:
|
||||
logger.info(f"[COMMITMENT MIGRATION] Found commitment subscription: {subscription.id} for account: {account_id}, price: {price_id}")
|
||||
else:
|
||||
logger.warning(f"[COMMITMENT MIGRATION] Found commitment subscription {subscription.id} but couldn't determine account_id")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"[COMMITMENT MIGRATION] Error processing subscription: {e}")
|
||||
continue
|
||||
|
||||
has_more = subscriptions.has_more
|
||||
if has_more and subscriptions.data:
|
||||
starting_after = subscriptions.data[-1].id
|
||||
|
||||
logger.info(f"[COMMITMENT MIGRATION] Checked {total_checked} total subscriptions")
|
||||
logger.info(f"[COMMITMENT MIGRATION] Found {len(commitment_subscriptions)} commitment subscriptions")
|
||||
|
||||
logger.info(f"[COMMITMENT MIGRATION] Found {len(all_price_ids_seen)} unique price IDs across all subscriptions")
|
||||
|
||||
if price_id_counts:
|
||||
sorted_price_ids = sorted(price_id_counts.items(), key=lambda x: x[1], reverse=True)
|
||||
logger.info("[COMMITMENT MIGRATION] Top 10 most common price IDs:")
|
||||
for price_id, count in sorted_price_ids[:10]:
|
||||
is_commitment = is_commitment_price_id(price_id)
|
||||
marker = " ✓ COMMITMENT" if is_commitment else ""
|
||||
logger.info(f" - {price_id}: {count} subscriptions{marker}")
|
||||
for commitment_id in commitment_price_ids:
|
||||
if commitment_id and price_id and commitment_id[:10] in price_id:
|
||||
logger.warning(f" ⚠️ This looks similar to commitment ID: {commitment_id}")
|
||||
|
||||
logger.info("[COMMITMENT MIGRATION] Checking commitment price ID configuration:")
|
||||
for cpid in commitment_price_ids:
|
||||
if cpid and cpid.startswith('price_'):
|
||||
is_recognized = is_commitment_price_id(cpid)
|
||||
if is_recognized:
|
||||
logger.info(f" ✓ {cpid} - recognized as commitment price ID by is_commitment_price_id()")
|
||||
else:
|
||||
logger.error(f" ✗ {cpid} - NOT recognized by is_commitment_price_id() function!")
|
||||
else:
|
||||
logger.warning(f" ✗ {cpid} - doesn't look like valid Stripe price ID")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[COMMITMENT MIGRATION] Error querying Stripe: {str(e)}")
|
||||
return
|
||||
|
||||
if not commitment_subscriptions:
|
||||
logger.info("[COMMITMENT MIGRATION] No commitment subscriptions found in Stripe")
|
||||
return
|
||||
|
||||
for subscription in commitment_subscriptions:
|
||||
try:
|
||||
account_id = subscription.metadata.get('account_id')
|
||||
|
||||
if not account_id:
|
||||
customer_result = await client.schema('basejump').from_('billing_customers')\
|
||||
.select('account_id')\
|
||||
.eq('id', subscription.customer)\
|
||||
.execute()
|
||||
|
||||
if customer_result.data:
|
||||
account_id = customer_result.data[0]['account_id']
|
||||
else:
|
||||
logger.warning(f"[COMMITMENT MIGRATION] Cannot find account_id for subscription {subscription.id}")
|
||||
error_count += 1
|
||||
continue
|
||||
|
||||
start_date = datetime.fromtimestamp(subscription.start_date, tz=timezone.utc)
|
||||
end_date = start_date + timedelta(days=365)
|
||||
|
||||
existing = await client.from_('credit_accounts').select(
|
||||
'commitment_type'
|
||||
).eq('account_id', account_id).execute()
|
||||
|
||||
if existing.data and existing.data[0].get('commitment_type'):
|
||||
logger.info(f"[COMMITMENT MIGRATION] Account {account_id} already has commitment tracked, skipping")
|
||||
skipped_count += 1
|
||||
continue
|
||||
|
||||
months_remaining = (end_date.year - datetime.now(timezone.utc).year) * 12 + \
|
||||
(end_date.month - datetime.now(timezone.utc).month)
|
||||
|
||||
price_id = None
|
||||
if hasattr(subscription, 'items'):
|
||||
items = subscription.items
|
||||
if hasattr(items, 'data') and len(items.data) > 0:
|
||||
price_id = items.data[0].price.id
|
||||
|
||||
if not price_id:
|
||||
logger.warning(f"[COMMITMENT MIGRATION] Cannot get price_id for subscription {subscription.id}")
|
||||
error_count += 1
|
||||
continue
|
||||
|
||||
if dry_run:
|
||||
logger.info(
|
||||
f"[COMMITMENT MIGRATION] 🔍 Would migrate account {account_id}:\n"
|
||||
f" - Subscription ID: {subscription.id}\n"
|
||||
f" - Price ID: {price_id}\n"
|
||||
f" - Commitment start: {start_date.date()}\n"
|
||||
f" - Commitment end: {end_date.date()}\n"
|
||||
f" - Months remaining: {months_remaining}"
|
||||
)
|
||||
accounts_to_migrate.append({
|
||||
'account_id': account_id,
|
||||
'subscription_id': subscription.id,
|
||||
'price_id': price_id,
|
||||
'end_date': end_date.date(),
|
||||
'months_remaining': months_remaining
|
||||
})
|
||||
else:
|
||||
await client.from_('credit_accounts').update({
|
||||
'commitment_type': 'yearly_commitment',
|
||||
'commitment_start_date': start_date.isoformat(),
|
||||
'commitment_end_date': end_date.isoformat(),
|
||||
'commitment_price_id': price_id,
|
||||
'can_cancel_after': end_date.isoformat()
|
||||
}).eq('account_id', account_id).execute()
|
||||
|
||||
await client.from_('commitment_history').insert({
|
||||
'account_id': account_id,
|
||||
'commitment_type': 'yearly_commitment',
|
||||
'price_id': price_id,
|
||||
'start_date': start_date.isoformat(),
|
||||
'end_date': end_date.isoformat(),
|
||||
'stripe_subscription_id': subscription.id
|
||||
}).execute()
|
||||
|
||||
logger.info(
|
||||
f"[COMMITMENT MIGRATION] ✅ Migrated account {account_id}: "
|
||||
f"commitment ends {end_date.date()}, {months_remaining} months remaining"
|
||||
)
|
||||
|
||||
migrated_count += 1
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[COMMITMENT MIGRATION] Error processing subscription {subscription.id}: {e}")
|
||||
error_count += 1
|
||||
continue
|
||||
|
||||
action = "Would migrate" if dry_run else "Migrated"
|
||||
logger.info(
|
||||
f"[COMMITMENT MIGRATION] Migration {'simulation' if dry_run else 'complete'}. "
|
||||
f"{action}: {migrated_count}, Skipped: {skipped_count}, Errors: {error_count}"
|
||||
)
|
||||
|
||||
if dry_run:
|
||||
logger.info("[COMMITMENT MIGRATION] ⚠️ This was a DRY RUN - no changes were made")
|
||||
logger.info("[COMMITMENT MIGRATION] To apply changes, run without --dry-run flag")
|
||||
|
||||
if accounts_to_migrate:
|
||||
logger.info("\n[COMMITMENT MIGRATION] Summary of accounts that would be migrated:")
|
||||
logger.info("=" * 70)
|
||||
for acc in accounts_to_migrate:
|
||||
logger.info(
|
||||
f" Account: {acc['account_id']}\n"
|
||||
f" - Subscription: {acc['subscription_id']}\n"
|
||||
f" - Commitment ends: {acc['end_date']}\n"
|
||||
f" - Months remaining: {acc['months_remaining']}\n"
|
||||
f" - Price ID: {acc['price_id']}"
|
||||
)
|
||||
logger.info("=" * 70)
|
||||
logger.info(f"Total accounts that would be migrated: {len(accounts_to_migrate)}")
|
||||
else:
|
||||
commitment_accounts = await client.from_('credit_accounts').select(
|
||||
'account_id, commitment_type, commitment_end_date'
|
||||
).not_.is_('commitment_type', 'null').execute()
|
||||
|
||||
if commitment_accounts.data:
|
||||
logger.info(f"[COMMITMENT MIGRATION] Total accounts with commitments: {len(commitment_accounts.data)}")
|
||||
|
||||
async def verify_commitment_tracking():
|
||||
db = DBConnection()
|
||||
client = await db.client
|
||||
|
||||
logger.info("[COMMITMENT VERIFICATION] Starting verification...")
|
||||
|
||||
commitment_price_ids = [
|
||||
config.STRIPE_TIER_2_17_YEARLY_COMMITMENT_ID,
|
||||
config.STRIPE_TIER_6_42_YEARLY_COMMITMENT_ID,
|
||||
config.STRIPE_TIER_25_170_YEARLY_COMMITMENT_ID
|
||||
]
|
||||
commitment_price_ids = [pid for pid in commitment_price_ids if pid]
|
||||
|
||||
tracked_accounts = await client.from_('credit_accounts').select(
|
||||
'account_id, stripe_subscription_id, commitment_type'
|
||||
).not_.is_('commitment_type', 'null').execute()
|
||||
|
||||
tracked_subscription_ids = set()
|
||||
if tracked_accounts.data:
|
||||
for acc in tracked_accounts.data:
|
||||
if acc.get('stripe_subscription_id'):
|
||||
tracked_subscription_ids.add(acc['stripe_subscription_id'])
|
||||
|
||||
logger.info(f"[COMMITMENT VERIFICATION] Found {len(tracked_subscription_ids)} tracked commitment subscriptions in database")
|
||||
|
||||
untracked = []
|
||||
has_more = True
|
||||
starting_after = None
|
||||
|
||||
while has_more:
|
||||
params = {
|
||||
'status': 'active',
|
||||
'limit': 100
|
||||
}
|
||||
if starting_after:
|
||||
params['starting_after'] = starting_after
|
||||
|
||||
subscriptions = await stripe.Subscription.list_async(**params)
|
||||
|
||||
for subscription in subscriptions.data:
|
||||
price_id = None
|
||||
try:
|
||||
if hasattr(subscription, 'items'):
|
||||
items = subscription.items
|
||||
if hasattr(items, 'data') and len(items.data) > 0:
|
||||
price_id = items.data[0].price.id
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
if price_id and is_commitment_price_id(price_id) and subscription.id not in tracked_subscription_ids:
|
||||
untracked.append({
|
||||
'subscription_id': subscription.id,
|
||||
'price_id': price_id,
|
||||
'customer': subscription.customer
|
||||
})
|
||||
|
||||
has_more = subscriptions.has_more
|
||||
if has_more and subscriptions.data:
|
||||
starting_after = subscriptions.data[-1].id
|
||||
|
||||
if untracked:
|
||||
logger.warning(f"[COMMITMENT VERIFICATION] Found {len(untracked)} untracked commitment subscriptions:")
|
||||
for item in untracked:
|
||||
logger.warning(f" - Subscription {item['subscription_id']}: {item['price_id']}")
|
||||
else:
|
||||
logger.info("[COMMITMENT VERIFICATION] ✅ All commitment subscriptions are properly tracked")
|
||||
|
||||
async def main(dry_run=False):
|
||||
try:
|
||||
await migrate_existing_commitments(dry_run=dry_run)
|
||||
if not dry_run:
|
||||
await verify_commitment_tracking()
|
||||
except Exception as e:
|
||||
logger.error(f"[COMMITMENT MIGRATION] Fatal error: {e}")
|
||||
raise
|
||||
|
||||
async def list_all_price_ids():
|
||||
logger.info("[PRICE ID DISCOVERY] Fetching all active subscriptions from Stripe...")
|
||||
|
||||
price_id_counts = {}
|
||||
total_checked = 0
|
||||
has_more = True
|
||||
starting_after = None
|
||||
|
||||
while has_more:
|
||||
params = {
|
||||
'status': 'active',
|
||||
'limit': 100
|
||||
}
|
||||
if starting_after:
|
||||
params['starting_after'] = starting_after
|
||||
|
||||
subscriptions = await stripe.Subscription.list_async(**params)
|
||||
|
||||
for subscription in subscriptions.data:
|
||||
total_checked += 1
|
||||
try:
|
||||
if hasattr(subscription, 'items'):
|
||||
items = subscription.items
|
||||
if hasattr(items, 'data') and len(items.data) > 0:
|
||||
price_id = items.data[0].price.id
|
||||
price_id_counts[price_id] = price_id_counts.get(price_id, 0) + 1
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
has_more = subscriptions.has_more
|
||||
if has_more and subscriptions.data:
|
||||
starting_after = subscriptions.data[-1].id
|
||||
|
||||
logger.info(f"[PRICE ID DISCOVERY] Checked {total_checked} subscriptions")
|
||||
logger.info(f"[PRICE ID DISCOVERY] Found {len(price_id_counts)} unique price IDs")
|
||||
|
||||
sorted_price_ids = sorted(price_id_counts.items(), key=lambda x: x[1], reverse=True)
|
||||
logger.info("[PRICE ID DISCOVERY] All price IDs (sorted by usage):")
|
||||
for price_id, count in sorted_price_ids:
|
||||
logger.info(f" {price_id}: {count} subscriptions")
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description='Migrate existing commitment plan users by querying Stripe directly'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--dry-run',
|
||||
action='store_true',
|
||||
help='Run in dry-run mode (no database changes)'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--verify-only',
|
||||
action='store_true',
|
||||
help='Only verify existing commitments without migrating'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--list-price-ids',
|
||||
action='store_true',
|
||||
help='List all price IDs from Stripe (for debugging)'
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.list_price_ids:
|
||||
asyncio.run(list_all_price_ids())
|
||||
elif args.verify_only:
|
||||
asyncio.run(verify_commitment_tracking())
|
||||
else:
|
||||
asyncio.run(main(dry_run=args.dry_run))
|
|
@ -12,7 +12,10 @@ from .config import (
|
|||
TIERS,
|
||||
TRIAL_DURATION_DAYS,
|
||||
TRIAL_CREDITS,
|
||||
get_tier_by_name
|
||||
get_tier_by_name,
|
||||
is_commitment_price_id,
|
||||
get_commitment_duration_months,
|
||||
get_price_type
|
||||
)
|
||||
from .credit_manager import credit_manager
|
||||
|
||||
|
@ -363,13 +366,28 @@ class SubscriptionService:
|
|||
client = await db.client
|
||||
|
||||
credit_result = await client.from_('credit_accounts').select(
|
||||
'stripe_subscription_id'
|
||||
'stripe_subscription_id, commitment_type, commitment_end_date'
|
||||
).eq('account_id', account_id).execute()
|
||||
|
||||
if not credit_result.data or not credit_result.data[0].get('stripe_subscription_id'):
|
||||
raise HTTPException(status_code=404, detail="No subscription found")
|
||||
|
||||
subscription_id = credit_result.data[0]['stripe_subscription_id']
|
||||
commitment_type = credit_result.data[0].get('commitment_type')
|
||||
commitment_end_date = credit_result.data[0].get('commitment_end_date')
|
||||
|
||||
# Check if user is in a commitment period
|
||||
if commitment_type and commitment_end_date:
|
||||
end_date = datetime.fromisoformat(commitment_end_date.replace('Z', '+00:00'))
|
||||
if datetime.now(timezone.utc) < end_date:
|
||||
months_remaining = (end_date.year - datetime.now(timezone.utc).year) * 12 + \
|
||||
(end_date.month - datetime.now(timezone.utc).month)
|
||||
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail=f"Cannot cancel during commitment period. Your commitment ends on {end_date.date()}. "
|
||||
f"You have {months_remaining} months remaining in your commitment."
|
||||
)
|
||||
|
||||
try:
|
||||
subscription = await stripe.Subscription.modify_async(
|
||||
|
@ -378,6 +396,18 @@ class SubscriptionService:
|
|||
metadata={'cancellation_feedback': feedback} if feedback else {}
|
||||
)
|
||||
|
||||
# Log cancellation in commitment history if applicable
|
||||
if commitment_type:
|
||||
await client.from_('commitment_history').insert({
|
||||
'account_id': account_id,
|
||||
'commitment_type': commitment_type,
|
||||
'start_date': credit_result.data[0].get('commitment_start_date'),
|
||||
'end_date': commitment_end_date,
|
||||
'stripe_subscription_id': subscription_id,
|
||||
'cancelled_at': datetime.now(timezone.utc).isoformat(),
|
||||
'cancellation_reason': feedback or 'User cancelled'
|
||||
}).execute()
|
||||
|
||||
return {
|
||||
'success': True,
|
||||
'message': 'Subscription will be cancelled at the end of the current period',
|
||||
|
@ -433,6 +463,9 @@ class SubscriptionService:
|
|||
account_id = customer_result.data[0]['account_id']
|
||||
price_id = subscription['items']['data'][0]['price']['id'] if subscription.get('items') else None
|
||||
|
||||
# Handle commitment tracking
|
||||
await self._track_commitment_if_needed(account_id, price_id, subscription, client)
|
||||
|
||||
new_tier_info = get_tier_by_price_id(price_id)
|
||||
if not new_tier_info:
|
||||
logger.warning(f"Unknown price ID in subscription: {price_id}")
|
||||
|
@ -685,6 +718,89 @@ class SubscriptionService:
|
|||
logger.error(f"[ALLOWED_MODELS] Error getting allowed models for user {user_id}: {e}")
|
||||
return []
|
||||
|
||||
async def _track_commitment_if_needed(self, account_id: str, price_id: str, subscription: Dict, client):
|
||||
"""Track commitment if the subscription uses a commitment price ID"""
|
||||
if not is_commitment_price_id(price_id):
|
||||
return
|
||||
|
||||
commitment_duration = get_commitment_duration_months(price_id)
|
||||
if commitment_duration == 0:
|
||||
return
|
||||
|
||||
start_date = datetime.fromtimestamp(subscription['current_period_start'], tz=timezone.utc)
|
||||
end_date = start_date + timedelta(days=commitment_duration * 30) # Approximate months
|
||||
|
||||
# Update credit_accounts with commitment info
|
||||
await client.from_('credit_accounts').update({
|
||||
'commitment_type': 'yearly_commitment',
|
||||
'commitment_start_date': start_date.isoformat(),
|
||||
'commitment_end_date': end_date.isoformat(),
|
||||
'commitment_price_id': price_id,
|
||||
'can_cancel_after': end_date.isoformat()
|
||||
}).eq('account_id', account_id).execute()
|
||||
|
||||
# Log in commitment history
|
||||
await client.from_('commitment_history').insert({
|
||||
'account_id': account_id,
|
||||
'commitment_type': 'yearly_commitment',
|
||||
'price_id': price_id,
|
||||
'start_date': start_date.isoformat(),
|
||||
'end_date': end_date.isoformat(),
|
||||
'stripe_subscription_id': subscription['id']
|
||||
}).execute()
|
||||
|
||||
logger.info(f"[COMMITMENT] Tracked yearly commitment for account {account_id}, ends {end_date.date()}")
|
||||
|
||||
async def get_commitment_status(self, account_id: str) -> Dict:
|
||||
"""Get the commitment status for an account"""
|
||||
db = DBConnection()
|
||||
client = await db.client
|
||||
|
||||
result = await client.from_('credit_accounts').select(
|
||||
'commitment_type, commitment_start_date, commitment_end_date, commitment_price_id'
|
||||
).eq('account_id', account_id).execute()
|
||||
|
||||
if not result.data or not result.data[0].get('commitment_type'):
|
||||
return {
|
||||
'has_commitment': False,
|
||||
'can_cancel': True,
|
||||
'commitment_type': None,
|
||||
'months_remaining': None,
|
||||
'commitment_end_date': None
|
||||
}
|
||||
|
||||
data = result.data[0]
|
||||
end_date = datetime.fromisoformat(data['commitment_end_date'].replace('Z', '+00:00'))
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
if now >= end_date:
|
||||
# Commitment has expired, clear it
|
||||
await client.from_('credit_accounts').update({
|
||||
'commitment_type': None,
|
||||
'commitment_start_date': None,
|
||||
'commitment_end_date': None,
|
||||
'commitment_price_id': None,
|
||||
'can_cancel_after': None
|
||||
}).eq('account_id', account_id).execute()
|
||||
|
||||
return {
|
||||
'has_commitment': False,
|
||||
'can_cancel': True,
|
||||
'commitment_type': None,
|
||||
'months_remaining': None,
|
||||
'commitment_end_date': None
|
||||
}
|
||||
|
||||
months_remaining = (end_date.year - now.year) * 12 + (end_date.month - now.month)
|
||||
|
||||
return {
|
||||
'has_commitment': True,
|
||||
'can_cancel': False,
|
||||
'commitment_type': data['commitment_type'],
|
||||
'months_remaining': max(1, months_remaining),
|
||||
'commitment_end_date': data['commitment_end_date']
|
||||
}
|
||||
|
||||
|
||||
subscription_service = SubscriptionService()
|
||||
|
||||
|
|
|
@ -14,6 +14,8 @@ from .config import (
|
|||
get_monthly_credits,
|
||||
TRIAL_DURATION_DAYS,
|
||||
TRIAL_CREDITS,
|
||||
is_commitment_price_id,
|
||||
get_commitment_duration_months
|
||||
)
|
||||
from .credit_manager import credit_manager
|
||||
|
||||
|
@ -265,6 +267,15 @@ class WebhookService:
|
|||
logger.info(f"[WEBHOOK] Previous attributes: {previous_attributes}")
|
||||
logger.info(f"[WEBHOOK] Account ID from metadata: {subscription.metadata.get('account_id')}")
|
||||
|
||||
# Track commitment if price changed to a commitment plan
|
||||
price_id = subscription['items']['data'][0]['price']['id'] if subscription.get('items') else None
|
||||
prev_price_id = previous_attributes.get('items', {}).get('data', [{}])[0].get('price', {}).get('id') if previous_attributes.get('items') else None
|
||||
|
||||
if price_id and price_id != prev_price_id and is_commitment_price_id(price_id):
|
||||
account_id = subscription.metadata.get('account_id')
|
||||
if account_id:
|
||||
await self._track_commitment(account_id, price_id, subscription, client)
|
||||
|
||||
if subscription.status == 'trialing' and subscription.get('default_payment_method') and not prev_default_payment:
|
||||
account_id = subscription.metadata.get('account_id')
|
||||
if account_id:
|
||||
|
@ -740,4 +751,32 @@ class WebhookService:
|
|||
logger.error(f"Error handling subscription renewal: {e}")
|
||||
|
||||
|
||||
async def _track_commitment(self, account_id: str, price_id: str, subscription: Dict, client):
|
||||
commitment_duration = get_commitment_duration_months(price_id)
|
||||
if commitment_duration == 0:
|
||||
return
|
||||
|
||||
start_date = datetime.fromtimestamp(subscription['current_period_start'], tz=timezone.utc)
|
||||
end_date = start_date + timedelta(days=365)
|
||||
|
||||
await client.from_('credit_accounts').update({
|
||||
'commitment_type': 'yearly_commitment',
|
||||
'commitment_start_date': start_date.isoformat(),
|
||||
'commitment_end_date': end_date.isoformat(),
|
||||
'commitment_price_id': price_id,
|
||||
'can_cancel_after': end_date.isoformat()
|
||||
}).eq('account_id', account_id).execute()
|
||||
|
||||
await client.from_('commitment_history').upsert({
|
||||
'account_id': account_id,
|
||||
'commitment_type': 'yearly_commitment',
|
||||
'price_id': price_id,
|
||||
'start_date': start_date.isoformat(),
|
||||
'end_date': end_date.isoformat(),
|
||||
'stripe_subscription_id': subscription['id']
|
||||
}, on_conflict='account_id,stripe_subscription_id').execute()
|
||||
|
||||
logger.info(f"[WEBHOOK COMMITMENT] Tracked yearly commitment for account {account_id}, ends {end_date.date()}")
|
||||
|
||||
|
||||
webhook_service = WebhookService()
|
|
@ -1,5 +1,5 @@
|
|||
create extension if not exists "pg_net" with schema "public" version '0.14.0';
|
||||
|
||||
alter table "public"."projects" add column "icon_name" text;
|
||||
alter table "public"."projects" add column if not exists "icon_name" text;
|
||||
|
||||
|
||||
|
|
|
@ -0,0 +1,61 @@
|
|||
BEGIN;
|
||||
|
||||
ALTER TABLE credit_accounts
|
||||
ADD COLUMN IF NOT EXISTS commitment_type VARCHAR(50),
|
||||
ADD COLUMN IF NOT EXISTS commitment_start_date TIMESTAMPTZ,
|
||||
ADD COLUMN IF NOT EXISTS commitment_end_date TIMESTAMPTZ,
|
||||
ADD COLUMN IF NOT EXISTS commitment_price_id VARCHAR(255),
|
||||
ADD COLUMN IF NOT EXISTS can_cancel_after TIMESTAMPTZ;
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_credit_accounts_commitment
|
||||
ON credit_accounts(commitment_end_date)
|
||||
WHERE commitment_type IS NOT NULL;
|
||||
|
||||
CREATE OR REPLACE FUNCTION can_cancel_subscription(p_account_id UUID)
|
||||
RETURNS BOOLEAN AS $$
|
||||
DECLARE
|
||||
v_commitment_end_date TIMESTAMPTZ;
|
||||
v_commitment_type VARCHAR(50);
|
||||
BEGIN
|
||||
SELECT commitment_end_date, commitment_type
|
||||
INTO v_commitment_end_date, v_commitment_type
|
||||
FROM credit_accounts
|
||||
WHERE account_id = p_account_id;
|
||||
|
||||
IF v_commitment_type IS NULL OR v_commitment_end_date IS NULL THEN
|
||||
RETURN TRUE;
|
||||
END IF;
|
||||
|
||||
IF NOW() >= v_commitment_end_date THEN
|
||||
RETURN TRUE;
|
||||
END IF;
|
||||
|
||||
RETURN FALSE;
|
||||
END;
|
||||
$$ LANGUAGE plpgsql SECURITY DEFINER;
|
||||
|
||||
CREATE TABLE IF NOT EXISTS commitment_history (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
account_id UUID NOT NULL REFERENCES auth.users(id) ON DELETE CASCADE,
|
||||
commitment_type VARCHAR(50),
|
||||
price_id VARCHAR(255),
|
||||
start_date TIMESTAMPTZ NOT NULL,
|
||||
end_date TIMESTAMPTZ NOT NULL,
|
||||
stripe_subscription_id VARCHAR(255),
|
||||
created_at TIMESTAMPTZ DEFAULT NOW(),
|
||||
cancelled_at TIMESTAMPTZ,
|
||||
cancellation_reason TEXT
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_commitment_history_account ON commitment_history(account_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_commitment_history_active ON commitment_history(end_date) WHERE cancelled_at IS NULL;
|
||||
|
||||
ALTER TABLE commitment_history ENABLE ROW LEVEL SECURITY;
|
||||
|
||||
CREATE POLICY "Users can view own commitment history" ON commitment_history
|
||||
FOR SELECT USING (auth.uid() = account_id);
|
||||
|
||||
CREATE POLICY "Service role can manage commitment history" ON commitment_history
|
||||
FOR ALL USING (auth.role() = 'service_role');
|
||||
|
||||
COMMIT;
|
Loading…
Reference in New Issue