# Mirror of https://github.com/kortix-ai/suna.git
# Python source, 112 lines, 3.9 KiB
from datetime import datetime, timezone
|
|
from typing import Dict, Optional, Tuple
|
|
from services.supabase import DBConnection
|
|
|
|
# Subscription tiers keyed by price ID; each entry carries the tier's
# name and its monthly allowance of agent-run hours.
SUBSCRIPTION_TIERS = {
    "price_1RDQbOG6l1KZGqIrgrYzMbnL": {
        "name": "free",
        "hours": 100,
    },
    "price_1RC2PYG6l1KZGqIrpbzFB9Lp": {
        "name": "base",
        "hours": 100,
    },
    "price_1RDQWqG6l1KZGqIrChli4Ys4": {
        "name": "extra",
        "hours": 100,
    },
}
|
|
|
|
async def get_account_subscription(client, account_id: str) -> Optional[Dict]:
    """Return the newest active billing subscription row for an account.

    Queries ``billing_subscriptions`` in the ``basejump`` schema for rows
    whose ``account_id`` matches and whose ``status`` is ``'active'``,
    ordered by ``created`` descending and limited to one row.

    Returns:
        The first matching row, or ``None`` when the query yields no rows.
    """
    query = client.schema('basejump').from_('billing_subscriptions')
    query = query.select('*')
    query = query.eq('account_id', account_id)
    query = query.eq('status', 'active')
    query = query.order('created', desc=True)
    query = query.limit(1)
    response = await query.execute()

    rows = response.data
    if not rows:
        return None
    return rows[0]
|
|
|
|
def _timestamp_from_iso(value: str) -> float:
    """Convert an ISO-8601 timestamp string to POSIX seconds.

    Accepts a trailing ``Z`` and any number of fractional-second digits.
    The fraction is padded or truncated to six digits because
    ``datetime.fromisoformat`` on Python <= 3.10 only accepts exactly three
    or six. A timestamp without a UTC offset is treated as UTC rather than
    local time, so it stays comparable with the UTC month boundary.
    """
    text = value.strip().replace('Z', '+00:00')
    head, dot, rest = text.partition('.')
    if dot:
        # Split "<digits><offset>" into the fraction and whatever follows it.
        digit_count = len(rest) - len(rest.lstrip('0123456789'))
        fraction = rest[:digit_count][:6].ljust(6, '0')
        text = f"{head}.{fraction}{rest[digit_count:]}"
    parsed = datetime.fromisoformat(text)
    if parsed.tzinfo is None:
        parsed = parsed.replace(tzinfo=timezone.utc)
    return parsed.timestamp()


async def calculate_monthly_usage(client, account_id: str) -> float:
    """Calculate total agent run hours for the current month for an account.

    Fetches the account's thread IDs from ``threads``, then every row of
    ``agent_runs`` for those threads whose ``started_at`` is on or after the
    first day of the current UTC month, and sums the run durations.

    Args:
        client: Database client exposing the ``table(...).select(...)``
            query builder with an awaitable ``execute()``.
        account_id: Account whose usage is being measured.

    Returns:
        Total run time in hours. ``0.0`` when the account has no threads or
        no runs this month. Runs without ``completed_at`` are counted up to
        the moment this function was called.
    """
    # Month boundary and "now" are both taken in UTC.
    now = datetime.now(timezone.utc)
    start_of_month = datetime(now.year, now.month, 1, tzinfo=timezone.utc)

    threads_result = await client.table('threads') \
        .select('thread_id') \
        .eq('account_id', account_id) \
        .execute()
    if not threads_result.data:
        return 0.0

    thread_ids = [t['thread_id'] for t in threads_result.data]

    # NOTE(review): every thread ID is sent in a single in_() filter; for
    # accounts with very many threads this may need batching - confirm.
    runs_result = await client.table('agent_runs') \
        .select('started_at, completed_at') \
        .in_('thread_id', thread_ids) \
        .gte('started_at', start_of_month.isoformat()) \
        .execute()
    if not runs_result.data:
        return 0.0

    now_ts = now.timestamp()
    total_seconds = 0.0
    for run in runs_result.data:
        started_at = run.get('started_at')
        if not started_at:
            # A run with no start time has not consumed any measurable time.
            continue
        start_ts = _timestamp_from_iso(started_at)

        completed_at = run.get('completed_at')
        # Still-running jobs are charged up to the time of this call.
        end_ts = _timestamp_from_iso(completed_at) if completed_at else now_ts

        # Clamp at zero: inconsistent rows, or a run that started after
        # `now` was captured above, must not reduce the total.
        total_seconds += max(0.0, end_ts - start_ts)

    return total_seconds / 3600  # Convert to hours
|
|
|
|
async def check_billing_status(client, account_id: str) -> Tuple[bool, str, Optional[Dict]]:
    """
    Decide whether an account may run agents, given its plan and usage.

    The account's active subscription is looked up; when there is none, a
    free-tier placeholder is used instead. The subscription's price ID
    selects a tier from SUBSCRIPTION_TIERS, and the account's usage for the
    current month is compared against that tier's hour limit.

    Returns:
        Tuple[bool, str, Optional[Dict]]: (can_run, message, subscription_info)
    """
    # No active subscription means the account is on the free tier.
    subscription = (await get_account_subscription(client, account_id)) or {
        'price_id': 'price_1RDQbOG6l1KZGqIrgrYzMbnL',  # Free tier
        'plan_name': 'Free',
    }

    # An unknown price ID cannot be mapped to a limit, so refuse to run.
    tier = SUBSCRIPTION_TIERS.get(subscription['price_id'])
    if not tier:
        return False, "Invalid subscription tier", subscription

    # Hours consumed so far this month versus the tier's allowance.
    hours_used = await calculate_monthly_usage(client, account_id)
    hours_allowed = tier['hours']
    if hours_used >= hours_allowed:
        limit_message = f"Monthly limit of {hours_allowed} hours reached. Please upgrade your plan or wait until next month."
        return False, limit_message, subscription

    return True, "OK", subscription
|
|
|
|
# Helper: resolve the account that a thread belongs to
async def get_account_id_from_thread(client, thread_id: str) -> Optional[str]:
    """Look up the account ID stored on a thread.

    Selects ``account_id`` from the ``threads`` table for the row whose
    ``thread_id`` matches, limited to one row.

    Returns:
        The ``account_id`` of the matching row, or ``None`` when the query
        yields no rows.
    """
    lookup = (
        client.table('threads')
        .select('account_id')
        .eq('thread_id', thread_id)
        .limit(1)
    )
    response = await lookup.execute()

    matches = response.data
    return matches[0]['account_id'] if matches else None
|