mirror of https://github.com/kortix-ai/suna.git
query util for batch in() calls
commit fbdc97c9fa
parent 66d96d0ad3
@@ -1386,11 +1386,17 @@ async def get_agents(
     version_ids = list({agent['current_version_id'] for agent in agents_data if agent.get('current_version_id')})
     if version_ids:
         try:
-            versions_result = await client.table('agent_versions').select(
-                'version_id, agent_id, version_number, version_name, is_active, created_at, updated_at, created_by, config'
-            ).in_('version_id', version_ids).execute()
+            from utils.query_utils import batch_query_in

-            for row in (versions_result.data or []):
+            versions_data = await batch_query_in(
+                client=client,
+                table_name='agent_versions',
+                select_fields='version_id, agent_id, version_number, version_name, is_active, created_at, updated_at, created_by, config',
+                in_field='version_id',
+                in_values=version_ids
+            )
+
+            for row in versions_data:
                 config = row.get('config') or {}
                 tools = config.get('tools') or {}
                 version_dict = {
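Note: for call sites like this where the id list usually fits in one request, the new helper (added as the utils.query_utils module later in this commit) falls back to a single query, so the call above is expected to behave like the removed lines. A minimal sketch of that equivalence, assuming a Supabase async client and fewer than 100 version_ids (field list trimmed here purely for illustration):

    versions_data = await batch_query_in(
        client=client,
        table_name='agent_versions',
        select_fields='version_id, agent_id, version_number',  # trimmed for illustration
        in_field='version_id',
        in_values=version_ids
    )
    # ...is expected to return the same rows as the removed direct query:
    result = await client.table('agent_versions').select(
        'version_id, agent_id, version_number'
    ).in_('version_id', version_ids).execute()
    rows = result.data or []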
@@ -2960,14 +2966,21 @@ async def get_user_threads(
     # Fetch projects if we have project IDs
     projects_by_id = {}
     if unique_project_ids:
-        projects_result = await client.table('projects').select('*').in_('project_id', unique_project_ids).execute()
+        from utils.query_utils import batch_query_in

-        if projects_result.data:
-            logger.debug(f"[API] Raw projects from DB: {len(projects_result.data)}")
+        projects_data = await batch_query_in(
+            client=client,
+            table_name='projects',
+            select_fields='*',
+            in_field='project_id',
+            in_values=unique_project_ids
+        )
+
+        logger.debug(f"[API] Retrieved {len(projects_data)} projects")
         # Create a lookup map of projects by ID
         projects_by_id = {
             project['project_id']: project
-            for project in projects_result.data
+            for project in projects_data
         }

     # Map threads with their associated projects
@@ -22,19 +22,19 @@ async def check_for_active_project_agent_run(client, project_id: str):
     project_thread_ids = [t['thread_id'] for t in project_threads.data]

     if project_thread_ids:
-        # Handle large numbers of threads by batching to avoid URI length limits
-        if len(project_thread_ids) > 100:
-            # Process in batches to avoid URI too large errors
-            batch_size = 100
-            for i in range(0, len(project_thread_ids), batch_size):
-                batch_thread_ids = project_thread_ids[i:i + batch_size]
-                active_runs = await client.table('agent_runs').select('id').in_('thread_id', batch_thread_ids).eq('status', 'running').execute()
-                if active_runs.data and len(active_runs.data) > 0:
-                    return active_runs.data[0]['id']
-        else:
-            active_runs = await client.table('agent_runs').select('id').in_('thread_id', project_thread_ids).eq('status', 'running').execute()
-            if active_runs.data and len(active_runs.data) > 0:
-                return active_runs.data[0]['id']
+        from utils.query_utils import batch_query_in
+
+        active_runs = await batch_query_in(
+            client=client,
+            table_name='agent_runs',
+            select_fields='id',
+            in_field='thread_id',
+            in_values=project_thread_ids,
+            additional_filters={'status': 'running'}
+        )
+
+        if active_runs:
+            return active_runs[0]['id']
     return None
@@ -124,27 +124,20 @@ async def check_agent_run_limit(client, account_id: str) -> Dict[str, Any]:
     logger.debug(f"Found {len(thread_ids)} threads for account {account_id}")

     # Query for running agent runs within the past 24 hours for these threads
-    # Handle large numbers of threads by batching to avoid URI length limits
-    if len(thread_ids) > 100:
-        # Process in batches to avoid URI too large errors
-        all_running_runs = []
-        batch_size = 100
-        for i in range(0, len(thread_ids), batch_size):
-            batch_thread_ids = thread_ids[i:i + batch_size]
-            batch_result = await client.table('agent_runs').select('id', 'thread_id', 'started_at').in_('thread_id', batch_thread_ids).eq('status', 'running').gte('started_at', twenty_four_hours_ago_iso).execute()
-            if batch_result.data:
-                all_running_runs.extend(batch_result.data)
+    from utils.query_utils import batch_query_in

-        # Create a mock result object similar to what Supabase returns
-        class MockResult:
-            def __init__(self, data):
-                self.data = data
+    running_runs = await batch_query_in(
+        client=client,
+        table_name='agent_runs',
+        select_fields='id, thread_id, started_at',
+        in_field='thread_id',
+        in_values=thread_ids,
+        additional_filters={
+            'status': 'running',
+            'started_at_gte': twenty_four_hours_ago_iso
+        }
+    )

-        running_runs_result = MockResult(all_running_runs)
-    else:
-        running_runs_result = await client.table('agent_runs').select('id', 'thread_id', 'started_at').in_('thread_id', thread_ids).eq('status', 'running').gte('started_at', twenty_four_hours_ago_iso).execute()
-
-    running_runs = running_runs_result.data or []
     running_count = len(running_runs)
     running_thread_ids = [run['thread_id'] for run in running_runs]
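Note on additional_filters: the new helper maps a plain key to .eq() and strips a trailing _gte to call .gte() (see the utils.query_utils module added at the end of this commit). A minimal sketch of the query it is expected to build for each chunk of thread_ids in the call above, assuming the same Supabase async client:

    # batch_values is one chunk of up to 100 thread ids
    query = client.table('agent_runs').select('id, thread_id, started_at').in_('thread_id', batch_values)
    query = query.eq('status', 'running')                        # from 'status': 'running'
    query = query.gte('started_at', twenty_four_hours_ago_iso)   # from 'started_at_gte': ...
    batch_result = await query.execute()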
@ -348,12 +341,82 @@ if __name__ == "__main__":
|
|||
except Exception as e:
|
||||
print(f"❌ Billing test failed: {str(e)}")
|
||||
|
||||
async def test_api_functions():
|
||||
"""Test the API functions that were also fixed for URI limits."""
|
||||
print("\n🔧 Testing API functions...")
|
||||
|
||||
try:
|
||||
# Import the API functions we fixed
|
||||
import sys
|
||||
sys.path.append('/app') # Add the app directory to path
|
||||
|
||||
db = DBConnection()
|
||||
client = await db.client
|
||||
|
||||
test_user_id = "2558d81e-5008-46d6-b7d3-8cc62d44e4f6"
|
||||
|
||||
print(f"📊 Testing API functions with user: {test_user_id}")
|
||||
|
||||
# Test 1: get_user_threads (which has the project batching fix)
|
||||
print("\n1️⃣ Testing get_user_threads simulation...")
|
||||
try:
|
||||
# Get threads for the user
|
||||
threads_result = await client.table('threads').select('*').eq('account_id', test_user_id).order('created_at', desc=True).execute()
|
||||
|
||||
if threads_result.data:
|
||||
print(f" - Found {len(threads_result.data)} threads")
|
||||
|
||||
# Extract unique project IDs (this is what could cause URI issues)
|
||||
project_ids = [
|
||||
thread['project_id'] for thread in threads_result.data[:1000] # Limit to first 1000
|
||||
if thread.get('project_id')
|
||||
]
|
||||
unique_project_ids = list(set(project_ids)) if project_ids else []
|
||||
|
||||
print(f" - Found {len(unique_project_ids)} unique project IDs")
|
||||
|
||||
if unique_project_ids:
|
||||
# Test the batching logic we implemented
|
||||
if len(unique_project_ids) > 100:
|
||||
print(f" - Would use batching for {len(unique_project_ids)} project IDs")
|
||||
else:
|
||||
print(f" - Would use direct query for {len(unique_project_ids)} project IDs")
|
||||
|
||||
# Actually test a small batch to verify it works
|
||||
test_batch = unique_project_ids[:min(10, len(unique_project_ids))]
|
||||
projects_result = await client.table('projects').select('*').in_('project_id', test_batch).execute()
|
||||
print(f"✅ Project query test succeeded: found {len(projects_result.data or [])} projects")
|
||||
else:
|
||||
print(" - No project IDs to test")
|
||||
else:
|
||||
print(" - No threads found for user")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ get_user_threads test failed: {str(e)}")
|
||||
|
||||
# Test 2: Template service simulation
|
||||
print("\n2️⃣ Testing template service simulation...")
|
||||
try:
|
||||
from templates.template_service import TemplateService
|
||||
|
||||
# This would test the creator ID batching, but we'll just verify the import works
|
||||
print("✅ Template service import succeeded")
|
||||
|
||||
except ImportError as e:
|
||||
print(f"⚠️ Could not import template service: {str(e)}")
|
||||
except Exception as e:
|
||||
print(f"❌ Template service test failed: {str(e)}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ API functions test failed: {str(e)}")
|
||||
|
||||
async def main():
|
||||
"""Main test function."""
|
||||
print("🚀 Starting URI limit fix tests...\n")
|
||||
|
||||
await test_large_thread_count()
|
||||
await test_billing_integration()
|
||||
await test_api_functions()
|
||||
|
||||
print("\n✨ Test suite completed!")
|
||||
|
||||
|
|
|
@@ -312,11 +312,20 @@ class TemplateService:
            return []

        creator_ids = list(set(template['creator_id'] for template in result.data))
-        accounts_result = await client.schema('basejump').from_('accounts').select('id, name, slug').in_('id', creator_ids).execute()

+        from utils.query_utils import batch_query_in
+
+        accounts_data = await batch_query_in(
+            client=client,
+            table_name='accounts',
+            select_fields='id, name, slug',
+            in_field='id',
+            in_values=creator_ids,
+            schema='basejump'
+        )
+
        creator_names = {}
-        if accounts_result.data:
-            for account in accounts_result.data:
-                creator_names[account['id']] = account.get('name') or account.get('slug')
+        for account in accounts_data:
+            creator_names[account['id']] = account.get('name') or account.get('slug')

        templates = []
@@ -0,0 +1,81 @@
+"""
+Query utilities for handling large datasets and avoiding URI length limits.
+"""
+from typing import List, Any, Dict, Optional
+from utils.logger import logger
+
+
+async def batch_query_in(
+    client,
+    table_name: str,
+    select_fields: str,
+    in_field: str,
+    in_values: List[Any],
+    batch_size: int = 100,
+    additional_filters: Optional[Dict[str, Any]] = None,
+    schema: Optional[str] = None
+) -> List[Dict[str, Any]]:
+    """
+    Execute a query with .in_() filtering, automatically batching large arrays to avoid URI limits.
+
+    Args:
+        client: Supabase client
+        table_name: Name of the table to query
+        select_fields: Fields to select (e.g., '*' or 'id, name, created_at')
+        in_field: Field name for the .in_() filter
+        in_values: List of values to filter by
+        batch_size: Maximum number of values per batch (default: 100)
+        additional_filters: Optional dict of additional filters to apply
+        schema: Optional schema name (for basejump tables)
+
+    Returns:
+        List of all matching records from all batches
+    """
+    if not in_values:
+        return []
+
+    all_results = []
+
+    # If values list is small, do a single query
+    if len(in_values) <= batch_size:
+        query = client.schema(schema).from_(table_name) if schema else client.table(table_name)
+        query = query.select(select_fields).in_(in_field, in_values)
+
+        # Apply additional filters
+        if additional_filters:
+            for field, value in additional_filters.items():
+                if field.endswith('_gte'):
+                    query = query.gte(field[:-4], value)
+                elif field.endswith('_eq'):
+                    query = query.eq(field[:-3], value)
+                else:
+                    query = query.eq(field, value)
+
+        result = await query.execute()
+        return result.data or []
+
+    # Batch processing for large arrays
+    logger.debug(f"Batching {len(in_values)} {in_field} values into chunks of {batch_size}")
+
+    for i in range(0, len(in_values), batch_size):
+        batch_values = in_values[i:i + batch_size]
+
+        query = client.schema(schema).from_(table_name) if schema else client.table(table_name)
+        query = query.select(select_fields).in_(in_field, batch_values)
+
+        # Apply additional filters
+        if additional_filters:
+            for field, value in additional_filters.items():
+                if field.endswith('_gte'):
+                    query = query.gte(field[:-4], value)
+                elif field.endswith('_eq'):
+                    query = query.eq(field[:-3], value)
+                else:
+                    query = query.eq(field, value)
+
+        batch_result = await query.execute()
+        if batch_result.data:
+            all_results.extend(batch_result.data)
+
+    logger.debug(f"Batched query returned {len(all_results)} total results")
+    return all_results
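Note: a minimal usage sketch of the new helper; the wrapper function and its name are illustrative rather than part of this commit, and an already-initialized Supabase async client is assumed to be passed in:

    from utils.query_utils import batch_query_in

    async def running_run_ids(client, thread_ids):
        # thread_ids may be arbitrarily long; the helper splits it into chunks of 100
        runs = await batch_query_in(
            client=client,
            table_name='agent_runs',
            select_fields='id',
            in_field='thread_id',
            in_values=thread_ids,
            additional_filters={'status': 'running'}
        )
        return [run['id'] for run in runs]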