suna/backend/agent/utils.py

362 lines
16 KiB
Python
Raw Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import json
from typing import Optional, List, Dict, Any
from datetime import datetime, timezone, timedelta
from utils.cache import Cache
from utils.logger import logger
from utils.config import config
from services import redis
from run_agent_background import update_agent_run_status
async def _cleanup_redis_response_list(agent_run_id: str):
    """Delete the Redis list that buffers responses for one agent run.

    This is a best-effort operation: an exception raised while deleting the
    key (or while logging success) is reported as a warning instead of being
    propagated to the caller.

    Args:
        agent_run_id: Identifier of the agent run whose response list
            (key ``agent_run:<agent_run_id>:responses``) is removed.
    """
    try:
        await redis.delete(f"agent_run:{agent_run_id}:responses")
        logger.debug(f"Cleaned up Redis response list for agent run {agent_run_id}")
    except Exception as e:
        logger.warning(f"Failed to clean up Redis response list for {agent_run_id}: {str(e)}")
async def check_for_active_project_agent_run(client, project_id: str):
    """Return the ID of a running agent run belonging to a project, if any.

    All threads of the project are looked up first; the ``agent_runs`` table
    is then queried for rows with status ``running`` on those threads. Thread
    IDs are sent in batches of 100 so that the ``in_`` filter never produces
    a request URI that is too large.

    Args:
        client: Database client exposing the ``table(...).select(...)``
            query-builder interface used below.
        project_id: ID of the project whose threads are inspected.

    Returns:
        The ``id`` of the first running agent run found, or ``None`` when the
        project has no threads or none of them has a running agent run.
    """
    batch_size = 100

    threads_result = await client.table('threads').select('thread_id').eq('project_id', project_id).execute()
    # ``data`` may be empty or missing; treat both as "no threads".
    thread_ids = [row['thread_id'] for row in (threads_result.data or [])]

    # A single loop covers both the small case (one batch) and the large case
    # (several batches); it stops at the first batch containing a running run.
    for start in range(0, len(thread_ids), batch_size):
        batch_thread_ids = thread_ids[start:start + batch_size]
        active_runs = await (
            client.table('agent_runs')
            .select('id')
            .in_('thread_id', batch_thread_ids)
            .eq('status', 'running')
            .execute()
        )
        if active_runs.data:
            return active_runs.data[0]['id']

    return None
async def stop_agent_run(db, agent_run_id: str, error_message: Optional[str] = None):
    """Persist the final state of an agent run and signal its workers to stop.

    Steps, in order:

    1. Read the buffered responses from the Redis list
       ``agent_run:<agent_run_id>:responses``.
    2. Store the final status (``failed`` when ``error_message`` is given,
       otherwise ``stopped``) together with those responses through
       ``update_agent_run_status``.
    3. Publish ``STOP`` on the global control channel and on the control
       channel of every instance found via ``active_run:*:<agent_run_id>``.
    4. Delete the Redis response list.

    Failures of the Redis operations are logged and do not abort the
    remaining steps. Exceptions from ``db.client`` or from
    ``update_agent_run_status`` are not caught here.

    Args:
        db: Object exposing an awaitable ``client`` attribute.
        agent_run_id: ID of the agent run to stop.
        error_message: Optional error text; its presence marks the run as
            ``failed`` instead of ``stopped``.
    """
    logger.debug(f"Stopping agent run: {agent_run_id}")
    client = await db.client
    final_status = "failed" if error_message else "stopped"

    # Collect everything buffered so far so it can be persisted with the status.
    response_list_key = f"agent_run:{agent_run_id}:responses"
    all_responses = []
    try:
        all_responses_json = await redis.lrange(response_list_key, 0, -1)
        all_responses = [json.loads(r) for r in all_responses_json]
        logger.debug(f"Fetched {len(all_responses)} responses from Redis for DB update on stop/fail: {agent_run_id}")
    except Exception as e:
        logger.error(f"Failed to fetch responses from Redis for {agent_run_id} during stop/fail: {e}")

    update_success = await update_agent_run_status(
        client, agent_run_id, final_status, error=error_message, responses=all_responses
    )
    if not update_success:
        logger.error(f"Failed to update database status for stopped/failed run {agent_run_id}")

    # Signal on the global control channel.
    global_control_channel = f"agent_run:{agent_run_id}:control"
    try:
        await redis.publish(global_control_channel, "STOP")
        logger.debug(f"Published STOP signal to global channel {global_control_channel}")
    except Exception as e:
        logger.error(f"Failed to publish STOP signal to global channel {global_control_channel}: {str(e)}")

    # Signal every instance that registered itself for this run.
    try:
        instance_keys = await redis.keys(f"active_run:*:{agent_run_id}")
        logger.debug(f"Found {len(instance_keys)} active instance keys for agent run {agent_run_id}")
        for key in instance_keys:
            # Key format: active_run:<instance_id>:<agent_run_id>.
            # NOTE(review): keys are decoded defensively in case the Redis
            # wrapper is not configured to return str - confirm.
            key_text = key.decode("utf-8") if isinstance(key, bytes) else key
            parts = key_text.split(":")
            if len(parts) != 3:
                logger.warning(f"Unexpected key format found: {key}")
                continue
            instance_id_from_key = parts[1]
            instance_control_channel = f"agent_run:{agent_run_id}:control:{instance_id_from_key}"
            try:
                await redis.publish(instance_control_channel, "STOP")
                logger.debug(f"Published STOP signal to instance channel {instance_control_channel}")
            except Exception as e:
                logger.warning(f"Failed to publish STOP signal to instance channel {instance_control_channel}: {str(e)}")
    except Exception as e:
        logger.error(f"Failed to find or signal active instances for {agent_run_id}: {str(e)}")

    # Always attempt the cleanup, even when instance discovery/signalling
    # failed, so the response list is not left behind. The helper logs and
    # swallows its own errors.
    await _cleanup_redis_response_list(agent_run_id)

    logger.debug(f"Successfully initiated stop process for agent run: {agent_run_id}")
async def check_agent_run_limit(client, account_id: str) -> Dict[str, Any]:
    """Check whether the account may start another parallel agent run.

    Counts the agent runs with status ``running`` that were started within
    the past 24 hours on any thread of the account and compares that count
    with ``config.MAX_PARALLEL_AGENT_RUNS``. Results are cached under
    ``agent_run_limit:<account_id>``; cache failures are logged and never
    change the outcome of the check.

    Args:
        client: Database client exposing the ``table(...).select(...)``
            query-builder interface used below.
        account_id: ID of the account to check.

    Returns:
        Dict with ``can_start`` (bool), ``running_count`` (int) and
        ``running_thread_ids`` (list). If the check itself fails, the error
        is logged and a permissive result (``can_start`` True, count 0) is
        returned so that the run is allowed to proceed.
    """
    cache_key = f"agent_run_limit:{account_id}"
    batch_size = 100

    try:
        # A cache problem must not short-circuit the real check.
        try:
            cached = await Cache.get(cache_key)
            if cached:
                return cached
        except Exception as cache_error:
            logger.warning(f"Cache read failed for agent run limit {account_id}: {str(cache_error)}")

        # Calculate 24 hours ago
        twenty_four_hours_ago = datetime.now(timezone.utc) - timedelta(hours=24)
        twenty_four_hours_ago_iso = twenty_four_hours_ago.isoformat()
        logger.debug(f"Checking agent run limit for account {account_id} since {twenty_four_hours_ago_iso}")

        # Get all threads for this account
        threads_result = await client.table('threads').select('thread_id').eq('account_id', account_id).execute()
        if not threads_result.data:
            logger.debug(f"No threads found for account {account_id}")
            return {
                'can_start': True,
                'running_count': 0,
                'running_thread_ids': []
            }

        thread_ids = [thread['thread_id'] for thread in threads_result.data]
        logger.debug(f"Found {len(thread_ids)} threads for account {account_id}")

        # Query running agent runs in batches of thread IDs to avoid
        # "URI too large" errors; a short list is simply a single batch.
        running_runs = []
        for start in range(0, len(thread_ids), batch_size):
            batch_thread_ids = thread_ids[start:start + batch_size]
            batch_result = await (
                client.table('agent_runs')
                .select('id', 'thread_id', 'started_at')
                .in_('thread_id', batch_thread_ids)
                .eq('status', 'running')
                .gte('started_at', twenty_four_hours_ago_iso)
                .execute()
            )
            running_runs.extend(batch_result.data or [])

        running_count = len(running_runs)
        running_thread_ids = [run['thread_id'] for run in running_runs]
        logger.debug(f"Account {account_id} has {running_count} running agent runs in the past 24 hours")

        result = {
            'can_start': running_count < config.MAX_PARALLEL_AGENT_RUNS,
            'running_count': running_count,
            'running_thread_ids': running_thread_ids
        }

        # A failed cache write must not discard the computed result.
        try:
            await Cache.set(cache_key, result)
        except Exception as cache_error:
            logger.warning(f"Cache write failed for agent run limit {account_id}: {str(cache_error)}")

        return result
    except Exception as e:
        logger.error(f"Error checking agent run limit for account {account_id}: {str(e)}")
        # In case of error, allow the run to proceed but log the error
        return {
            'can_start': True,
            'running_count': 0,
            'running_thread_ids': []
        }
async def check_agent_count_limit(client, account_id: str) -> Dict[str, Any]:
    """Check whether the account may create another custom agent.

    Agents whose metadata marks them as ``is_suna_default`` are not counted.
    The limit is looked up in ``config.AGENT_LIMITS`` by subscription tier,
    falling back to the ``free`` entry for unknown tiers or when the tier
    cannot be determined. Results are cached for 300 seconds; cache errors
    are logged and otherwise ignored.

    Args:
        client: Database client exposing the ``table(...).select(...)``
            query-builder interface used below.
        account_id: ID of the account to check.

    Returns:
        Dict with ``can_create`` (bool), ``current_count`` (int), ``limit``
        (int) and ``tier_name`` (str). In local mode a practically unlimited
        result is returned without touching cache or database. If the check
        fails, the error is logged and a permissive ``free``-tier result is
        returned.
    """
    try:
        # Local development: no meaningful limit, and a zero count so that no
        # limit warnings are shown.
        if config.ENV_MODE.value == "local":
            return {
                'can_create': True,
                'current_count': 0,
                'limit': 999999,
                'tier_name': 'local'
            }

        cache_key = f"agent_count_limit:{account_id}"

        try:
            result = await Cache.get(cache_key)
            if result:
                logger.debug(f"Cache hit for agent count limit: {account_id}")
                return result
        except Exception as cache_error:
            logger.warning(f"Cache read failed for agent count limit {account_id}: {str(cache_error)}")

        agents_query = client.table('agents').select('agent_id, metadata').eq('account_id', account_id)
        agent_rows = (await agents_query.execute()).data or []

        # Count every agent that is not flagged as a Suna default; a missing
        # or null metadata value counts as "not a default".
        current_count = sum(
            1
            for agent in agent_rows
            if not (agent.get('metadata', {}) or {}).get('is_suna_default', False)
        )
        logger.debug(f"Account {account_id} has {current_count} custom agents (excluding Suna defaults)")

        try:
            from services.billing import get_subscription_tier
            tier_name = await get_subscription_tier(client, account_id)
            logger.debug(f"Account {account_id} subscription tier: {tier_name}")
        except Exception as billing_error:
            logger.warning(f"Could not get subscription tier for {account_id}: {str(billing_error)}, defaulting to free")
            tier_name = 'free'

        agent_limit = config.AGENT_LIMITS.get(tier_name, config.AGENT_LIMITS['free'])
        can_create = current_count < agent_limit

        result = dict(
            can_create=can_create,
            current_count=current_count,
            limit=agent_limit,
            tier_name=tier_name,
        )

        try:
            await Cache.set(cache_key, result, ttl=300)
        except Exception as cache_error:
            logger.warning(f"Cache write failed for agent count limit {account_id}: {str(cache_error)}")

        logger.debug(f"Account {account_id} has {current_count}/{agent_limit} agents (tier: {tier_name}) - can_create: {can_create}")
        return result
    except Exception as e:
        logger.error(f"Error checking agent count limit for account {account_id}: {str(e)}", exc_info=True)
        return {
            'can_create': True,
            'current_count': 0,
            'limit': config.AGENT_LIMITS['free'],
            'tier_name': 'free'
        }
if __name__ == "__main__":
    # Manual smoke tests for the helpers in this module. They talk to the
    # real database and billing services, so they only run when the file is
    # executed directly. An account ID may be passed as the first command-line
    # argument; otherwise the default below is used.
    import asyncio
    import os
    import sys
    import traceback

    # Add the backend directory to the Python path
    sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

    from services.supabase import DBConnection

    # Test user ID (replace with, or pass as argv[1], an actual user ID that
    # has many threads). The default is the user from the error logs.
    DEFAULT_TEST_USER_ID = "2558d81e-5008-46d6-b7d3-8cc62d44e4f6"
    TEST_USER_ID = sys.argv[1] if len(sys.argv) > 1 else DEFAULT_TEST_USER_ID

    async def test_large_thread_count():
        """Test the functions with a large number of threads to verify URI limit fixes."""
        print("🧪 Testing URI limit fixes with large thread counts...")
        try:
            # Initialize database connection
            db = DBConnection()
            client = await db.client

            test_user_id = TEST_USER_ID
            print(f"📊 Testing with user ID: {test_user_id}")

            # Test 1: check_agent_run_limit with many threads
            print("\n1⃣ Testing check_agent_run_limit...")
            try:
                result = await check_agent_run_limit(client, test_user_id)
                print("✅ check_agent_run_limit succeeded:")
                print(f" - Can start: {result['can_start']}")
                print(f" - Running count: {result['running_count']}")
                print(f" - Running thread IDs: {len(result['running_thread_ids'])} threads")
            except Exception as e:
                print(f"❌ check_agent_run_limit failed: {str(e)}")

            # Test 2: Get a project ID to test check_for_active_project_agent_run
            print("\n2⃣ Testing check_for_active_project_agent_run...")
            try:
                # Get a project for this user
                projects_result = await client.table('projects').select('project_id').eq('account_id', test_user_id).limit(1).execute()
                if projects_result.data and len(projects_result.data) > 0:
                    test_project_id = projects_result.data[0]['project_id']
                    print(f" Using project ID: {test_project_id}")
                    result = await check_for_active_project_agent_run(client, test_project_id)
                    print("✅ check_for_active_project_agent_run succeeded:")
                    print(f" - Active run ID: {result}")
                else:
                    print(" ⚠️ No projects found for user, skipping this test")
            except Exception as e:
                print(f"❌ check_for_active_project_agent_run failed: {str(e)}")

            # Test 3: check_agent_count_limit (doesn't have URI issues but good to test)
            print("\n3⃣ Testing check_agent_count_limit...")
            try:
                result = await check_agent_count_limit(client, test_user_id)
                print("✅ check_agent_count_limit succeeded:")
                print(f" - Can create: {result['can_create']}")
                print(f" - Current count: {result['current_count']}")
                print(f" - Limit: {result['limit']}")
                print(f" - Tier: {result['tier_name']}")
            except Exception as e:
                print(f"❌ check_agent_count_limit failed: {str(e)}")

            print("\n🎉 All agent utils tests completed!")
        except Exception as e:
            print(f"❌ Test setup failed: {str(e)}")
            traceback.print_exc()

    async def test_billing_integration():
        """Test the billing integration to make sure it works with the fixed functions."""
        print("\n💰 Testing billing integration...")
        try:
            from services.billing import calculate_monthly_usage, get_usage_logs

            db = DBConnection()
            client = await db.client

            test_user_id = TEST_USER_ID
            print(f"📊 Testing billing functions with user: {test_user_id}")

            # Test calculate_monthly_usage (which uses get_usage_logs internally)
            print("\n1⃣ Testing calculate_monthly_usage...")
            try:
                usage = await calculate_monthly_usage(client, test_user_id)
                print(f"✅ calculate_monthly_usage succeeded: ${usage:.4f}")
            except Exception as e:
                print(f"❌ calculate_monthly_usage failed: {str(e)}")

            # Test get_usage_logs directly with pagination
            print("\n2⃣ Testing get_usage_logs with pagination...")
            try:
                logs = await get_usage_logs(client, test_user_id, page=0, items_per_page=10)
                print("✅ get_usage_logs succeeded:")
                print(f" - Found {len(logs.get('logs', []))} log entries")
                print(f" - Has more: {logs.get('has_more', False)}")
                print(f" - Subscription limit: ${logs.get('subscription_limit', 0)}")
            except Exception as e:
                print(f"❌ get_usage_logs failed: {str(e)}")
        except ImportError as e:
            print(f"⚠️ Could not import billing functions: {str(e)}")
        except Exception as e:
            print(f"❌ Billing test failed: {str(e)}")

    async def main():
        """Main test function."""
        print("🚀 Starting URI limit fix tests...\n")
        await test_large_thread_count()
        await test_billing_integration()
        print("\n✨ Test suite completed!")

    # Run the tests
    asyncio.run(main())