suna/backend/core/utils/limits_checker.py

"""
Subscription limit checking utilities.
Handles checking various limits based on user tier:
- Agent run limits
- Agent count limits
- Project count limits
"""

from typing import Dict, Any
from datetime import datetime, timezone, timedelta

from core.utils.logger import logger
from core.utils.config import config
from core.utils.cache import Cache


async def check_agent_run_limit(client, account_id: str) -> Dict[str, Any]:
    """
    Check if the account has reached the limit of parallel agent runs within the past 24 hours.

    Args:
        client: Database client
        account_id: Account ID to check

    Returns:
        Dict with 'can_start' (bool), 'running_count' (int), 'running_thread_ids' (list)

    Note: This function does not use caching to ensure real-time limit checks.
    """
    try:
        # Calculate 24 hours ago
        twenty_four_hours_ago = datetime.now(timezone.utc) - timedelta(hours=24)
        twenty_four_hours_ago_iso = twenty_four_hours_ago.isoformat()
        logger.debug(f"Checking agent run limit for account {account_id} since {twenty_four_hours_ago_iso}")

        # Get all threads for this account
        threads_result = await client.table('threads').select('thread_id').eq('account_id', account_id).execute()
        if not threads_result.data:
            logger.debug(f"No threads found for account {account_id}")
            return {
                'can_start': True,
                'running_count': 0,
                'running_thread_ids': []
            }

        thread_ids = [thread['thread_id'] for thread in threads_result.data]
        logger.debug(f"Found {len(thread_ids)} threads for account {account_id}")

        # Query for running agent runs within the past 24 hours for these threads
        from core.utils.query_utils import batch_query_in
        running_runs = await batch_query_in(
            client=client,
            table_name='agent_runs',
            select_fields='id, thread_id, started_at',
            in_field='thread_id',
            in_values=thread_ids,
            additional_filters={
                'status': 'running',
                'started_at_gte': twenty_four_hours_ago_iso
            }
        )

        running_count = len(running_runs)
        running_thread_ids = [run['thread_id'] for run in running_runs]
        logger.debug(f"Account {account_id} has {running_count} running agent runs in the past 24 hours")

        result = {
            'can_start': running_count < config.MAX_PARALLEL_AGENT_RUNS,
            'running_count': running_count,
            'running_thread_ids': running_thread_ids
        }
        return result
    except Exception as e:
        logger.error(f"Error checking agent run limit for account {account_id}: {str(e)}")
        # In case of error, allow the run to proceed but log the error
        return {
            'can_start': True,
            'running_count': 0,
            'running_thread_ids': []
        }
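
# Illustrative usage sketch (not part of the original module): one way a request handler
# could consume check_agent_run_limit() before starting a new run. The `db_client` name
# and the FastAPI-style HTTPException are assumptions for the example only.
#
#     from fastapi import HTTPException
#
#     async def ensure_can_start_agent_run(db_client, account_id: str) -> None:
#         limit_info = await check_agent_run_limit(db_client, account_id)
#         if not limit_info['can_start']:
#             raise HTTPException(
#                 status_code=429,
#                 detail=(
#                     f"{limit_info['running_count']} agent runs already active in the "
#                     f"past 24 hours (max {config.MAX_PARALLEL_AGENT_RUNS})"
#                 ),
#             )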


async def check_agent_count_limit(client, account_id: str) -> Dict[str, Any]:
    """
    Check if a user can create more agents based on their subscription tier.

    Args:
        client: Database client
        account_id: Account ID to check

    Returns:
        Dict containing:
        - can_create: bool - whether user can create another agent
        - current_count: int - current number of custom agents (excluding Suna defaults)
        - limit: int - maximum agents allowed for this tier
        - tier_name: str - subscription tier name

    Note: This function does not use caching to ensure real-time agent counts.
    """
    try:
        # In local mode, allow practically unlimited custom agents
        if config.ENV_MODE.value == "local":
            return {
                'can_create': True,
                'current_count': 0,  # Return 0 to avoid showing any limit warnings
                'limit': 999999,  # Practically unlimited
                'tier_name': 'local'
            }

        # Always query fresh data from database to avoid stale cache issues
        agents_result = await client.table('agents').select('agent_id, metadata').eq('account_id', account_id).execute()
        non_suna_agents = []
        for agent in agents_result.data or []:
            metadata = agent.get('metadata', {}) or {}
            is_suna_default = metadata.get('is_suna_default', False)
            if not is_suna_default:
                non_suna_agents.append(agent)

        current_count = len(non_suna_agents)
        logger.debug(f"Account {account_id} has {current_count} custom agents (excluding Suna defaults)")

        try:
            from core.billing import subscription_service
            tier_info = await subscription_service.get_user_subscription_tier(account_id)
            tier_name = tier_info['name']
            logger.debug(f"Account {account_id} subscription tier: {tier_name}")
        except Exception as billing_error:
            logger.warning(f"Could not get subscription tier for {account_id}: {str(billing_error)}, defaulting to free")
            tier_name = 'free'

        agent_limit = config.AGENT_LIMITS.get(tier_name, config.AGENT_LIMITS['free'])
        can_create = current_count < agent_limit

        result = {
            'can_create': can_create,
            'current_count': current_count,
            'limit': agent_limit,
            'tier_name': tier_name
        }
        logger.debug(f"Account {account_id} has {current_count}/{agent_limit} agents (tier: {tier_name}) - can_create: {can_create}")
        return result
    except Exception as e:
        logger.error(f"Error checking agent count limit for account {account_id}: {str(e)}", exc_info=True)
        return {
            'can_create': True,
            'current_count': 0,
            'limit': config.AGENT_LIMITS['free'],
            'tier_name': 'free'
        }
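
# Illustrative usage sketch (assumption, not from this module): gating agent creation on
# check_agent_count_limit(). The hypothetical helper simply surfaces the returned counts;
# how the caller reports the failure is up to the API layer.
#
#     async def ensure_can_create_agent(db_client, account_id: str) -> None:
#         info = await check_agent_count_limit(db_client, account_id)
#         if not info['can_create']:
#             raise ValueError(
#                 f"Agent limit reached: {info['current_count']}/{info['limit']} "
#                 f"custom agents on tier '{info['tier_name']}'"
#             )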


async def check_project_count_limit(client, account_id: str) -> Dict[str, Any]:
    """
    Check if a user can create more projects based on their subscription tier.

    Args:
        client: Database client
        account_id: Account ID to check

    Returns:
        Dict containing:
        - can_create: bool - whether user can create another project
        - current_count: int - current number of projects
        - limit: int - maximum projects allowed for this tier
        - tier_name: str - subscription tier name

    Note: Results are cached for 60 seconds to balance freshness against database load,
    so a project created or deleted moments ago may take up to a minute to be reflected.
    """
    try:
        # In local mode, allow practically unlimited projects
        if config.ENV_MODE.value == "local":
            return {
                'can_create': True,
                'current_count': 0,  # Return 0 to avoid showing any limit warnings
                'limit': 999999,  # Practically unlimited
                'tier_name': 'local'
            }

        try:
            result = await Cache.get(f"project_count_limit:{account_id}")
            if result:
                logger.debug(f"Cache hit for project count limit: {account_id}")
                return result
        except Exception as cache_error:
            logger.warning(f"Cache read failed for project count limit {account_id}: {str(cache_error)}")

        projects_result = await client.table('projects').select('project_id').eq('account_id', account_id).execute()
        current_count = len(projects_result.data or [])
        logger.debug(f"Account {account_id} has {current_count} projects (real-time count)")

        try:
            credit_result = await client.table('credit_accounts').select('tier').eq('account_id', account_id).single().execute()
            tier_name = credit_result.data.get('tier', 'free') if credit_result.data else 'free'
            logger.debug(f"Account {account_id} credit tier: {tier_name}")
        except Exception as credit_error:
            try:
                logger.debug(f"Trying user_id fallback for account {account_id}")
                credit_result = await client.table('credit_accounts').select('tier').eq('user_id', account_id).single().execute()
                tier_name = credit_result.data.get('tier', 'free') if credit_result.data else 'free'
                logger.debug(f"Account {account_id} credit tier (via fallback): {tier_name}")
            except Exception:
                logger.debug(f"No credit account for {account_id}, defaulting to free tier")
                tier_name = 'free'

        from core.billing.config import get_project_limit
        project_limit = get_project_limit(tier_name)
        can_create = current_count < project_limit

        result = {
            'can_create': can_create,
            'current_count': current_count,
            'limit': project_limit,
            'tier_name': tier_name
        }
        logger.debug(f"Account {account_id} has {current_count}/{project_limit} projects (tier: {tier_name}) - can_create: {can_create}")

        # Cache for 1 minute - balance between staleness and DB load
        try:
            await Cache.set(f"project_count_limit:{account_id}", result, ttl=60)
        except Exception as cache_error:
            logger.warning(f"Cache write failed for project count limit {account_id}: {str(cache_error)}")

        return result
    except Exception as e:
        logger.error(f"Error checking project count limit for account {account_id}: {str(e)}", exc_info=True)
        from core.billing.config import get_project_limit
        return {
            'can_create': True,
            'current_count': 0,
            'limit': get_project_limit('free'),
            'tier_name': 'free'
        }
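
# Illustrative usage sketch (assumption): a project-creation path could gate on
# check_project_count_limit(). Because results are cached for 60 seconds, a project
# created or deleted just before the check may not show up in `current_count` until
# the TTL expires.
#
#     async def ensure_can_create_project(db_client, account_id: str) -> None:
#         info = await check_project_count_limit(db_client, account_id)
#         if not info['can_create']:
#             raise ValueError(
#                 f"Project limit reached: {info['current_count']}/{info['limit']} "
#                 f"projects on tier '{info['tier_name']}'"
#             )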