suna/backend/services/billing.py

2702 lines
131 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Stripe Billing API implementation for Suna on top of Basejump. ONLY HAS SUPPORT FOR USER ACCOUNTS, not team accounts, because the user_id is used as the account_id. That holds for personal accounts, where the account_id equals the user_id; in team accounts, the account_id is a separate unique ID.
To forward Stripe webhooks to a local server during development, run: stripe listen --forward-to localhost:8000/api/billing/webhook
"""
from fastapi import APIRouter, HTTPException, Depends, Request
from typing import Optional, Dict, Tuple
import stripe
from datetime import datetime, timezone, timedelta
from dateutil import parser as dateutil_parser
from supabase import Client as SupabaseClient
from utils.cache import Cache
from utils.logger import logger
from utils.config import config, EnvMode
from services.supabase import DBConnection
from utils.auth_utils import get_current_user_id_from_jwt
from pydantic import BaseModel
from models import model_manager
from litellm.cost_calculator import cost_per_token
import time
import json
# Configure the Stripe SDK globally with the secret key from the app config.
stripe.api_key = config.STRIPE_SECRET_KEY

# Factor applied to raw model token cost when computing billed usage
# (calculate_token_cost multiplies its result by this value).
TOKEN_PRICE_MULTIPLIER = 1.5

# Minimum credits (in dollars) required to allow a new request when the user
# is over their subscription limit.
CREDIT_MIN_START_DOLLARS = 0.20

# One-off credit packages, keyed by package id. 'amount' and 'price' are both
# dollar figures; 'stripe_price_id' is the matching Stripe price from config.
CREDIT_PACKAGES = {
    package_id: {'amount': dollars, 'price': dollars, 'stripe_price_id': stripe_price}
    for package_id, dollars, stripe_price in (
        ('credits_10', 10, config.STRIPE_CREDITS_10_PRICE_ID),
        ('credits_25', 25, config.STRIPE_CREDITS_25_PRICE_ID),
        ('credits_50', 50, config.STRIPE_CREDITS_50_PRICE_ID),
        ('credits_100', 100, config.STRIPE_CREDITS_100_PRICE_ID),
        ('credits_250', 250, config.STRIPE_CREDITS_250_PRICE_ID),
        ('credits_500', 500, config.STRIPE_CREDITS_500_PRICE_ID),
    )
}

router = APIRouter(tags=["billing"], prefix="/billing")
def get_plan_info(price_id: str) -> dict:
    """Look up tier, billing type and display name for a Stripe price ID.

    ``tier`` (1-7) orders plans for upgrade/downgrade checks and is shared by
    the monthly, yearly and yearly-commitment prices of the same usage level.
    ``type`` is 'monthly', 'yearly' or 'yearly_commitment'.

    Returns:
        A freshly built dict with 'tier', 'type' and 'name'. Unknown price IDs
        yield {'tier': 0, 'type': 'unknown', 'name': 'Unknown'}.
    """
    plan_rows = (
        # Monthly plans
        (config.STRIPE_TIER_2_20_ID, 1, 'monthly', '2h/$20'),
        (config.STRIPE_TIER_6_50_ID, 2, 'monthly', '6h/$50'),
        (config.STRIPE_TIER_12_100_ID, 3, 'monthly', '12h/$100'),
        (config.STRIPE_TIER_25_200_ID, 4, 'monthly', '25h/$200'),
        (config.STRIPE_TIER_50_400_ID, 5, 'monthly', '50h/$400'),
        (config.STRIPE_TIER_125_800_ID, 6, 'monthly', '125h/$800'),
        (config.STRIPE_TIER_200_1000_ID, 7, 'monthly', '200h/$1000'),
        # Yearly plans (paid once per year)
        (config.STRIPE_TIER_2_20_YEARLY_ID, 1, 'yearly', '2h/$204/year'),
        (config.STRIPE_TIER_6_50_YEARLY_ID, 2, 'yearly', '6h/$510/year'),
        (config.STRIPE_TIER_12_100_YEARLY_ID, 3, 'yearly', '12h/$1020/year'),
        (config.STRIPE_TIER_25_200_YEARLY_ID, 4, 'yearly', '25h/$2040/year'),
        (config.STRIPE_TIER_50_400_YEARLY_ID, 5, 'yearly', '50h/$4080/year'),
        (config.STRIPE_TIER_125_800_YEARLY_ID, 6, 'yearly', '125h/$8160/year'),
        (config.STRIPE_TIER_200_1000_YEARLY_ID, 7, 'yearly', '200h/$10200/year'),
        # Yearly commitment plans (monthly price)
        (config.STRIPE_TIER_2_17_YEARLY_COMMITMENT_ID, 1, 'yearly_commitment', '2h/$17/month'),
        (config.STRIPE_TIER_6_42_YEARLY_COMMITMENT_ID, 2, 'yearly_commitment', '6h/$42.50/month'),
        (config.STRIPE_TIER_25_170_YEARLY_COMMITMENT_ID, 4, 'yearly_commitment', '25h/$170/month'),
    )
    plans_by_price = {
        plan_price_id: {'tier': tier, 'type': billing_type, 'name': label}
        for plan_price_id, tier, billing_type, label in plan_rows
    }
    return plans_by_price.get(price_id, {'tier': 0, 'type': 'unknown', 'name': 'Unknown'})
def is_plan_change_allowed(current_price_id: str, new_price_id: str) -> tuple[bool, str]:
    """
    Decide whether switching from one plan price to another is permitted.

    Rules (tiers and types come from ``get_plan_info``):
      * staying on the same price is always allowed;
      * from a monthly plan: no move to a lower monthly tier, and no move to a
        yearly commitment plan below the current tier;
      * from a yearly commitment plan: no move to any monthly plan, and no move
        to a lower yearly commitment tier;
      * every other change is allowed.

    Returns:
        ``(True, "")`` when allowed, otherwise ``(False, reason)``.
    """
    current = get_plan_info(current_price_id)
    target = get_plan_info(new_price_id)

    if current_price_id == new_price_id:
        return True, ""

    moving_down = target['tier'] < current['tier']

    if current['type'] == 'monthly':
        if target['type'] == 'monthly' and moving_down:
            return False, "Downgrading to a lower monthly plan is not allowed. You can only upgrade to a higher tier or switch to yearly billing."
        if target['type'] == 'yearly_commitment' and moving_down:
            return False, "You can only upgrade to yearly commitment plans at the same tier level or higher."
    elif current['type'] == 'yearly_commitment':
        if target['type'] == 'monthly':
            return False, "Downgrading from yearly commitment to monthly is not allowed. You can only upgrade within yearly commitment plans."
        if target['type'] == 'yearly_commitment' and moving_down:
            return False, "Downgrading to a lower yearly commitment plan is not allowed. You can only upgrade to higher commitment tiers."

    # Upgrades, yearly-to-yearly moves, yearly commitment upgrades, etc.
    return True, ""
# Simplified yearly commitment logic - no subscription schedules needed
def get_model_pricing(model: str) -> tuple[float, float] | None:
    """
    Get pricing for a model. Returns (input_cost_per_million, output_cost_per_million) or None.

    The alias-resolved model ID is tried first, then the name exactly as given.
    Each distinct candidate is looked up only once.

    Args:
        model: The model name to get pricing for (can be display name or model ID)

    Returns:
        Tuple of (input_cost_per_million_tokens, output_cost_per_million_tokens),
        or None if neither candidate has pricing in the model registry.
    """
    # First try to resolve the model ID to handle aliases
    resolved_model = model_manager.resolve_model_id(model)
    logger.debug(f"Resolving model '{model}' -> '{resolved_model}'")
    # dict.fromkeys de-duplicates while preserving order, so a model that resolves
    # to itself is not fetched (and logged as missing) twice.
    for model_to_try in dict.fromkeys([resolved_model, model]):
        model_obj = model_manager.get_model(model_to_try)
        if model_obj and model_obj.pricing:
            logger.debug(f"Found pricing for model {model_to_try}: input=${model_obj.pricing.input_cost_per_million_tokens}/M, output=${model_obj.pricing.output_cost_per_million_tokens}/M")
            return model_obj.pricing.input_cost_per_million_tokens, model_obj.pricing.output_cost_per_million_tokens
        logger.debug(f"No pricing for model_to_try='{model_to_try}' (model_obj: {model_obj is not None}, has_pricing: {model_obj.pricing is not None if model_obj else False})")
    logger.warning(f"No pricing found for model '{model}' (resolved: '{resolved_model}')")
    return None
# Usage allowance per Stripe price ID.
#   name:    internal tier name
#   minutes: monthly run-time allowance in minutes
#   cost:    dollar usage limit (compared against summed token cost); it is the
#            plan's monthly base price plus a flat $5 (the free tier is just $5)
# Yearly and yearly-commitment prices share the limits of the matching monthly
# tier; only the billing differs.
SUBSCRIPTION_TIERS = {
    tier_price_id: {'name': tier_name, 'minutes': minutes, 'cost': cost}
    for tier_price_id, tier_name, minutes, cost in (
        (config.STRIPE_FREE_TIER_ID, 'free', 60, 5),
        # Monthly tiers
        (config.STRIPE_TIER_2_20_ID, 'tier_2_20', 120, 20 + 5),  # 2 hours
        (config.STRIPE_TIER_6_50_ID, 'tier_6_50', 360, 50 + 5),  # 6 hours
        (config.STRIPE_TIER_12_100_ID, 'tier_12_100', 720, 100 + 5),  # 12 hours
        (config.STRIPE_TIER_25_200_ID, 'tier_25_200', 1500, 200 + 5),  # 25 hours
        (config.STRIPE_TIER_50_400_ID, 'tier_50_400', 3000, 400 + 5),  # 50 hours
        (config.STRIPE_TIER_125_800_ID, 'tier_125_800', 7500, 800 + 5),  # 125 hours
        (config.STRIPE_TIER_200_1000_ID, 'tier_200_1000', 12000, 1000 + 5),  # 200 hours
        # Yearly tiers (same usage limits, billed once a year)
        (config.STRIPE_TIER_2_20_YEARLY_ID, 'tier_2_20', 120, 20 + 5),  # $204/year
        (config.STRIPE_TIER_6_50_YEARLY_ID, 'tier_6_50', 360, 50 + 5),  # $510/year
        (config.STRIPE_TIER_12_100_YEARLY_ID, 'tier_12_100', 720, 100 + 5),  # $1020/year
        (config.STRIPE_TIER_25_200_YEARLY_ID, 'tier_25_200', 1500, 200 + 5),  # $2040/year
        (config.STRIPE_TIER_50_400_YEARLY_ID, 'tier_50_400', 3000, 400 + 5),  # $4080/year
        (config.STRIPE_TIER_125_800_YEARLY_ID, 'tier_125_800', 7500, 800 + 5),  # $8160/year
        (config.STRIPE_TIER_200_1000_YEARLY_ID, 'tier_200_1000', 12000, 1000 + 5),  # $10200/year
        # Yearly commitment tiers (15% discount, paid monthly with a 12-month commitment)
        (config.STRIPE_TIER_2_17_YEARLY_COMMITMENT_ID, 'tier_2_17_yearly_commitment', 120, 20 + 5),  # $17/month
        (config.STRIPE_TIER_6_42_YEARLY_COMMITMENT_ID, 'tier_6_42_yearly_commitment', 360, 50 + 5),  # $42.50/month
        (config.STRIPE_TIER_25_170_YEARLY_COMMITMENT_ID, 'tier_25_170_yearly_commitment', 1500, 200 + 5),  # $170/month
    )
}
# Pydantic models for request/response validation
# Request body for creating a checkout session for the plan identified by price_id.
class CreateCheckoutSessionRequest(BaseModel):
    # Stripe price ID of the plan being purchased.
    price_id: str
    # URLs to return the user to on success / cancellation
    # (presumably forwarded to the checkout provider — confirm at the handler).
    success_url: str
    cancel_url: str
    # Optional referral identifier; how it is used is not visible in this section.
    tolt_referral: Optional[str] = None
    # Intended values: "monthly", "yearly" or "yearly_commitment". This is a plain
    # string, so the model itself does not reject other values.
    commitment_type: Optional[str] = "monthly"
# Request body for creating a billing portal session.
class CreatePortalSessionRequest(BaseModel):
    # URL to send the user back to when they leave the portal (presumably; confirm at the handler).
    return_url: str
# Snapshot of a user's subscription state, usage limits and credit information.
class SubscriptionStatus(BaseModel):
    # e.g., 'active', 'trialing', 'past_due', 'scheduled_downgrade', 'no_subscription'
    status: str
    plan_name: Optional[str] = None
    # Stripe price ID of the current plan.
    price_id: Optional[str] = None
    current_period_end: Optional[datetime] = None
    cancel_at_period_end: bool = False
    trial_end: Optional[datetime] = None
    # Limits/usage; presumably filled from SUBSCRIPTION_TIERS ('minutes' / 'cost')
    # and the computed monthly usage — confirm where this model is built.
    minutes_limit: Optional[int] = None
    cost_limit: Optional[float] = None
    current_usage: Optional[float] = None
    # Fields for scheduled changes
    has_schedule: bool = False
    scheduled_plan_name: Optional[str] = None
    # Stripe price ID the subscription is scheduled to change to.
    scheduled_price_id: Optional[str] = None
    scheduled_change_date: Optional[datetime] = None
    # Subscription data for frontend components
    subscription_id: Optional[str] = None
    subscription: Optional[Dict] = None
    # Credit information
    credit_balance: Optional[float] = None
    can_purchase_credits: bool = False
# Request body for buying credits.
class PurchaseCreditsRequest(BaseModel):
    # Amount of credits to purchase in dollars. Not range-checked by the model itself.
    amount_dollars: float
    # Redirect targets after the purchase flow succeeds / is cancelled (presumably).
    success_url: str
    cancel_url: str
# A user's credit balance and lifetime totals, all in dollars.
class CreditBalance(BaseModel):
    balance_dollars: float
    total_purchased: float
    total_used: float
    last_updated: Optional[datetime] = None
    # Per the original author's note, True only for highest-tier users
    # (the rule itself is enforced elsewhere).
    can_purchase_credits: bool = False
# A single credit purchase record.
class CreditPurchase(BaseModel):
    id: str
    amount_dollars: float
    # Purchase state as a free-form string (allowed values are not visible here).
    status: str
    created_at: datetime
    completed_at: Optional[datetime] = None
    # Stripe PaymentIntent linked to the purchase, when known.
    stripe_payment_intent_id: Optional[str] = None
# A single deduction from a user's credit balance.
class CreditUsage(BaseModel):
    id: str
    amount_dollars: float
    description: Optional[str] = None
    created_at: datetime
    # Optional links to the thread/message the credits were spent on
    # (get_usage_logs matches credit usage to messages by message_id).
    thread_id: Optional[str] = None
    message_id: Optional[str] = None
# Helper functions
async def get_stripe_customer_id(client: SupabaseClient, user_id: str) -> Optional[str]:
    """Get the Stripe customer ID for a user.

    Lookup order:
      1. the cache (key ``stripe_customer_id:{user_id}``);
      2. the ``basejump.billing_customers`` table;
      3. a Stripe customer search on ``metadata['user_id']`` or
         ``metadata['basejump_account_id']``.

    When the customer is found only through Stripe, a missing ``user_id``
    metadata entry is back-filled (best effort). The ``billing_customers`` row is
    then upserted with whether the customer has an active subscription, and the
    ID is cached, so both lookup paths populate the cache the same way.

    Returns:
        The Stripe customer ID, or None if no customer could be found.
    """
    cache_key = f"stripe_customer_id:{user_id}"
    # NOTE(review): 24 * 60 == 1440. If Cache TTLs are in seconds this is 24
    # minutes — confirm whether 24 hours was intended.
    cache_ttl = 24 * 60

    cached_id = await Cache.get(cache_key)
    if cached_id:
        return cached_id

    result = await client.schema('basejump').from_('billing_customers') \
        .select('id') \
        .eq('account_id', user_id) \
        .execute()
    if result.data:
        customer_id = result.data[0]['id']
        await Cache.set(cache_key, customer_id, ttl=cache_ttl)
        return customer_id

    customer_result = await stripe.Customer.search_async(
        query=f"metadata['user_id']:'{user_id}' OR metadata['basejump_account_id']:'{user_id}'"
    )
    if not customer_result.data:
        return None

    customer = customer_result.data[0]
    # If the customer does not have 'user_id' in metadata, add it now (best effort:
    # a failure is logged and the lookup still succeeds).
    if not customer.get('metadata', {}).get('user_id'):
        try:
            await stripe.Customer.modify_async(
                customer['id'],
                metadata={**customer.get('metadata', {}), 'user_id': user_id}
            )
            logger.debug(f"Added missing user_id metadata to Stripe customer {customer['id']}")
        except Exception as e:
            logger.error(f"Failed to add user_id metadata to Stripe customer {customer['id']}: {str(e)}")

    has_active = len((await stripe.Subscription.list_async(
        customer=customer['id'],
        status='active',
        limit=1
    )).get('data', [])) > 0

    # Create or update record in billing_customers table
    await client.schema('basejump').from_('billing_customers').upsert({
        'id': customer['id'],
        'account_id': user_id,
        'email': customer.get('email'),
        'provider': 'stripe',
        'active': has_active
    }).execute()
    logger.debug(f"Updated billing_customers record for customer {customer['id']} and user {user_id}")

    # Cache exactly as the database path above does.
    await Cache.set(cache_key, customer['id'], ttl=cache_ttl)
    return customer['id']
async def create_stripe_customer(client, user_id: str, email: str) -> str:
    """Create a Stripe customer for ``user_id`` and record it in Supabase.

    The customer is created with ``metadata={"user_id": user_id}``. A matching row
    is then inserted into ``basejump.billing_customers`` with the user as account.

    Returns:
        The new Stripe customer ID.
    """
    new_customer = await stripe.Customer.create_async(
        email=email,
        metadata={"user_id": user_id},
    )
    billing_row = {
        'id': new_customer.id,
        'account_id': user_id,
        'email': email,
        'provider': 'stripe',
    }
    await client.schema('basejump').from_('billing_customers').insert(billing_row).execute()
    return new_customer.id
async def get_user_subscription(user_id: str) -> Optional[Dict]:
    """Get the current subscription for a user from Stripe.

    Only ``active`` subscriptions with at least one item priced at a known tier
    (a key of ``SUBSCRIPTION_TIERS``) are considered. If several exist, the most
    recently created one is returned and the others are set to cancel at period
    end. The chosen subscription is cached under ``user_subscription:{user_id}``
    with ttl=60.

    Returns:
        The Stripe subscription object, or None if there is none or an error
        occurs (errors are logged, not raised).
    """
    cache_key = f"user_subscription:{user_id}"
    try:
        # NOTE(review): a cached None is falsy, so the "no subscription" entries
        # written below never short-circuit this check.
        cached = await Cache.get(cache_key)
        if cached:
            return cached

        # Get customer ID
        db = DBConnection()
        client = await db.client
        customer_id = await get_stripe_customer_id(client, user_id)
        if not customer_id:
            await Cache.set(cache_key, None, ttl=1 * 60)
            return None

        # Get all active subscriptions for the customer
        subscriptions = await stripe.Subscription.list_async(
            customer=customer_id,
            status='active'
        )
        if not subscriptions or not subscriptions.get('data'):
            await Cache.set(cache_key, None, ttl=1 * 60)
            return None

        # Keep subscriptions that include at least one of our tier prices. any()
        # adds each subscription once even if several of its items match, so one
        # subscription is never mistaken for "multiple active subscriptions".
        known_price_ids = set(SUBSCRIPTION_TIERS)
        our_subscriptions = [
            sub for sub in subscriptions['data']
            if any(
                item.get('price', {}).get('id') in known_price_ids
                for item in sub.get('items', {}).get('data', [])
            )
        ]
        if not our_subscriptions:
            await Cache.set(cache_key, None, ttl=1 * 60)
            return None

        if len(our_subscriptions) > 1:
            logger.warning(f"User {user_id} has multiple active subscriptions: {[sub['id'] for sub in our_subscriptions]}")
            # Keep the most recently created subscription; cancel the rest at period end.
            result = max(our_subscriptions, key=lambda x: x['created'])
            for sub in our_subscriptions:
                if sub['id'] == result['id']:
                    continue
                # A subscription set to cancel at period end stays 'active' until then,
                # so it is listed again on later calls; don't re-issue the modification.
                if sub.get('cancel_at_period_end'):
                    continue
                try:
                    await stripe.Subscription.modify_async(
                        sub['id'],
                        cancel_at_period_end=True
                    )
                    logger.debug(f"Cancelled subscription {sub['id']} for user {user_id}")
                except Exception as e:
                    logger.error(f"Error cancelling subscription {sub['id']}: {str(e)}")
        else:
            result = our_subscriptions[0]

        # Cache the chosen subscription whichever branch selected it.
        await Cache.set(cache_key, result, ttl=1 * 60)
        return result
    except Exception as e:
        logger.error(f"Error getting subscription from Stripe: {str(e)}")
        return None
async def calculate_monthly_usage(client, user_id: str) -> float:
    """Calculate the user's estimated spend (in dollars) for the current month.

    Sums ``estimated_cost`` over every page returned by ``get_usage_logs``, which
    applies the month/cutoff date filtering. The total is cached under
    ``monthly_usage:{user_id}`` with ttl=5.

    Returns:
        The summed estimated cost (a dollar amount, not agent-run minutes).
    """
    cache_key = f"monthly_usage:{user_id}"
    cached_total = await Cache.get(cache_key)
    # Compare with None so a cached total of 0.0 still counts as a cache hit.
    # NOTE(review): assumes Cache.get returns None on a miss — confirm.
    if cached_total is not None:
        return cached_total

    start_time = time.time()
    total_cost = 0.0
    page = 0
    items_per_page = 1000
    while True:
        # Get usage logs for this page
        usage_result = await get_usage_logs(client, user_id, page, items_per_page)
        if not usage_result['logs']:
            break
        # Sum up the estimated costs from this page
        for log_entry in usage_result['logs']:
            total_cost += log_entry['estimated_cost']
        # If there are no more pages, break
        if not usage_result['has_more']:
            break
        page += 1

    execution_time = time.time() - start_time
    logger.debug(f"Calculate monthly usage took {execution_time:.3f} seconds, total cost: {total_cost}")
    await Cache.set(cache_key, total_cost, ttl=5)
    return total_cost
async def get_usage_logs(client, user_id: str, page: int = 0, items_per_page: int = 1000) -> Dict:
    """Get detailed usage logs for a user with pagination, including credit usage info.

    Usage is read from the user's ``assistant_response_end`` messages created since
    the start of the current UTC month (but never before a fixed cutoff), newest
    first. Each entry carries token counts, the estimated cost from
    ``calculate_token_cost``, the thread's project_id and credit-usage details.

    Args:
        client: Supabase client used for all queries.
        user_id: Account whose usage is listed (user_id == account_id for personal accounts).
        page: Zero-based page index.
        items_per_page: Number of messages fetched per page.

    Returns:
        A dict with ``logs`` (list of entries) and ``has_more``. When messages were
        found it also contains ``subscription_limit`` and ``cumulative_cost``, plus
        ``error`` if the result could not be JSON-serialized.

    Raises:
        Unexpected errors (e.g. from the thread or message queries) are logged and re-raised.
    """
    logger.info(f"[USAGE_LOGS] Starting get_usage_logs for user_id={user_id}, page={page}, items_per_page={items_per_page}")
    try:
        # Get start of current month in UTC
        now = datetime.now(timezone.utc)
        start_of_month = datetime(now.year, now.month, 1, tzinfo=timezone.utc)
        # Fixed cutoff: June 30, 2025 09:00 UTC. Token usage recorded before this
        # moment is ignored even when it falls inside the current month.
        cutoff_date = datetime(2025, 6, 30, 9, 0, 0, tzinfo=timezone.utc)
        start_of_month = max(start_of_month, cutoff_date)
        logger.debug(f"[USAGE_LOGS] user_id={user_id} - Using start_of_month: {start_of_month.isoformat()}")

        # First get all threads for this user in batches.
        # NOTE(review): in PostgREST a filter on an embedded resource without !inner
        # (agent_runs.created_at here) restricts the embedded rows, not the threads,
        # so this most likely returns every thread of the account — confirm.
        batch_size = 1000
        offset = 0
        all_threads = []
        logger.debug(f"[USAGE_LOGS] user_id={user_id} - Fetching threads in batches")
        while True:
            try:
                threads_batch = await client.table('threads') \
                    .select('thread_id, agent_runs(thread_id)') \
                    .eq('account_id', user_id) \
                    .gte('agent_runs.created_at', start_of_month.isoformat()) \
                    .range(offset, offset + batch_size - 1) \
                    .execute()
                if not threads_batch.data:
                    break
                all_threads.extend(threads_batch.data)
                logger.debug(f"[USAGE_LOGS] user_id={user_id} - Fetched {len(threads_batch.data)} threads in batch (offset={offset})")
                # If we got less than batch_size, we've reached the end
                if len(threads_batch.data) < batch_size:
                    break
                offset += batch_size
            except Exception as thread_error:
                logger.error(f"[USAGE_LOGS] user_id={user_id} - Error fetching threads batch at offset {offset}: {str(thread_error)}")
                raise

        logger.info(f"[USAGE_LOGS] user_id={user_id} - Found {len(all_threads)} total threads")
        if not all_threads:
            logger.info(f"[USAGE_LOGS] user_id={user_id} - No threads found, returning empty result")
            return {"logs": [], "has_more": False}

        thread_ids = [t['thread_id'] for t in all_threads]
        logger.debug(f"[USAGE_LOGS] user_id={user_id} - Thread IDs: {thread_ids[:5]}..." if len(thread_ids) > 5 else f"[USAGE_LOGS] user_id={user_id} - Thread IDs: {thread_ids}")

        # Fetch usage messages with pagination, including thread project info
        start_time = time.time()
        logger.debug(f"[USAGE_LOGS] user_id={user_id} - Starting messages query")
        try:
            # Join messages to threads and filter on the thread's account instead of
            # passing every thread ID to .in_(), which can exceed URI length limits.
            messages_result = await client.table('messages') \
                .select(
                    'message_id, thread_id, created_at, content, threads!inner(project_id, account_id)'
                ) \
                .eq('threads.account_id', user_id) \
                .eq('type', 'assistant_response_end') \
                .gte('created_at', start_of_month.isoformat()) \
                .order('created_at', desc=True) \
                .range(page * items_per_page, (page + 1) * items_per_page - 1) \
                .execute()
        except Exception as query_error:
            logger.error(f"[USAGE_LOGS] user_id={user_id} - Database query failed: {str(query_error)}")
            logger.error(f"[USAGE_LOGS] user_id={user_id} - Query details: page={page}, items_per_page={items_per_page}, thread_count={len(thread_ids)}")
            # Fallback: If the join approach fails, try batching the thread IDs
            logger.debug(f"[USAGE_LOGS] user_id={user_id} - Attempting fallback with batched thread ID queries")
            try:
                all_messages = []
                batch_size = 100  # Process threads in smaller batches to avoid URI limits
                for i in range(0, len(thread_ids), batch_size):
                    batch_thread_ids = thread_ids[i:i + batch_size]
                    logger.debug(f"[USAGE_LOGS] user_id={user_id} - Processing thread batch {i//batch_size + 1}/{(len(thread_ids) + batch_size - 1)//batch_size}")
                    batch_result = await client.table('messages') \
                        .select(
                            'message_id, thread_id, created_at, content, threads!inner(project_id)'
                        ) \
                        .in_('thread_id', batch_thread_ids) \
                        .eq('type', 'assistant_response_end') \
                        .gte('created_at', start_of_month.isoformat()) \
                        .order('created_at', desc=True) \
                        .execute()
                    if batch_result.data:
                        all_messages.extend(batch_result.data)

                # Sort all messages by created_at descending and apply pagination
                all_messages.sort(key=lambda x: x['created_at'], reverse=True)
                start_idx = page * items_per_page
                end_idx = start_idx + items_per_page
                paginated_messages = all_messages[start_idx:end_idx]

                # Create a mock result object similar to what Supabase returns
                class MockResult:
                    def __init__(self, data):
                        self.data = data

                messages_result = MockResult(paginated_messages)
                logger.debug(f"[USAGE_LOGS] user_id={user_id} - Fallback successful, found {len(all_messages)} total messages, returning {len(paginated_messages)} for page {page}")
            except Exception as fallback_error:
                logger.error(f"[USAGE_LOGS] user_id={user_id} - Fallback query also failed: {str(fallback_error)}")
                raise query_error  # Raise the original error

        execution_time = time.time() - start_time
        logger.debug(f"[USAGE_LOGS] user_id={user_id} - Database query for usage logs took {execution_time:.3f} seconds")

        if not messages_result.data:
            logger.info(f"[USAGE_LOGS] user_id={user_id} - No messages found, returning empty result")
            return {"logs": [], "has_more": False}

        logger.info(f"[USAGE_LOGS] user_id={user_id} - Found {len(messages_result.data)} messages to process")

        # Get the user's subscription tier info for credit checking
        logger.debug(f"[USAGE_LOGS] user_id={user_id} - Getting subscription info")
        try:
            subscription = await get_user_subscription(user_id)
            price_id = config.STRIPE_FREE_TIER_ID  # Default to free
            if subscription and subscription.get('items'):
                items = subscription['items'].get('data', [])
                if items:
                    price_id = items[0]['price']['id']
            tier_info = SUBSCRIPTION_TIERS.get(price_id, SUBSCRIPTION_TIERS[config.STRIPE_FREE_TIER_ID])
            subscription_limit = tier_info['cost']
            logger.debug(f"[USAGE_LOGS] user_id={user_id} - Subscription limit: {subscription_limit}, price_id: {price_id}")
        except Exception as sub_error:
            logger.error(f"[USAGE_LOGS] user_id={user_id} - Error getting subscription info: {str(sub_error)}")
            # Use free tier as fallback
            tier_info = SUBSCRIPTION_TIERS[config.STRIPE_FREE_TIER_ID]
            subscription_limit = tier_info['cost']

        # Get credit usage records for this month to match with messages
        logger.debug(f"[USAGE_LOGS] user_id={user_id} - Fetching credit usage records")
        try:
            credit_usage_result = await client.table('credit_usage') \
                .select('message_id, amount_dollars, created_at') \
                .eq('user_id', user_id) \
                .gte('created_at', start_of_month.isoformat()) \
                .execute()
            logger.debug(f"[USAGE_LOGS] user_id={user_id} - Found {len(credit_usage_result.data) if credit_usage_result.data else 0} credit usage records")
        except Exception as credit_error:
            logger.error(f"[USAGE_LOGS] user_id={user_id} - Error fetching credit usage: {str(credit_error)}")
            credit_usage_result = None

        # Create a map of message_id to credit usage
        credit_usage_map = {}
        if credit_usage_result and credit_usage_result.data:
            for usage in credit_usage_result.data:
                if usage.get('message_id'):
                    try:
                        credit_usage_map[usage['message_id']] = {
                            'amount': float(usage['amount_dollars']),
                            'created_at': usage['created_at']
                        }
                    except Exception as parse_error:
                        logger.warning(f"[USAGE_LOGS] user_id={user_id} - Error parsing credit usage record: {str(parse_error)}")
                        continue

        # NOTE(review): messages are ordered newest-first and this total restarts on
        # every page, so cumulative_cost / was_over_limit reflect a reverse-chronological,
        # per-page sum rather than month-to-date spend at the time of each message.
        cumulative_cost = 0.0

        # Process messages into usage log entries
        processed_logs = []
        logger.debug(f"[USAGE_LOGS] user_id={user_id} - Starting to process {len(messages_result.data)} messages")
        for i, message in enumerate(messages_result.data):
            try:
                message_id = message.get('message_id', f'unknown_{i}')
                logger.debug(f"[USAGE_LOGS] user_id={user_id} - Processing message {i+1}/{len(messages_result.data)}: {message_id}")

                # Safely extract usage data with defaults
                content = message.get('content', {})
                usage = content.get('usage', {})
                prompt_tokens = usage.get('prompt_tokens', 0)
                completion_tokens = usage.get('completion_tokens', 0)
                model = content.get('model', 'unknown')

                # Non-numeric token values (including None) are logged and treated as 0.
                if not isinstance(prompt_tokens, (int, float)):
                    logger.warning(f"[USAGE_LOGS] user_id={user_id} - Invalid prompt_tokens for message {message_id}: {prompt_tokens}")
                    prompt_tokens = 0
                if not isinstance(completion_tokens, (int, float)):
                    logger.warning(f"[USAGE_LOGS] user_id={user_id} - Invalid completion_tokens for message {message_id}: {completion_tokens}")
                    completion_tokens = 0

                total_tokens = int(prompt_tokens) + int(completion_tokens)

                # Calculate estimated cost using the same logic as calculate_monthly_usage
                try:
                    estimated_cost = calculate_token_cost(
                        prompt_tokens,
                        completion_tokens,
                        model
                    )
                except Exception as cost_error:
                    logger.warning(f"[USAGE_LOGS] user_id={user_id} - Error calculating cost for message {message_id}: {str(cost_error)}")
                    estimated_cost = 0.0

                cumulative_cost += estimated_cost

                # messages -> threads is many-to-one, which PostgREST returns as a single
                # object rather than a list; accept both shapes.
                project_id = 'unknown'
                try:
                    thread_info = message.get('threads')
                    if isinstance(thread_info, list):
                        thread_info = thread_info[0] if thread_info else None
                    if isinstance(thread_info, dict):
                        project_id = thread_info.get('project_id') or 'unknown'
                except Exception as project_error:
                    logger.warning(f"[USAGE_LOGS] user_id={user_id} - Error extracting project_id for message {message_id}: {str(project_error)}")

                # Check if credits were used for this message
                credit_used = credit_usage_map.get(message_id, {})

                # Safely handle datetime serialization for created_at
                created_at = message.get('created_at')
                if created_at and isinstance(created_at, datetime):
                    created_at = created_at.isoformat()
                elif created_at and not isinstance(created_at, str):
                    try:
                        created_at = str(created_at)
                    except Exception:
                        logger.warning(f"[USAGE_LOGS] user_id={user_id} - Could not convert created_at to string for message {message_id}")
                        created_at = None

                log_entry = {
                    'message_id': str(message_id) if message_id else 'unknown',
                    'thread_id': str(message.get('thread_id', 'unknown')),
                    'created_at': created_at,
                    'content': {
                        'usage': {
                            'prompt_tokens': int(prompt_tokens),
                            'completion_tokens': int(completion_tokens)
                        },
                        'model': str(model)
                    },
                    'total_tokens': int(total_tokens),
                    'estimated_cost': float(estimated_cost),
                    'project_id': str(project_id),
                    # Add credit usage info
                    'credit_used': float(credit_used.get('amount', 0)) if credit_used else 0.0,
                    'payment_method': 'credits' if credit_used else 'subscription',
                    'was_over_limit': bool(cumulative_cost > subscription_limit if not credit_used else True)
                }

                # Test JSON serialization of this entry before adding it
                try:
                    json.dumps(log_entry, default=str)
                except Exception as json_error:
                    logger.error(f"[USAGE_LOGS] user_id={user_id} - JSON serialization failed for message {message_id}: {str(json_error)}")
                    logger.error(f"[USAGE_LOGS] user_id={user_id} - Problematic log_entry: {log_entry}")
                    continue

                processed_logs.append(log_entry)
            except Exception as e:
                logger.error(f"[USAGE_LOGS] user_id={user_id} - Error processing usage log entry for message {message.get('message_id', 'unknown')}: {str(e)}")
                continue

        logger.info(f"[USAGE_LOGS] user_id={user_id} - Successfully processed {len(processed_logs)} messages")

        # A full page of fetched rows means there may be more. Count fetched rows, not
        # processed_logs: entries skipped above must not end pagination early.
        has_more = len(messages_result.data) == items_per_page

        result = {
            "logs": processed_logs,
            "has_more": bool(has_more),
            "subscription_limit": float(subscription_limit),
            "cumulative_cost": float(cumulative_cost)
        }

        # Validate final JSON serialization
        try:
            json.dumps(result, default=str)
            logger.debug(f"[USAGE_LOGS] user_id={user_id} - Final result JSON validation passed")
        except Exception as final_json_error:
            logger.error(f"[USAGE_LOGS] user_id={user_id} - Final result JSON serialization failed: {str(final_json_error)}")
            logger.error(f"[USAGE_LOGS] user_id={user_id} - Problematic result keys: {list(result.keys())}")
            # Return safe fallback
            return {
                "logs": [],
                "has_more": False,
                "subscription_limit": float(subscription_limit),
                "cumulative_cost": 0.0,
                "error": "Failed to serialize usage data"
            }

        logger.info(f"[USAGE_LOGS] user_id={user_id} - Returning {len(processed_logs)} logs, has_more={has_more}")
        return result
    except Exception as outer_error:
        logger.error(f"[USAGE_LOGS] user_id={user_id} - Outer exception in get_usage_logs: {str(outer_error)}")
        raise
def calculate_token_cost(prompt_tokens: int, completion_tokens: int, model: str) -> float:
    """Calculate the billed cost (in dollars) for a model call's token usage.

    Pricing is taken from the model registry via ``get_model_pricing`` when
    available. Otherwise litellm's ``cost_per_token`` is tried with several
    spellings of the model name:
      * the original name;
      * the alias-resolved name;
      * either of those without its provider prefix;
      * Google models without the ``openrouter/`` prefix.

    The raw cost is multiplied by ``TOKEN_PRICE_MULTIPLIER``.

    Args:
        prompt_tokens: Input token count; None is treated as 0.
        completion_tokens: Output token count; None is treated as 0.
        model: Model name or ID as recorded on the message.

    Returns:
        The cost, or 0.0 when no price can be found or any error occurs (logged).
    """
    try:
        # Ensure tokens are valid integers
        prompt_tokens = int(prompt_tokens) if prompt_tokens is not None else 0
        completion_tokens = int(completion_tokens) if completion_tokens is not None else 0
        logger.debug(f"Calculating token cost for model '{model}' with {prompt_tokens} input tokens and {completion_tokens} output tokens")

        # model_manager is the module-level import; no local re-import is needed.
        resolved_model = model_manager.resolve_model_id(model)
        logger.debug(f"Model '{model}' resolved to '{resolved_model}'")

        # Check if we have hardcoded pricing for this model (try both original and resolved)
        hardcoded_pricing = get_model_pricing(model) or get_model_pricing(resolved_model)
        if hardcoded_pricing:
            input_cost_per_million, output_cost_per_million = hardcoded_pricing
            input_cost = (prompt_tokens / 1_000_000) * input_cost_per_million
            output_cost = (completion_tokens / 1_000_000) * output_cost_per_million
            message_cost = input_cost + output_cost
        else:
            # Use litellm pricing as fallback - try multiple variations
            try:
                candidates = [model, resolved_model]
                # Try without provider prefix if it has one
                if '/' in model:
                    candidates.append(model.split('/', 1)[1])
                if '/' in resolved_model and resolved_model != model:
                    candidates.append(resolved_model.split('/', 1)[1])
                # Special handling for Google models accessed via OpenRouter
                for name in (model, resolved_model):
                    if name.startswith('openrouter/google/'):
                        candidates.append(name.replace('openrouter/', ''))
                # Drop repeats (e.g. 'openrouter/google/x' yields 'google/x' twice)
                # while keeping the first-seen order, which decides precedence.
                models_to_try = list(dict.fromkeys(candidates))

                # Try each model name variation until we find one that works
                message_cost = None
                for model_name in models_to_try:
                    try:
                        prompt_token_cost, completion_token_cost = cost_per_token(model_name, prompt_tokens, completion_tokens)
                        if prompt_token_cost is not None and completion_token_cost is not None:
                            message_cost = prompt_token_cost + completion_token_cost
                            break
                    except Exception as e:
                        logger.debug(f"Failed to get pricing for model variation {model_name}: {str(e)}")
                        continue

                if message_cost is None:
                    logger.warning(f"Could not get pricing for model {model} (resolved: {resolved_model}), returning 0 cost")
                    return 0.0
            except Exception as e:
                logger.warning(f"Could not get pricing for model {model} (resolved: {resolved_model}): {str(e)}, returning 0 cost")
                return 0.0

        # Apply the TOKEN_PRICE_MULTIPLIER
        return message_cost * TOKEN_PRICE_MULTIPLIER
    except Exception as e:
        logger.error(f"Error calculating token cost for model {model}: {str(e)}")
        return 0.0
async def get_allowed_models_for_user(client, user_id: str):
    """
    Resolve the model IDs a user may call, based on their subscription tier.

    The result is cached per user for 60 seconds under ``allowed_models_for_user:<user_id>``.
    Users without a subscription, or whose price ID is not in ``SUBSCRIPTION_TIERS``, get the
    'free' model set; any other known tier gets the 'paid' model set.

    Args:
        client: Not used by this function; kept for signature compatibility with callers.
        user_id: The user ID (equal to the personal account ID).

    Returns:
        List of model IDs available to the user.
    """
    cache_key = f"allowed_models_for_user:{user_id}"
    cached_models = await Cache.get(cache_key)
    if cached_models:
        return cached_models

    tier_name = 'free'
    subscription = await get_user_subscription(user_id)
    if subscription:
        line_items = subscription.get('items')
        item_list = line_items.get('data') if line_items else None
        if item_list:
            price_id = item_list[0]['price']['id']
        else:
            price_id = subscription.get('price_id', config.STRIPE_FREE_TIER_ID)
        tier_info = SUBSCRIPTION_TIERS.get(price_id)
        if tier_info:
            tier_name = tier_info['name']

    # This function only distinguishes two model buckets: 'free' and everything else ('paid').
    model_tier = 'free' if tier_name == 'free' else 'paid'
    allowed_ids = [model.id for model in model_manager.get_models_for_tier(model_tier)]
    await Cache.set(cache_key, allowed_ids, ttl=60)
    return allowed_ids
async def can_use_model(client, user_id: str, model_name: str):
    """
    Decide whether ``user_id`` may call ``model_name`` under their subscription.

    ``model_name`` is first resolved to its canonical ID with ``model_manager.resolve_model_id``
    and then checked against the models allowed for the user's tier.

    Returns:
        Tuple of (allowed, message, info). Outside local mode ``info`` is the list of allowed
        model IDs. In local development mode billing is bypassed and ``info`` is a placeholder
        plan dict.
        NOTE(review): the third element's type differs between modes; callers should not rely on it.
    """
    if config.ENV_MODE == EnvMode.LOCAL:
        logger.debug("Running in local development mode - billing checks are disabled")
        return True, "Local development mode - billing disabled", {
            "price_id": "local_dev",
            "plan_name": "Local Development",
            "minutes_limit": "no limit"
        }

    allowed_models = await get_allowed_models_for_user(client, user_id)
    resolved_model = model_manager.resolve_model_id(model_name)
    if resolved_model in allowed_models:
        return True, "Model access allowed", allowed_models

    return False, f"Your current subscription plan does not include access to {model_name}. Please upgrade your subscription or choose from your available models: {', '.join(allowed_models)}", allowed_models
async def get_subscription_tier(client, user_id: str) -> str:
    """
    Return the tier name for the user's current subscription.

    Falls back to 'free' when the user has no subscription, when the price ID is not in
    ``SUBSCRIPTION_TIERS`` (a warning is logged), or when any error occurs (an error is logged).
    ``client`` is not used by this function.
    """
    try:
        subscription = await get_user_subscription(user_id)
        if not subscription:
            return 'free'

        line_items = subscription.get('items')
        item_list = line_items.get('data') if line_items else None
        price_id = (
            item_list[0]['price']['id']
            if item_list
            else subscription.get('price_id', config.STRIPE_FREE_TIER_ID)
        )

        tier_info = SUBSCRIPTION_TIERS.get(price_id)
        if not tier_info:
            logger.warning(f"Unknown price_id {price_id} for user {user_id}, defaulting to free tier")
            return 'free'
        return tier_info['name']
    except Exception as e:
        logger.error(f"Error getting subscription tier for user {user_id}: {str(e)}")
        return 'free'
async def check_billing_status(client, user_id: str) -> Tuple[bool, str, Optional[Dict]]:
    """
    Decide whether a user may start a new agent run.

    Users under their tier's monthly ``cost`` allowance may always run. Once the allowance is
    used up, a run is allowed only if the prepaid credit balance is at least
    ``CREDIT_MIN_START_DOLLARS``. Users without a subscription are treated as free tier, and
    unknown price IDs fall back to the free tier (with a warning).

    Returns:
        Tuple[bool, str, Optional[Dict]]: (can_run, message, subscription_info)
    """
    if config.ENV_MODE == EnvMode.LOCAL:
        logger.debug("Running in local development mode - billing checks are disabled")
        return True, "Local development mode - billing disabled", {
            "price_id": "local_dev",
            "plan_name": "Local Development",
            "minutes_limit": "no limit"
        }

    # No Stripe subscription -> synthesize a free-tier record.
    subscription = await get_user_subscription(user_id) or {
        'price_id': config.STRIPE_FREE_TIER_ID,
        'plan_name': 'free'
    }

    line_items = subscription.get('items')
    item_list = line_items.get('data') if line_items else None
    if item_list:
        price_id = item_list[0]['price']['id']
    else:
        price_id = subscription.get('price_id', config.STRIPE_FREE_TIER_ID)

    tier_info = SUBSCRIPTION_TIERS.get(price_id)
    if not tier_info:
        logger.warning(f"Unknown subscription tier: {price_id}, defaulting to free tier")
        tier_info = SUBSCRIPTION_TIERS[config.STRIPE_FREE_TIER_ID]

    current_usage = await calculate_monthly_usage(client, user_id)
    over_limit = current_usage >= tier_info['cost']
    if not over_limit:
        return True, "OK", subscription

    # Allowance exhausted: fall back to prepaid credits, requiring a small cushion to start.
    credits = await get_user_credit_balance(client, user_id)
    balance = credits.balance_dollars
    if balance >= CREDIT_MIN_START_DOLLARS:
        return True, f"Subscription limit reached, using credits. Balance: ${balance:.2f}", subscription

    if credits.can_purchase_credits:
        denial = (
            f"Monthly limit of ${tier_info['cost']} reached. You need at least ${CREDIT_MIN_START_DOLLARS:.2f} in credits to continue. "
            f"Current balance: ${balance:.2f}."
        )
    else:
        denial = (
            f"Monthly limit of ${tier_info['cost']} reached and credits are unavailable. Please upgrade your plan or wait until next month."
        )
    return False, denial, subscription
async def check_subscription_commitment(subscription_id: str) -> dict:
    """
    Report whether a subscription is still inside a yearly commitment that blocks cancellation.

    A subscription is a yearly commitment when its metadata has
    ``commitment_type == 'yearly_commitment'`` or its first item uses one of the configured
    yearly-commitment price IDs. The commitment lasts one calendar year from the subscription's
    ``created`` timestamp (UTC). A subscription created on 29 February ends its commitment on
    28 February of the following year.

    Returns:
        dict with at least ``has_commitment`` and ``can_cancel``. Active commitments also carry
        ``months_remaining``, ``commitment_end_date`` and ``subscription_start_date``.
        On any error this fails open (``has_commitment=False``, ``can_cancel=True``).
    """
    try:
        subscription = await stripe.Subscription.retrieve_async(subscription_id)

        price_id = None
        if subscription.get('items') and subscription['items'].get('data') and len(subscription['items']['data']) > 0:
            price_id = subscription['items']['data'][0]['price']['id']

        commitment_type = (subscription.get('metadata') or {}).get('commitment_type')

        yearly_commitment_price_ids = [
            config.STRIPE_TIER_2_17_YEARLY_COMMITMENT_ID,
            config.STRIPE_TIER_6_42_YEARLY_COMMITMENT_ID,
            config.STRIPE_TIER_25_170_YEARLY_COMMITMENT_ID
        ]
        # Guard against price_id None matching an unset (None) config value.
        is_yearly_commitment = (
            commitment_type == 'yearly_commitment' or
            (price_id is not None and price_id in yearly_commitment_price_ids)
        )

        if not is_yearly_commitment:
            return {
                'has_commitment': False,
                'can_cancel': True,
                'price_id': price_id
            }

        start_date = datetime.fromtimestamp(subscription.created, tz=timezone.utc)
        try:
            commitment_end_date = start_date.replace(year=start_date.year + 1)
        except ValueError:
            # Started on 29 Feb: the following year has no such day. Without this, the
            # ValueError would reach the outer handler and wrongly report can_cancel=True.
            commitment_end_date = start_date.replace(year=start_date.year + 1, day=28)
        commitment_end_timestamp = int(commitment_end_date.timestamp())
        current_time = int(time.time())

        if current_time < commitment_end_timestamp:
            current_date = datetime.fromtimestamp(current_time, tz=timezone.utc)
            months_remaining = (
                (commitment_end_date.year - current_date.year) * 12
                + (commitment_end_date.month - current_date.month)
            )
            # Don't count a month that has not fully elapsed yet.
            if current_date.day > commitment_end_date.day:
                months_remaining -= 1
            months_remaining = max(0, months_remaining)

            logger.debug(f"Subscription {subscription_id} has active yearly commitment: {months_remaining} months remaining")
            return {
                'has_commitment': True,
                'commitment_type': 'yearly_commitment',
                'months_remaining': months_remaining,
                'can_cancel': False,
                'commitment_end_date': commitment_end_date.isoformat(),
                'subscription_start_date': start_date.isoformat(),
                'price_id': price_id
            }

        logger.debug(f"Subscription {subscription_id} yearly commitment period has ended")
        return {
            'has_commitment': False,
            'commitment_type': 'yearly_commitment',
            'commitment_completed': True,
            'can_cancel': True,
            'subscription_start_date': start_date.isoformat(),
            'price_id': price_id
        }

    except Exception as e:
        logger.error(f"Error checking subscription commitment: {str(e)}", exc_info=True)
        return {
            'has_commitment': False,
            'can_cancel': True
        }
async def is_user_on_highest_tier(user_id: str) -> bool:
    """
    Return True if the user's subscription uses one of the highest-tier (200h/$1000) price IDs.

    The price ID is read from the first subscription item only. Users without a subscription,
    without a readable price ID, or for whom any error occurs are reported as not on the
    highest tier.
    """
    try:
        subscription = await get_user_subscription(user_id)
        if not subscription:
            logger.debug(f"User {user_id} has no subscription")
            return False

        price_id = None
        line_items = subscription.get('items')
        item_list = line_items.get('data') if line_items else None
        if item_list:
            price_id = item_list[0]['price']['id']
        logger.debug(f"User {user_id} subscription price_id: {price_id}")

        # Production 200h/$1000 monthly + yearly plans, plus the *_STAGING price IDs.
        # NOTE(review): the staging 2_20 IDs are also treated as "highest" - confirm intended.
        # Unset (None/empty) config values are dropped so a missing price_id can never match one.
        highest_tier_price_ids = {
            tier_price_id
            for tier_price_id in (
                config.STRIPE_TIER_200_1000_ID,
                config.STRIPE_TIER_200_1000_YEARLY_ID,
                config.STRIPE_TIER_25_200_ID_STAGING,
                config.STRIPE_TIER_25_200_YEARLY_ID_STAGING,
                config.STRIPE_TIER_2_20_ID_STAGING,
                config.STRIPE_TIER_2_20_YEARLY_ID_STAGING,
            )
            if tier_price_id
        }

        is_highest = bool(price_id) and price_id in highest_tier_price_ids
        logger.debug(f"User {user_id} is_highest_tier: {is_highest}, price_id: {price_id}, checked against: {highest_tier_price_ids}")
        return is_highest

    except Exception as e:
        logger.error(f"Error checking if user is on highest tier: {str(e)}")
        return False
async def get_user_credit_balance(client: SupabaseClient, user_id: str) -> CreditBalance:
    """
    Load a user's prepaid credit balance from the ``credit_balance`` table.

    ``can_purchase_credits`` mirrors ``is_user_on_highest_tier``. A user with no row yet gets an
    all-zero balance. Numeric columns that come back NULL are treated as 0 instead of aborting
    the lookup.

    Returns:
        CreditBalance. On any error the error is logged and an all-zero balance with
        ``can_purchase_credits=False`` is returned.
    """
    try:
        # execute() rather than single(): a missing row is normal, not an error.
        result = await client.table('credit_balance') \
            .select('*') \
            .eq('user_id', user_id) \
            .execute()
        rows = result.data or []
        is_highest_tier = await is_user_on_highest_tier(user_id)

        if not rows:
            # No balance record yet - normal for users who haven't purchased credits.
            return CreditBalance(
                balance_dollars=0.0,
                total_purchased=0.0,
                total_used=0.0,
                can_purchase_credits=is_highest_tier
            )

        data = rows[0]

        last_updated = None
        raw_last_updated = data.get('last_updated')
        if raw_last_updated:
            try:
                if isinstance(raw_last_updated, datetime):
                    last_updated = raw_last_updated
                else:
                    last_updated = dateutil_parser.parse(raw_last_updated)
            except Exception as dt_error:
                logger.warning(f"Error parsing last_updated datetime for user {user_id}: {dt_error}")
                last_updated = None

        # `or 0` makes NULL columns count as zero; float(None) would raise and hide the balance.
        return CreditBalance(
            balance_dollars=float(data.get('balance_dollars') or 0),
            total_purchased=float(data.get('total_purchased') or 0),
            total_used=float(data.get('total_used') or 0),
            last_updated=last_updated,
            can_purchase_credits=is_highest_tier
        )

    except Exception as e:
        logger.error(f"Error getting credit balance for user {user_id}: {str(e)}")
        return CreditBalance(
            balance_dollars=0.0,
            total_purchased=0.0,
            total_used=0.0,
            can_purchase_credits=False
        )
async def add_credits_to_balance(client: SupabaseClient, user_id: str, amount: float, purchase_id: str = None) -> float:
    """
    Credit ``amount`` dollars to ``user_id`` through the ``add_credits`` database function.

    Returns:
        The RPC result converted to float (presumably the new balance; confirm against the SQL
        function), or 0.0 when the RPC returns no data.

    Raises:
        Any error from the RPC call or conversion, after logging it.
    """
    rpc_params = {
        'p_user_id': user_id,
        'p_amount': amount,
        'p_purchase_id': purchase_id
    }
    try:
        response = await client.rpc('add_credits', rpc_params).execute()
        payload = response.data
        return 0.0 if payload is None else float(payload)
    except Exception as e:
        logger.error(f"Error adding credits for user {user_id}: {str(e)}")
        raise
async def use_credits_from_balance(
    client: SupabaseClient,
    user_id: str,
    amount: float,
    description: str = None,
    thread_id: str = None,
    message_id: str = None
) -> bool:
    """
    Deduct ``amount`` dollars from ``user_id`` via the ``use_credits`` database function.

    ``description``, ``thread_id`` and ``message_id`` are forwarded to the RPC unchanged.

    Returns:
        The RPC result coerced to bool. False when the RPC returns no data or raises; errors
        are logged, not propagated.
    """
    rpc_params = {
        'p_user_id': user_id,
        'p_amount': amount,
        'p_description': description,
        'p_thread_id': thread_id,
        'p_message_id': message_id
    }
    try:
        response = await client.rpc('use_credits', rpc_params).execute()
        payload = response.data
        return payload is not None and bool(payload)
    except Exception as e:
        logger.error(f"Error using credits for user {user_id}: {str(e)}")
        return False
async def handle_usage_with_credits(
    client: SupabaseClient,
    user_id: str,
    token_cost: float,
    thread_id: str = None,
    message_id: str = None,
    model: str = None
) -> Tuple[bool, str]:
    """
    Charge the part of ``token_cost`` that exceeds the monthly plan allowance to prepaid credits.

    Call after each agent response. Usage within the tier's ``cost`` allowance is free. The
    portion beyond it (or the whole cost if the user was already over the limit) is deducted
    from the credit balance. Nothing is deducted if the balance does not cover the overage.

    Args:
        client: Supabase client used for usage and credit queries.
        user_id: The user ID (equal to the personal account ID).
        token_cost: Dollar cost of the response just produced.
        thread_id: Recorded with the credit transaction.
        message_id: Recorded with the credit transaction.
        model: Model name, used only in the transaction description.

    Returns:
        Tuple[bool, str]: (success, message). success is False when credits are insufficient,
        the deduction fails, or any error occurs.
    """
    try:
        # Nothing to bill (e.g. cost calculation fell back to 0): skip the Stripe/DB lookups and
        # avoid recording a $0 credit deduction for users already over their limit.
        if token_cost <= 0:
            return True, "No billable usage"

        subscription = await get_user_subscription(user_id)

        # Same price-ID resolution as the other billing helpers in this module.
        price_id = config.STRIPE_FREE_TIER_ID
        if subscription:
            line_items = subscription.get('items')
            item_list = line_items.get('data') if line_items else None
            if item_list:
                price_id = item_list[0]['price']['id']
            else:
                price_id = subscription.get('price_id', config.STRIPE_FREE_TIER_ID)

        tier_info = SUBSCRIPTION_TIERS.get(price_id, SUBSCRIPTION_TIERS[config.STRIPE_FREE_TIER_ID])
        monthly_limit = tier_info['cost']

        current_usage = await calculate_monthly_usage(client, user_id)
        new_total_usage = current_usage + token_cost
        if new_total_usage <= monthly_limit:
            return True, "Within subscription limits"

        # If this response is what crosses the limit, only the excess is billable;
        # if the user was already over, the whole cost is.
        if current_usage < monthly_limit:
            overage_amount = new_total_usage - monthly_limit
        else:
            overage_amount = token_cost

        credit_balance = await get_user_credit_balance(client, user_id)
        if credit_balance.balance_dollars < overage_amount:
            if credit_balance.can_purchase_credits:
                return False, f"Insufficient credits. Balance: ${credit_balance.balance_dollars:.2f}, Required: ${overage_amount:.4f}. Purchase more credits to continue."
            return False, "Monthly limit exceeded and no credits available. Upgrade to the highest tier to purchase credits."

        success = await use_credits_from_balance(
            client,
            user_id,
            overage_amount,
            description=f"Token overage for model {model or 'unknown'}",
            thread_id=thread_id,
            message_id=message_id
        )
        if not success:
            return False, "Failed to deduct credits"

        logger.debug(f"Used ${overage_amount:.4f} credits for user {user_id} overage")
        return True, f"Used ${overage_amount:.4f} from credits (Balance: ${credit_balance.balance_dollars - overage_amount:.2f})"

    except Exception as e:
        logger.error(f"Error handling usage with credits: {str(e)}")
        return False, f"Error processing usage: {str(e)}"
# API endpoints


def _checkout_price_dollars(price) -> float:
    """Convert a Stripe price's ``unit_amount`` (in cents) to dollars; 0 when missing or zero."""
    return round(price['unit_amount'] / 100, 2) if price.get('unit_amount') else 0


def _checkout_invoice_summary(invoice) -> Optional[Dict]:
    """Summarise a Stripe invoice (amounts in dollars) for the API response; None if no invoice."""
    if not invoice:
        return None
    return {
        "id": invoice['id'],
        "status": invoice['status'],
        "amount_due": round(invoice['amount_due'] / 100, 2),
        "amount_paid": round(invoice['amount_paid'] / 100, 2)
    }


async def _checkout_mark_customer_active(client, customer_id: str, reason: str) -> None:
    """Set ``basejump.billing_customers.active`` to TRUE for ``customer_id`` and log why."""
    await client.schema('basejump').from_('billing_customers').update(
        {'active': True}
    ).eq('id', customer_id).execute()
    logger.debug(f"Updated customer {customer_id} active status to TRUE after {reason}")


async def _checkout_collect_invoice_now(latest_invoice_id):
    """
    Retrieve a subscription's latest invoice and try to collect it immediately.

    Draft invoices are finalized and then paid if they end up 'open'. Already-open invoices are
    paid. Any other status is left alone. The async Stripe calls are used so the event loop is
    not blocked. Errors after the initial retrieval are logged and swallowed (best-effort); the
    most recently obtained invoice object is then returned.

    Returns:
        The latest invoice object, or None when ``latest_invoice_id`` is empty.
    """
    if not latest_invoice_id:
        return None
    latest_invoice = await stripe.Invoice.retrieve_async(latest_invoice_id)
    try:
        logger.debug(f"Latest invoice {latest_invoice_id} status: {latest_invoice.status}")
        if latest_invoice.status == 'draft':
            latest_invoice = await stripe.Invoice.finalize_invoice_async(latest_invoice_id)
            logger.debug(f"Finalized invoice {latest_invoice_id} for immediate payment")
            if latest_invoice.status == 'open':
                latest_invoice = await stripe.Invoice.pay_async(latest_invoice_id)
                logger.debug(f"Paid invoice {latest_invoice_id} immediately, status: {latest_invoice.status}")
        elif latest_invoice.status == 'open':
            latest_invoice = await stripe.Invoice.pay_async(latest_invoice_id)
            logger.debug(f"Paid existing open invoice {latest_invoice_id}, status: {latest_invoice.status}")
        else:
            logger.debug(f"Invoice {latest_invoice_id} is in status {latest_invoice.status}, no action needed")
    except Exception as invoice_error:
        # Don't fail the plan change if collecting payment fails.
        logger.error(f"Error processing invoice for immediate payment: {str(invoice_error)}")
    return latest_invoice


async def _checkout_change_subscription(client, customer_id, existing_subscription, price, request):
    """
    Move an existing subscription to ``request.price_id``.

    - Same price: returns status "no_change".
    - Upgrade (higher unit price, or monthly -> yearly): applied immediately with
      ``always_invoice`` proration and a reset billing cycle, and the resulting invoice is
      collected at once. For subscriptions with an active commitment, the metadata
      ``commitment_type`` is also set.
    - Downgrade: price swapped without proration, billing cycle unchanged. Refused with status
      "commitment_blocks_downgrade" while a commitment that cannot be cancelled is active.

    Raises:
        HTTPException: 400 if ``is_plan_change_allowed`` rejects the change, 500 on any other failure.
    """
    try:
        subscription_id = existing_subscription['id']
        subscription_item = existing_subscription['items']['data'][0]
        current_price_id = subscription_item['price']['id']

        if current_price_id == request.price_id:
            return {
                "subscription_id": subscription_id,
                "status": "no_change",
                "message": "Already subscribed to this plan.",
                "details": {
                    "is_upgrade": None,
                    "effective_date": None,
                    "current_price": _checkout_price_dollars(price),
                    "new_price": _checkout_price_dollars(price),
                }
            }

        is_allowed, restriction_reason = is_plan_change_allowed(current_price_id, request.price_id)
        if not is_allowed:
            raise HTTPException(
                status_code=400,
                detail=f"Plan change not allowed: {restriction_reason}"
            )

        commitment_info = await check_subscription_commitment(subscription_id)

        current_price = await stripe.Price.retrieve_async(current_price_id)
        new_price = price  # Already retrieved (with product expanded) by the caller.

        # Monthly -> yearly counts as an upgrade regardless of unit price (yearly is discounted).
        current_interval = (current_price.get('recurring') or {}).get('interval', 'month')
        new_interval = (new_price.get('recurring') or {}).get('interval', 'month')
        is_upgrade = (
            new_price['unit_amount'] > current_price['unit_amount'] or
            (current_interval == 'month' and new_interval == 'year')
        )
        logger.debug(f"Price comparison: current={current_price['unit_amount']}, new={new_price['unit_amount']}, "
                     f"intervals: {current_interval}->{new_interval}, is_upgrade={is_upgrade}")

        has_commitment = commitment_info.get('has_commitment')

        if has_commitment and not is_upgrade and not commitment_info.get('can_cancel'):
            return {
                "subscription_id": subscription_id,
                "status": "commitment_blocks_downgrade",
                "message": f"Cannot downgrade during commitment period. {commitment_info.get('months_remaining', 0)} months remaining.",
                "details": {
                    "is_upgrade": False,
                    "effective_date": commitment_info.get('commitment_end_date'),
                    "current_price": _checkout_price_dollars(current_price),
                    "new_price": _checkout_price_dollars(new_price),
                    "commitment_end_date": commitment_info.get('commitment_end_date'),
                    "months_remaining": commitment_info.get('months_remaining', 0)
                }
            }

        if is_upgrade:
            modify_params = {
                'items': [{
                    'id': subscription_item['id'],
                    'price': request.price_id,
                }],
                'proration_behavior': 'always_invoice',  # Prorate and charge immediately
                'billing_cycle_anchor': 'now',  # Reset billing cycle
            }
            if has_commitment:
                logger.debug(f"Upgrading commitment subscription {subscription_id}")
                modify_params['metadata'] = {
                    **(existing_subscription.get('metadata') or {}),
                    'commitment_type': request.commitment_type or 'monthly'
                }

            updated_subscription = await stripe.Subscription.modify_async(subscription_id, **modify_params)
            await _checkout_mark_customer_active(client, customer_id, "subscription upgrade")
            latest_invoice = await _checkout_collect_invoice_now(updated_subscription.latest_invoice)

            return {
                "subscription_id": updated_subscription.id,
                "status": "updated",
                "message": "Subscription upgraded successfully",
                "details": {
                    "is_upgrade": True,
                    "effective_date": "immediate",
                    "current_price": _checkout_price_dollars(current_price),
                    "new_price": _checkout_price_dollars(new_price),
                    "invoice": _checkout_invoice_summary(latest_invoice)
                }
            }

        # Downgrade: swap the price now, no proration, keep the current billing cycle.
        updated_subscription = await stripe.Subscription.modify_async(
            subscription_id,
            items=[{
                'id': subscription_item['id'],
                'price': request.price_id,
            }],
            proration_behavior='none',
            billing_cycle_anchor='unchanged'
        )
        await _checkout_mark_customer_active(client, customer_id, "subscription downgrade")

        return {
            "subscription_id": updated_subscription.id,
            "status": "updated",
            "message": "Subscription downgraded successfully",
            "details": {
                "is_upgrade": False,
                "effective_date": "immediate",
                "current_price": _checkout_price_dollars(current_price),
                "new_price": _checkout_price_dollars(new_price),
            }
        }
    except HTTPException:
        # Keep deliberate 4xx responses instead of rewrapping them as 500s.
        raise
    except Exception as e:
        logger.exception(f"Error updating subscription {existing_subscription.get('id') if existing_subscription else 'N/A'}: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error updating subscription: {str(e)}")


@router.post("/create-checkout-session")
async def create_checkout_session(
    request: CreateCheckoutSessionRequest,
    current_user_id: str = Depends(get_current_user_id_from_jwt)
):
    """
    Create a Stripe Checkout session, or change the plan of an existing subscription.

    - No subscription: creates a subscription-mode Checkout session (promotion codes allowed)
      and returns ``session_id``/``url`` with status "new".
    - Existing subscription: delegates to ``_checkout_change_subscription``, which returns a
      status of "no_change", "updated" or "commitment_blocks_downgrade".

    Raises:
        HTTPException: 404 if the user does not exist; 400 for an unknown price, a price of
        another product, or a disallowed plan change; 500 for any other failure.
    """
    try:
        db = DBConnection()
        client = await db.client

        user_result = await client.auth.admin.get_user_by_id(current_user_id)
        if not user_result:
            raise HTTPException(status_code=404, detail="User not found")
        email = user_result.user.email

        customer_id = await get_stripe_customer_id(client, current_user_id)
        if not customer_id:
            customer_id = await create_stripe_customer(client, current_user_id, email)

        try:
            price = await stripe.Price.retrieve_async(request.price_id, expand=['product'])
            product_id = price['product']['id']
        except stripe.error.InvalidRequestError:
            raise HTTPException(status_code=400, detail=f"Invalid price ID: {request.price_id}")

        if product_id != config.STRIPE_PRODUCT_ID:
            raise HTTPException(status_code=400, detail="Price ID does not belong to the correct product.")

        existing_subscription = await get_user_subscription(current_user_id)
        if existing_subscription:
            return await _checkout_change_subscription(
                client, customer_id, existing_subscription, price, request
            )

        session = await stripe.checkout.Session.create_async(
            customer=customer_id,
            payment_method_types=['card'],
            line_items=[{'price': request.price_id, 'quantity': 1}],
            mode='subscription',
            subscription_data={
                'metadata': {
                    'commitment_type': request.commitment_type or 'monthly',
                    'user_id': current_user_id
                }
            },
            success_url=request.success_url,
            cancel_url=request.cancel_url,
            metadata={
                'user_id': current_user_id,
                'product_id': product_id,
                'tolt_referral': request.tolt_referral,
                'commitment_type': request.commitment_type or 'monthly'
            },
            allow_promotion_codes=True
        )

        # NOTE(review): marks the customer active before checkout completes; the comment in the
        # original code says the webhook confirms it - verify the webhook resets it on abandonment.
        await _checkout_mark_customer_active(client, customer_id, "creating checkout session")

        return {"session_id": session['id'], "url": session['url'], "status": "new"}

    except HTTPException:
        # Propagate 4xx (and the already-formatted 500 from plan changes) unchanged.
        raise
    except Exception as e:
        logger.exception(f"Error creating checkout session: {str(e)}")
        # Prefer Stripe's own error message when available.
        if hasattr(e, 'json_body') and e.json_body and 'error' in e.json_body:
            error_detail = e.json_body['error'].get('message', str(e))
        else:
            error_detail = str(e)
        raise HTTPException(status_code=500, detail=f"Error creating checkout session: {error_detail}")
async def _ensure_portal_configuration_with_subscription_update():
    """
    Return a Stripe billing-portal configuration with ``subscription_update`` enabled, or None.

    Prefers an existing configuration that already enables it. Otherwise it updates the first
    listed configuration to enable it, or creates a new one when none exist. Any Stripe error is
    logged and swallowed (returns None) so the caller can fall back to Stripe's default portal
    configuration.
    """
    try:
        configurations = await stripe.billing_portal.Configuration.list_async(limit=100)
        existing_configs = configurations.get('data', [])

        # The loop variable must not be named ``config``: that would make the module-level
        # ``config`` a local name for the whole function and break ``config.FRONTEND_URL`` below.
        active_config = None
        for portal_config in existing_configs:
            features = portal_config.get('features') or {}
            if (features.get('subscription_update') or {}).get('enabled', False):
                active_config = portal_config
                logger.debug(f"Found existing portal configuration with subscription_update enabled: {portal_config['id']}")
                break

        if not active_config and existing_configs:
            default_config = existing_configs[0]
            logger.debug(f"Updating default portal configuration: {default_config['id']} to enable subscription_update")
            active_config = await stripe.billing_portal.Configuration.update_async(
                default_config['id'],
                features={
                    'subscription_update': {
                        'enabled': True,
                        'proration_behavior': 'create_prorations',
                        'default_allowed_updates': ['price']
                    },
                    # Preserve customer_update settings that may already be configured.
                    'customer_update': default_config.get('features', {}).get('customer_update', {'enabled': True, 'allowed_updates': ['email', 'address']}),
                    'invoice_history': {'enabled': True},
                    'payment_method_update': {'enabled': True}
                }
            )
        elif not active_config:
            logger.debug("Creating new portal configuration with subscription_update enabled")
            active_config = await stripe.billing_portal.Configuration.create_async(
                business_profile={
                    'headline': 'Subscription Management',
                    'privacy_policy_url': config.FRONTEND_URL + '/privacy',
                    'terms_of_service_url': config.FRONTEND_URL + '/terms'
                },
                features={
                    'subscription_update': {
                        'enabled': True,
                        'proration_behavior': 'create_prorations',
                        'default_allowed_updates': ['price']
                    },
                    'customer_update': {
                        'enabled': True,
                        'allowed_updates': ['email', 'address']
                    },
                    'invoice_history': {'enabled': True},
                    'payment_method_update': {'enabled': True}
                }
            )

        logger.debug(f"Using portal configuration: {active_config['id']} with subscription_update: {active_config.get('features', {}).get('subscription_update', {}).get('enabled', False)}")
        return active_config
    except Exception as config_error:
        logger.warning(f"Error configuring portal: {config_error}. Continuing with default configuration.")
        return None


@router.post("/create-portal-session")
async def create_portal_session(
    request: CreatePortalSessionRequest,
    current_user_id: str = Depends(get_current_user_id_from_jwt)
):
    """
    Create a Stripe Customer Portal session for subscription management.

    Uses a portal configuration with plan switching (``subscription_update``) enabled when one
    can be found or set up; otherwise Stripe's default configuration applies.

    Returns:
        {"url": <portal session URL>}

    Raises:
        HTTPException: 404 if the user has no billing customer, 500 on any other failure.
    """
    try:
        db = DBConnection()
        client = await db.client

        customer_id = await get_stripe_customer_id(client, current_user_id)
        if not customer_id:
            raise HTTPException(status_code=404, detail="No billing customer found")

        active_config = await _ensure_portal_configuration_with_subscription_update()

        portal_params = {
            "customer": customer_id,
            "return_url": request.return_url
        }
        if active_config:
            portal_params["configuration"] = active_config['id']

        session = await stripe.billing_portal.Session.create_async(**portal_params)
        return {"url": session.url}
    except HTTPException:
        # Keep the 404 instead of rewrapping it as a 500.
        raise
    except Exception as e:
        logger.error(f"Error creating portal session: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
@router.get("/subscription")
async def get_subscription(
    current_user_id: str = Depends(get_current_user_id_from_jwt)
):
    """Get the current subscription status for the current user, including scheduled changes and credit balance.

    Every lookup (Stripe subscription, monthly usage, credit balance) is wrapped in
    its own error handler so that one failing source degrades to a default value
    instead of failing the whole request.

    Returns:
        SubscriptionStatus. ``status`` is:
          * ``"no_subscription"`` with free-tier limits when no subscription is found
            or the subscription data cannot be parsed;
          * ``"error"`` with free-tier limits when the built response fails the JSON
            serialization check or an unexpected outer error occurs;
          * ``"scheduled_downgrade"`` when an attached schedule has a phase starting
            at the current item's ``current_period_end``;
          * otherwise the Stripe subscription's own status.

    Raises:
        HTTPException(500): only if even the fallback free-tier response cannot be built.
    """
    try:
        logger.debug(f"Getting subscription status for user {current_user_id}")
        # Initialize default values with safe fallbacks
        subscription = None
        current_usage = 0.0
        credit_balance_info = None
        # Get subscription from Stripe with error handling
        try:
            subscription = await get_user_subscription(current_user_id)
            logger.debug(f"Retrieved subscription data for user {current_user_id}: {subscription is not None}")
        except Exception as sub_error:
            logger.error(f"Error retrieving subscription for user {current_user_id}: {str(sub_error)}")
            # Continue with None subscription - will default to free tier
        # Calculate current usage with error handling
        try:
            db = DBConnection()
            client = await db.client
            current_usage = await calculate_monthly_usage(client, current_user_id)
            logger.debug(f"Retrieved usage for user {current_user_id}: {current_usage}")
        except Exception as usage_error:
            logger.error(f"Error calculating usage for user {current_user_id}: {str(usage_error)}")
            current_usage = 0.0 # Default to 0 if calculation fails
        # Get credit balance with error handling
        try:
            # Reuse the client from the usage step when it was obtained; if that step
            # failed before `client` was bound, create a fresh connection here.
            if 'client' not in locals():
                db = DBConnection()
                client = await db.client
            credit_balance_info = await get_user_credit_balance(client, current_user_id)
            logger.debug(f"Retrieved credit balance for user {current_user_id}: {credit_balance_info.balance_dollars if credit_balance_info else 'None'}")
        except Exception as balance_error:
            logger.error(f"Error getting credit balance for user {current_user_id}: {str(balance_error)}")
            # Create safe fallback credit balance
            credit_balance_info = CreditBalance(
                balance_dollars=0.0,
                total_purchased=0.0,
                total_used=0.0,
                can_purchase_credits=False
            )
        # Return free tier if no subscription
        if not subscription:
            logger.debug(f"No subscription found for user {current_user_id}, returning free tier")
            free_tier_id = config.STRIPE_FREE_TIER_ID
            free_tier_info = SUBSCRIPTION_TIERS.get(free_tier_id)
            return SubscriptionStatus(
                status="no_subscription",
                plan_name=free_tier_info.get('name', 'free') if free_tier_info else 'free',
                price_id=free_tier_id,
                minutes_limit=free_tier_info.get('minutes') if free_tier_info else 0,
                cost_limit=free_tier_info.get('cost') if free_tier_info else 0,
                current_usage=current_usage,
                credit_balance=credit_balance_info.balance_dollars if credit_balance_info else 0.0,
                can_purchase_credits=credit_balance_info.can_purchase_credits if credit_balance_info else False
            )
        # Safely extract current plan details with validation
        try:
            if not subscription.get('items') or not subscription['items'].get('data'):
                raise ValueError("Subscription has no items data")
            # Only the first subscription item is considered when determining the plan.
            current_item = subscription['items']['data'][0]
            if not current_item.get('price') or not current_item['price'].get('id'):
                raise ValueError("Subscription item has no price data")
            current_price_id = current_item['price']['id']
            current_tier_info = SUBSCRIPTION_TIERS.get(current_price_id)
            if not current_tier_info:
                logger.warning(f"User {current_user_id} subscribed to unknown price {current_price_id}. Using defaults.")
                current_tier_info = {'name': 'unknown', 'minutes': 0, 'cost': 0}
            # Safely get timestamps with validation
            # The billing period end is read from the subscription *item*, not the subscription.
            current_period_end = None
            trial_end = None
            try:
                if current_item.get('current_period_end'):
                    current_period_end = datetime.fromtimestamp(current_item['current_period_end'], tz=timezone.utc)
            except (ValueError, TypeError, OSError) as ts_error:
                logger.error(f"Error parsing current_period_end timestamp for user {current_user_id}: {ts_error}")
                current_period_end = None
            try:
                if subscription.get('trial_end'):
                    trial_end = datetime.fromtimestamp(subscription['trial_end'], tz=timezone.utc)
            except (ValueError, TypeError, OSError) as ts_error:
                logger.error(f"Error parsing trial_end timestamp for user {current_user_id}: {ts_error}")
                trial_end = None
            # Safely construct subscription object for response
            # (raw epoch-second values, as opposed to the datetimes computed above)
            subscription_data = {
                'id': subscription.get('id', ''),
                'status': subscription.get('status', 'unknown'),
                'cancel_at_period_end': bool(subscription.get('cancel_at_period_end', False)),
                'cancel_at': subscription.get('cancel_at'),
                'current_period_end': current_item.get('current_period_end')
            }
            # Get plan name safely
            # Preference order: Stripe plan nickname, then tier name from SUBSCRIPTION_TIERS, else 'unknown'.
            plan_name = 'unknown'
            if subscription.get('plan') and subscription['plan'].get('nickname'):
                plan_name = subscription['plan']['nickname']
            elif current_tier_info.get('name'):
                plan_name = current_tier_info['name']
            status_response = SubscriptionStatus(
                status=subscription.get('status', 'unknown'),
                plan_name=plan_name,
                price_id=current_price_id,
                current_period_end=current_period_end,
                cancel_at_period_end=bool(subscription.get('cancel_at_period_end', False)),
                trial_end=trial_end,
                minutes_limit=current_tier_info.get('minutes', 0),
                cost_limit=current_tier_info.get('cost', 0),
                current_usage=current_usage,
                has_schedule=False,
                subscription_id=subscription.get('id'),
                subscription=subscription_data,
                credit_balance=credit_balance_info.balance_dollars if credit_balance_info else 0.0,
                can_purchase_credits=credit_balance_info.can_purchase_credits if credit_balance_info else False
            )
            # Check for an attached schedule (indicates pending downgrade) with error handling
            schedule_id = subscription.get('schedule')
            if schedule_id:
                try:
                    logger.debug(f"Processing subscription schedule {schedule_id} for user {current_user_id}")
                    schedule = await stripe.SubscriptionSchedule.retrieve_async(schedule_id)
                    # Find the *next* phase after the current one
                    # i.e. the phase whose start_date equals the current item's period end.
                    next_phase = None
                    current_phase_end = current_item.get('current_period_end')
                    if current_phase_end and schedule.get('phases'):
                        for phase in schedule.get('phases', []):
                            if phase.get('start_date') == current_phase_end:
                                next_phase = phase
                                break
                    if next_phase and next_phase.get('items'):
                        try:
                            scheduled_item = next_phase['items'][0]
                            scheduled_price_id = scheduled_item.get('price', '')
                            scheduled_tier_info = SUBSCRIPTION_TIERS.get(scheduled_price_id, {})
                            scheduled_change_date = None
                            if next_phase.get('start_date'):
                                try:
                                    scheduled_change_date = datetime.fromtimestamp(next_phase['start_date'], tz=timezone.utc)
                                except (ValueError, TypeError, OSError) as ts_error:
                                    logger.error(f"Error parsing scheduled change date for user {current_user_id}: {ts_error}")
                            # Any matching next phase is reported as a downgrade; the price
                            # direction is not compared here.
                            status_response.has_schedule = True
                            status_response.status = 'scheduled_downgrade'
                            status_response.scheduled_plan_name = scheduled_tier_info.get('name', 'unknown')
                            status_response.scheduled_price_id = scheduled_price_id
                            status_response.scheduled_change_date = scheduled_change_date
                        except Exception as schedule_parse_error:
                            logger.error(f"Error parsing schedule details for user {current_user_id}: {schedule_parse_error}")
                except Exception as schedule_error:
                    logger.error(f"Error retrieving schedule {schedule_id} for user {current_user_id}: {schedule_error}")
            logger.debug(f"Successfully constructed subscription response for user {current_user_id}")
            # Validate JSON serialization before returning
            try:
                # Test serialization using FastAPI's JSON encoder
                # NOTE(review): this actually uses json.dumps with default=str, which
                # stringifies any unknown object, so it rarely fails in practice.
                test_dict = status_response.model_dump()
                json.dumps(test_dict, default=str) # Use default=str to handle datetime objects
                logger.debug(f"JSON serialization validation passed for user {current_user_id}")
            except Exception as json_error:
                logger.error(f"JSON serialization failed for user {current_user_id}: {json_error}")
                logger.error(f"Response data: {status_response.model_dump()}")
                # Fall back to safe response
                free_tier_id = config.STRIPE_FREE_TIER_ID
                free_tier_info = SUBSCRIPTION_TIERS.get(free_tier_id)
                return SubscriptionStatus(
                    status="error",
                    plan_name=free_tier_info.get('name', 'free') if free_tier_info else 'free',
                    price_id=free_tier_id,
                    minutes_limit=free_tier_info.get('minutes') if free_tier_info else 0,
                    cost_limit=free_tier_info.get('cost') if free_tier_info else 0,
                    current_usage=current_usage,
                    credit_balance=credit_balance_info.balance_dollars if credit_balance_info else 0.0,
                    can_purchase_credits=credit_balance_info.can_purchase_credits if credit_balance_info else False
                )
            return status_response
        except Exception as parsing_error:
            logger.error(f"Error parsing subscription data for user {current_user_id}: {str(parsing_error)}")
            # Fall back to free tier if subscription data is malformed
            free_tier_id = config.STRIPE_FREE_TIER_ID
            free_tier_info = SUBSCRIPTION_TIERS.get(free_tier_id)
            return SubscriptionStatus(
                status="no_subscription",
                plan_name=free_tier_info.get('name', 'free') if free_tier_info else 'free',
                price_id=free_tier_id,
                minutes_limit=free_tier_info.get('minutes') if free_tier_info else 0,
                cost_limit=free_tier_info.get('cost') if free_tier_info else 0,
                current_usage=current_usage,
                credit_balance=credit_balance_info.balance_dollars if credit_balance_info else 0.0,
                can_purchase_credits=credit_balance_info.can_purchase_credits if credit_balance_info else False
            )
    except Exception as e:
        logger.exception(f"Error getting subscription status for user {current_user_id}: {str(e)}")
        # Return a safe fallback response instead of raising an error
        try:
            free_tier_id = config.STRIPE_FREE_TIER_ID
            free_tier_info = SUBSCRIPTION_TIERS.get(free_tier_id)
            return SubscriptionStatus(
                status="error",
                plan_name=free_tier_info.get('name', 'free') if free_tier_info else 'free',
                price_id=free_tier_id,
                minutes_limit=free_tier_info.get('minutes') if free_tier_info else 0,
                cost_limit=free_tier_info.get('cost') if free_tier_info else 0,
                current_usage=0.0,
                credit_balance=0.0,
                can_purchase_credits=False
            )
        except Exception as fallback_error:
            logger.exception(f"Error creating fallback response for user {current_user_id}: {str(fallback_error)}")
            raise HTTPException(status_code=500, detail="Error retrieving subscription status.")
@router.get("/check-status")
async def check_status(
    current_user_id: str = Depends(get_current_user_id_from_jwt)
):
    """Report whether the caller may run agents, plus their credit balance.

    Delegates the run decision to ``check_billing_status`` and augments the
    result with the balance returned by ``get_user_credit_balance``.

    Raises:
        HTTPException(500): on any failure, with the error text as detail.
    """
    try:
        supabase = await DBConnection().client
        allowed, reason, sub_info = await check_billing_status(supabase, current_user_id)
        # The balance lookup adds purchase info on top of the run decision.
        balance = await get_user_credit_balance(supabase, current_user_id)
        response = {
            "can_run": allowed,
            "message": reason,
            "subscription": sub_info,
            "credit_balance": balance.balance_dollars,
            "can_purchase_credits": balance.can_purchase_credits,
        }
        return response
    except Exception as err:
        logger.error(f"Error checking billing status: {str(err)}")
        raise HTTPException(status_code=500, detail=str(err))
@router.post("/webhook")
async def stripe_webhook(request: Request):
    """Handle Stripe webhook events.

    The ``stripe-signature`` header is verified against
    ``config.STRIPE_WEBHOOK_SECRET`` before anything else happens. Handled events:

    * ``checkout.session.completed`` with ``metadata.type == 'credit_purchase'``:
      marks the matching ``credit_purchases`` row completed (matched by payment
      intent, falling back to ``metadata.session_id``), adds the credits to the
      user's balance and clears the user's usage/subscription cache entries.
      Processing errors are logged, but the webhook is still acknowledged.
    * ``payment_intent.payment_failed`` for credit purchases: marks the row failed.
    * ``customer.subscription.created`` / ``updated`` / ``deleted``: keeps
      ``basejump.billing_customers.active`` in sync.

    Returns:
        A small JSON status dict.

    Raises:
        HTTPException(400): invalid payload or invalid signature.
        HTTPException(500): any other unexpected failure.
    """
    try:
        # Get the webhook secret from config
        webhook_secret = config.STRIPE_WEBHOOK_SECRET
        # Get the webhook payload
        payload = await request.body()
        sig_header = request.headers.get('stripe-signature')
        # Verify webhook signature
        try:
            event = stripe.Webhook.construct_event(
                payload, sig_header, webhook_secret
            )
            logger.debug(f"Received Stripe webhook: {event.type} - Event ID: {event.id}")
        except ValueError as e:
            logger.error(f"Invalid webhook payload: {str(e)}")
            raise HTTPException(status_code=400, detail="Invalid payload")
        except stripe.error.SignatureVerificationError as e:
            logger.error(f"Invalid webhook signature: {str(e)}")
            raise HTTPException(status_code=400, detail="Invalid signature")
        # Get database connection
        db = DBConnection()
        client = await db.client
        # Handle credit purchase completion
        if event.type == 'checkout.session.completed':
            session = event.data.object
            # Check if this is a credit purchase
            if session.get('metadata', {}).get('type') == 'credit_purchase':
                user_id = session['metadata']['user_id']
                credit_amount = float(session['metadata']['credit_amount'])
                payment_intent_id = session.get('payment_intent')
                logger.debug(f"Processing credit purchase for user {user_id}: ${credit_amount}")
                # NOTE(review): Stripe may deliver the same event more than once; nothing
                # here prevents crediting a purchase twice. Consider skipping rows that are
                # already 'completed' (or recording processed event IDs).
                try:
                    # Update the purchase record status
                    purchase_update = await client.table('credit_purchases') \
                        .update({
                            'status': 'completed',
                            'completed_at': datetime.now(timezone.utc).isoformat(),
                            'stripe_payment_intent_id': payment_intent_id
                        }) \
                        .eq('stripe_payment_intent_id', payment_intent_id) \
                        .execute()
                    if not purchase_update.data:
                        # If no record found by payment_intent_id, try by session_id in metadata (PostgREST JSON operator requires filter)
                        purchase_update = await client.table('credit_purchases') \
                            .update({
                                'status': 'completed',
                                'completed_at': datetime.now(timezone.utc).isoformat(),
                                'stripe_payment_intent_id': payment_intent_id
                            }) \
                            .filter('metadata->>session_id', 'eq', session['id']) \
                            .execute()
                    # Add credits to user's balance
                    purchase_id = purchase_update.data[0]['id'] if purchase_update.data else None
                    new_balance = await add_credits_to_balance(client, user_id, credit_amount, purchase_id)
                    logger.info(f"Successfully added ${credit_amount} credits to user {user_id}. New balance: ${new_balance}")
                    # Clear cache for this user
                    await Cache.delete(f"monthly_usage:{user_id}")
                    await Cache.delete(f"user_subscription:{user_id}")
                except Exception as e:
                    logger.error(f"Error processing credit purchase: {str(e)}")
                    # Don't fail the webhook, but log the error
                return {"status": "success", "message": "Credit purchase processed"}
        # Handle payment failed for credit purchases
        if event.type == 'payment_intent.payment_failed':
            payment_intent = event.data.object
            # Check if this is related to a credit purchase
            if payment_intent.get('metadata', {}).get('type') == 'credit_purchase':
                user_id = payment_intent['metadata']['user_id']
                # Update purchase record to failed
                await client.table('credit_purchases') \
                    .update({'status': 'failed'}) \
                    .eq('stripe_payment_intent_id', payment_intent['id']) \
                    .execute()
                logger.debug(f"Credit purchase failed for user {user_id}")
        # Handle the existing subscription events
        if event.type in ['customer.subscription.created', 'customer.subscription.updated', 'customer.subscription.deleted']:
            # Extract the subscription and customer information
            subscription = event.data.object
            customer_id = subscription.get('customer')
            if not customer_id:
                logger.warning(f"No customer ID found in subscription event: {event.type}")
                return {"status": "error", "message": "No customer ID found"}
            if event.type == 'customer.subscription.created':
                # Update customer active status for new subscriptions
                if subscription.get('status') in ['active', 'trialing']:
                    await client.schema('basejump').from_('billing_customers').update(
                        {'active': True}
                    ).eq('id', customer_id).execute()
                    logger.debug(f"Webhook: Updated customer {customer_id} active status to TRUE based on {event.type}")
            elif event.type == 'customer.subscription.updated':
                # Check if subscription is active
                if subscription.get('status') in ['active', 'trialing']:
                    # Update customer's active status to true
                    await client.schema('basejump').from_('billing_customers').update(
                        {'active': True}
                    ).eq('id', customer_id).execute()
                    logger.debug(f"Webhook: Updated customer {customer_id} active status to TRUE based on {event.type}")
                else:
                    # Subscription is not active (e.g., past_due, canceled, etc.)
                    # Check if customer has any other active subscriptions before updating status.
                    # Await the list call first: calling .get() on the un-awaited coroutine
                    # raises AttributeError.
                    active_subscriptions = await stripe.Subscription.list_async(
                        customer=customer_id,
                        status='active',
                        limit=1
                    )
                    has_active = len(active_subscriptions.get('data', [])) > 0
                    if not has_active:
                        await client.schema('basejump').from_('billing_customers').update(
                            {'active': False}
                        ).eq('id', customer_id).execute()
                        logger.debug(f"Webhook: Updated customer {customer_id} active status to FALSE based on {event.type}")
            elif event.type == 'customer.subscription.deleted':
                # Check if customer has any other active subscriptions
                active_subscriptions = await stripe.Subscription.list_async(
                    customer=customer_id,
                    status='active',
                    limit=1
                )
                has_active = len(active_subscriptions.get('data', [])) > 0
                if not has_active:
                    # If no active subscriptions left, set active to false
                    await client.schema('basejump').from_('billing_customers').update(
                        {'active': False}
                    ).eq('id', customer_id).execute()
                    logger.debug(f"Webhook: Updated customer {customer_id} active status to FALSE after subscription deletion")
            logger.debug(f"Processed {event.type} event for customer {customer_id}")
        return {"status": "success"}
    except HTTPException:
        # Preserve deliberate 4xx responses (bad payload/signature) instead of
        # turning them into 500s, which would also make Stripe retry the delivery.
        raise
    except Exception as e:
        logger.error(f"Error processing webhook: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
@router.get("/available-models")
async def get_available_models(
    current_user_id: str = Depends(get_current_user_id_from_jwt)
):
    """Get the list of models available to the user based on their subscription tier.

    In local development mode (``config.ENV_MODE == EnvMode.LOCAL``) every
    enabled model is returned as not requiring a subscription, with raw
    (un-multiplied) pricing.

    Otherwise all enabled models are returned so the UI can preview them. Each
    model carries an ``is_available`` flag computed from
    ``get_allowed_models_for_user``, and prices are multiplied by
    ``TOKEN_PRICE_MULTIPLIER``.

    Raises:
        HTTPException(500): on any failure.
    """

    def _short_name(model_data):
        """Return the first alias when the model has any, else its display name."""
        aliases = model_data.get("aliases")
        return aliases[0] if aliases else model_data["name"]

    try:
        # `model_manager` is imported at module level; no local re-import needed.
        # Check if we're in local development mode
        if config.ENV_MODE == EnvMode.LOCAL:
            logger.debug("Running in local development mode - billing checks are disabled")
            # In local mode, return all enabled models
            all_models = model_manager.list_available_models(include_disabled=False)
            model_info = [
                {
                    "id": model_data["id"],
                    "display_name": model_data["name"],
                    "short_name": _short_name(model_data),
                    "requires_subscription": False,  # Always false in local dev mode
                    "input_cost_per_million_tokens": model_data["pricing"]["input_per_million"] if model_data["pricing"] else None,
                    "output_cost_per_million_tokens": model_data["pricing"]["output_per_million"] if model_data["pricing"] else None,
                    "context_window": model_data["context_window"],
                    "capabilities": model_data["capabilities"],
                    "recommended": model_data["recommended"],
                    "priority": model_data["priority"]
                }
                for model_data in all_models
            ]
            return {
                "models": model_info,
                "subscription_tier": "Local Development",
                "total_models": len(model_info)
            }
        # The DB client is only needed for the per-user access check below.
        db = DBConnection()
        client = await db.client
        # Get subscription info for context
        subscription = await get_user_subscription(current_user_id)
        # Determine tier name from subscription
        tier_name = 'free'
        if subscription:
            if subscription.get('items') and subscription['items'].get('data') and len(subscription['items']['data']) > 0:
                price_id = subscription['items']['data'][0]['price']['id']
            else:
                price_id = subscription.get('price_id', config.STRIPE_FREE_TIER_ID)
            # Get tier info for this price_id
            tier_info = SUBSCRIPTION_TIERS.get(price_id)
            if tier_info:
                tier_name = tier_info['name']
        # Get ALL enabled models for preview UI (don't filter by tier here)
        all_models = model_manager.list_available_models(tier=None, include_disabled=False)
        logger.debug(f"Found {len(all_models)} total models available")
        # Get allowed models for this specific user (for access checking)
        allowed_models = await get_allowed_models_for_user(client, current_user_id)
        logger.debug(f"User {current_user_id} allowed models: {allowed_models}")
        logger.debug(f"User tier: {tier_name}")
        # Create clean model info for frontend
        model_info = []
        for model_data in all_models:
            model_id = model_data["id"]
            # Get pricing with multiplier applied
            if model_data["pricing"]:
                pricing_info = {
                    "input_cost_per_million_tokens": model_data["pricing"]["input_per_million"] * TOKEN_PRICE_MULTIPLIER,
                    "output_cost_per_million_tokens": model_data["pricing"]["output_per_million"] * TOKEN_PRICE_MULTIPLIER,
                    "max_tokens": model_data["max_output_tokens"]
                }
            else:
                pricing_info = {
                    "input_cost_per_million_tokens": None,
                    "output_cost_per_million_tokens": None,
                    "max_tokens": None
                }
            model_info.append({
                "id": model_id,
                "display_name": model_data["name"],
                "short_name": _short_name(model_data),
                # A model requires a subscription unless "free" is in its tier list.
                "requires_subscription": not model_data.get("tier_availability", []) or "free" not in model_data["tier_availability"],
                # Check if model is available with current subscription
                "is_available": model_id in allowed_models,
                "context_window": model_data["context_window"],
                "capabilities": model_data["capabilities"],
                "recommended": model_data["recommended"],
                "priority": model_data["priority"],
                **pricing_info
            })
        logger.debug(f"Returning {len(model_info)} models to user {current_user_id} (tier: {tier_name})")
        if model_info:
            model_names = [m["display_name"] for m in model_info]
            logger.debug(f"Model names: {model_names}")
        return {
            "models": model_info,
            "subscription_tier": tier_name,
            "total_models": len(model_info)
        }
    except Exception as e:
        logger.error(f"Error getting available models: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error getting available models: {str(e)}")
@router.get("/usage-logs")
async def get_usage_logs_endpoint(
    page: int = 0,
    items_per_page: int = 1000,
    current_user_id: str = Depends(get_current_user_id_from_jwt)
):
    """Return a page of the caller's usage logs.

    Args:
        page: zero-based page index; must be non-negative.
        items_per_page: page size, 1..1000 inclusive.

    Returns:
        The dict produced by ``get_usage_logs``. In local development mode, an
        empty log list with an explanatory message.

    Raises:
        HTTPException(400): invalid pagination, an ``error`` key in the result,
            or an exception whose text mentions JSON.
        HTTPException(500): any other failure.
    """
    logger.info(f"[USAGE_LOGS_ENDPOINT] Starting get_usage_logs_endpoint for user_id={current_user_id}, page={page}, items_per_page={items_per_page}")
    log_prefix = f"[USAGE_LOGS_ENDPOINT] user_id={current_user_id}"
    try:
        supabase = await DBConnection().client

        if config.ENV_MODE == EnvMode.LOCAL:
            logger.debug(f"{log_prefix} - Running in local development mode - usage logs are not available")
            return {
                "logs": [],
                "has_more": False,
                "message": "Usage logs are not available in local development mode"
            }

        if page < 0:
            logger.error(f"{log_prefix} - Invalid page parameter: {page}")
            raise HTTPException(status_code=400, detail="Page must be non-negative")
        if not 1 <= items_per_page <= 1000:
            logger.error(f"{log_prefix} - Invalid items_per_page parameter: {items_per_page}")
            raise HTTPException(status_code=400, detail="Items per page must be between 1 and 1000")

        logger.debug(f"{log_prefix} - Calling get_usage_logs")
        result = await get_usage_logs(supabase, current_user_id, page, items_per_page)

        reported_error = result.get('error') if isinstance(result, dict) else None
        if reported_error:
            logger.error(f"{log_prefix} - Usage logs returned error: {reported_error}")
            raise HTTPException(status_code=400, detail=f"Failed to retrieve usage logs: {reported_error}")

        logger.info(f"{log_prefix} - Successfully returned {len(result.get('logs', []))} usage logs")
        return result
    except HTTPException:
        raise
    except Exception as e:
        error_text = str(e)
        logger.exception(f"{log_prefix} - Error getting usage logs: {error_text}")
        # Any error mentioning JSON is treated as a serialization problem.
        if "JSON" not in error_text:
            raise HTTPException(status_code=500, detail=f"Error getting usage logs: {error_text}")
        logger.error(f"{log_prefix} - Detected JSON serialization error")
        raise HTTPException(status_code=400, detail=f"Data serialization error: {error_text}")
@router.get("/subscription-commitment/{subscription_id}")
async def get_subscription_commitment(
    subscription_id: str,
    current_user_id: str = Depends(get_current_user_id_from_jwt)
):
    """Get commitment status for a subscription owned by the current user.

    Args:
        subscription_id: Stripe subscription ID; it must be the caller's current
            subscription as returned by ``get_user_subscription``.

    Returns:
        Whatever ``check_subscription_commitment`` returns for that subscription.

    Raises:
        HTTPException(404): the user has no subscription, or the ID does not match it.
        HTTPException(500): any other failure.
    """
    try:
        # Verify the subscription belongs to the current user. (The previous
        # DBConnection/client created here was never used and has been removed.)
        user_subscription = await get_user_subscription(current_user_id)
        if not user_subscription or user_subscription.get('id') != subscription_id:
            raise HTTPException(status_code=404, detail="Subscription not found or access denied")
        return await check_subscription_commitment(subscription_id)
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error getting subscription commitment: {str(e)}")
        raise HTTPException(status_code=500, detail="Error retrieving commitment information")
@router.get("/subscription-details")
async def get_subscription_details(
    current_user_id: str = Depends(get_current_user_id_from_jwt)
):
    """Get detailed subscription information including commitment status.

    Returns:
        ``{"subscription": None, "commitment": {"has_commitment": False, "can_cancel": True}}``
        when the user has no subscription. Otherwise a summary of the subscription
        plus the result of ``check_subscription_commitment``.

        ``current_period_start``/``current_period_end`` are taken from the
        subscription and, when absent there, from its first item. That matches
        where ``/subscription`` reads the billing period from.

    Raises:
        HTTPException(500): on any failure.
    """
    try:
        subscription = await get_user_subscription(current_user_id)
        if not subscription:
            return {
                "subscription": None,
                "commitment": {"has_commitment": False, "can_cancel": True}
            }
        # Get commitment information
        commitment_info = await check_subscription_commitment(subscription['id'])
        items = (subscription.get('items') or {}).get('data') or []
        first_item = items[0] if items else {}
        # Enhanced subscription details
        subscription_details = {
            "id": subscription.get('id'),
            "status": subscription.get('status'),
            # The period may be present only on the subscription item.
            "current_period_end": subscription.get('current_period_end') or first_item.get('current_period_end'),
            "current_period_start": subscription.get('current_period_start') or first_item.get('current_period_start'),
            "cancel_at_period_end": subscription.get('cancel_at_period_end'),
            "items": items,
            "metadata": subscription.get('metadata', {})
        }
        return {
            "subscription": subscription_details,
            "commitment": commitment_info
        }
    except Exception as e:
        logger.error(f"Error getting subscription details: {str(e)}")
        raise HTTPException(status_code=500, detail="Error retrieving subscription details")
@router.post("/cancel-subscription")
async def cancel_subscription(
    current_user_id: str = Depends(get_current_user_id_from_jwt)
):
    """Cancel the current user's subscription, honouring any yearly commitment.

    * Still inside a yearly commitment (``has_commitment`` and not ``can_cancel``
      per ``check_subscription_commitment``): cancellation is scheduled for the
      commitment end date via ``cancel_at``.
    * Otherwise: ``cancel_at_period_end`` is set.

    Both paths merge cancellation markers into the subscription metadata.

    Raises:
        HTTPException(404): the user has no subscription.
        HTTPException(500): the commitment end or period end cannot be
            determined, or any other failure.
    """

    def _period_end(sub) -> Optional[int]:
        """Return ``current_period_end`` from ``sub`` or, failing that, its first item.

        Uses ``.get`` rather than attribute access, because attribute access on a
        Stripe object raises ``AttributeError`` when the key is absent. The period
        may also live on the subscription item rather than the subscription itself.
        """
        if not sub:
            return None
        value = sub.get('current_period_end')
        if value:
            return value
        items = (sub.get('items') or {}).get('data') or []
        return items[0].get('current_period_end') if items else None

    try:
        # Get user's current subscription
        subscription = await get_user_subscription(current_user_id)
        if not subscription:
            raise HTTPException(status_code=404, detail="No active subscription found")
        subscription_id = subscription['id']
        # Check commitment status
        commitment_info = await check_subscription_commitment(subscription_id)
        # If subscription has yearly commitment and still in commitment period
        if commitment_info.get('has_commitment') and not commitment_info.get('can_cancel'):
            commitment_end_raw = commitment_info.get('commitment_end_date')
            if not commitment_end_raw:
                # Don't silently fall back to period-end cancellation: that would
                # bypass the commitment.
                logger.error(f"No commitment_end_date for committed subscription {subscription_id}")
                raise HTTPException(status_code=500, detail="Unable to determine commitment end date")
            # Schedule cancellation at the end of the commitment period (1 year anniversary)
            commitment_end_date = datetime.fromisoformat(commitment_end_raw.replace('Z', '+00:00'))
            cancel_at_timestamp = int(commitment_end_date.timestamp())
            # Update subscription to cancel at the commitment end date
            await stripe.Subscription.modify_async(
                subscription_id,
                cancel_at=cancel_at_timestamp,
                metadata={
                    **subscription.get('metadata', {}),
                    'cancelled_by_user': 'true',
                    'cancellation_date': str(int(datetime.now(timezone.utc).timestamp())),
                    'scheduled_cancel_at_commitment_end': 'true'
                }
            )
            logger.debug(f"Subscription {subscription_id} scheduled for cancellation at commitment end: {commitment_end_date}")
            return {
                "success": True,
                "status": "scheduled_for_commitment_end",
                "message": f"Subscription will be cancelled at the end of your yearly commitment period. {commitment_info.get('months_remaining', 0)} months remaining.",
                "details": {
                    "subscription_id": subscription_id,
                    "cancellation_effective_date": commitment_end_date.isoformat(),
                    "months_remaining": commitment_info.get('months_remaining', 0),
                    "access_until": commitment_end_date.strftime("%B %d, %Y"),
                    "commitment_end_date": commitment_info.get('commitment_end_date')
                }
            }
        # For non-commitment subscriptions or commitment period has ended, cancel at period end
        updated_subscription = await stripe.Subscription.modify_async(
            subscription_id,
            cancel_at_period_end=True,
            metadata={
                **subscription.get('metadata', {}),
                'cancelled_by_user': 'true',
                'cancellation_date': str(int(datetime.now(timezone.utc).timestamp()))
            }
        )
        logger.debug(f"Subscription {subscription_id} marked for cancellation at period end")
        # Calculate when the subscription will actually end
        current_period_end = _period_end(updated_subscription) or _period_end(subscription)
        # If still no period end, fetch fresh subscription data from Stripe
        if not current_period_end:
            logger.warning(f"No current_period_end found in cached data for subscription {subscription_id}, fetching fresh data from Stripe")
            try:
                fresh_subscription = await stripe.Subscription.retrieve_async(subscription_id)
                current_period_end = _period_end(fresh_subscription)
            except Exception as fetch_error:
                logger.error(f"Failed to fetch fresh subscription data: {fetch_error}")
        if not current_period_end:
            logger.error(f"No current_period_end found in subscription {subscription_id} even after fresh fetch")
            raise HTTPException(status_code=500, detail="Unable to determine subscription period end")
        period_end_date = datetime.fromtimestamp(current_period_end, timezone.utc)
        return {
            "success": True,
            "status": "cancelled_at_period_end",
            "message": "Subscription will be cancelled at the end of your current billing period.",
            "details": {
                "subscription_id": subscription_id,
                "cancellation_effective_date": period_end_date.isoformat(),
                "current_period_end": current_period_end,
                "access_until": period_end_date.strftime("%B %d, %Y")
            }
        }
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error cancelling subscription: {str(e)}")
        raise HTTPException(status_code=500, detail="Error processing cancellation request")
@router.post("/reactivate-subscription")
async def reactivate_subscription(
    current_user_id: str = Depends(get_current_user_id_from_jwt)
):
    """Reactivate a subscription that was marked for cancellation.

    A subscription counts as cancelled when either ``cancel_at_period_end`` or
    ``cancel_at`` is set.

    * Period-end cancellations are undone with ``cancel_at_period_end=False``.
    * ``cancel_at``-only cancellations (e.g. scheduled at a commitment end) are
      undone by unsetting ``cancel_at`` with an empty string. stripe-python drops
      ``None``-valued parameters from the request, so ``None`` would be a no-op.

    Returns:
        ``success: False`` / ``not_cancelled`` when nothing is scheduled;
        otherwise ``success: True`` with the next billing date.

    Raises:
        HTTPException(404): the user has no subscription.
        HTTPException(500): the period end cannot be determined, or any other failure.
    """

    def _period_end(sub) -> Optional[int]:
        """Return ``current_period_end`` from ``sub`` or, failing that, its first item.

        Uses ``.get`` because attribute access on a Stripe object raises
        ``AttributeError`` for a missing key.
        """
        if not sub:
            return None
        value = sub.get('current_period_end')
        if value:
            return value
        items = (sub.get('items') or {}).get('data') or []
        return items[0].get('current_period_end') if items else None

    try:
        # Get user's current subscription
        subscription = await get_user_subscription(current_user_id)
        if not subscription:
            raise HTTPException(status_code=404, detail="No subscription found")
        subscription_id = subscription['id']
        # Check if subscription is marked for cancellation (either cancel_at_period_end or cancel_at)
        is_cancelled = subscription.get('cancel_at_period_end') or subscription.get('cancel_at')
        if not is_cancelled:
            return {
                "success": False,
                "status": "not_cancelled",
                "message": "Subscription is not marked for cancellation."
            }
        # Prepare the modification parameters
        modify_params = {
            'metadata': {
                **subscription.get('metadata', {}),
                'reactivated_by_user': 'true',
                'reactivation_date': str(int(datetime.now(timezone.utc).timestamp()))
            }
        }
        if subscription.get('cancel_at_period_end'):
            modify_params['cancel_at_period_end'] = False
        else:
            # Cancellation scheduled via an explicit cancel_at (yearly commitment).
            # An empty string unsets the field; None would be dropped from the request.
            modify_params['cancel_at'] = ''
        # Reactivate the subscription
        updated_subscription = await stripe.Subscription.modify_async(
            subscription_id,
            **modify_params
        )
        logger.debug(f"Subscription {subscription_id} reactivated by user")
        # Get the current period end safely
        current_period_end = _period_end(updated_subscription) or _period_end(subscription)
        # If still no period end, fetch fresh subscription data from Stripe
        if not current_period_end:
            logger.warning(f"No current_period_end found in cached data for subscription {subscription_id}, fetching fresh data from Stripe")
            try:
                fresh_subscription = await stripe.Subscription.retrieve_async(subscription_id)
                current_period_end = _period_end(fresh_subscription)
            except Exception as fetch_error:
                logger.error(f"Failed to fetch fresh subscription data: {fetch_error}")
        if not current_period_end:
            logger.error(f"No current_period_end found in subscription {subscription_id} even after fresh fetch")
            raise HTTPException(status_code=500, detail="Unable to determine subscription period end")
        return {
            "success": True,
            "status": "reactivated",
            "message": "Subscription has been reactivated and will continue billing normally.",
            "details": {
                "subscription_id": subscription_id,
                "next_billing_date": datetime.fromtimestamp(
                    current_period_end,
                    timezone.utc
                ).strftime("%B %d, %Y")
            }
        }
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error reactivating subscription: {str(e)}")
        raise HTTPException(status_code=500, detail="Error processing reactivation request")
@router.post("/purchase-credits")
async def purchase_credits(
    request: PurchaseCreditsRequest,
    current_user_id: str = Depends(get_current_user_id_from_jwt)
):
    """
    Create a Stripe checkout session for purchasing credits.

    Only available for users on the highest subscription tier.

    Args:
        request: Requested dollar amount plus the success/cancel redirect URLs.
        current_user_id: Authenticated user id (used as the account id).

    Returns:
        Dict with the Stripe checkout ``session_id``, the hosted checkout ``url``
        and the id of the pending ``credit_purchases`` row (None if the insert
        returned no data).

    Raises:
        HTTPException: 403 if the user is not on the highest tier, 400 if the
            amount is outside [$10, $5000], 404 if the user cannot be found,
            500 on any other failure.
    """
    try:
        if not await is_user_on_highest_tier(current_user_id):
            raise HTTPException(
                status_code=403,
                detail="Credit purchases are only available for users on the highest subscription tier ($1000/month)."
            )

        if request.amount_dollars < 10:
            raise HTTPException(status_code=400, detail="Minimum credit purchase is $10")
        if request.amount_dollars > 5000:
            raise HTTPException(status_code=400, detail="Maximum credit purchase is $5000")

        db = DBConnection()
        client = await db.client

        user_result = await client.auth.admin.get_user_by_id(current_user_id)
        # Also guard the nested `.user`: a response without a user would otherwise
        # raise AttributeError below and be reported as a 500 instead of a 404.
        if not user_result or not getattr(user_result, 'user', None):
            raise HTTPException(status_code=404, detail="User not found")
        email = user_result.user.email

        customer_id = await get_stripe_customer_id(client, current_user_id)
        if not customer_id:
            customer_id = await create_stripe_customer(client, current_user_id, email)

        session = await stripe.checkout.Session.create_async(
            customer=customer_id,
            payment_method_types=['card'],
            line_items=[_credit_checkout_line_item(request.amount_dollars)],
            mode='payment',
            success_url=request.success_url,
            cancel_url=request.cancel_url,
            metadata={
                'user_id': current_user_id,
                'credit_amount': str(request.amount_dollars),
                'type': 'credit_purchase'
            }
        )

        # Record the pending purchase; completion is handled elsewhere.
        # NOTE(review): for payment-mode Checkout Sessions, Stripe may not have
        # created the PaymentIntent yet, so `session.payment_intent` can be None
        # here. The session id stored in metadata is the reliable lookup key.
        purchase_record = await client.table('credit_purchases').insert({
            'user_id': current_user_id,
            'amount_dollars': request.amount_dollars,
            'status': 'pending',
            'stripe_payment_intent_id': session.payment_intent,
            'description': 'Credit purchase via Stripe Checkout',
            'metadata': {
                'session_id': session.id,
                'checkout_url': session.url,
                'success_url': request.success_url,
                'cancel_url': request.cancel_url
            }
        }).execute()

        return {
            "session_id": session.id,
            "url": session.url,
            "purchase_id": purchase_record.data[0]['id'] if purchase_record.data else None
        }
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error creating credit purchase session: {str(e)}")
        # Generic detail (like the sibling endpoints): don't leak internal
        # Stripe/database error text to the client; it is logged above.
        raise HTTPException(status_code=500, detail="Error creating checkout session")


def _credit_checkout_line_item(amount_dollars: float) -> Dict:
    """Build the Stripe Checkout line item for a credit purchase of ``amount_dollars``.

    Uses the first CREDIT_PACKAGES entry whose ``amount`` equals the request and
    that has a configured ``stripe_price_id``. Otherwise it falls back to inline
    ``price_data`` priced in USD cents.
    """
    package = next(
        (
            info for info in CREDIT_PACKAGES.values()
            if info['amount'] == amount_dollars and info.get('stripe_price_id')
        ),
        None,
    )
    if package:
        return {'price': package['stripe_price_id'], 'quantity': 1}
    return {
        'price_data': {
            'currency': 'usd',
            'product_data': {
                'name': 'Suna AI Credits',
                'description': f'${amount_dollars:.2f} in usage credits for Suna AI',
            },
            # round(), not bare int(): 19.99 * 100 == 1998.9999999999998, which
            # int() would truncate to 1998 cents.
            'unit_amount': int(round(amount_dollars * 100)),
        },
        'quantity': 1,
    }
@router.get("/credit-balance")
async def get_credit_balance_endpoint(
    current_user_id: str = Depends(get_current_user_id_from_jwt)
):
    """Return the credit balance of the authenticated user.

    The balance payload comes straight from ``get_user_credit_balance``. Any
    failure is logged and reported to the client as a generic HTTP 500.
    """
    try:
        supabase_client = await DBConnection().client
        return await get_user_credit_balance(supabase_client, current_user_id)
    except Exception as exc:
        logger.error(f"Error getting credit balance: {str(exc)}")
        raise HTTPException(status_code=500, detail="Error retrieving credit balance")
@router.get("/credit-history")
async def get_credit_history(
    page: int = 0,
    items_per_page: int = 50,
    current_user_id: str = Depends(get_current_user_id_from_jwt)
):
    """Get credit purchase and usage history for the user.

    Completed purchases and usage rows are each paged independently with the
    same ``page``/``items_per_page`` window, newest first.

    Args:
        page: Zero-based page index (must be >= 0).
        items_per_page: Page size (must be >= 1).
        current_user_id: Authenticated user id.

    Returns:
        Dict with ``purchases`` (CreditPurchase list), ``usage`` (CreditUsage
        list), the requested ``page``, and ``has_more``. ``has_more`` is true
        when either list filled the page.

    Raises:
        HTTPException: 400 for invalid paging parameters, 500 on query failure.
    """
    # Validate outside the try block so these 400s are not converted into the
    # generic 500 below. Previously, invalid values produced a negative or
    # inverted .range() and an opaque server error.
    if page < 0:
        raise HTTPException(status_code=400, detail="page must be >= 0")
    if items_per_page < 1:
        raise HTTPException(status_code=400, detail="items_per_page must be >= 1")

    # .range() bounds are inclusive, hence the -1.
    range_start = page * items_per_page
    range_end = range_start + items_per_page - 1

    try:
        db = DBConnection()
        client = await db.client

        purchases_result = await client.table('credit_purchases') \
            .select('*') \
            .eq('user_id', current_user_id) \
            .eq('status', 'completed') \
            .order('created_at', desc=True) \
            .range(range_start, range_end) \
            .execute()

        usage_result = await client.table('credit_usage') \
            .select('*') \
            .eq('user_id', current_user_id) \
            .order('created_at', desc=True) \
            .range(range_start, range_end) \
            .execute()

        purchases = [
            CreditPurchase(
                id=p['id'],
                amount_dollars=float(p['amount_dollars']),
                status=p['status'],
                created_at=p['created_at'],
                completed_at=p.get('completed_at'),
                stripe_payment_intent_id=p.get('stripe_payment_intent_id')
            )
            for p in (purchases_result.data or [])
        ]

        usage = [
            CreditUsage(
                id=u['id'],
                amount_dollars=float(u['amount_dollars']),
                description=u.get('description'),
                created_at=u['created_at'],
                thread_id=u.get('thread_id'),
                message_id=u.get('message_id')
            )
            for u in (usage_result.data or [])
        ]

        return {
            "purchases": purchases,
            "usage": usage,
            "page": page,
            # Heuristic: a full page in either list suggests more rows may follow.
            "has_more": len(purchases) == items_per_page or len(usage) == items_per_page
        }
    except Exception as e:
        logger.error(f"Error getting credit history: {str(e)}")
        raise HTTPException(status_code=500, detail="Error retrieving credit history")
@router.get("/can-purchase-credits")
async def can_purchase_credits(
    current_user_id: str = Depends(get_current_user_id_from_jwt)
):
    """Report whether the authenticated user is allowed to buy credits.

    Eligibility is decided by ``is_user_on_highest_tier``. The response carries
    a boolean ``can_purchase`` plus a human-readable ``reason``. A failed
    eligibility check is logged and reported as HTTP 500.
    """
    try:
        eligible = await is_user_on_highest_tier(current_user_id)
    except Exception as exc:
        logger.error(f"Error checking credit purchase eligibility: {str(exc)}")
        raise HTTPException(status_code=500, detail="Error checking eligibility")

    if eligible:
        reason = "Credit purchases are available"
    else:
        reason = "Must be on the highest subscription tier ($1000/month) to purchase credits"
    return {"can_purchase": eligible, "reason": reason}