# Mirror of https://github.com/kortix-ai/suna.git (Python source, 300 lines, 12 KiB)
from decimal import Decimal
|
|
from datetime import datetime, timezone, timedelta
|
|
from typing import Optional, Dict, List, Any
|
|
from core.services.supabase import DBConnection
|
|
from core.utils.logger import logger
|
|
from core.utils.cache import Cache
|
|
from core.utils.config import config, EnvMode
|
|
from core.billing.config import FREE_TIER_INITIAL_CREDITS, TRIAL_ENABLED
|
|
|
|
class CreditService:
    """Credit balance and ledger operations backed by Supabase.

    Balances are read from the ``credit_accounts`` table; new accounts are
    created lazily by :meth:`get_balance` together with an opening entry in
    ``credit_ledger``. Deductions, additions and tier grants go through the
    ``deduct_credits`` / ``add_credits`` / ``grant_tier_credits`` database
    RPCs. Balance reads are cached under ``credit_balance:{user_id}`` and
    every mutation drops that cache entry.
    """

    # TTL passed to Cache.set for cached balances (presumably seconds).
    _BALANCE_CACHE_TTL = 300
    # Tier credits are not granted again while last_grant_date is newer than this.
    _GRANT_COOLDOWN = timedelta(days=25)
    # Ledger entry types counted as granted credits in get_account_summary().
    _GRANT_LEDGER_TYPES = frozenset({'tier_grant', 'admin_grant', 'tier_upgrade'})

    def __init__(self):
        self.db = DBConnection()
        # The Cache class itself (not an instance) is stored; it is used through
        # awaited get/set/invalidate calls below.
        self.cache = Cache
        # Supabase client, created lazily by _get_client().
        self._client = None

    async def _get_client(self):
        """Return the Supabase client, initialising the DB connection on first use."""
        if self._client is None:
            await self.db.initialize()
            self._client = await self.db.client
        return self._client

    @staticmethod
    def _balance_cache_key(user_id: str) -> str:
        """Return the cache key under which ``user_id``'s balance is stored."""
        return f"credit_balance:{user_id}"

    async def _invalidate_balance(self, user_id: str) -> None:
        """Drop the cached balance for ``user_id``; no-op when caching is disabled."""
        if self.cache:
            await self.cache.invalidate(self._balance_cache_key(user_id))

    @staticmethod
    async def _select_account(client, user_id: str):
        """Fetch the full ``credit_accounts`` row(s) for ``user_id``."""
        return await client.from_('credit_accounts')\
            .select('*')\
            .eq('account_id', user_id)\
            .execute()

    async def get_balance(self, user_id: str, use_cache: bool = True) -> Decimal:
        """Return the current credit balance of ``user_id``, creating the account if missing.

        A missing ``credit_accounts`` row is created on the fly: in LOCAL mode
        with tier ``'none'`` and a zero balance, otherwise on the free tier with
        ``FREE_TIER_INITIAL_CREDITS``. The resulting balance is written back to
        the cache whenever a cache is configured, even when ``use_cache`` is False.

        Args:
            user_id: Account identifier (``credit_accounts.account_id``).
            use_cache: When True, a cached balance is returned without touching the DB.

        Returns:
            The balance as a ``Decimal``.

        Raises:
            Exception: Database errors while reading or creating the account are re-raised.
        """
        cache_key = self._balance_cache_key(user_id)

        if use_cache and self.cache:
            cached = await self.cache.get(cache_key)
            if cached:
                # str() so non-string cache payloads (e.g. floats) convert exactly.
                return Decimal(str(cached))

        try:
            client = await self._get_client()
            result = await client.from_('credit_accounts').select('balance').eq('account_id', user_id).execute()
        except Exception as e:
            logger.error(f"Error fetching balance for user {user_id}: {e}")
            raise

        if result.data:
            balance = Decimal(str(result.data[0]['balance']))
        elif config.ENV_MODE == EnvMode.LOCAL:
            balance = await self._create_local_account(client, user_id)
        else:
            balance = await self._create_free_tier_account(client, user_id)

        if self.cache:
            await self.cache.set(cache_key, str(balance), ttl=self._BALANCE_CACHE_TTL)

        return balance

    async def _create_local_account(self, client, user_id: str) -> Decimal:
        """Create a zero-balance account with tier 'none' (LOCAL mode) and return its balance."""
        logger.info(f"LOCAL mode: Creating user {user_id} with tier='none' (no free tier in local mode)")
        account_data = {
            'account_id': user_id,
            'balance': '0',
            'tier': 'none',
            'trial_status': 'none'
        }

        try:
            await client.from_('credit_accounts').insert(account_data).execute()
            logger.info(f"Successfully created tier='none' account for user {user_id} in LOCAL mode")
        except Exception as e:
            logger.error(f"Failed to create account for user {user_id}: {e}")
            raise

        await client.from_('credit_ledger').insert({
            'account_id': user_id,
            'amount': '0',
            'type': 'initial',
            'description': 'Account created - no free tier in local mode',
            'balance_after': '0'
        }).execute()

        return Decimal('0')

    async def _create_free_tier_account(self, client, user_id: str) -> Decimal:
        """Create a free-tier account seeded with FREE_TIER_INITIAL_CREDITS and return its balance.

        NOTE(review): TRIAL_ENABLED is not consulted here. Every new non-LOCAL
        account starts on the free tier, and per the log message below a
        separate trial migration handles any later conversion.
        """
        logger.info(f"Creating new user {user_id} with free tier (trial migration will handle conversion)")
        account_data = {
            'account_id': user_id,
            'balance': str(FREE_TIER_INITIAL_CREDITS),
            'tier': 'free'
        }

        try:
            logger.info(f"Creating FREE TIER account for new user {user_id}")
            try:
                with_grant_date = {**account_data, 'last_grant_date': datetime.now(timezone.utc).isoformat()}
                await client.from_('credit_accounts').insert(with_grant_date).execute()
                logger.info(f"Successfully created FREE TIER account for user {user_id}")
            except Exception as e1:
                # Retry without last_grant_date (presumably for schemas lacking that column).
                logger.warning(f"Creating account without last_grant_date: {e1}")
                await client.from_('credit_accounts').insert(account_data).execute()
                logger.info(f"Successfully created minimal FREE TIER account for user {user_id}")
        except Exception as e:
            logger.error(f"Failed to create FREE TIER account for user {user_id}: {e}")
            raise

        await client.from_('credit_ledger').insert({
            'account_id': user_id,
            'amount': str(FREE_TIER_INITIAL_CREDITS),
            'type': 'tier_grant',
            'description': 'Welcome to Suna! Free tier initial credits',
            'balance_after': str(FREE_TIER_INITIAL_CREDITS)
        }).execute()

        # Normalise to Decimal so get_balance() honours its return annotation
        # regardless of the numeric type of the config constant.
        return Decimal(str(FREE_TIER_INITIAL_CREDITS))

    async def deduct_credits(
        self,
        user_id: str,
        amount: Decimal,
        description: Optional[str] = None,
        reference_id: Optional[str] = None,
        reference_type: Optional[str] = None,
    ) -> Dict:
        """Deduct ``amount`` credits from ``user_id`` via the ``deduct_credits`` RPC.

        Never raises. Any exception is logged and reported in the returned dict.

        Returns:
            On success: ``{'success': True, 'new_balance': Decimal, 'transaction_id': ...}``.
            When the RPC reports failure: ``{'success': False, 'new_balance': Decimal,
            'error': 'Insufficient credits'}``.
            When the RPC returns no rows: ``new_balance`` is re-read from the DB.
            On exception: ``{'success': False, 'error': str}`` (no ``new_balance`` key).
        """
        try:
            client = await self._get_client()
            result = await client.rpc('deduct_credits', {
                'p_user_id': user_id,
                'p_amount': str(amount),
                'p_description': description or f'Credit usage: {amount}',
                'p_reference_id': reference_id,
                'p_reference_type': reference_type
            }).execute()

            await self._invalidate_balance(user_id)

            if not result.data:
                return {
                    'success': False,
                    'new_balance': await self.get_balance(user_id, use_cache=False),
                    'error': 'No result from deduct_credits'
                }

            row = result.data[0]
            new_balance = Decimal(str(row.get('new_balance', 0)))
            transaction_id = row.get('transaction_id')

            if row.get('success', False):
                return {
                    'success': True,
                    'new_balance': new_balance,
                    'transaction_id': transaction_id
                }
            return {
                'success': False,
                'new_balance': new_balance,
                'error': 'Insufficient credits'
            }

        except Exception as e:
            logger.error(f"Failed to deduct credits: {e}", user_id=user_id, amount=str(amount))
            return {
                'success': False,
                'error': str(e)
            }

    async def add_credits(
        self,
        user_id: str,
        amount: Decimal,
        type: str = 'admin_grant',
        description: Optional[str] = None,
        metadata: Optional[Dict] = None
    ) -> Decimal:
        """Add ``amount`` credits to ``user_id`` via the ``add_credits`` RPC.

        NOTE(review): ``type`` and ``metadata`` are accepted for API
        compatibility but are not forwarded to the RPC.

        Returns:
            The new balance reported by the RPC.

        Raises:
            Exception: If the RPC returns no data, or on any database error.
        """
        try:
            client = await self._get_client()
            result = await client.rpc('add_credits', {
                'p_user_id': user_id,
                'p_amount': str(amount),
                'p_description': description or f'Credit added: {amount}'
            }).execute()

            data = result.data
            # A new balance of 0 is falsy but valid; only an empty payload means failure.
            if data in (None, [], ''):
                raise Exception("Failed to add credits")

            new_balance = Decimal(str(data))
            await self._invalidate_balance(user_id)

            logger.info(f"Added {amount} credits to user {user_id}. New balance: {new_balance}")
            return new_balance

        except Exception as e:
            logger.error(f"Failed to add credits: {e}", user_id=user_id, amount=str(amount))
            raise

    async def grant_tier_credits(self, user_id: str, price_id: str, tier_name: str) -> bool:
        """Grant the monthly credits of the tier matching ``price_id`` to ``user_id``.

        A grant is skipped when the account's ``last_grant_date`` is less than
        ``_GRANT_COOLDOWN`` (25 days) old. Never raises; errors are logged.

        Returns:
            True if the ``grant_tier_credits`` RPC returned data; False when the
            price is unknown, the cooldown applies, or an error occurred.
        """
        try:
            try:
                # Same package the module-level imports use.
                from core.billing.config import get_tier_by_price_id
            except ImportError:
                # Legacy location, kept for backward compatibility.
                from billing.config import get_tier_by_price_id

            tier = get_tier_by_price_id(price_id)
            if not tier:
                logger.error(f"Unknown price_id: {price_id}")
                return False

            # NOTE(review): assumes the tier is a mapping with a 'credits' key.
            amount = Decimal(str(tier['credits']))

            client = await self._get_client()
            account_result = await client.from_('credit_accounts').select('last_grant_date').eq('account_id', user_id).execute()

            if account_result.data:
                last_grant = account_result.data[0].get('last_grant_date')
                if last_grant:
                    if isinstance(last_grant, datetime):
                        last_grant_date = last_grant
                    else:
                        last_grant_date = datetime.fromisoformat(last_grant.replace('Z', '+00:00'))
                    # Without an offset the subtraction below would raise TypeError
                    # (naive minus aware); assume such timestamps are UTC.
                    if last_grant_date.tzinfo is None:
                        last_grant_date = last_grant_date.replace(tzinfo=timezone.utc)
                    if datetime.now(timezone.utc) - last_grant_date < self._GRANT_COOLDOWN:
                        logger.info(f"Credits already granted this month for user {user_id}")
                        return False

            result = await client.rpc('grant_tier_credits', {
                'p_user_id': user_id,
                'p_amount': str(amount),
                'p_tier': tier_name
            }).execute()

            await self._invalidate_balance(user_id)

            granted = bool(result.data)
            if granted:
                logger.info(f"Granted {amount} {tier_name} credits to user {user_id}")
            else:
                logger.warning(f"grant_tier_credits returned no data for user {user_id}")
            return granted

        except Exception as e:
            logger.error(f"Failed to grant tier credits: {e}", user_id=user_id)
            return False

    async def get_ledger(
        self,
        user_id: str,
        limit: int = 50,
        offset: int = 0
    ) -> List[Dict[str, Any]]:
        """Return up to ``limit`` ledger rows for ``user_id``, newest ``created_at`` first, skipping ``offset``."""
        client = await self._get_client()
        result = await client.from_('credit_ledger')\
            .select('*')\
            .eq('account_id', user_id)\
            .order('created_at', desc=True)\
            .limit(limit)\
            .offset(offset)\
            .execute()

        return result.data or []

    async def get_account_summary(self, user_id: str) -> Dict[str, Any]:
        """Summarise ``user_id``'s account: balance, tier, lifetime totals and last grant date.

        A missing account is created first (through :meth:`get_balance`, so the
        tier and opening balance follow ENV_MODE) and then reported as stored.
        Lifetime totals are summed from ``credit_ledger``. Grant types are
        ``tier_grant``/``admin_grant``/``tier_upgrade``, purchases are
        ``purchase``, and usage (``usage``) is summed as absolute values.
        """
        client = await self._get_client()
        account_result = await self._select_account(client, user_id)

        if not account_result.data:
            # Bypass the cache so the account is really created, then re-read it.
            await self.get_balance(user_id, use_cache=False)
            account_result = await self._select_account(client, user_id)
            if not account_result.data:
                logger.warning(f"Account for user {user_id} not readable after creation; returning defaults")
                return {
                    'balance': str(FREE_TIER_INITIAL_CREDITS),
                    'tier': 'free',
                    'lifetime_granted': float(FREE_TIER_INITIAL_CREDITS),
                    'lifetime_purchased': 0,
                    'lifetime_used': 0,
                    'last_grant_date': datetime.now(timezone.utc).isoformat()
                }

        ledger_result = await client.from_('credit_ledger')\
            .select('type, amount')\
            .eq('account_id', user_id)\
            .execute()

        lifetime_granted = Decimal('0')
        lifetime_purchased = Decimal('0')
        lifetime_used = Decimal('0')

        for entry in ledger_result.data or []:
            entry_type = entry['type']
            entry_amount = Decimal(str(entry['amount']))
            if entry_type in self._GRANT_LEDGER_TYPES:
                lifetime_granted += entry_amount
            elif entry_type == 'purchase':
                lifetime_purchased += entry_amount
            elif entry_type == 'usage':
                lifetime_used += abs(entry_amount)

        account = account_result.data[0]
        return {
            'balance': str(account['balance']),
            'tier': account['tier'],
            'lifetime_granted': float(lifetime_granted),
            'lifetime_purchased': float(lifetime_purchased),
            'lifetime_used': float(lifetime_used),
            'last_grant_date': account.get('last_grant_date')
        }
|
|
# Module-level shared instance of the credit service.
credit_service = CreditService()