suna/backend/core/credits.py

300 lines
12 KiB
Python

from decimal import Decimal
from datetime import datetime, timezone, timedelta
from typing import Optional, Dict, List, Any
from core.services.supabase import DBConnection
from core.utils.logger import logger
from core.utils.cache import Cache
from core.utils.config import config, EnvMode
from core.billing.config import FREE_TIER_INITIAL_CREDITS, TRIAL_ENABLED
class CreditService:
def __init__(self):
self.db = DBConnection()
self.cache = Cache
self._client = None
async def _get_client(self):
if self._client is None:
await self.db.initialize()
self._client = await self.db.client
return self._client
async def get_balance(self, user_id: str, use_cache: bool = True) -> Decimal:
cache_key = f"credit_balance:{user_id}"
if use_cache and self.cache:
cached = await self.cache.get(cache_key)
if cached:
return Decimal(cached)
try:
client = await self._get_client()
result = await client.from_('credit_accounts').select('balance').eq('account_id', user_id).execute()
except Exception as e:
logger.error(f"Error fetching balance for user {user_id}: {e}")
raise
if result.data and len(result.data) > 0:
balance = Decimal(str(result.data[0]['balance']))
else:
if config.ENV_MODE == EnvMode.LOCAL:
logger.info(f"LOCAL mode: Creating user {user_id} with tier='none' (no free tier in local mode)")
account_data = {
'account_id': user_id,
'balance': '0',
'tier': 'none',
'trial_status': 'none'
}
try:
await client.from_('credit_accounts').insert(account_data).execute()
logger.info(f"Successfully created tier='none' account for user {user_id} in LOCAL mode")
except Exception as e:
logger.error(f"Failed to create account for user {user_id}: {e}")
raise
balance = Decimal('0')
await client.from_('credit_ledger').insert({
'account_id': user_id,
'amount': '0',
'type': 'initial',
'description': 'Account created - no free tier in local mode',
'balance_after': '0'
}).execute()
else:
trial_mode = TRIAL_ENABLED
logger.info(f"Creating new user {user_id} with free tier (trial migration will handle conversion)")
if trial_mode == TRIAL_ENABLED:
account_data = {
'account_id': user_id,
'balance': str(FREE_TIER_INITIAL_CREDITS),
'tier': 'free'
}
try:
logger.info(f"Creating FREE TIER account for new user {user_id}")
try:
test_data = {**account_data, 'last_grant_date': datetime.now(timezone.utc).isoformat()}
await client.from_('credit_accounts').insert(test_data).execute()
logger.info(f"Successfully created FREE TIER account for user {user_id}")
except Exception as e1:
logger.warning(f"Creating account without last_grant_date: {e1}")
await client.from_('credit_accounts').insert(account_data).execute()
logger.info(f"Successfully created minimal FREE TIER account for user {user_id}")
except Exception as e:
logger.error(f"Failed to create FREE TIER account for user {user_id}: {e}")
raise
balance = FREE_TIER_INITIAL_CREDITS
await client.from_('credit_ledger').insert({
'account_id': user_id,
'amount': str(FREE_TIER_INITIAL_CREDITS),
'type': 'tier_grant',
'description': 'Welcome to Suna! Free tier initial credits',
'balance_after': str(FREE_TIER_INITIAL_CREDITS)
}).execute()
else:
account_data = {
'account_id': user_id,
'balance': '0',
'tier': 'free'
}
try:
logger.info(f"Creating TRIAL PENDING account for new user {user_id}")
await client.from_('credit_accounts').insert(account_data).execute()
logger.info(f"Successfully created TRIAL PENDING account for user {user_id}")
except Exception as e:
logger.error(f"Failed to create TRIAL PENDING account for user {user_id}: {e}")
raise
balance = Decimal('0')
if self.cache:
await self.cache.set(cache_key, str(balance), ttl=300)
return balance
async def deduct_credits(self, user_id: str, amount: Decimal, description: str = None, reference_id: str = None, reference_type: str = None) -> Dict:
try:
client = await self._get_client()
result = await client.rpc('deduct_credits', {
'p_user_id': user_id,
'p_amount': str(amount),
'p_description': description or f'Credit usage: {amount}',
'p_reference_id': reference_id,
'p_reference_type': reference_type
}).execute()
if self.cache:
await self.cache.invalidate(f"credit_balance:{user_id}")
if result.data and len(result.data) > 0:
row = result.data[0]
success = row.get('success', False)
new_balance = Decimal(str(row.get('new_balance', 0)))
transaction_id = row.get('transaction_id')
if success:
return {
'success': True,
'new_balance': new_balance,
'transaction_id': transaction_id
}
else:
return {
'success': False,
'new_balance': new_balance,
'error': 'Insufficient credits'
}
else:
return {
'success': False,
'new_balance': await self.get_balance(user_id, use_cache=False),
'error': 'No result from deduct_credits'
}
except Exception as e:
logger.error(f"Failed to deduct credits: {e}", user_id=user_id, amount=str(amount))
return {
'success': False,
'error': str(e)
}
async def add_credits(
self,
user_id: str,
amount: Decimal,
type: str = 'admin_grant',
description: str = None,
metadata: Dict = None
) -> Decimal:
try:
client = await self._get_client()
result = await client.rpc('add_credits', {
'p_user_id': user_id,
'p_amount': str(amount),
'p_description': description or f'Credit added: {amount}'
}).execute()
if result.data:
new_balance = Decimal(str(result.data))
if self.cache:
await self.cache.invalidate(f"credit_balance:{user_id}")
logger.info(f"Added {amount} credits to user {user_id}. New balance: {new_balance}")
return new_balance
else:
raise Exception("Failed to add credits")
except Exception as e:
logger.error(f"Failed to add credits: {e}", user_id=user_id, amount=str(amount))
raise
async def grant_tier_credits(self, user_id: str, price_id: str, tier_name: str) -> bool:
try:
from billing.config import get_tier_by_price_id
tier = get_tier_by_price_id(price_id)
if not tier:
logger.error(f"Unknown price_id: {price_id}")
return False
amount = Decimal(str(tier['credits']))
client = await self._get_client()
account_result = await client.from_('credit_accounts').select('last_grant_date').eq('account_id', user_id).execute()
if account_result.data and len(account_result.data) > 0:
last_grant = account_result.data[0].get('last_grant_date')
if last_grant:
last_grant_date = datetime.fromisoformat(last_grant.replace('Z', '+00:00'))
if (datetime.now(timezone.utc) - last_grant_date) < timedelta(days=25):
logger.info(f"Credits already granted this month for user {user_id}")
return False
result = await client.rpc('grant_tier_credits', {
'p_user_id': user_id,
'p_amount': str(amount),
'p_tier': tier_name
}).execute()
if self.cache:
await self.cache.invalidate(f"credit_balance:{user_id}")
logger.info(f"Granted {amount} {tier_name} credits to user {user_id}")
return bool(result.data)
except Exception as e:
logger.error(f"Failed to grant tier credits: {e}", user_id=user_id)
return False
async def get_ledger(
self,
user_id: str,
limit: int = 50,
offset: int = 0
) -> List[Dict[str, Any]]:
client = await self._get_client()
result = await client.from_('credit_ledger')\
.select('*')\
.eq('account_id', user_id)\
.order('created_at', desc=True)\
.limit(limit)\
.offset(offset)\
.execute()
return result.data or []
async def get_account_summary(self, user_id: str) -> Dict[str, Any]:
client = await self._get_client()
account_result = await client.from_('credit_accounts')\
.select('*')\
.eq('account_id', user_id)\
.execute()
if not account_result.data or len(account_result.data) == 0:
await self.get_balance(user_id)
return {
'balance': str(FREE_TIER_INITIAL_CREDITS),
'tier': 'free',
'lifetime_granted': float(FREE_TIER_INITIAL_CREDITS),
'lifetime_purchased': 0,
'lifetime_used': 0,
'last_grant_date': datetime.now(timezone.utc).isoformat()
}
ledger_result = await client.from_('credit_ledger')\
.select('type, amount, description')\
.eq('account_id', user_id)\
.execute()
lifetime_granted = Decimal('0')
lifetime_purchased = Decimal('0')
lifetime_used = Decimal('0')
for entry in (ledger_result.data or []):
amount = Decimal(str(entry['amount']))
if entry['type'] in ['tier_grant', 'admin_grant', 'tier_upgrade']:
lifetime_granted += amount
elif entry['type'] == 'purchase':
lifetime_purchased += amount
elif entry['type'] == 'usage':
lifetime_used += abs(amount)
account = account_result.data[0]
return {
'balance': str(account['balance']),
'tier': account['tier'],
'lifetime_granted': float(lifetime_granted),
'lifetime_purchased': float(lifetime_purchased),
'lifetime_used': float(lifetime_used),
'last_grant_date': account.get('last_grant_date')
}
credit_service = CreditService()