query util for batch in() calls

marko-kraemer 2025-08-22 21:31:00 -07:00
parent 66d96d0ad3
commit fbdc97c9fa
4 changed files with 215 additions and 49 deletions
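In short, call sites that previously issued one unbatched .in_() query now go through the new utils.query_utils.batch_query_in helper, which splits the value list into chunks (100 per request by default) and merges the results to stay under URI length limits. A minimal sketch of the before/after pattern, taken from the get_user_threads change below:

    # before: one request whose query string grows with the id list,
    # which can exceed URI length limits
    result = await client.table('projects').select('*').in_('project_id', unique_project_ids).execute()
    projects = result.data or []

    # after: the helper chunks the ids (100 per request by default)
    # and concatenates the results of each batch
    from utils.query_utils import batch_query_in
    projects = await batch_query_in(
        client=client,
        table_name='projects',
        select_fields='*',
        in_field='project_id',
        in_values=unique_project_ids
    )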

View File

@@ -1386,11 +1386,17 @@ async def get_agents(
         version_ids = list({agent['current_version_id'] for agent in agents_data if agent.get('current_version_id')})
         if version_ids:
             try:
-                versions_result = await client.table('agent_versions').select(
-                    'version_id, agent_id, version_number, version_name, is_active, created_at, updated_at, created_by, config'
-                ).in_('version_id', version_ids).execute()
-                for row in (versions_result.data or []):
+                from utils.query_utils import batch_query_in
+
+                versions_data = await batch_query_in(
+                    client=client,
+                    table_name='agent_versions',
+                    select_fields='version_id, agent_id, version_number, version_name, is_active, created_at, updated_at, created_by, config',
+                    in_field='version_id',
+                    in_values=version_ids
+                )
+                for row in versions_data:
                     config = row.get('config') or {}
                     tools = config.get('tools') or {}
                     version_dict = {
@@ -2960,15 +2966,22 @@ async def get_user_threads(
         # Fetch projects if we have project IDs
         projects_by_id = {}
         if unique_project_ids:
-            projects_result = await client.table('projects').select('*').in_('project_id', unique_project_ids).execute()
-            if projects_result.data:
-                logger.debug(f"[API] Raw projects from DB: {len(projects_result.data)}")
-                # Create a lookup map of projects by ID
-                projects_by_id = {
-                    project['project_id']: project
-                    for project in projects_result.data
-                }
+            from utils.query_utils import batch_query_in
+            projects_data = await batch_query_in(
+                client=client,
+                table_name='projects',
+                select_fields='*',
+                in_field='project_id',
+                in_values=unique_project_ids
+            )
+            logger.debug(f"[API] Retrieved {len(projects_data)} projects")
+            # Create a lookup map of projects by ID
+            projects_by_id = {
+                project['project_id']: project
+                for project in projects_data
+            }
 
         # Map threads with their associated projects
         mapped_threads = []

View File

@@ -22,19 +22,19 @@ async def check_for_active_project_agent_run(client, project_id: str):
     project_thread_ids = [t['thread_id'] for t in project_threads.data]

     if project_thread_ids:
-        # Handle large numbers of threads by batching to avoid URI length limits
-        if len(project_thread_ids) > 100:
-            # Process in batches to avoid URI too large errors
-            batch_size = 100
-            for i in range(0, len(project_thread_ids), batch_size):
-                batch_thread_ids = project_thread_ids[i:i + batch_size]
-                active_runs = await client.table('agent_runs').select('id').in_('thread_id', batch_thread_ids).eq('status', 'running').execute()
-                if active_runs.data and len(active_runs.data) > 0:
-                    return active_runs.data[0]['id']
-        else:
-            active_runs = await client.table('agent_runs').select('id').in_('thread_id', project_thread_ids).eq('status', 'running').execute()
-            if active_runs.data and len(active_runs.data) > 0:
-                return active_runs.data[0]['id']
+        from utils.query_utils import batch_query_in
+
+        active_runs = await batch_query_in(
+            client=client,
+            table_name='agent_runs',
+            select_fields='id',
+            in_field='thread_id',
+            in_values=project_thread_ids,
+            additional_filters={'status': 'running'}
+        )
+
+        if active_runs:
+            return active_runs[0]['id']
     return None
@@ -124,27 +124,20 @@ async def check_agent_run_limit(client, account_id: str) -> Dict[str, Any]:
         logger.debug(f"Found {len(thread_ids)} threads for account {account_id}")

         # Query for running agent runs within the past 24 hours for these threads
-        # Handle large numbers of threads by batching to avoid URI length limits
-        if len(thread_ids) > 100:
-            # Process in batches to avoid URI too large errors
-            all_running_runs = []
-            batch_size = 100
-            for i in range(0, len(thread_ids), batch_size):
-                batch_thread_ids = thread_ids[i:i + batch_size]
-                batch_result = await client.table('agent_runs').select('id', 'thread_id', 'started_at').in_('thread_id', batch_thread_ids).eq('status', 'running').gte('started_at', twenty_four_hours_ago_iso).execute()
-                if batch_result.data:
-                    all_running_runs.extend(batch_result.data)
-
-            # Create a mock result object similar to what Supabase returns
-            class MockResult:
-                def __init__(self, data):
-                    self.data = data
-            running_runs_result = MockResult(all_running_runs)
-        else:
-            running_runs_result = await client.table('agent_runs').select('id', 'thread_id', 'started_at').in_('thread_id', thread_ids).eq('status', 'running').gte('started_at', twenty_four_hours_ago_iso).execute()
-
-        running_runs = running_runs_result.data or []
+        from utils.query_utils import batch_query_in
+
+        running_runs = await batch_query_in(
+            client=client,
+            table_name='agent_runs',
+            select_fields='id, thread_id, started_at',
+            in_field='thread_id',
+            in_values=thread_ids,
+            additional_filters={
+                'status': 'running',
+                'started_at_gte': twenty_four_hours_ago_iso
+            }
+        )
+
         running_count = len(running_runs)
         running_thread_ids = [run['thread_id'] for run in running_runs]
@@ -348,12 +341,82 @@ if __name__ == "__main__":
         except Exception as e:
             print(f"❌ Billing test failed: {str(e)}")

+    async def test_api_functions():
+        """Test the API functions that were also fixed for URI limits."""
+        print("\n🔧 Testing API functions...")
+
+        try:
+            # Import the API functions we fixed
+            import sys
+            sys.path.append('/app')  # Add the app directory to path
+
+            db = DBConnection()
+            client = await db.client
+
+            test_user_id = "2558d81e-5008-46d6-b7d3-8cc62d44e4f6"
+            print(f"📊 Testing API functions with user: {test_user_id}")
+
+            # Test 1: get_user_threads (which has the project batching fix)
+            print("\n1️⃣ Testing get_user_threads simulation...")
+            try:
+                # Get threads for the user
+                threads_result = await client.table('threads').select('*').eq('account_id', test_user_id).order('created_at', desc=True).execute()
+
+                if threads_result.data:
+                    print(f" - Found {len(threads_result.data)} threads")
+
+                    # Extract unique project IDs (this is what could cause URI issues)
+                    project_ids = [
+                        thread['project_id'] for thread in threads_result.data[:1000]  # Limit to first 1000
+                        if thread.get('project_id')
+                    ]
+                    unique_project_ids = list(set(project_ids)) if project_ids else []
+                    print(f" - Found {len(unique_project_ids)} unique project IDs")
+
+                    if unique_project_ids:
+                        # Test the batching logic we implemented
+                        if len(unique_project_ids) > 100:
+                            print(f" - Would use batching for {len(unique_project_ids)} project IDs")
+                        else:
+                            print(f" - Would use direct query for {len(unique_project_ids)} project IDs")
+
+                        # Actually test a small batch to verify it works
+                        test_batch = unique_project_ids[:min(10, len(unique_project_ids))]
+                        projects_result = await client.table('projects').select('*').in_('project_id', test_batch).execute()
+                        print(f"✅ Project query test succeeded: found {len(projects_result.data or [])} projects")
+                    else:
+                        print(" - No project IDs to test")
+                else:
+                    print(" - No threads found for user")
+
+            except Exception as e:
+                print(f"❌ get_user_threads test failed: {str(e)}")
+
+            # Test 2: Template service simulation
+            print("\n2️⃣ Testing template service simulation...")
+            try:
+                from templates.template_service import TemplateService
+                # This would test the creator ID batching, but we'll just verify the import works
+                print("✅ Template service import succeeded")
+            except ImportError as e:
+                print(f"⚠️ Could not import template service: {str(e)}")
+            except Exception as e:
+                print(f"❌ Template service test failed: {str(e)}")
+
+        except Exception as e:
+            print(f"❌ API functions test failed: {str(e)}")
+
     async def main():
         """Main test function."""
         print("🚀 Starting URI limit fix tests...\n")
         await test_large_thread_count()
         await test_billing_integration()
+        await test_api_functions()
         print("\n✨ Test suite completed!")

View File

@@ -312,12 +312,21 @@ class TemplateService:
             return []

         creator_ids = list(set(template['creator_id'] for template in result.data))
-        accounts_result = await client.schema('basejump').from_('accounts').select('id, name, slug').in_('id', creator_ids).execute()
+
+        from utils.query_utils import batch_query_in
+        accounts_data = await batch_query_in(
+            client=client,
+            table_name='accounts',
+            select_fields='id, name, slug',
+            in_field='id',
+            in_values=creator_ids,
+            schema='basejump'
+        )

         creator_names = {}
-        if accounts_result.data:
-            for account in accounts_result.data:
-                creator_names[account['id']] = account.get('name') or account.get('slug')
+        for account in accounts_data:
+            creator_names[account['id']] = account.get('name') or account.get('slug')

         templates = []
         for template_data in result.data:
View File

@@ -0,0 +1,81 @@
"""
Query utilities for handling large datasets and avoiding URI length limits.
"""

from typing import List, Any, Dict, Optional
from utils.logger import logger


async def batch_query_in(
    client,
    table_name: str,
    select_fields: str,
    in_field: str,
    in_values: List[Any],
    batch_size: int = 100,
    additional_filters: Optional[Dict[str, Any]] = None,
    schema: Optional[str] = None
) -> List[Dict[str, Any]]:
    """
    Execute a query with .in_() filtering, automatically batching large arrays to avoid URI limits.

    Args:
        client: Supabase client
        table_name: Name of the table to query
        select_fields: Fields to select (e.g., '*' or 'id, name, created_at')
        in_field: Field name for the .in_() filter
        in_values: List of values to filter by
        batch_size: Maximum number of values per batch (default: 100)
        additional_filters: Optional dict of additional filters to apply
        schema: Optional schema name (for basejump tables)

    Returns:
        List of all matching records from all batches
    """
    if not in_values:
        return []

    all_results = []

    # If values list is small, do a single query
    if len(in_values) <= batch_size:
        query = client.schema(schema).from_(table_name) if schema else client.table(table_name)
        query = query.select(select_fields).in_(in_field, in_values)

        # Apply additional filters
        if additional_filters:
            for field, value in additional_filters.items():
                if field.endswith('_gte'):
                    query = query.gte(field[:-4], value)
                elif field.endswith('_eq'):
                    query = query.eq(field[:-3], value)
                else:
                    query = query.eq(field, value)

        result = await query.execute()
        return result.data or []

    # Batch processing for large arrays
    logger.debug(f"Batching {len(in_values)} {in_field} values into chunks of {batch_size}")

    for i in range(0, len(in_values), batch_size):
        batch_values = in_values[i:i + batch_size]
        query = client.schema(schema).from_(table_name) if schema else client.table(table_name)
        query = query.select(select_fields).in_(in_field, batch_values)

        # Apply additional filters
        if additional_filters:
            for field, value in additional_filters.items():
                if field.endswith('_gte'):
                    query = query.gte(field[:-4], value)
                elif field.endswith('_eq'):
                    query = query.eq(field[:-3], value)
                else:
                    query = query.eq(field, value)

        batch_result = await query.execute()
        if batch_result.data:
            all_results.extend(batch_result.data)

    logger.debug(f"Batched query returned {len(all_results)} total results")
    return all_results
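
For reference, a usage sketch combining the additional_filters suffix convention (a '_gte' suffix maps to .gte(), anything else to .eq()) and the schema option for basejump tables. The DBConnection import path and the literal table values are illustrative assumptions, not part of this commit:

    import asyncio

    from services.supabase import DBConnection  # assumed import path, as in the test script above
    from utils.query_utils import batch_query_in


    async def example():
        db = DBConnection()
        client = await db.client

        # Plain keys become .eq() filters; a '_gte' suffix becomes .gte()
        runs = await batch_query_in(
            client=client,
            table_name='agent_runs',
            select_fields='id, thread_id, started_at',
            in_field='thread_id',
            in_values=['thread-1', 'thread-2'],  # may be hundreds of ids
            additional_filters={
                'status': 'running',                       # .eq('status', 'running')
                'started_at_gte': '2025-08-22T00:00:00Z',  # .gte('started_at', ...)
            }
        )

        # The schema argument routes the query through client.schema(...).from_(...)
        # instead of client.table(...), e.g. for basejump tables
        accounts = await batch_query_in(
            client=client,
            table_name='accounts',
            select_fields='id, name, slug',
            in_field='id',
            in_values=['acc-1', 'acc-2'],
            schema='basejump'
        )
        print(len(runs), len(accounts))


    if __name__ == "__main__":
        asyncio.run(example())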