From 8dce0d3254f52b133be900b3e81749ff995c1225 Mon Sep 17 00:00:00 2001
From: marko-kraemer
Date: Wed, 23 Apr 2025 07:12:50 +0100
Subject: [PATCH] in-memory locking

---
 backend/agent/api.py | 86 ++++++++++++++++++--------------------------
 1 file changed, 35 insertions(+), 51 deletions(-)

diff --git a/backend/agent/api.py b/backend/agent/api.py
index 4acb8d68..b028b99b 100644
--- a/backend/agent/api.py
+++ b/backend/agent/api.py
@@ -10,6 +10,7 @@ import jwt
 from pydantic import BaseModel
 import tempfile
 import os
+import threading
 
 from agentpress.thread_manager import ThreadManager
 from services.supabase import DBConnection
@@ -30,6 +31,9 @@ db = None
 # In-memory storage for active agent runs and their responses
 active_agent_runs: Dict[str, List[Any]] = {}
 
+# In-memory locks serializing sandbox creation per project
+sandbox_locks: Dict[str, threading.Lock] = {}
+
 MODEL_NAME_ALIASES = {
     "sonnet-3.7": "anthropic/claude-3-7-sonnet-latest",
     "gpt-4.1": "openai/gpt-4.1-2025-04-14",
@@ -259,7 +263,8 @@ async def _cleanup_agent_run(agent_run_id: str):
 
 async def get_or_create_project_sandbox(client, project_id: str, sandbox_cache={}):
     """
-    Safely get or create a sandbox for a project using distributed locking to avoid race conditions.
+    Get or create a sandbox for a project, using in-memory locking to avoid race
+    conditions within a single-instance deployment.
     
     Args:
         client: The Supabase client
@@ -296,59 +301,43 @@ async def get_or_create_project_sandbox(client, project_id: str, sandbox_cache={
             logger.error(f"Failed to retrieve existing sandbox {sandbox_id} for project {project_id}: {str(e)}")
             # Fall through to create a new sandbox if retrieval fails
 
-    # Need to create a new sandbox - use Redis for distributed locking
-    lock_key = f"project_sandbox_lock:{project_id}"
-    lock_value = str(uuid.uuid4())  # Unique identifier for this lock acquisition
-    lock_timeout = 60  # seconds
+    # Need to create a new sandbox - use simple in-memory locking.
+    # setdefault is atomic under the GIL, so concurrent threads get the same lock.
+    lock = sandbox_locks.setdefault(project_id, threading.Lock())
+
+    # Try to acquire the lock without blocking
+    lock_acquired = lock.acquire(blocking=False)
 
-    # Try to acquire a lock
     try:
-        # Attempt to get a lock with a timeout - this is atomic in Redis
-        acquired = await redis.set(
-            lock_key,
-            lock_value,
-            nx=True,  # Only set if key doesn't exist (NX = not exists)
-            ex=lock_timeout  # Auto-expire the lock
-        )
-
-        if not acquired:
+        if not lock_acquired:
             # Someone else is creating a sandbox for this project
-            logger.info(f"Waiting for another process to create sandbox for project {project_id}")
+            logger.info(f"Waiting for another thread to create sandbox for project {project_id}")
 
-            # Wait and retry a few times
-            max_retries = 5
-            retry_delay = 2  # seconds
+            # Block until the creating thread releases the lock (or we time out).
+            # Capture the result so the finally block knows whether we hold it.
+            lock_acquired = lock.acquire(blocking=True, timeout=60)
 
-            for retry in range(max_retries):
-                await asyncio.sleep(retry_delay)
+            # Check if the other thread completed while we were waiting
+            fresh_project = await client.table('projects').select('*').eq('project_id', project_id).execute()
+            if fresh_project.data and fresh_project.data[0].get('sandbox', {}).get('id'):
+                sandbox_id = fresh_project.data[0]['sandbox']['id']
+                sandbox_pass = fresh_project.data[0]['sandbox']['pass']
+                logger.info(f"Another thread created sandbox {sandbox_id} for project {project_id}")
 
-                # Check if the other process completed
-                fresh_project = await client.table('projects').select('*').eq('project_id', project_id).execute()
-                if fresh_project.data and fresh_project.data[0].get('sandbox', {}).get('id'):
-                    sandbox_id = fresh_project.data[0]['sandbox']['id']
-                    sandbox_pass = fresh_project.data[0]['sandbox']['pass']
-                    logger.info(f"Another process created sandbox {sandbox_id} for project {project_id}")
-
-                    sandbox = await get_or_start_sandbox(sandbox_id)
-                    # Cache the result
-                    sandbox_cache[project_id] = (sandbox, sandbox_id, sandbox_pass)
-                    return (sandbox, sandbox_id, sandbox_pass)
+                sandbox = await get_or_start_sandbox(sandbox_id)
+                # Cache the result
+                sandbox_cache[project_id] = (sandbox, sandbox_id, sandbox_pass)
+                return (sandbox, sandbox_id, sandbox_pass)
 
-            # If we got here, the other process didn't complete in time
-            # Force-acquire the lock by deleting and recreating it
-            logger.warning(f"Timeout waiting for sandbox creation for project {project_id}, acquiring lock forcefully")
-            await redis.delete(lock_key)
-            await redis.set(lock_key, lock_value, ex=lock_timeout)
+            # If we got here and still don't have a sandbox, we create one now
+            # (the other thread must have failed or timed out)
 
-        # We have the lock now - check one more time to avoid race conditions
+        # Double-check the project data once more to avoid race conditions
         fresh_project = await client.table('projects').select('*').eq('project_id', project_id).execute()
         if fresh_project.data and fresh_project.data[0].get('sandbox', {}).get('id'):
             sandbox_id = fresh_project.data[0]['sandbox']['id']
             sandbox_pass = fresh_project.data[0]['sandbox']['pass']
-            logger.info(f"Sandbox {sandbox_id} was created by another process while acquiring lock for project {project_id}")
-
-            # Release the lock
-            await redis.delete(lock_key)
+            logger.info(f"Sandbox {sandbox_id} was created by another thread while waiting for project {project_id}")
 
             sandbox = await get_or_start_sandbox(sandbox_id)
             # Cache the result
@@ -407,15 +396,10 @@ async def get_or_create_project_sandbox(client, project_id: str, sandbox_cache={
         raise e
     finally:
-        # Always try to release the lock if we have it
-        try:
-            # Only delete the lock if it's still ours
-            current_value = await redis.get(lock_key)
-            if current_value == lock_value:
-                await redis.delete(lock_key)
-                logger.debug(f"Released lock for project {project_id} sandbox creation")
-        except Exception as lock_error:
-            logger.warning(f"Error releasing sandbox creation lock for project {project_id}: {str(lock_error)}")
+        # Always release the lock if we hold it
+        if lock_acquired:
+            lock.release()
+            logger.debug(f"Released lock for project {project_id} sandbox creation")
 
 @router.post("/thread/{thread_id}/agent/start")
 async def start_agent(
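
Note on the blocking acquire: get_or_create_project_sandbox is a coroutine, so the acquire(blocking=True, timeout=60) call above holds the event-loop thread while it waits. If the lock holder is another coroutine on the same loop, that coroutine cannot resume to release the lock, and requests stall until the 60-second timeout expires. Below is a minimal sketch of the same single-instance idea built on asyncio.Lock, which suspends the waiting coroutine instead of blocking the thread. It assumes the same client and get_or_start_sandbox helpers the patch uses; create_sandbox_for_project is a hypothetical stand-in for the existing creation logic, not an API from this codebase.

import asyncio
from typing import Any, Dict, Tuple

# Per-project asyncio locks. Lazy creation via setdefault is safe here
# because every coroutine runs on the same event-loop thread.
_sandbox_locks: Dict[str, asyncio.Lock] = {}
_sandbox_cache: Dict[str, Tuple[Any, str, str]] = {}

async def get_or_create_project_sandbox(client, project_id: str):
    # Fast path: this process already resolved the sandbox
    if project_id in _sandbox_cache:
        return _sandbox_cache[project_id]

    lock = _sandbox_locks.setdefault(project_id, asyncio.Lock())

    # Awaiting the lock suspends this coroutine; the event loop stays
    # free, so the current holder can finish its work and release.
    async with lock:
        # Double-check after acquiring: another coroutine may have
        # created the sandbox while we were waiting.
        project = await client.table('projects').select('*').eq('project_id', project_id).execute()
        sandbox_info = (project.data or [{}])[0].get('sandbox') or {}
        if sandbox_info.get('id'):
            sandbox = await get_or_start_sandbox(sandbox_info['id'])
            result = (sandbox, sandbox_info['id'], sandbox_info['pass'])
        else:
            # We hold the lock, so we are the only creator for this project.
            # Hypothetical helper wrapping the sandbox-creation logic.
            sandbox, sandbox_id, sandbox_pass = await create_sandbox_for_project(project_id)
            result = (sandbox, sandbox_id, sandbox_pass)

        _sandbox_cache[project_id] = result
        return result

Like the patch, this serializes only within one process: multiple Uvicorn workers or instances would still race, which is the trade-off this commit accepts by dropping the Redis lock. asyncio.Lock has no timeout on async with, but asyncio.wait_for(lock.acquire(), timeout=60) would reproduce the 60-second cap of the threading version if a wait bound is wanted.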