From fbdc97c9fac8caa0d98a4f06aa28de453532fa49 Mon Sep 17 00:00:00 2001
From: marko-kraemer
Date: Fri, 22 Aug 2025 21:31:00 -0700
Subject: [PATCH] query util for batch in() calls

---
 backend/agent/api.py                  |  37 +++++---
 backend/agent/utils.py                | 129 +++++++++++++++++++-------
 backend/templates/template_service.py |  17 +++-
 backend/utils/query_utils.py          |  81 ++++++++++++++++
 4 files changed, 215 insertions(+), 49 deletions(-)
 create mode 100644 backend/utils/query_utils.py

diff --git a/backend/agent/api.py b/backend/agent/api.py
index 1ab50641..bb204a8e 100644
--- a/backend/agent/api.py
+++ b/backend/agent/api.py
@@ -1386,11 +1386,17 @@ async def get_agents(
     version_ids = list({agent['current_version_id'] for agent in agents_data if agent.get('current_version_id')})
     if version_ids:
         try:
-            versions_result = await client.table('agent_versions').select(
-                'version_id, agent_id, version_number, version_name, is_active, created_at, updated_at, created_by, config'
-            ).in_('version_id', version_ids).execute()
+            from utils.query_utils import batch_query_in
+
+            versions_data = await batch_query_in(
+                client=client,
+                table_name='agent_versions',
+                select_fields='version_id, agent_id, version_number, version_name, is_active, created_at, updated_at, created_by, config',
+                in_field='version_id',
+                in_values=version_ids
+            )
 
-            for row in (versions_result.data or []):
+            for row in versions_data:
                 config = row.get('config') or {}
                 tools = config.get('tools') or {}
                 version_dict = {
@@ -2960,15 +2966,22 @@ async def get_user_threads(
     # Fetch projects if we have project IDs
     projects_by_id = {}
     if unique_project_ids:
-        projects_result = await client.table('projects').select('*').in_('project_id', unique_project_ids).execute()
+        from utils.query_utils import batch_query_in
 
-        if projects_result.data:
-            logger.debug(f"[API] Raw projects from DB: {len(projects_result.data)}")
-            # Create a lookup map of projects by ID
-            projects_by_id = {
-                project['project_id']: project
-                for project in projects_result.data
-            }
+        projects_data = await batch_query_in(
+            client=client,
+            table_name='projects',
+            select_fields='*',
+            in_field='project_id',
+            in_values=unique_project_ids
+        )
+
+        logger.debug(f"[API] Retrieved {len(projects_data)} projects")
+        # Create a lookup map of projects by ID
+        projects_by_id = {
+            project['project_id']: project
+            for project in projects_data
+        }
 
     # Map threads with their associated projects
     mapped_threads = []
diff --git a/backend/agent/utils.py b/backend/agent/utils.py
index 40a4b8cf..846d3f13 100644
--- a/backend/agent/utils.py
+++ b/backend/agent/utils.py
@@ -22,19 +22,19 @@ async def check_for_active_project_agent_run(client, project_id: str):
     project_thread_ids = [t['thread_id'] for t in project_threads.data]
 
     if project_thread_ids:
-        # Handle large numbers of threads by batching to avoid URI length limits
-        if len(project_thread_ids) > 100:
-            # Process in batches to avoid URI too large errors
-            batch_size = 100
-            for i in range(0, len(project_thread_ids), batch_size):
-                batch_thread_ids = project_thread_ids[i:i + batch_size]
-                active_runs = await client.table('agent_runs').select('id').in_('thread_id', batch_thread_ids).eq('status', 'running').execute()
-                if active_runs.data and len(active_runs.data) > 0:
-                    return active_runs.data[0]['id']
-        else:
-            active_runs = await client.table('agent_runs').select('id').in_('thread_id', project_thread_ids).eq('status', 'running').execute()
-            if active_runs.data and len(active_runs.data) > 0:
-                return active_runs.data[0]['id']
+        from utils.query_utils import batch_query_in
+
+        active_runs = await batch_query_in(
+            client=client,
+            table_name='agent_runs',
+            select_fields='id',
+            in_field='thread_id',
+            in_values=project_thread_ids,
+            additional_filters={'status': 'running'}
+        )
+
+        if active_runs:
+            return active_runs[0]['id']
 
     return None
@@ -124,27 +124,20 @@ async def check_agent_run_limit(client, account_id: str) -> Dict[str, Any]:
     logger.debug(f"Found {len(thread_ids)} threads for account {account_id}")
 
     # Query for running agent runs within the past 24 hours for these threads
-    # Handle large numbers of threads by batching to avoid URI length limits
-    if len(thread_ids) > 100:
-        # Process in batches to avoid URI too large errors
-        all_running_runs = []
-        batch_size = 100
-        for i in range(0, len(thread_ids), batch_size):
-            batch_thread_ids = thread_ids[i:i + batch_size]
-            batch_result = await client.table('agent_runs').select('id', 'thread_id', 'started_at').in_('thread_id', batch_thread_ids).eq('status', 'running').gte('started_at', twenty_four_hours_ago_iso).execute()
-            if batch_result.data:
-                all_running_runs.extend(batch_result.data)
-
-        # Create a mock result object similar to what Supabase returns
-        class MockResult:
-            def __init__(self, data):
-                self.data = data
-
-        running_runs_result = MockResult(all_running_runs)
-    else:
-        running_runs_result = await client.table('agent_runs').select('id', 'thread_id', 'started_at').in_('thread_id', thread_ids).eq('status', 'running').gte('started_at', twenty_four_hours_ago_iso).execute()
+    from utils.query_utils import batch_query_in
+
+    running_runs = await batch_query_in(
+        client=client,
+        table_name='agent_runs',
+        select_fields='id, thread_id, started_at',
+        in_field='thread_id',
+        in_values=thread_ids,
+        additional_filters={
+            'status': 'running',
+            'started_at_gte': twenty_four_hours_ago_iso
+        }
+    )
 
-    running_runs = running_runs_result.data or []
     running_count = len(running_runs)
     running_thread_ids = [run['thread_id'] for run in running_runs]
@@ -348,12 +341,82 @@ if __name__ == "__main__":
     except Exception as e:
         print(f"āŒ Billing test failed: {str(e)}")
 
+    async def test_api_functions():
+        """Test the API functions that were also fixed for URI limits."""
+        print("\nšŸ”§ Testing API functions...")
+
+        try:
+            # Import the API functions we fixed
+            import sys
+            sys.path.append('/app') # Add the app directory to path
+
+            db = DBConnection()
+            client = await db.client
+
+            test_user_id = "2558d81e-5008-46d6-b7d3-8cc62d44e4f6"
+
+            print(f"šŸ“Š Testing API functions with user: {test_user_id}")
+
+            # Test 1: get_user_threads (which has the project batching fix)
+            print("\n1ļøāƒ£ Testing get_user_threads simulation...")
+            try:
+                # Get threads for the user
+                threads_result = await client.table('threads').select('*').eq('account_id', test_user_id).order('created_at', desc=True).execute()
+
+                if threads_result.data:
+                    print(f"   - Found {len(threads_result.data)} threads")
+
+                    # Extract unique project IDs (this is what could cause URI issues)
+                    project_ids = [
+                        thread['project_id'] for thread in threads_result.data[:1000] # Limit to first 1000
+                        if thread.get('project_id')
+                    ]
+                    unique_project_ids = list(set(project_ids)) if project_ids else []
+
+                    print(f"   - Found {len(unique_project_ids)} unique project IDs")
+
+                    if unique_project_ids:
+                        # Test the batching logic we implemented
+                        if len(unique_project_ids) > 100:
+                            print(f"   - Would use batching for {len(unique_project_ids)} project IDs")
+                        else:
+                            print(f"   - Would use direct query for {len(unique_project_ids)} project IDs")
+
+                        # Actually test a small batch to verify it works
+                        test_batch = unique_project_ids[:min(10, len(unique_project_ids))]
+                        projects_result = await client.table('projects').select('*').in_('project_id', test_batch).execute()
+                        print(f"āœ… Project query test succeeded: found {len(projects_result.data or [])} projects")
+                    else:
+                        print("   - No project IDs to test")
+                else:
+                    print("   - No threads found for user")
+
+            except Exception as e:
+                print(f"āŒ get_user_threads test failed: {str(e)}")
+
+            # Test 2: Template service simulation
+            print("\n2ļøāƒ£ Testing template service simulation...")
+            try:
+                from templates.template_service import TemplateService
+
+                # This would test the creator ID batching, but we'll just verify the import works
+                print("āœ… Template service import succeeded")
+
+            except ImportError as e:
+                print(f"āš ļø Could not import template service: {str(e)}")
+            except Exception as e:
+                print(f"āŒ Template service test failed: {str(e)}")
+
+        except Exception as e:
+            print(f"āŒ API functions test failed: {str(e)}")
+
     async def main():
         """Main test function."""
         print("šŸš€ Starting URI limit fix tests...\n")
 
         await test_large_thread_count()
         await test_billing_integration()
+        await test_api_functions()
 
         print("\n✨ Test suite completed!")
diff --git a/backend/templates/template_service.py b/backend/templates/template_service.py
index 112256c1..d1a8f8fd 100644
--- a/backend/templates/template_service.py
+++ b/backend/templates/template_service.py
@@ -312,12 +312,21 @@ class TemplateService:
             return []
 
         creator_ids = list(set(template['creator_id'] for template in result.data))
-        accounts_result = await client.schema('basejump').from_('accounts').select('id, name, slug').in_('id', creator_ids).execute()
+
+        from utils.query_utils import batch_query_in
+
+        accounts_data = await batch_query_in(
+            client=client,
+            table_name='accounts',
+            select_fields='id, name, slug',
+            in_field='id',
+            in_values=creator_ids,
+            schema='basejump'
+        )
 
         creator_names = {}
-        if accounts_result.data:
-            for account in accounts_result.data:
-                creator_names[account['id']] = account.get('name') or account.get('slug')
+        for account in accounts_data:
+            creator_names[account['id']] = account.get('name') or account.get('slug')
 
         templates = []
         for template_data in result.data:
diff --git a/backend/utils/query_utils.py b/backend/utils/query_utils.py
new file mode 100644
index 00000000..82387354
--- /dev/null
+++ b/backend/utils/query_utils.py
@@ -0,0 +1,81 @@
+"""
+Query utilities for handling large datasets and avoiding URI length limits.
+"""
+from typing import List, Any, Dict, Optional
+from utils.logger import logger
+
+
+async def batch_query_in(
+    client,
+    table_name: str,
+    select_fields: str,
+    in_field: str,
+    in_values: List[Any],
+    batch_size: int = 100,
+    additional_filters: Optional[Dict[str, Any]] = None,
+    schema: Optional[str] = None
+) -> List[Dict[str, Any]]:
+    """
+    Execute a query with .in_() filtering, automatically batching large arrays to avoid URI limits.
+
+    Args:
+        client: Supabase client
+        table_name: Name of the table to query
+        select_fields: Fields to select (e.g., '*' or 'id, name, created_at')
+        in_field: Field name for the .in_() filter
+        in_values: List of values to filter by
+        batch_size: Maximum number of values per batch (default: 100)
+        additional_filters: Optional dict of additional filters to apply
+        schema: Optional schema name (for basejump tables)
+
+    Returns:
+        List of all matching records from all batches
+    """
+    if not in_values:
+        return []
+
+    all_results = []
+
+    # If values list is small, do a single query
+    if len(in_values) <= batch_size:
+        query = client.schema(schema).from_(table_name) if schema else client.table(table_name)
+        query = query.select(select_fields).in_(in_field, in_values)
+
+        # Apply additional filters
+        if additional_filters:
+            for field, value in additional_filters.items():
+                if field.endswith('_gte'):
+                    query = query.gte(field[:-4], value)
+                elif field.endswith('_eq'):
+                    query = query.eq(field[:-3], value)
+                else:
+                    query = query.eq(field, value)
+
+        result = await query.execute()
+        return result.data or []
+
+    # Batch processing for large arrays
+    logger.debug(f"Batching {len(in_values)} {in_field} values into chunks of {batch_size}")
+
+    for i in range(0, len(in_values), batch_size):
+        batch_values = in_values[i:i + batch_size]
+
+        query = client.schema(schema).from_(table_name) if schema else client.table(table_name)
+        query = query.select(select_fields).in_(in_field, batch_values)
+
+        # Apply additional filters
+        if additional_filters:
+            for field, value in additional_filters.items():
+                if field.endswith('_gte'):
+                    query = query.gte(field[:-4], value)
+                elif field.endswith('_eq'):
+                    query = query.eq(field[:-3], value)
+                else:
+                    query = query.eq(field, value)
+
+        batch_result = await query.execute()
+        if batch_result.data:
+            all_results.extend(batch_result.data)
+
+    logger.debug(f"Batched query returned {len(all_results)} total results")
+    return all_results
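Usage note (not part of the patch): a minimal sketch of how call sites are expected to use the new helper. The `find_running_runs` wrapper below is hypothetical; the table, fields, and filters mirror the `check_agent_run_limit` call site above. Plain keys in `additional_filters` map to `.eq()`, while the `_gte`/`_eq` suffixes select `.gte()`/`.eq()` on the stripped field name.

    from utils.query_utils import batch_query_in

    async def find_running_runs(client, thread_ids, since_iso):
        # Hypothetical wrapper for illustration only.
        # thread_ids may exceed 100 entries; batch_query_in splits them into
        # chunks of 100 so each generated .in_() URL stays under URI length
        # limits, then concatenates the rows returned by every chunk.
        return await batch_query_in(
            client=client,
            table_name='agent_runs',
            select_fields='id, thread_id, started_at',
            in_field='thread_id',
            in_values=thread_ids,
            additional_filters={
                'status': 'running',          # plain key -> .eq('status', ...)
                'started_at_gte': since_iso,  # '_gte' suffix -> .gte('started_at', ...)
            },
        )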