marko-kraemer 2025-04-23 19:46:22 +01:00
parent 7bffa72056
commit e985cbdc2b
6 changed files with 54 additions and 140 deletions

View File

@@ -256,23 +256,17 @@ async def _cleanup_agent_run(agent_run_id: str):
         logger.warning(f"Failed to clean up Redis keys for agent run {agent_run_id}: {str(e)}")
         # Non-fatal error, can continue
-async def get_or_create_project_sandbox(client, project_id: str, sandbox_cache={}):
+async def get_or_create_project_sandbox(client, project_id: str):
     """
-    Safely get or create a sandbox for a project using distributed locking to avoid race conditions.
+    Get or create a sandbox for a project without distributed locking.
     Args:
         client: The Supabase client
         project_id: The project ID to get or create a sandbox for
-        sandbox_cache: Optional in-memory cache to avoid repeated lookups in the same process
     Returns:
         Tuple of (sandbox object, sandbox_id, sandbox_pass)
     """
-    # Check in-memory cache first (optimization for repeated calls within same process)
-    if project_id in sandbox_cache:
-        logger.debug(f"Using cached sandbox for project {project_id}")
-        return sandbox_cache[project_id]
     # First get the current project data to check if a sandbox already exists
     project = await client.table('projects').select('*').eq('project_id', project_id).execute()
     if not project.data or len(project.data) == 0:
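A side note on the dropped `sandbox_cache={}` parameter: beyond the locking cleanup, it was also Python's classic mutable-default-argument trap, where the default dict is created once at definition time and silently shared across every call (and, here, across every project in the process). A minimal standalone sketch of that behavior, with illustrative names:

    def remember(key, cache={}):  # the default dict is built once, when the function is defined
        cache[key] = True
        return cache

    print(remember("a"))  # {'a': True}
    print(remember("b"))  # {'a': True, 'b': True} -- state from the first call leaks into the second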
@@ -288,133 +282,55 @@ async def get_or_create_project_sandbox(client, project_id: str, sandbox_cache={}):
         try:
             sandbox = await get_or_start_sandbox(sandbox_id)
-            # Cache the result
-            sandbox_cache[project_id] = (sandbox, sandbox_id, sandbox_pass)
             return (sandbox, sandbox_id, sandbox_pass)
         except Exception as e:
             logger.error(f"Failed to retrieve existing sandbox {sandbox_id} for project {project_id}: {str(e)}")
             # Fall through to create a new sandbox if retrieval fails
-    # Need to create a new sandbox - use Redis for distributed locking
-    lock_key = f"project_sandbox_lock:{project_id}"
-    lock_value = str(uuid.uuid4())  # Unique identifier for this lock acquisition
-    lock_timeout = 60  # seconds
-    # Try to acquire a lock
+    # Create a new sandbox
     try:
-        # Attempt to get a lock with a timeout - this is atomic in Redis
-        acquired = await redis.set(
-            lock_key,
-            lock_value,
-            nx=True,  # Only set if key doesn't exist (NX = not exists)
-            ex=lock_timeout  # Auto-expire the lock
-        )
-        if not acquired:
-            # Someone else is creating a sandbox for this project
-            logger.info(f"Waiting for another process to create sandbox for project {project_id}")
-            # Wait and retry a few times
-            max_retries = 5
-            retry_delay = 2  # seconds
-            for retry in range(max_retries):
-                await asyncio.sleep(retry_delay)
-                # Check if the other process completed
-                fresh_project = await client.table('projects').select('*').eq('project_id', project_id).execute()
-                if fresh_project.data and fresh_project.data[0].get('sandbox', {}).get('id'):
-                    sandbox_id = fresh_project.data[0]['sandbox']['id']
-                    sandbox_pass = fresh_project.data[0]['sandbox']['pass']
-                    logger.info(f"Another process created sandbox {sandbox_id} for project {project_id}")
-                    sandbox = await get_or_start_sandbox(sandbox_id)
-                    # Cache the result
-                    sandbox_cache[project_id] = (sandbox, sandbox_id, sandbox_pass)
-                    return (sandbox, sandbox_id, sandbox_pass)
-            # If we got here, the other process didn't complete in time
-            # Force-acquire the lock by deleting and recreating it
-            logger.warning(f"Timeout waiting for sandbox creation for project {project_id}, acquiring lock forcefully")
-            await redis.delete(lock_key)
-            await redis.set(lock_key, lock_value, ex=lock_timeout)
-        # We have the lock now - check one more time to avoid race conditions
-        fresh_project = await client.table('projects').select('*').eq('project_id', project_id).execute()
-        if fresh_project.data and fresh_project.data[0].get('sandbox', {}).get('id'):
-            sandbox_id = fresh_project.data[0]['sandbox']['id']
-            sandbox_pass = fresh_project.data[0]['sandbox']['pass']
-            logger.info(f"Sandbox {sandbox_id} was created by another process while acquiring lock for project {project_id}")
-            # Release the lock
-            await redis.delete(lock_key)
-            sandbox = await get_or_start_sandbox(sandbox_id)
-            # Cache the result
-            sandbox_cache[project_id] = (sandbox, sandbox_id, sandbox_pass)
-            return (sandbox, sandbox_id, sandbox_pass)
-        # Create a new sandbox
-        try:
-            logger.info(f"Creating new sandbox for project {project_id}")
-            sandbox_pass = str(uuid.uuid4())
-            sandbox = create_sandbox(sandbox_pass)
-            sandbox_id = sandbox.id
-            logger.info(f"Created new sandbox {sandbox_id} with preview: {sandbox.get_preview_link(6080)}/vnc_lite.html?password={sandbox_pass}")
-            # Get preview links
-            vnc_link = sandbox.get_preview_link(6080)
-            website_link = sandbox.get_preview_link(8080)
-            # Extract the actual URLs and token from the preview link objects
-            vnc_url = vnc_link.url if hasattr(vnc_link, 'url') else str(vnc_link).split("url='")[1].split("'")[0]
-            website_url = website_link.url if hasattr(website_link, 'url') else str(website_link).split("url='")[1].split("'")[0]
-            # Extract token if available
-            token = None
-            if hasattr(vnc_link, 'token'):
-                token = vnc_link.token
-            elif "token='" in str(vnc_link):
-                token = str(vnc_link).split("token='")[1].split("'")[0]
-            # Update the project with the new sandbox info
-            update_result = await client.table('projects').update({
-                'sandbox': {
-                    'id': sandbox_id,
-                    'pass': sandbox_pass,
-                    'vnc_preview': vnc_url,
-                    'sandbox_url': website_url,
-                    'token': token
-                }
-            }).eq('project_id', project_id).execute()
-            if not update_result.data:
-                logger.error(f"Failed to update project {project_id} with new sandbox {sandbox_id}")
-                raise Exception("Database update failed")
-            # Cache the result
-            sandbox_cache[project_id] = (sandbox, sandbox_id, sandbox_pass)
-            return (sandbox, sandbox_id, sandbox_pass)
-        except Exception as e:
-            logger.error(f"Error creating sandbox for project {project_id}: {str(e)}")
-            raise e
+        logger.info(f"Creating new sandbox for project {project_id}")
+        sandbox_pass = str(uuid.uuid4())
+        sandbox = create_sandbox(sandbox_pass)
+        sandbox_id = sandbox.id
+        logger.info(f"Created new sandbox {sandbox_id} with preview: {sandbox.get_preview_link(6080)}/vnc_lite.html?password={sandbox_pass}")
+        # Get preview links
+        vnc_link = sandbox.get_preview_link(6080)
+        website_link = sandbox.get_preview_link(8080)
+        # Extract the actual URLs and token from the preview link objects
+        vnc_url = vnc_link.url if hasattr(vnc_link, 'url') else str(vnc_link).split("url='")[1].split("'")[0]
+        website_url = website_link.url if hasattr(website_link, 'url') else str(website_link).split("url='")[1].split("'")[0]
+        # Extract token if available
+        token = None
+        if hasattr(vnc_link, 'token'):
+            token = vnc_link.token
+        elif "token='" in str(vnc_link):
+            token = str(vnc_link).split("token='")[1].split("'")[0]
+        # Update the project with the new sandbox info
+        update_result = await client.table('projects').update({
+            'sandbox': {
+                'id': sandbox_id,
+                'pass': sandbox_pass,
+                'vnc_preview': vnc_url,
+                'sandbox_url': website_url,
+                'token': token
+            }
+        }).eq('project_id', project_id).execute()
+        if not update_result.data:
+            logger.error(f"Failed to update project {project_id} with new sandbox {sandbox_id}")
+            raise Exception("Database update failed")
+        return (sandbox, sandbox_id, sandbox_pass)
     except Exception as e:
-        logger.error(f"Error in sandbox creation process for project {project_id}: {str(e)}")
+        logger.error(f"Error creating sandbox for project {project_id}: {str(e)}")
         raise e
-    finally:
-        # Always try to release the lock if we have it
-        try:
-            # Only delete the lock if it's still ours
-            current_value = await redis.get(lock_key)
-            if current_value == lock_value:
-                await redis.delete(lock_key)
-                logger.debug(f"Released lock for project {project_id} sandbox creation")
-        except Exception as lock_error:
-            logger.warning(f"Error releasing sandbox creation lock for project {project_id}: {str(lock_error)}")
 @router.post("/thread/{thread_id}/agent/start")
 async def start_agent(
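For reference, the deleted block implements the standard Redis lock recipe: an atomic SET with nx=True (set only if the key is absent) and ex= (auto-expiry) acquires the lock, and the holder deletes the key to release it. A minimal self-contained sketch of that pattern with redis.asyncio; the key name and timeout mirror the removed code, and the Lua compare-and-delete release is the safe variant of its non-atomic get/compare/delete:

    import uuid
    import redis.asyncio as aioredis

    RELEASE_SCRIPT = """
    if redis.call('get', KEYS[1]) == ARGV[1] then
        return redis.call('del', KEYS[1])
    end
    return 0
    """

    async def create_with_lock(r: aioredis.Redis, project_id: str, critical_section):
        lock_key = f"project_sandbox_lock:{project_id}"
        lock_value = str(uuid.uuid4())  # unique value proves ownership at release time
        # Atomic acquire: only succeeds if the key is absent; expires on its own if we crash.
        if not await r.set(lock_key, lock_value, nx=True, ex=60):
            return None  # another process holds the lock
        try:
            return await critical_section()
        finally:
            # Atomic compare-and-delete: never removes a lock that expired and was re-acquired.
            await r.eval(RELEASE_SCRIPT, 1, lock_key, lock_value)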
@@ -894,7 +810,6 @@ async def run_agent_background(
     logger.info(f"Agent run background task fully completed for: {agent_run_id} (instance: {instance_id})")
-# New background task function
 async def generate_and_update_project_name(project_id: str, prompt: str):
     """Generates a project name using an LLM and updates the database."""
     logger.info(f"Starting background task to generate name for project: {project_id}")
@@ -938,7 +853,6 @@ async def generate_and_update_project_name(project_id: str, prompt: str):
     else:
         logger.warning(f"Failed to get valid response from LLM for project {project_id} naming. Response: {response}")
-    print(f"\n\n\nGenerated name: {generated_name}\n\n\n")
     # Update database if name was generated
     if generated_name:
         update_result = await client.table('projects') \
@@ -1026,7 +940,7 @@ async def initiate_agent_with_files(
         )
     # -----------------------------------------
-    # 3. Create Sandbox - Using safe method with distributed locking
+    # 3. Create Sandbox
     try:
         sandbox, sandbox_id, sandbox_pass = await get_or_create_project_sandbox(client, project_id)
         logger.info(f"Using sandbox {sandbox_id} for new project {project_id}")
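One fragility the commit keeps: when the preview-link object lacks a url or token attribute, the surviving code slices its repr with chained split calls, which raises IndexError if the string format ever shifts. A regex fallback fails soft instead; extract_field is an illustrative helper name, not part of the codebase:

    import re
    from typing import Optional

    def extract_field(link: object, field: str) -> Optional[str]:
        """Prefer the real attribute; otherwise parse e.g. url='...' out of repr(link)."""
        value = getattr(link, field, None)
        if value is not None:
            return str(value)
        match = re.search(rf"{field}='([^']*)'", str(link))
        return match.group(1) if match else None  # None instead of IndexError on format drift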

View File

@@ -247,9 +247,9 @@ class WebSearchTool(Tool):
             )
             response.raise_for_status()
             data = response.json()
-            print(f"--- Raw Tavily Response ---")
-            print(data)
-            print(f"--------------------------")
+            # print(f"--- Raw Tavily Response ---")
+            # print(data)
+            # print(f"--------------------------")
             # Normalise Tavily extract output to a list of dicts
             extracted = []
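Commenting the prints out removes the noise but also the signal. Since the tool already logs through a logger, the same dump can stay available behind the DEBUG level; a sketch assuming a module-level logger like the one used elsewhere in the backend:

    import logging

    logger = logging.getLogger(__name__)  # assumption: the module already defines this

    def log_raw_tavily_response(data) -> None:
        # Same content as the commented-out prints, emitted only when DEBUG is enabled.
        logger.debug("--- Raw Tavily Response ---")
        logger.debug("%s", data)
        logger.debug("--------------------------")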

View File

@@ -46,9 +46,9 @@ async def lifespan(app: FastAPI):
     # Initialize the sandbox API with shared resources
     sandbox_api.initialize(db)
-    # Initialize Redis before restoring agent runs
-    from services import redis
-    await redis.initialize_async()
+    # Redis is no longer needed for a single-server setup
+    # from services import redis
+    # await redis.initialize_async()
     asyncio.create_task(agent_api.restore_running_agent_runs())
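Commenting the initialization out hard-codes the single-server assumption into the lifespan hook. If multi-instance deployment might return, a configuration flag preserves both modes; REDIS_ENABLED here is a hypothetical setting, and services.redis is the module the removed lines imported:

    import os

    REDIS_ENABLED = os.getenv("REDIS_ENABLED", "false").lower() == "true"  # hypothetical flag

    async def init_optional_services() -> None:
        if REDIS_ENABLED:
            from services import redis  # same module the removed lines used
            await redis.initialize_async()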

View File

@@ -74,7 +74,6 @@ async def verify_sandbox_access(client, sandbox_id: str, user_id: Optional[str]
 async def get_sandbox_by_id_safely(client, sandbox_id: str):
     """
     Safely retrieve a sandbox object by its ID, using the project that owns it.
-    This prevents race conditions by leveraging the distributed locking mechanism.
     Args:
         client: The Supabase client
@@ -97,7 +96,7 @@ async def get_sandbox_by_id_safely(client, sandbox_id: str):
     logger.debug(f"Found project {project_id} for sandbox {sandbox_id}")
     try:
-        # Use the race-condition-safe function to get the sandbox
+        # Get the sandbox
         sandbox, retrieved_sandbox_id, sandbox_pass = await get_or_create_project_sandbox(client, project_id)
         # Verify we got the right sandbox
@@ -259,7 +258,6 @@ async def ensure_project_sandbox_active(
     """
     Ensure that a project's sandbox is active and running.
     Checks the sandbox status and starts it if it's not running.
-    Uses distributed locking to prevent race conditions.
     """
     client = await db.client
@@ -286,8 +284,8 @@ async def ensure_project_sandbox_active(
         raise HTTPException(status_code=403, detail="Not authorized to access this project")
     try:
-        # Use the safer function that handles race conditions with distributed locking
-        logger.info(f"Ensuring sandbox is active for project {project_id} using distributed locking")
+        # Get or create the sandbox
+        logger.info(f"Ensuring sandbox is active for project {project_id}")
         sandbox, sandbox_id, sandbox_pass = await get_or_create_project_sandbox(client, project_id)
         logger.info(f"Successfully ensured sandbox {sandbox_id} is active for project {project_id}")
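With the locking gone, the "# Verify we got the right sandbox" context line carries the remaining safety check: the project row could point at a different sandbox than the one requested. The diff does not show how a mismatch is handled; one hypothetical shape of that guard, reusing names from the hunk above:

    from fastapi import HTTPException

    async def get_verified_sandbox(client, project_id: str, sandbox_id: str):
        # get_or_create_project_sandbox is the function rewritten in the first file of this commit.
        sandbox, retrieved_sandbox_id, _sandbox_pass = await get_or_create_project_sandbox(client, project_id)
        if retrieved_sandbox_id != sandbox_id:
            # Illustrative handling only; the real branch is not shown in the diff.
            raise HTTPException(status_code=404, detail="Sandbox not found for this project")
        return sandbox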

View File

@@ -3,7 +3,7 @@ from typing import Dict, Optional, Tuple
 # Define subscription tiers and their monthly limits (in minutes)
 SUBSCRIPTION_TIERS = {
-    'price_1RGJ9GG6l1KZGqIroxSqgphC': {'name': 'free', 'minutes': 0},
+    'price_1RGJ9GG6l1KZGqIroxSqgphC': {'name': 'free', 'minutes': 100000},
     'price_1RGJ9LG6l1KZGqIrd9pwzeNW': {'name': 'base', 'minutes': 300},  # 5 hours
     'price_1RGJ9JG6l1KZGqIrVUU4ZRv6': {'name': 'extra', 'minutes': 2400}  # 40 hours
 }
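Raising the free tier from 0 to 100000 minutes (roughly 69 days of continuous runtime per month) effectively disables the limit without touching the enforcement path. A self-contained sketch of how a lookup against this table typically gates usage; minutes_remaining is an illustrative helper, not the billing module's API:

    SUBSCRIPTION_TIERS = {
        'price_1RGJ9GG6l1KZGqIroxSqgphC': {'name': 'free', 'minutes': 100000},
        'price_1RGJ9LG6l1KZGqIrd9pwzeNW': {'name': 'base', 'minutes': 300},    # 5 hours
        'price_1RGJ9JG6l1KZGqIrVUU4ZRv6': {'name': 'extra', 'minutes': 2400},  # 40 hours
    }

    def minutes_remaining(price_id: str, minutes_used: float) -> float:
        """Remaining monthly allowance for a subscription; unknown tiers get nothing."""
        tier = SUBSCRIPTION_TIERS.get(price_id)
        if tier is None:
            return 0.0
        return max(0.0, tier['minutes'] - minutes_used)

    # minutes_remaining('price_1RGJ9LG6l1KZGqIrd9pwzeNW', 120.0) == 180.0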

View File

@@ -243,7 +243,7 @@ export default function ThreadPage({ params }: { params: Promise<ThreadParams> }
   const messagesLoadedRef = useRef(false);
   const agentRunsCheckedRef = useRef(false);
   const previousAgentStatus = useRef<typeof agentStatus>('idle');
-  const pollingIntervalRef = useRef<NodeJS.Timeout | null>(null);
+  const pollingIntervalRef = useRef<NodeJS.Timeout | null>(null); // POLLING FOR MESSAGES
   const handleProjectRenamed = useCallback((newName: string) => {
     setProjectName(newName);
@@ -969,6 +969,7 @@ export default function ThreadPage({ params }: { params: Promise<ThreadParams> }
     }
   }, [projectName]);
+  // POLLING FOR MESSAGES
   // Set up polling for messages
   useEffect(() => {
     // Function to fetch messages
@@ -1024,6 +1025,7 @@ export default function ThreadPage({ params }: { params: Promise<ThreadParams> }
     }
   };
   }, [threadId, userHasScrolled, initialLoadCompleted]);
+  // POLLING FOR MESSAGES
   // Add another useEffect to ensure messages are refreshed when agent status changes to idle
   useEffect(() => {