Merge pull request #1700 from escapade-mckv/fix-yearly-commitment

Fix yearly commitment
This commit is contained in:
Bobbie 2025-09-23 13:43:09 +05:30 committed by GitHub
commit 7c7ccbaa46
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 700 additions and 11 deletions

View File

@ -822,13 +822,22 @@ async def get_subscription_commitment(
subscription_id: str,
account_id: str = Depends(verify_and_get_user_id_from_jwt)
) -> Dict:
return {
'has_commitment': False,
'can_cancel': True,
'commitment_type': None,
'months_remaining': None,
'commitment_end_date': None
}
try:
commitment_status = await subscription_service.get_commitment_status(account_id)
if commitment_status['has_commitment']:
logger.info(f"[COMMITMENT] Account {account_id} has active commitment, {commitment_status['months_remaining']} months remaining")
return commitment_status
except Exception as e:
logger.error(f"Error checking commitment status for account {account_id}: {e}")
return {
'has_commitment': False,
'can_cancel': True,
'commitment_type': None,
'months_remaining': None,
'commitment_end_date': None
}
@router.get("/trial/status")
async def get_trial_status(

View File

@ -172,4 +172,36 @@ def is_model_allowed(tier_name: str, model: str) -> bool:
def get_project_limit(tier_name: str) -> int:
tier = TIERS.get(tier_name)
return tier.project_limit if tier else 3
return tier.project_limit if tier else 3
def is_commitment_price_id(price_id: str) -> bool:
commitment_price_ids = [
config.STRIPE_TIER_2_17_YEARLY_COMMITMENT_ID,
config.STRIPE_TIER_6_42_YEARLY_COMMITMENT_ID,
config.STRIPE_TIER_25_170_YEARLY_COMMITMENT_ID
]
return price_id in commitment_price_ids
def get_commitment_duration_months(price_id: str) -> int:
if is_commitment_price_id(price_id):
return 12
return 0
def get_price_type(price_id: str) -> str:
if is_commitment_price_id(price_id):
return 'yearly_commitment'
yearly_price_ids = [
config.STRIPE_TIER_2_20_YEARLY_ID,
config.STRIPE_TIER_6_50_YEARLY_ID,
config.STRIPE_TIER_12_100_YEARLY_ID,
config.STRIPE_TIER_25_200_YEARLY_ID,
config.STRIPE_TIER_50_400_YEARLY_ID,
config.STRIPE_TIER_125_800_YEARLY_ID,
config.STRIPE_TIER_200_1000_YEARLY_ID
]
if price_id in yearly_price_ids:
return 'yearly'
return 'monthly'

View File

@ -0,0 +1,432 @@
#!/usr/bin/env python3
"""
Migration script to track commitment plans by querying Stripe directly.
Usage:
# Dry run - see what would be changed without making changes
python -m core.billing.migrate_existing_commitments_stripe --dry-run
# Apply the migration
python -m core.billing.migrate_existing_commitments_stripe
# Only verify existing commitments
python -m core.billing.migrate_existing_commitments_stripe --verify-only
"""
import asyncio
import sys
import argparse
import stripe
from datetime import datetime, timezone, timedelta
from core.services.supabase import DBConnection
from core.utils.config import config
from core.utils.logger import logger
from .config import is_commitment_price_id, get_commitment_duration_months
if config.STRIPE_SECRET_KEY:
stripe.api_key = config.STRIPE_SECRET_KEY
else:
logger.warning("[COMMITMENT MIGRATION] No STRIPE_SECRET_KEY configured")
async def migrate_existing_commitments(dry_run=False):
db = DBConnection()
client = await db.client
mode = "DRY RUN" if dry_run else "LIVE"
logger.info(f"[COMMITMENT MIGRATION] Starting migration of existing commitment users - {mode} MODE")
if dry_run:
logger.info("[COMMITMENT MIGRATION] ⚠️ DRY RUN - No changes will be made to the database")
commitment_price_ids = [
config.STRIPE_TIER_2_17_YEARLY_COMMITMENT_ID,
config.STRIPE_TIER_6_42_YEARLY_COMMITMENT_ID,
config.STRIPE_TIER_25_170_YEARLY_COMMITMENT_ID
]
commitment_price_ids = [pid for pid in commitment_price_ids if pid]
if not commitment_price_ids:
logger.error("[COMMITMENT MIGRATION] No commitment price IDs configured!")
return
logger.info(f"[COMMITMENT MIGRATION] Looking for subscriptions with commitment price IDs: {commitment_price_ids}")
logger.info("[COMMITMENT MIGRATION] Querying Stripe for ALL subscriptions to find commitments...")
commitment_subscriptions = []
migrated_count = 0
error_count = 0
skipped_count = 0
accounts_to_migrate = []
all_price_ids_seen = set()
price_id_counts = {}
try:
has_more = True
starting_after = None
total_checked = 0
while has_more:
params = {
'status': 'active',
'limit': 100
}
if starting_after:
params['starting_after'] = starting_after
logger.info(f"[COMMITMENT MIGRATION] Fetching batch of subscriptions from Stripe (checked {total_checked} so far)...")
subscriptions = await stripe.Subscription.list_async(**params)
logger.info(f"[COMMITMENT MIGRATION] Retrieved {len(subscriptions.data)} subscriptions in this batch")
for subscription in subscriptions.data:
total_checked += 1
try:
price_id = None
if total_checked == 1:
logger.info(f"[COMMITMENT MIGRATION] First subscription structure: {type(subscription)}")
logger.info(f"[COMMITMENT MIGRATION] Subscription has items attr: {hasattr(subscription, 'items')}")
if hasattr(subscription, 'items'):
logger.info(f"[COMMITMENT MIGRATION] Items type: {type(subscription.items)}")
if hasattr(subscription, 'items'):
items = subscription.items
if hasattr(items, 'data') and len(items.data) > 0:
price_id = items.data[0].price.id
if not price_id:
continue
all_price_ids_seen.add(price_id)
price_id_counts[price_id] = price_id_counts.get(price_id, 0) + 1
if is_commitment_price_id(price_id):
commitment_subscriptions.append(subscription)
account_id = subscription.metadata.get('account_id')
if not account_id:
customer_result = await client.schema('basejump').from_('billing_customers')\
.select('account_id')\
.eq('id', subscription.customer)\
.execute()
if customer_result.data:
account_id = customer_result.data[0]['account_id']
if account_id:
logger.info(f"[COMMITMENT MIGRATION] Found commitment subscription: {subscription.id} for account: {account_id}, price: {price_id}")
else:
logger.warning(f"[COMMITMENT MIGRATION] Found commitment subscription {subscription.id} but couldn't determine account_id")
except Exception as e:
logger.warning(f"[COMMITMENT MIGRATION] Error processing subscription: {e}")
continue
has_more = subscriptions.has_more
if has_more and subscriptions.data:
starting_after = subscriptions.data[-1].id
logger.info(f"[COMMITMENT MIGRATION] Checked {total_checked} total subscriptions")
logger.info(f"[COMMITMENT MIGRATION] Found {len(commitment_subscriptions)} commitment subscriptions")
logger.info(f"[COMMITMENT MIGRATION] Found {len(all_price_ids_seen)} unique price IDs across all subscriptions")
if price_id_counts:
sorted_price_ids = sorted(price_id_counts.items(), key=lambda x: x[1], reverse=True)
logger.info("[COMMITMENT MIGRATION] Top 10 most common price IDs:")
for price_id, count in sorted_price_ids[:10]:
is_commitment = is_commitment_price_id(price_id)
marker = " ✓ COMMITMENT" if is_commitment else ""
logger.info(f" - {price_id}: {count} subscriptions{marker}")
for commitment_id in commitment_price_ids:
if commitment_id and price_id and commitment_id[:10] in price_id:
logger.warning(f" ⚠️ This looks similar to commitment ID: {commitment_id}")
logger.info("[COMMITMENT MIGRATION] Checking commitment price ID configuration:")
for cpid in commitment_price_ids:
if cpid and cpid.startswith('price_'):
is_recognized = is_commitment_price_id(cpid)
if is_recognized:
logger.info(f"{cpid} - recognized as commitment price ID by is_commitment_price_id()")
else:
logger.error(f"{cpid} - NOT recognized by is_commitment_price_id() function!")
else:
logger.warning(f"{cpid} - doesn't look like valid Stripe price ID")
except Exception as e:
logger.error(f"[COMMITMENT MIGRATION] Error querying Stripe: {str(e)}")
return
if not commitment_subscriptions:
logger.info("[COMMITMENT MIGRATION] No commitment subscriptions found in Stripe")
return
for subscription in commitment_subscriptions:
try:
account_id = subscription.metadata.get('account_id')
if not account_id:
customer_result = await client.schema('basejump').from_('billing_customers')\
.select('account_id')\
.eq('id', subscription.customer)\
.execute()
if customer_result.data:
account_id = customer_result.data[0]['account_id']
else:
logger.warning(f"[COMMITMENT MIGRATION] Cannot find account_id for subscription {subscription.id}")
error_count += 1
continue
start_date = datetime.fromtimestamp(subscription.start_date, tz=timezone.utc)
end_date = start_date + timedelta(days=365)
existing = await client.from_('credit_accounts').select(
'commitment_type'
).eq('account_id', account_id).execute()
if existing.data and existing.data[0].get('commitment_type'):
logger.info(f"[COMMITMENT MIGRATION] Account {account_id} already has commitment tracked, skipping")
skipped_count += 1
continue
months_remaining = (end_date.year - datetime.now(timezone.utc).year) * 12 + \
(end_date.month - datetime.now(timezone.utc).month)
price_id = None
if hasattr(subscription, 'items'):
items = subscription.items
if hasattr(items, 'data') and len(items.data) > 0:
price_id = items.data[0].price.id
if not price_id:
logger.warning(f"[COMMITMENT MIGRATION] Cannot get price_id for subscription {subscription.id}")
error_count += 1
continue
if dry_run:
logger.info(
f"[COMMITMENT MIGRATION] 🔍 Would migrate account {account_id}:\n"
f" - Subscription ID: {subscription.id}\n"
f" - Price ID: {price_id}\n"
f" - Commitment start: {start_date.date()}\n"
f" - Commitment end: {end_date.date()}\n"
f" - Months remaining: {months_remaining}"
)
accounts_to_migrate.append({
'account_id': account_id,
'subscription_id': subscription.id,
'price_id': price_id,
'end_date': end_date.date(),
'months_remaining': months_remaining
})
else:
await client.from_('credit_accounts').update({
'commitment_type': 'yearly_commitment',
'commitment_start_date': start_date.isoformat(),
'commitment_end_date': end_date.isoformat(),
'commitment_price_id': price_id,
'can_cancel_after': end_date.isoformat()
}).eq('account_id', account_id).execute()
await client.from_('commitment_history').insert({
'account_id': account_id,
'commitment_type': 'yearly_commitment',
'price_id': price_id,
'start_date': start_date.isoformat(),
'end_date': end_date.isoformat(),
'stripe_subscription_id': subscription.id
}).execute()
logger.info(
f"[COMMITMENT MIGRATION] ✅ Migrated account {account_id}: "
f"commitment ends {end_date.date()}, {months_remaining} months remaining"
)
migrated_count += 1
except Exception as e:
logger.error(f"[COMMITMENT MIGRATION] Error processing subscription {subscription.id}: {e}")
error_count += 1
continue
action = "Would migrate" if dry_run else "Migrated"
logger.info(
f"[COMMITMENT MIGRATION] Migration {'simulation' if dry_run else 'complete'}. "
f"{action}: {migrated_count}, Skipped: {skipped_count}, Errors: {error_count}"
)
if dry_run:
logger.info("[COMMITMENT MIGRATION] ⚠️ This was a DRY RUN - no changes were made")
logger.info("[COMMITMENT MIGRATION] To apply changes, run without --dry-run flag")
if accounts_to_migrate:
logger.info("\n[COMMITMENT MIGRATION] Summary of accounts that would be migrated:")
logger.info("=" * 70)
for acc in accounts_to_migrate:
logger.info(
f" Account: {acc['account_id']}\n"
f" - Subscription: {acc['subscription_id']}\n"
f" - Commitment ends: {acc['end_date']}\n"
f" - Months remaining: {acc['months_remaining']}\n"
f" - Price ID: {acc['price_id']}"
)
logger.info("=" * 70)
logger.info(f"Total accounts that would be migrated: {len(accounts_to_migrate)}")
else:
commitment_accounts = await client.from_('credit_accounts').select(
'account_id, commitment_type, commitment_end_date'
).not_.is_('commitment_type', 'null').execute()
if commitment_accounts.data:
logger.info(f"[COMMITMENT MIGRATION] Total accounts with commitments: {len(commitment_accounts.data)}")
async def verify_commitment_tracking():
db = DBConnection()
client = await db.client
logger.info("[COMMITMENT VERIFICATION] Starting verification...")
commitment_price_ids = [
config.STRIPE_TIER_2_17_YEARLY_COMMITMENT_ID,
config.STRIPE_TIER_6_42_YEARLY_COMMITMENT_ID,
config.STRIPE_TIER_25_170_YEARLY_COMMITMENT_ID
]
commitment_price_ids = [pid for pid in commitment_price_ids if pid]
tracked_accounts = await client.from_('credit_accounts').select(
'account_id, stripe_subscription_id, commitment_type'
).not_.is_('commitment_type', 'null').execute()
tracked_subscription_ids = set()
if tracked_accounts.data:
for acc in tracked_accounts.data:
if acc.get('stripe_subscription_id'):
tracked_subscription_ids.add(acc['stripe_subscription_id'])
logger.info(f"[COMMITMENT VERIFICATION] Found {len(tracked_subscription_ids)} tracked commitment subscriptions in database")
untracked = []
has_more = True
starting_after = None
while has_more:
params = {
'status': 'active',
'limit': 100
}
if starting_after:
params['starting_after'] = starting_after
subscriptions = await stripe.Subscription.list_async(**params)
for subscription in subscriptions.data:
price_id = None
try:
if hasattr(subscription, 'items'):
items = subscription.items
if hasattr(items, 'data') and len(items.data) > 0:
price_id = items.data[0].price.id
except Exception:
continue
if price_id and is_commitment_price_id(price_id) and subscription.id not in tracked_subscription_ids:
untracked.append({
'subscription_id': subscription.id,
'price_id': price_id,
'customer': subscription.customer
})
has_more = subscriptions.has_more
if has_more and subscriptions.data:
starting_after = subscriptions.data[-1].id
if untracked:
logger.warning(f"[COMMITMENT VERIFICATION] Found {len(untracked)} untracked commitment subscriptions:")
for item in untracked:
logger.warning(f" - Subscription {item['subscription_id']}: {item['price_id']}")
else:
logger.info("[COMMITMENT VERIFICATION] ✅ All commitment subscriptions are properly tracked")
async def main(dry_run=False):
try:
await migrate_existing_commitments(dry_run=dry_run)
if not dry_run:
await verify_commitment_tracking()
except Exception as e:
logger.error(f"[COMMITMENT MIGRATION] Fatal error: {e}")
raise
async def list_all_price_ids():
logger.info("[PRICE ID DISCOVERY] Fetching all active subscriptions from Stripe...")
price_id_counts = {}
total_checked = 0
has_more = True
starting_after = None
while has_more:
params = {
'status': 'active',
'limit': 100
}
if starting_after:
params['starting_after'] = starting_after
subscriptions = await stripe.Subscription.list_async(**params)
for subscription in subscriptions.data:
total_checked += 1
try:
if hasattr(subscription, 'items'):
items = subscription.items
if hasattr(items, 'data') and len(items.data) > 0:
price_id = items.data[0].price.id
price_id_counts[price_id] = price_id_counts.get(price_id, 0) + 1
except Exception:
continue
has_more = subscriptions.has_more
if has_more and subscriptions.data:
starting_after = subscriptions.data[-1].id
logger.info(f"[PRICE ID DISCOVERY] Checked {total_checked} subscriptions")
logger.info(f"[PRICE ID DISCOVERY] Found {len(price_id_counts)} unique price IDs")
sorted_price_ids = sorted(price_id_counts.items(), key=lambda x: x[1], reverse=True)
logger.info("[PRICE ID DISCOVERY] All price IDs (sorted by usage):")
for price_id, count in sorted_price_ids:
logger.info(f" {price_id}: {count} subscriptions")
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description='Migrate existing commitment plan users by querying Stripe directly'
)
parser.add_argument(
'--dry-run',
action='store_true',
help='Run in dry-run mode (no database changes)'
)
parser.add_argument(
'--verify-only',
action='store_true',
help='Only verify existing commitments without migrating'
)
parser.add_argument(
'--list-price-ids',
action='store_true',
help='List all price IDs from Stripe (for debugging)'
)
args = parser.parse_args()
if args.list_price_ids:
asyncio.run(list_all_price_ids())
elif args.verify_only:
asyncio.run(verify_commitment_tracking())
else:
asyncio.run(main(dry_run=args.dry_run))

View File

@ -12,7 +12,10 @@ from .config import (
TIERS,
TRIAL_DURATION_DAYS,
TRIAL_CREDITS,
get_tier_by_name
get_tier_by_name,
is_commitment_price_id,
get_commitment_duration_months,
get_price_type
)
from .credit_manager import credit_manager
@ -363,13 +366,28 @@ class SubscriptionService:
client = await db.client
credit_result = await client.from_('credit_accounts').select(
'stripe_subscription_id'
'stripe_subscription_id, commitment_type, commitment_end_date'
).eq('account_id', account_id).execute()
if not credit_result.data or not credit_result.data[0].get('stripe_subscription_id'):
raise HTTPException(status_code=404, detail="No subscription found")
subscription_id = credit_result.data[0]['stripe_subscription_id']
commitment_type = credit_result.data[0].get('commitment_type')
commitment_end_date = credit_result.data[0].get('commitment_end_date')
# Check if user is in a commitment period
if commitment_type and commitment_end_date:
end_date = datetime.fromisoformat(commitment_end_date.replace('Z', '+00:00'))
if datetime.now(timezone.utc) < end_date:
months_remaining = (end_date.year - datetime.now(timezone.utc).year) * 12 + \
(end_date.month - datetime.now(timezone.utc).month)
raise HTTPException(
status_code=403,
detail=f"Cannot cancel during commitment period. Your commitment ends on {end_date.date()}. "
f"You have {months_remaining} months remaining in your commitment."
)
try:
subscription = await stripe.Subscription.modify_async(
@ -378,6 +396,18 @@ class SubscriptionService:
metadata={'cancellation_feedback': feedback} if feedback else {}
)
# Log cancellation in commitment history if applicable
if commitment_type:
await client.from_('commitment_history').insert({
'account_id': account_id,
'commitment_type': commitment_type,
'start_date': credit_result.data[0].get('commitment_start_date'),
'end_date': commitment_end_date,
'stripe_subscription_id': subscription_id,
'cancelled_at': datetime.now(timezone.utc).isoformat(),
'cancellation_reason': feedback or 'User cancelled'
}).execute()
return {
'success': True,
'message': 'Subscription will be cancelled at the end of the current period',
@ -433,6 +463,9 @@ class SubscriptionService:
account_id = customer_result.data[0]['account_id']
price_id = subscription['items']['data'][0]['price']['id'] if subscription.get('items') else None
# Handle commitment tracking
await self._track_commitment_if_needed(account_id, price_id, subscription, client)
new_tier_info = get_tier_by_price_id(price_id)
if not new_tier_info:
logger.warning(f"Unknown price ID in subscription: {price_id}")
@ -685,6 +718,89 @@ class SubscriptionService:
logger.error(f"[ALLOWED_MODELS] Error getting allowed models for user {user_id}: {e}")
return []
async def _track_commitment_if_needed(self, account_id: str, price_id: str, subscription: Dict, client):
"""Track commitment if the subscription uses a commitment price ID"""
if not is_commitment_price_id(price_id):
return
commitment_duration = get_commitment_duration_months(price_id)
if commitment_duration == 0:
return
start_date = datetime.fromtimestamp(subscription['current_period_start'], tz=timezone.utc)
end_date = start_date + timedelta(days=commitment_duration * 30) # Approximate months
# Update credit_accounts with commitment info
await client.from_('credit_accounts').update({
'commitment_type': 'yearly_commitment',
'commitment_start_date': start_date.isoformat(),
'commitment_end_date': end_date.isoformat(),
'commitment_price_id': price_id,
'can_cancel_after': end_date.isoformat()
}).eq('account_id', account_id).execute()
# Log in commitment history
await client.from_('commitment_history').insert({
'account_id': account_id,
'commitment_type': 'yearly_commitment',
'price_id': price_id,
'start_date': start_date.isoformat(),
'end_date': end_date.isoformat(),
'stripe_subscription_id': subscription['id']
}).execute()
logger.info(f"[COMMITMENT] Tracked yearly commitment for account {account_id}, ends {end_date.date()}")
async def get_commitment_status(self, account_id: str) -> Dict:
"""Get the commitment status for an account"""
db = DBConnection()
client = await db.client
result = await client.from_('credit_accounts').select(
'commitment_type, commitment_start_date, commitment_end_date, commitment_price_id'
).eq('account_id', account_id).execute()
if not result.data or not result.data[0].get('commitment_type'):
return {
'has_commitment': False,
'can_cancel': True,
'commitment_type': None,
'months_remaining': None,
'commitment_end_date': None
}
data = result.data[0]
end_date = datetime.fromisoformat(data['commitment_end_date'].replace('Z', '+00:00'))
now = datetime.now(timezone.utc)
if now >= end_date:
# Commitment has expired, clear it
await client.from_('credit_accounts').update({
'commitment_type': None,
'commitment_start_date': None,
'commitment_end_date': None,
'commitment_price_id': None,
'can_cancel_after': None
}).eq('account_id', account_id).execute()
return {
'has_commitment': False,
'can_cancel': True,
'commitment_type': None,
'months_remaining': None,
'commitment_end_date': None
}
months_remaining = (end_date.year - now.year) * 12 + (end_date.month - now.month)
return {
'has_commitment': True,
'can_cancel': False,
'commitment_type': data['commitment_type'],
'months_remaining': max(1, months_remaining),
'commitment_end_date': data['commitment_end_date']
}
subscription_service = SubscriptionService()

View File

@ -14,6 +14,8 @@ from .config import (
get_monthly_credits,
TRIAL_DURATION_DAYS,
TRIAL_CREDITS,
is_commitment_price_id,
get_commitment_duration_months
)
from .credit_manager import credit_manager
@ -265,6 +267,15 @@ class WebhookService:
logger.info(f"[WEBHOOK] Previous attributes: {previous_attributes}")
logger.info(f"[WEBHOOK] Account ID from metadata: {subscription.metadata.get('account_id')}")
# Track commitment if price changed to a commitment plan
price_id = subscription['items']['data'][0]['price']['id'] if subscription.get('items') else None
prev_price_id = previous_attributes.get('items', {}).get('data', [{}])[0].get('price', {}).get('id') if previous_attributes.get('items') else None
if price_id and price_id != prev_price_id and is_commitment_price_id(price_id):
account_id = subscription.metadata.get('account_id')
if account_id:
await self._track_commitment(account_id, price_id, subscription, client)
if subscription.status == 'trialing' and subscription.get('default_payment_method') and not prev_default_payment:
account_id = subscription.metadata.get('account_id')
if account_id:
@ -740,4 +751,32 @@ class WebhookService:
logger.error(f"Error handling subscription renewal: {e}")
async def _track_commitment(self, account_id: str, price_id: str, subscription: Dict, client):
commitment_duration = get_commitment_duration_months(price_id)
if commitment_duration == 0:
return
start_date = datetime.fromtimestamp(subscription['current_period_start'], tz=timezone.utc)
end_date = start_date + timedelta(days=365)
await client.from_('credit_accounts').update({
'commitment_type': 'yearly_commitment',
'commitment_start_date': start_date.isoformat(),
'commitment_end_date': end_date.isoformat(),
'commitment_price_id': price_id,
'can_cancel_after': end_date.isoformat()
}).eq('account_id', account_id).execute()
await client.from_('commitment_history').upsert({
'account_id': account_id,
'commitment_type': 'yearly_commitment',
'price_id': price_id,
'start_date': start_date.isoformat(),
'end_date': end_date.isoformat(),
'stripe_subscription_id': subscription['id']
}, on_conflict='account_id,stripe_subscription_id').execute()
logger.info(f"[WEBHOOK COMMITMENT] Tracked yearly commitment for account {account_id}, ends {end_date.date()}")
webhook_service = WebhookService()

View File

@ -1,5 +1,5 @@
create extension if not exists "pg_net" with schema "public" version '0.14.0';
alter table "public"."projects" add column "icon_name" text;
alter table "public"."projects" add column if not exists "icon_name" text;

View File

@ -0,0 +1,61 @@
BEGIN;
ALTER TABLE credit_accounts
ADD COLUMN IF NOT EXISTS commitment_type VARCHAR(50),
ADD COLUMN IF NOT EXISTS commitment_start_date TIMESTAMPTZ,
ADD COLUMN IF NOT EXISTS commitment_end_date TIMESTAMPTZ,
ADD COLUMN IF NOT EXISTS commitment_price_id VARCHAR(255),
ADD COLUMN IF NOT EXISTS can_cancel_after TIMESTAMPTZ;
CREATE INDEX IF NOT EXISTS idx_credit_accounts_commitment
ON credit_accounts(commitment_end_date)
WHERE commitment_type IS NOT NULL;
CREATE OR REPLACE FUNCTION can_cancel_subscription(p_account_id UUID)
RETURNS BOOLEAN AS $$
DECLARE
v_commitment_end_date TIMESTAMPTZ;
v_commitment_type VARCHAR(50);
BEGIN
SELECT commitment_end_date, commitment_type
INTO v_commitment_end_date, v_commitment_type
FROM credit_accounts
WHERE account_id = p_account_id;
IF v_commitment_type IS NULL OR v_commitment_end_date IS NULL THEN
RETURN TRUE;
END IF;
IF NOW() >= v_commitment_end_date THEN
RETURN TRUE;
END IF;
RETURN FALSE;
END;
$$ LANGUAGE plpgsql SECURITY DEFINER;
CREATE TABLE IF NOT EXISTS commitment_history (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
account_id UUID NOT NULL REFERENCES auth.users(id) ON DELETE CASCADE,
commitment_type VARCHAR(50),
price_id VARCHAR(255),
start_date TIMESTAMPTZ NOT NULL,
end_date TIMESTAMPTZ NOT NULL,
stripe_subscription_id VARCHAR(255),
created_at TIMESTAMPTZ DEFAULT NOW(),
cancelled_at TIMESTAMPTZ,
cancellation_reason TEXT
);
CREATE INDEX IF NOT EXISTS idx_commitment_history_account ON commitment_history(account_id);
CREATE INDEX IF NOT EXISTS idx_commitment_history_active ON commitment_history(end_date) WHERE cancelled_at IS NULL;
ALTER TABLE commitment_history ENABLE ROW LEVEL SECURITY;
CREATE POLICY "Users can view own commitment history" ON commitment_history
FOR SELECT USING (auth.uid() = account_id);
CREATE POLICY "Service role can manage commitment history" ON commitment_history
FOR ALL USING (auth.role() = 'service_role');
COMMIT;