diff --git a/backend/agent/api.py b/backend/agent/api.py index 97c89713..0364a604 100644 --- a/backend/agent/api.py +++ b/backend/agent/api.py @@ -24,10 +24,11 @@ from services.llm import make_llm_api_call # Initialize shared resources router = APIRouter() thread_manager = None -db = None +db = None +instance_id = None # Global instance ID for this backend instance -# In-memory storage for active agent runs and their responses -active_agent_runs: Dict[str, List[Any]] = {} +# TTL for Redis response lists (24 hours) +REDIS_RESPONSE_LIST_TTL = 3600 * 24 MODEL_NAME_ALIASES = { "sonnet-3.7": "anthropic/claude-3-7-sonnet-latest", @@ -55,33 +56,42 @@ def initialize( global thread_manager, db, instance_id thread_manager = _thread_manager db = _db - + # Use provided instance_id or generate a new one if _instance_id: instance_id = _instance_id else: # Generate instance ID instance_id = str(uuid.uuid4())[:8] - + logger.info(f"Initialized agent API with instance ID: {instance_id}") - + # Note: Redis will be initialized in the lifespan function in api.py async def cleanup(): """Clean up resources and stop running agents on shutdown.""" logger.info("Starting cleanup of agent API resources") - + # Use the instance_id to find and clean up this instance's keys try: - running_keys = await redis.keys(f"active_run:{instance_id}:*") - logger.info(f"Found {len(running_keys)} running agent runs to clean up") - - for key in running_keys: - agent_run_id = key.split(":")[-1] - await stop_agent_run(agent_run_id) + if instance_id: # Ensure instance_id is set + running_keys = await redis.keys(f"active_run:{instance_id}:*") + logger.info(f"Found {len(running_keys)} running agent runs for instance {instance_id} to clean up") + + for key in running_keys: + # Key format: active_run:{instance_id}:{agent_run_id} + parts = key.split(":") + if len(parts) == 3: + agent_run_id = parts[2] + await stop_agent_run(agent_run_id, error_message=f"Instance {instance_id} shutting down") + else: + logger.warning(f"Unexpected key format found: {key}") + else: + logger.warning("Instance ID not set, cannot clean up instance-specific agent runs.") + except Exception as e: logger.error(f"Failed to clean up running agent runs: {str(e)}") - + # Close Redis connection await redis.close() logger.info("Completed cleanup of agent API resources") @@ -91,7 +101,7 @@ async def update_agent_run_status( agent_run_id: str, status: str, error: Optional[str] = None, - responses: Optional[List[Any]] = None + responses: Optional[List[Any]] = None # Expects parsed list of dicts ) -> bool: """ Centralized function to update agent run status. 
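# A minimal sketch of the Redis naming scheme shared by the hunks in this diff.
# Helper names here are illustrative only; the f-strings in the code below are
# the source of truth for the actual keys and channels.

from typing import Optional

def active_run_key(instance_id: str, agent_run_id: str) -> str:
    # Marks a run as executing on one backend instance; the worker refreshes its TTL.
    return f"active_run:{instance_id}:{agent_run_id}"

def response_list_key(agent_run_id: str) -> str:
    # Redis list of JSON-encoded responses (written with RPUSH, read with LRANGE).
    return f"agent_run:{agent_run_id}:responses"

def response_channel(agent_run_id: str) -> str:
    # Pub/Sub channel that receives "new" whenever a response is appended.
    return f"agent_run:{agent_run_id}:new_response"

def control_channel(agent_run_id: str, instance_id: Optional[str] = None) -> str:
    # Global control channel (STOP/END_STREAM/ERROR), or the instance-specific variant.
    base = f"agent_run:{agent_run_id}:control"
    return f"{base}:{instance_id}" if instance_id else base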
@@ -102,21 +112,22 @@ async def update_agent_run_status( "status": status, "completed_at": datetime.now(timezone.utc).isoformat() } - + if error: update_data["error"] = error - + if responses: + # Ensure responses are stored correctly as JSONB update_data["responses"] = responses - + # Retry up to 3 times for retry in range(3): try: update_result = await client.table('agent_runs').update(update_data).eq("id", agent_run_id).execute() - + if hasattr(update_result, 'data') and update_result.data: - logger.info(f"Successfully updated agent run status to '{status}' (retry {retry}): {agent_run_id}") - + logger.info(f"Successfully updated agent run {agent_run_id} status to '{status}' (retry {retry})") + # Verify the update verify_result = await client.table('agent_runs').select('status', 'completed_at').eq("id", agent_run_id).execute() if verify_result.data: @@ -125,317 +136,264 @@ async def update_agent_run_status( logger.info(f"Verified agent run update: status={actual_status}, completed_at={completed_at}") return True else: - logger.warning(f"Database update returned no data on retry {retry}: {update_result}") + logger.warning(f"Database update returned no data for agent run {agent_run_id} on retry {retry}: {update_result}") if retry == 2: # Last retry logger.error(f"Failed to update agent run status after all retries: {agent_run_id}") return False except Exception as db_error: - logger.error(f"Database error on retry {retry} updating status: {str(db_error)}") + logger.error(f"Database error on retry {retry} updating status for {agent_run_id}: {str(db_error)}") if retry < 2: # Not the last retry yet await asyncio.sleep(0.5 * (2 ** retry)) # Exponential backoff else: logger.error(f"Failed to update agent run status after all retries: {agent_run_id}", exc_info=True) return False except Exception as e: - logger.error(f"Unexpected error updating agent run status: {str(e)}", exc_info=True) + logger.error(f"Unexpected error updating agent run status for {agent_run_id}: {str(e)}", exc_info=True) return False - + return False async def stop_agent_run(agent_run_id: str, error_message: Optional[str] = None): """Update database and publish stop signal to Redis.""" logger.info(f"Stopping agent run: {agent_run_id}") client = await db.client - - # Update the agent run status - status = "failed" if error_message else "stopped" - await update_agent_run_status(client, agent_run_id, status, error=error_message) - - # Send stop signal to global channel + final_status = "failed" if error_message else "stopped" + + # Attempt to fetch final responses from Redis + response_list_key = f"agent_run:{agent_run_id}:responses" + all_responses = [] try: - await redis.publish(f"agent_run:{agent_run_id}:control", "STOP") - logger.debug(f"Published STOP signal to global channel for agent run {agent_run_id}") + all_responses_json = await redis.lrange(response_list_key, 0, -1) + all_responses = [json.loads(r) for r in all_responses_json] + logger.info(f"Fetched {len(all_responses)} responses from Redis for DB update on stop/fail: {agent_run_id}") except Exception as e: - logger.error(f"Failed to publish STOP signal to global channel: {str(e)}") - - # Find all instances handling this agent run + logger.error(f"Failed to fetch responses from Redis for {agent_run_id} during stop/fail: {e}") + # Try fetching from DB as a fallback? Or proceed without responses? Proceeding without for now. 
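# A minimal sketch of the database fallback the comment above considers
# (hypothetical helper, not part of this PR; it assumes the agent_runs.responses
# JSONB column that update_agent_run_status already writes):

async def fetch_responses_with_fallback(client, agent_run_id: str) -> list:
    # Prefer the hot copy in Redis; fall back to the JSONB snapshot in Postgres.
    response_list_key = f"agent_run:{agent_run_id}:responses"
    try:
        raw = await redis.lrange(response_list_key, 0, -1)
        if raw:
            return [json.loads(r) for r in raw]
    except Exception as e:
        logger.warning(f"Redis fetch failed for {agent_run_id}, falling back to DB: {e}")
    result = await client.table('agent_runs').select('responses').eq("id", agent_run_id).execute()
    if result.data and result.data[0].get('responses'):
        return result.data[0]['responses']
    return []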
+ + # Update the agent run status in the database + update_success = await update_agent_run_status( + client, agent_run_id, final_status, error=error_message, responses=all_responses + ) + + if not update_success: + logger.error(f"Failed to update database status for stopped/failed run {agent_run_id}") + + # Send STOP signal to the global control channel + global_control_channel = f"agent_run:{agent_run_id}:control" + try: + await redis.publish(global_control_channel, "STOP") + logger.debug(f"Published STOP signal to global channel {global_control_channel}") + except Exception as e: + logger.error(f"Failed to publish STOP signal to global channel {global_control_channel}: {str(e)}") + + # Find all instances handling this agent run and send STOP to instance-specific channels try: instance_keys = await redis.keys(f"active_run:*:{agent_run_id}") - logger.debug(f"Found {len(instance_keys)} active instances for agent run {agent_run_id}") - + logger.debug(f"Found {len(instance_keys)} active instance keys for agent run {agent_run_id}") + for key in instance_keys: - # Extract instance ID from the key pattern: active_run:{instance_id}:{agent_run_id} + # Key format: active_run:{instance_id}:{agent_run_id} parts = key.split(":") - if len(parts) >= 3: - instance_id = parts[1] + if len(parts) == 3: + instance_id_from_key = parts[1] + instance_control_channel = f"agent_run:{agent_run_id}:control:{instance_id_from_key}" try: - # Send stop signal to instance-specific channel - await redis.publish(f"agent_run:{agent_run_id}:control:{instance_id}", "STOP") - logger.debug(f"Published STOP signal to instance {instance_id} for agent run {agent_run_id}") + await redis.publish(instance_control_channel, "STOP") + logger.debug(f"Published STOP signal to instance channel {instance_control_channel}") except Exception as e: - logger.warning(f"Failed to publish STOP signal to instance {instance_id}: {str(e)}") + logger.warning(f"Failed to publish STOP signal to instance channel {instance_control_channel}: {str(e)}") + else: + logger.warning(f"Unexpected key format found: {key}") + + # Clean up the response list immediately on stop/fail + await _cleanup_redis_response_list(agent_run_id) + except Exception as e: - logger.error(f"Failed to find or signal active instances: {str(e)}") - + logger.error(f"Failed to find or signal active instances for {agent_run_id}: {str(e)}") + logger.info(f"Successfully initiated stop process for agent run: {agent_run_id}") + +async def _cleanup_redis_response_list(agent_run_id: str): + """Set TTL on the Redis response list.""" + response_list_key = f"agent_run:{agent_run_id}:responses" + try: + await redis.expire(response_list_key, REDIS_RESPONSE_LIST_TTL) + logger.debug(f"Set TTL ({REDIS_RESPONSE_LIST_TTL}s) on response list: {response_list_key}") + except Exception as e: + logger.warning(f"Failed to set TTL on response list {response_list_key}: {str(e)}") + async def restore_running_agent_runs(): - """Restore any agent runs that were still marked as running in the database.""" + """Mark agent runs that were still 'running' in the database as failed.""" logger.info("Restoring running agent runs after server restart") client = await db.client - running_agent_runs = await client.table('agent_runs').select('*').eq("status", "running").execute() + running_agent_runs = await client.table('agent_runs').select('id').eq("status", "running").execute() for run in running_agent_runs.data: - logger.warning(f"Found running agent run {run['id']} from before server restart") - await 
client.table('agent_runs').update({ - "status": "failed", - "error": "Server restarted while agent was running", - "completed_at": datetime.now(timezone.utc).isoformat() - }).eq("id", run['id']).execute() + agent_run_id = run['id'] + logger.warning(f"Found running agent run {agent_run_id} from before server restart") + # Call stop_agent_run to handle status update and cleanup + await stop_agent_run(agent_run_id, error_message="Server restarted while agent was running") async def check_for_active_project_agent_run(client, project_id: str): """ Check if there is an active agent run for any thread in the given project. If found, returns the ID of the active run, otherwise returns None. - - Args: - client: The Supabase client - project_id: The project ID to check - - Returns: - str or None: The ID of the active agent run if found, None otherwise """ - # Get all threads from this project project_threads = await client.table('threads').select('thread_id').eq('project_id', project_id).execute() project_thread_ids = [t['thread_id'] for t in project_threads.data] - - # Check if there are any active agent runs for any thread in this project + if project_thread_ids: active_runs = await client.table('agent_runs').select('id').in_('thread_id', project_thread_ids).eq('status', 'running').execute() - if active_runs.data and len(active_runs.data) > 0: return active_runs.data[0]['id'] - return None async def get_agent_run_with_access_check(client, agent_run_id: str, user_id: str): - """ - Get an agent run's data after verifying the user has access to it through account membership. - - Args: - client: The Supabase client - agent_run_id: The agent run ID to check access for - user_id: The user ID to check permissions for - - Returns: - dict: The agent run data if access is granted - - Raises: - HTTPException: If the user doesn't have access or the agent run doesn't exist - """ + """Get agent run data after verifying user access.""" agent_run = await client.table('agent_runs').select('*').eq('id', agent_run_id).execute() - - if not agent_run.data or len(agent_run.data) == 0: + if not agent_run.data: raise HTTPException(status_code=404, detail="Agent run not found") - + agent_run_data = agent_run.data[0] thread_id = agent_run_data['thread_id'] - - # Verify user has access to this thread using the updated verify_thread_access function await verify_thread_access(client, thread_id, user_id) - return agent_run_data -async def _cleanup_agent_run(agent_run_id: str): - """Clean up Redis keys when an agent run is done.""" - logger.debug(f"Cleaning up Redis keys for agent run: {agent_run_id}") +async def _cleanup_redis_instance_key(agent_run_id: str): + """Clean up the instance-specific Redis key for an agent run.""" + if not instance_id: + logger.warning("Instance ID not set, cannot clean up instance key.") + return + key = f"active_run:{instance_id}:{agent_run_id}" + logger.debug(f"Cleaning up Redis instance key: {key}") try: - await redis.delete(f"active_run:{instance_id}:{agent_run_id}") - logger.debug(f"Successfully cleaned up Redis keys for agent run: {agent_run_id}") + await redis.delete(key) + logger.debug(f"Successfully cleaned up Redis key: {key}") except Exception as e: - logger.warning(f"Failed to clean up Redis keys for agent run {agent_run_id}: {str(e)}") - # Non-fatal error, can continue + logger.warning(f"Failed to clean up Redis key {key}: {str(e)}") + async def get_or_create_project_sandbox(client, project_id: str): - """ - Get or create a sandbox for a project without distributed locking. 
- - Args: - client: The Supabase client - project_id: The project ID to get or create a sandbox for - - Returns: - Tuple of (sandbox object, sandbox_id, sandbox_pass) - """ - # First get the current project data to check if a sandbox already exists + """Get or create a sandbox for a project.""" project = await client.table('projects').select('*').eq('project_id', project_id).execute() - if not project.data or len(project.data) == 0: + if not project.data: raise ValueError(f"Project {project_id} not found") - project_data = project.data[0] - - # If project already has a sandbox, just use it + if project_data.get('sandbox', {}).get('id'): sandbox_id = project_data['sandbox']['id'] sandbox_pass = project_data['sandbox']['pass'] logger.info(f"Project {project_id} already has sandbox {sandbox_id}, retrieving it") - try: sandbox = await get_or_start_sandbox(sandbox_id) - return (sandbox, sandbox_id, sandbox_pass) + return sandbox, sandbox_id, sandbox_pass except Exception as e: - logger.error(f"Failed to retrieve existing sandbox {sandbox_id} for project {project_id}: {str(e)}") - # Fall through to create a new sandbox if retrieval fails - - # Create a new sandbox - try: - logger.info(f"Creating new sandbox for project {project_id}") - sandbox_pass = str(uuid.uuid4()) - sandbox = create_sandbox(sandbox_pass) - sandbox_id = sandbox.id - - logger.info(f"Created new sandbox {sandbox_id} with preview: {sandbox.get_preview_link(6080)}/vnc_lite.html?password={sandbox_pass}") - - # Get preview links - vnc_link = sandbox.get_preview_link(6080) - website_link = sandbox.get_preview_link(8080) - - # Extract the actual URLs and token from the preview link objects - vnc_url = vnc_link.url if hasattr(vnc_link, 'url') else str(vnc_link).split("url='")[1].split("'")[0] - website_url = website_link.url if hasattr(website_link, 'url') else str(website_link).split("url='")[1].split("'")[0] - - # Extract token if available - token = None - if hasattr(vnc_link, 'token'): - token = vnc_link.token - elif "token='" in str(vnc_link): - token = str(vnc_link).split("token='")[1].split("'")[0] - - # Update the project with the new sandbox info - update_result = await client.table('projects').update({ - 'sandbox': { - 'id': sandbox_id, - 'pass': sandbox_pass, - 'vnc_preview': vnc_url, - 'sandbox_url': website_url, - 'token': token - } - }).eq('project_id', project_id).execute() - - if not update_result.data: - logger.error(f"Failed to update project {project_id} with new sandbox {sandbox_id}") - raise Exception("Database update failed") - - return (sandbox, sandbox_id, sandbox_pass) - - except Exception as e: - logger.error(f"Error creating sandbox for project {project_id}: {str(e)}") - raise e + logger.error(f"Failed to retrieve existing sandbox {sandbox_id}: {str(e)}. 
Creating a new one.") + + logger.info(f"Creating new sandbox for project {project_id}") + sandbox_pass = str(uuid.uuid4()) + sandbox = create_sandbox(sandbox_pass) + sandbox_id = sandbox.id + logger.info(f"Created new sandbox {sandbox_id}") + + vnc_link = sandbox.get_preview_link(6080) + website_link = sandbox.get_preview_link(8080) + vnc_url = vnc_link.url if hasattr(vnc_link, 'url') else str(vnc_link).split("url='")[1].split("'")[0] + website_url = website_link.url if hasattr(website_link, 'url') else str(website_link).split("url='")[1].split("'")[0] + token = None + if hasattr(vnc_link, 'token'): + token = vnc_link.token + elif "token='" in str(vnc_link): + token = str(vnc_link).split("token='")[1].split("'")[0] + + update_result = await client.table('projects').update({ + 'sandbox': { + 'id': sandbox_id, 'pass': sandbox_pass, 'vnc_preview': vnc_url, + 'sandbox_url': website_url, 'token': token + } + }).eq('project_id', project_id).execute() + + if not update_result.data: + logger.error(f"Failed to update project {project_id} with new sandbox {sandbox_id}") + raise Exception("Database update failed") + + return sandbox, sandbox_id, sandbox_pass @router.post("/thread/{thread_id}/agent/start") async def start_agent( thread_id: str, - body: AgentStartRequest = Body(...), # Accept request body + body: AgentStartRequest = Body(...), user_id: str = Depends(get_current_user_id) ): """Start an agent for a specific thread in the background.""" - logger.info(f"Starting new agent for thread: {thread_id} with config: model={body.model_name}, thinking={body.enable_thinking}, effort={body.reasoning_effort}, stream={body.stream}, context_manager={body.enable_context_manager}") + global instance_id # Ensure instance_id is accessible + if not instance_id: + raise HTTPException(status_code=500, detail="Agent API not initialized with instance ID") + + logger.info(f"Starting new agent for thread: {thread_id} with config: model={body.model_name}, thinking={body.enable_thinking}, effort={body.reasoning_effort}, stream={body.stream}, context_manager={body.enable_context_manager} (Instance: {instance_id})") client = await db.client - - # Verify user has access to this thread + await verify_thread_access(client, thread_id, user_id) - - # Get the project_id and account_id for this thread thread_result = await client.table('threads').select('project_id', 'account_id').eq('thread_id', thread_id).execute() if not thread_result.data: raise HTTPException(status_code=404, detail="Thread not found") - thread_data = thread_result.data[0] project_id = thread_data.get('project_id') account_id = thread_data.get('account_id') - - # Check billing status + can_run, message, subscription = await check_billing_status(client, account_id) if not can_run: - raise HTTPException(status_code=402, detail={ - "message": message, - "subscription": subscription - }) - - # Check if there is already an active agent run for this project + raise HTTPException(status_code=402, detail={"message": message, "subscription": subscription}) + active_run_id = await check_for_active_project_agent_run(client, project_id) - - # If there's an active run, stop it first if active_run_id: - logger.info(f"Stopping existing agent run {active_run_id} before starting new one") + logger.info(f"Stopping existing agent run {active_run_id} for project {project_id}") await stop_agent_run(active_run_id) - # Get or create a sandbox for this project using the safe function try: sandbox, sandbox_id, sandbox_pass = await get_or_create_project_sandbox(client, project_id) 
except Exception as e: - logger.error(f"Failed to get or create sandbox for project {project_id}: {str(e)}") + logger.error(f"Failed to get/create sandbox for project {project_id}: {str(e)}") raise HTTPException(status_code=500, detail=f"Failed to initialize sandbox: {str(e)}") - + agent_run = await client.table('agent_runs').insert({ - "thread_id": thread_id, - "status": "running", + "thread_id": thread_id, "status": "running", "started_at": datetime.now(timezone.utc).isoformat() }).execute() - agent_run_id = agent_run.data[0]['id'] logger.info(f"Created new agent run: {agent_run_id}") - - # Initialize in-memory storage for this agent run - active_agent_runs[agent_run_id] = [] - - # Register this run in Redis with TTL + + # Register this run in Redis with TTL using instance ID + instance_key = f"active_run:{instance_id}:{agent_run_id}" try: - await redis.set( - f"active_run:{instance_id}:{agent_run_id}", - "running", - ex=redis.REDIS_KEY_TTL - ) + await redis.set(instance_key, "running", ex=redis.REDIS_KEY_TTL) except Exception as e: - logger.warning(f"Failed to register agent run in Redis, continuing without Redis tracking: {str(e)}") - + logger.warning(f"Failed to register agent run in Redis ({instance_key}): {str(e)}") + # Run the agent in the background task = asyncio.create_task( run_agent_background( - agent_run_id=agent_run_id, - thread_id=thread_id, - instance_id=instance_id, - project_id=project_id, - sandbox=sandbox, - model_name=MODEL_NAME_ALIASES.get(body.model_name, body.model_name), - enable_thinking=body.enable_thinking, - reasoning_effort=body.reasoning_effort, - stream=body.stream, - enable_context_manager=body.enable_context_manager + agent_run_id=agent_run_id, thread_id=thread_id, instance_id=instance_id, + project_id=project_id, sandbox=sandbox, + model_name=MODEL_NAME_ALIASES.get(body.model_name, body.model_name), + enable_thinking=body.enable_thinking, reasoning_effort=body.reasoning_effort, + stream=body.stream, enable_context_manager=body.enable_context_manager ) ) - - # Set a callback to clean up when task is done - task.add_done_callback( - lambda _: asyncio.create_task( - _cleanup_agent_run(agent_run_id) - ) - ) - + + # Set a callback to clean up Redis instance key when task is done + task.add_done_callback(lambda _: asyncio.create_task(_cleanup_redis_instance_key(agent_run_id))) + return {"agent_run_id": agent_run_id, "status": "running"} @router.post("/agent-run/{agent_run_id}/stop") async def stop_agent(agent_run_id: str, user_id: str = Depends(get_current_user_id)): """Stop a running agent.""" - logger.info(f"Stopping agent run: {agent_run_id}") + logger.info(f"Received request to stop agent run: {agent_run_id}") client = await db.client - - # Verify user has access to the agent run await get_agent_run_with_access_check(client, agent_run_id, user_id) - - # Stop the agent run await stop_agent_run(agent_run_id) - return {"status": "stopped"} @router.get("/thread/{thread_id}/agent-runs") @@ -443,11 +401,8 @@ async def get_agent_runs(thread_id: str, user_id: str = Depends(get_current_user """Get all agent runs for a thread.""" logger.info(f"Fetching agent runs for thread: {thread_id}") client = await db.client - - # Verify user has access to this thread await verify_thread_access(client, thread_id, user_id) - - agent_runs = await client.table('agent_runs').select('*').eq("thread_id", thread_id).execute() + agent_runs = await client.table('agent_runs').select('*').eq("thread_id", thread_id).order('created_at', desc=True).execute() logger.debug(f"Found 
{len(agent_runs.data)} agent runs for thread: {thread_id}") return {"agent_runs": agent_runs.data} @@ -456,9 +411,8 @@ async def get_agent_run(agent_run_id: str, user_id: str = Depends(get_current_us """Get agent run status and responses.""" logger.info(f"Fetching agent run details: {agent_run_id}") client = await db.client - agent_run_data = await get_agent_run_with_access_check(client, agent_run_id, user_id) - + # Note: Responses are not included here by default, they are in the stream or DB return { "id": agent_run_data['id'], "threadId": agent_run_data['thread_id'], @@ -470,95 +424,191 @@ async def get_agent_run(agent_run_id: str, user_id: str = Depends(get_current_us @router.get("/agent-run/{agent_run_id}/stream") async def stream_agent_run( - agent_run_id: str, + agent_run_id: str, token: Optional[str] = None, request: Request = None ): - """Stream the responses of an agent run from in-memory storage or reconnect to ongoing run.""" + """Stream the responses of an agent run using Redis Lists and Pub/Sub.""" logger.info(f"Starting stream for agent run: {agent_run_id}") client = await db.client - - # Get user ID using the streaming auth function + user_id = await get_user_id_from_stream_auth(request, token) - - # Verify user has access to the agent run and get run data agent_run_data = await get_agent_run_with_access_check(client, agent_run_id, user_id) - - # Initialize response storage if not already in memory - if agent_run_id not in active_agent_runs: - active_agent_runs[agent_run_id] = [] - logger.info(f"Initialized missing response storage for agent run: {agent_run_id}") - - # If run is completed/failed, try to load responses from database - if agent_run_data['status'] in ['completed', 'failed', 'stopped']: - try: - # Get responses from database - run_with_responses = await client.table('agent_runs').select('responses').eq("id", agent_run_id).execute() - if run_with_responses.data and run_with_responses.data[0].get('responses'): - active_agent_runs[agent_run_id] = run_with_responses.data[0]['responses'] - logger.info(f"Loaded {len(active_agent_runs[agent_run_id])} responses from database for agent run: {agent_run_id}") - except Exception as e: - logger.error(f"Failed to load responses from database for agent run {agent_run_id}: {str(e)}") - - # Define a streaming generator that uses in-memory responses + + response_list_key = f"agent_run:{agent_run_id}:responses" + response_channel = f"agent_run:{agent_run_id}:new_response" + control_channel = f"agent_run:{agent_run_id}:control" # Global control channel + async def stream_generator(): - logger.debug(f"Streaming responses for agent run: {agent_run_id}") - - # Check if this is an active run with stored responses - if agent_run_id in active_agent_runs: - # First, send all existing responses - stored_responses = active_agent_runs[agent_run_id] - logger.debug(f"Sending {len(stored_responses)} existing responses for agent run: {agent_run_id}") - - for response in stored_responses: - yield f"data: {json.dumps(response)}\n\n" - - # If the run is still active (status is running), set up to stream new responses - if agent_run_data['status'] == 'running': - # Get the current length to know where to start watching for new responses - current_length = len(stored_responses) - - # Keep checking for new responses - while agent_run_id in active_agent_runs: - # Check if there are new responses - if len(active_agent_runs[agent_run_id]) > current_length: - # Send all new responses - for i in range(current_length, len(active_agent_runs[agent_run_id])): 
- response = active_agent_runs[agent_run_id][i] - yield f"data: {json.dumps(response)}\n\n" - - # Update current length - current_length = len(active_agent_runs[agent_run_id]) - - # Brief pause before checking again - await asyncio.sleep(0.1) - else: - # This should not happen now, but just in case - - logger.warning(f"Agent run {agent_run_id} not found in active runs even after initialization") - yield f"data: {json.dumps({'type': 'status', 'status': agent_run_data['status'], 'message': 'Run data not available for streaming'})}\n\n" - - # Always send a completion status at the end - yield f"data: {json.dumps({'type': 'status', 'status': 'completed'})}\n\n" - logger.debug(f"Streaming complete for agent run: {agent_run_id}") - - # Return a streaming response - return StreamingResponse( - stream_generator(), - media_type="text/event-stream", - headers={ - "Cache-Control": "no-cache, no-transform", - "Connection": "keep-alive", - "X-Accel-Buffering": "no", - "Content-Type": "text/event-stream", - "Access-Control-Allow-Origin": "*" - } - ) + logger.debug(f"Streaming responses for {agent_run_id} using Redis list {response_list_key} and channel {response_channel}") + last_processed_index = -1 + pubsub_response = None + pubsub_control = None + listener_task = None + terminate_stream = False + initial_yield_complete = False + + try: + # 1. Fetch and yield initial responses from Redis list + initial_responses_json = await redis.lrange(response_list_key, 0, -1) + initial_responses = [] + if initial_responses_json: + initial_responses = [json.loads(r) for r in initial_responses_json] + logger.debug(f"Sending {len(initial_responses)} initial responses for {agent_run_id}") + for response in initial_responses: + yield f"data: {json.dumps(response)}\n\n" + last_processed_index = len(initial_responses) - 1 + initial_yield_complete = True + + # 2. Check run status *after* yielding initial data + run_status = await client.table('agent_runs').select('status').eq("id", agent_run_id).maybe_single().execute() + current_status = run_status.data.get('status') if run_status.data else None + + if current_status != 'running': + logger.info(f"Agent run {agent_run_id} is not running (status: {current_status}). Ending stream.") + yield f"data: {json.dumps({'type': 'status', 'status': 'completed'})}\n\n" + return + + # 3. 
Set up Pub/Sub listeners for new responses and control signals
+            pubsub_response = await redis.create_pubsub()
+            await pubsub_response.subscribe(response_channel)
+            logger.debug(f"Subscribed to response channel: {response_channel}")
+
+            pubsub_control = await redis.create_pubsub()
+            await pubsub_control.subscribe(control_channel)
+            logger.debug(f"Subscribed to control channel: {control_channel}")
+
+            # Queue to communicate between listeners and the main generator loop
+            message_queue = asyncio.Queue()
+
+            async def listen_messages():
+                response_reader = pubsub_response.listen()
+                control_reader = pubsub_control.listen()
+                tasks = [asyncio.create_task(response_reader.__anext__()), asyncio.create_task(control_reader.__anext__())]
+                pending = set()  # Ensure defined even if the loop below never runs
+
+                while not terminate_stream:
+                    done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
+                    for task in done:
+                        message = None  # Ensure defined for the finally block below
+                        try:
+                            message = task.result()
+                            if message and isinstance(message, dict) and message.get("type") == "message":
+                                channel = message.get("channel")
+                                if isinstance(channel, bytes): channel = channel.decode('utf-8')
+                                data = message.get("data")
+                                if isinstance(data, bytes): data = data.decode('utf-8')
+
+                                if channel == response_channel and data == "new":
+                                    await message_queue.put({"type": "new_response"})
+                                elif channel == control_channel and data in ["STOP", "END_STREAM", "ERROR"]:
+                                    logger.info(f"Received control signal '{data}' for {agent_run_id}")
+                                    await message_queue.put({"type": "control", "data": data})
+                                    return # Stop listening on control signal
+
+                        except StopAsyncIteration:
+                            logger.warning(f"Listener {task} stopped.")
+                            # Treat an unexpectedly closed reader as a stream error
+                            await message_queue.put({"type": "error", "data": "Listener stopped unexpectedly"})
+                            return
+                        except Exception as e:
+                            logger.error(f"Error in listener for {agent_run_id}: {e}")
+                            await message_queue.put({"type": "error", "data": "Listener failed"})
+                            return
+                        finally:
+                            # Reschedule the reader whose task just completed
+                            if task in tasks:
+                                tasks.remove(task)
+                                msg_channel = message.get("channel") if isinstance(message, dict) else None
+                                if isinstance(msg_channel, bytes): msg_channel = msg_channel.decode('utf-8')
+                                if msg_channel == response_channel:
+                                    tasks.append(asyncio.create_task(response_reader.__anext__()))
+                                elif msg_channel == control_channel:
+                                    tasks.append(asyncio.create_task(control_reader.__anext__()))
+
+                # Cancel pending listener tasks on exit
+                for p_task in pending: p_task.cancel()
+                for task in tasks: task.cancel()
+
+
+            listener_task = asyncio.create_task(listen_messages())
+
+            # 4. 
Main loop to process messages from the queue + while not terminate_stream: + try: + queue_item = await message_queue.get() + + if queue_item["type"] == "new_response": + # Fetch new responses from Redis list starting after the last processed index + new_start_index = last_processed_index + 1 + new_responses_json = await redis.lrange(response_list_key, new_start_index, -1) + + if new_responses_json: + new_responses = [json.loads(r) for r in new_responses_json] + num_new = len(new_responses) + logger.debug(f"Received {num_new} new responses for {agent_run_id} (index {new_start_index} onwards)") + for response in new_responses: + yield f"data: {json.dumps(response)}\n\n" + # Check if this response signals completion + if response.get('type') == 'status' and response.get('status') in ['completed', 'failed', 'stopped']: + logger.info(f"Detected run completion via status message in stream: {response.get('status')}") + terminate_stream = True + break # Stop processing further new responses + last_processed_index += num_new + if terminate_stream: break + + elif queue_item["type"] == "control": + control_signal = queue_item["data"] + terminate_stream = True # Stop the stream on any control signal + yield f"data: {json.dumps({'type': 'status', 'status': control_signal})}\n\n" + break + + elif queue_item["type"] == "error": + logger.error(f"Listener error for {agent_run_id}: {queue_item['data']}") + terminate_stream = True + yield f"data: {json.dumps({'type': 'status', 'status': 'error'})}\n\n" + break + + except asyncio.CancelledError: + logger.info(f"Stream generator main loop cancelled for {agent_run_id}") + terminate_stream = True + break + except Exception as loop_err: + logger.error(f"Error in stream generator main loop for {agent_run_id}: {loop_err}", exc_info=True) + terminate_stream = True + yield f"data: {json.dumps({'type': 'status', 'status': 'error', 'message': f'Stream failed: {loop_err}'})}\n\n" + break + + except Exception as e: + logger.error(f"Error setting up stream for agent run {agent_run_id}: {e}", exc_info=True) + # Only yield error if initial yield didn't happen + if not initial_yield_complete: + yield f"data: {json.dumps({'type': 'status', 'status': 'error', 'message': f'Failed to start stream: {e}'})}\n\n" + finally: + terminate_stream = True + # Graceful shutdown order: unsubscribe → close → cancel + if pubsub_response: await pubsub_response.unsubscribe(response_channel) + if pubsub_control: await pubsub_control.unsubscribe(control_channel) + if pubsub_response: await pubsub_response.close() + if pubsub_control: await pubsub_control.close() + + if listener_task: + listener_task.cancel() + try: + await listener_task # Reap inner tasks & swallow their errors + except asyncio.CancelledError: + pass + except Exception as e: + logger.debug(f"listener_task ended with: {e}") + # Wait briefly for tasks to cancel + await asyncio.sleep(0.1) + logger.debug(f"Streaming cleanup complete for agent run: {agent_run_id}") + + return StreamingResponse(stream_generator(), media_type="text/event-stream", headers={ + "Cache-Control": "no-cache, no-transform", "Connection": "keep-alive", + "X-Accel-Buffering": "no", "Content-Type": "text/event-stream", + "Access-Control-Allow-Origin": "*" + }) async def run_agent_background( agent_run_id: str, thread_id: str, - instance_id: str, + instance_id: str, # Use the global instance ID passed during initialization project_id: str, sandbox, model_name: str, @@ -567,283 +617,191 @@ async def run_agent_background( stream: bool, enable_context_manager: bool 
): - """Run the agent in the background and handle status updates.""" - logger.debug(f"Starting background agent run: {agent_run_id} for thread: {thread_id} (instance: {instance_id}) with model={model_name}, thinking={enable_thinking}, effort={reasoning_effort}, stream={stream}, context_manager={enable_context_manager}") + """Run the agent in the background using Redis for state.""" + logger.debug(f"Starting background agent run: {agent_run_id} for thread: {thread_id} (Instance: {instance_id})") client = await db.client - - # Tracking variables - total_responses = 0 start_time = datetime.now(timezone.utc) - - # Create a pubsub to listen for control messages + total_responses = 0 pubsub = None - try: - pubsub = await redis.create_pubsub() - - # Use instance-specific control channel to avoid cross-talk between instances - control_channel = f"agent_run:{agent_run_id}:control:{instance_id}" - # Use backoff retry pattern for pubsub connection - retry_count = 0 - while retry_count < 3: - try: - await pubsub.subscribe(control_channel) - logger.debug(f"Subscribed to control channel: {control_channel}") - break - except Exception as e: - retry_count += 1 - if retry_count >= 3: - logger.error(f"Failed to subscribe to control channel after 3 attempts: {str(e)}") - raise - wait_time = 0.5 * (2 ** (retry_count - 1)) - logger.warning(f"Failed to subscribe to control channel (attempt {retry_count}/3): {str(e)}. Retrying in {wait_time}s...") - await asyncio.sleep(wait_time) - - # Also subscribe to the global control channel for cross-instance control - global_control_channel = f"agent_run:{agent_run_id}:control" - retry_count = 0 - while retry_count < 3: - try: - await pubsub.subscribe(global_control_channel) - logger.debug(f"Subscribed to global control channel: {global_control_channel}") - break - except Exception as e: - retry_count += 1 - if retry_count >= 3: - logger.error(f"Failed to subscribe to global control channel after 3 attempts: {str(e)}") - # We can continue with just the instance-specific channel - break - wait_time = 0.5 * (2 ** (retry_count - 1)) - logger.warning(f"Failed to subscribe to global control channel (attempt {retry_count}/3): {str(e)}. 
Retrying in {wait_time}s...") - await asyncio.sleep(wait_time) - except Exception as e: - logger.error(f"Failed to initialize Redis pubsub: {str(e)}") - pubsub = None - - # Keep Redis key up-to-date with TTL refresh - try: - # Extend TTL on the active run key to prevent expiration during long runs - await redis.set( - f"active_run:{instance_id}:{agent_run_id}", - "running", - ex=redis.REDIS_KEY_TTL - ) - except Exception as e: - logger.warning(f"Failed to refresh active run key TTL: {str(e)}") - - # Start a background task to check for stop signals - stop_signal_received = False stop_checker = None - + stop_signal_received = False + + # Define Redis keys and channels + response_list_key = f"agent_run:{agent_run_id}:responses" + response_channel = f"agent_run:{agent_run_id}:new_response" + instance_control_channel = f"agent_run:{agent_run_id}:control:{instance_id}" + global_control_channel = f"agent_run:{agent_run_id}:control" + instance_active_key = f"active_run:{instance_id}:{agent_run_id}" + async def check_for_stop_signal(): nonlocal stop_signal_received - if not pubsub: - logger.warning("Stop signal checker not started - pubsub not available") - return - + if not pubsub: return try: - while True: - try: - message = await pubsub.get_message(timeout=0.5) - if message and message["type"] == "message": - stop_signal = "STOP" - if message["data"] == stop_signal or message["data"] == stop_signal.encode('utf-8'): - logger.info(f"Received stop signal for agent run: {agent_run_id} (instance: {instance_id})") - stop_signal_received = True - break - except Exception as e: - logger.warning(f"Error checking for stop signals: {str(e)}") - # Brief pause before retry - await asyncio.sleep(1) - - # Check if we should stop naturally - if stop_signal_received: - break - - # Periodically refresh the active run key's TTL - try: - if total_responses % 100 == 0: - await redis.set( - f"active_run:{instance_id}:{agent_run_id}", - "running", - ex=redis.REDIS_KEY_TTL - ) - except Exception as e: - logger.warning(f"Failed to refresh active run key TTL: {str(e)}") - - await asyncio.sleep(0.1) + while not stop_signal_received: + message = await pubsub.get_message(ignore_subscribe_messages=True, timeout=0.5) + if message and message.get("type") == "message": + data = message.get("data") + if isinstance(data, bytes): data = data.decode('utf-8') + if data == "STOP": + logger.info(f"Received STOP signal for agent run {agent_run_id} (Instance: {instance_id})") + stop_signal_received = True + break + # Periodically refresh the active run key TTL + if total_responses % 50 == 0: # Refresh every 50 responses or so + try: await redis.expire(instance_active_key, redis.REDIS_KEY_TTL) + except Exception as ttl_err: logger.warning(f"Failed to refresh TTL for {instance_active_key}: {ttl_err}") + await asyncio.sleep(0.1) # Short sleep to prevent tight loop except asyncio.CancelledError: - logger.info(f"Stop signal checker task cancelled (instance: {instance_id})") + logger.info(f"Stop signal checker cancelled for {agent_run_id} (Instance: {instance_id})") except Exception as e: - logger.error(f"Unexpected error in stop signal checker: {str(e)}", exc_info=True) - - # Start the stop signal checker if pubsub is available - if pubsub: - stop_checker = asyncio.create_task(check_for_stop_signal()) - logger.debug(f"Started stop signal checker for agent run: {agent_run_id} (instance: {instance_id})") - else: - logger.warning(f"No stop signal checker for agent run: {agent_run_id} - pubsub unavailable") - + logger.error(f"Error in stop 
signal checker for {agent_run_id}: {e}", exc_info=True) + stop_signal_received = True # Stop the run if the checker fails + try: - # Run the agent - logger.debug(f"Initializing agent generator for thread: {thread_id} (instance: {instance_id})") + # Setup Pub/Sub listener for control signals + pubsub = await redis.create_pubsub() + await pubsub.subscribe(instance_control_channel, global_control_channel) + logger.debug(f"Subscribed to control channels: {instance_control_channel}, {global_control_channel}") + stop_checker = asyncio.create_task(check_for_stop_signal()) + + # Ensure active run key exists and has TTL + await redis.set(instance_active_key, "running", ex=redis.REDIS_KEY_TTL) + + # Initialize agent generator agent_gen = run_agent( - thread_id=thread_id, - project_id=project_id, - stream=stream, - thread_manager=thread_manager, - model_name=model_name, - enable_thinking=enable_thinking, - reasoning_effort=reasoning_effort, + thread_id=thread_id, project_id=project_id, stream=stream, + thread_manager=thread_manager, model_name=model_name, + enable_thinking=enable_thinking, reasoning_effort=reasoning_effort, enable_context_manager=enable_context_manager ) - - # Collect all responses to save to database - all_responses = [] - + + final_status = "running" + error_message = None + async for response in agent_gen: - # Check if stop signal received if stop_signal_received: - logger.info(f"Agent run stopped due to stop signal: {agent_run_id} (instance: {instance_id})") - await update_agent_run_status(client, agent_run_id, "stopped", responses=all_responses) + logger.info(f"Agent run {agent_run_id} stopped by signal.") + final_status = "stopped" break - - # Check for billing error status - if response.get('type') == 'status' and response.get('status') == 'error': - error_msg = response.get('message', '') - logger.info(f"Agent run failed with error: {error_msg} (instance: {instance_id})") - await update_agent_run_status(client, agent_run_id, "failed", error=error_msg, responses=all_responses) - break - - # Store response in memory - if agent_run_id in active_agent_runs: - active_agent_runs[agent_run_id].append(response) - all_responses.append(response) - total_responses += 1 - - # Signal all done if we weren't stopped - if not stop_signal_received: - duration = (datetime.now(timezone.utc) - start_time).total_seconds() - logger.info(f"Thread Run Response completed successfully: {agent_run_id} (duration: {duration:.2f}s, total responses: {total_responses}, instance: {instance_id})") - - # Add completion message to the stream - completion_message = { - "type": "status", - "status": "completed", - "message": "Agent run completed successfully" - } - if agent_run_id in active_agent_runs: - active_agent_runs[agent_run_id].append(completion_message) - all_responses.append(completion_message) - - # Update the agent run status - await update_agent_run_status(client, agent_run_id, "completed", responses=all_responses) - - # Notify any clients monitoring the control channels that we're done - try: - if pubsub: - await redis.publish(f"agent_run:{agent_run_id}:control:{instance_id}", "END_STREAM") - await redis.publish(f"agent_run:{agent_run_id}:control", "END_STREAM") - logger.debug(f"Sent END_STREAM signals for agent run: {agent_run_id} (instance: {instance_id})") - except Exception as e: - logger.warning(f"Failed to publish END_STREAM signals: {str(e)}") - + + # Store response in Redis list and publish notification + response_json = json.dumps(response) + await redis.rpush(response_list_key, 
response_json) + await redis.publish(response_channel, "new") + total_responses += 1 + + # Check for agent-signaled completion or error + if response.get('type') == 'status': + status_val = response.get('status') + if status_val in ['completed', 'failed', 'stopped']: + logger.info(f"Agent run {agent_run_id} finished via status message: {status_val}") + final_status = status_val + if status_val == 'failed' or status_val == 'stopped': + error_message = response.get('message', f"Run ended with status: {status_val}") + break + + # If loop finished without explicit completion/error/stop signal, mark as completed + if final_status == "running": + final_status = "completed" + duration = (datetime.now(timezone.utc) - start_time).total_seconds() + logger.info(f"Agent run {agent_run_id} completed normally (duration: {duration:.2f}s, responses: {total_responses})") + completion_message = {"type": "status", "status": "completed", "message": "Agent run completed successfully"} + await redis.rpush(response_list_key, json.dumps(completion_message)) + await redis.publish(response_channel, "new") # Notify about the completion message + + # Fetch final responses from Redis for DB update + all_responses_json = await redis.lrange(response_list_key, 0, -1) + all_responses = [json.loads(r) for r in all_responses_json] + + # Update DB status + await update_agent_run_status(client, agent_run_id, final_status, error=error_message, responses=all_responses) + + # Publish final control signal (END_STREAM or ERROR) + control_signal = "END_STREAM" if final_status == "completed" else "ERROR" if final_status == "failed" else "STOP" + try: + await redis.publish(global_control_channel, control_signal) + # No need to publish to instance channel as the run is ending on this instance + logger.debug(f"Published final control signal '{control_signal}' to {global_control_channel}") + except Exception as e: + logger.warning(f"Failed to publish final control signal {control_signal}: {str(e)}") + except Exception as e: - # Log the error and update the agent run error_message = str(e) traceback_str = traceback.format_exc() duration = (datetime.now(timezone.utc) - start_time).total_seconds() - logger.error(f"Error in agent run {agent_run_id} after {duration:.2f}s: {error_message}\n{traceback_str} (instance: {instance_id})") - - # Add error message to the stream - error_response = { - "type": "status", - "status": "error", - "message": error_message - } - if agent_run_id in active_agent_runs: - active_agent_runs[agent_run_id].append(error_response) - if 'all_responses' in locals(): - all_responses.append(error_response) - else: - all_responses = [error_response] - - # Update the agent run with the error - await update_agent_run_status( - client, - agent_run_id, - "failed", - error=f"{error_message}\n{traceback_str}", - responses=all_responses - ) - - # Notify any clients of the error + logger.error(f"Error in agent run {agent_run_id} after {duration:.2f}s: {error_message}\n{traceback_str} (Instance: {instance_id})") + final_status = "failed" + + # Push error message to Redis list + error_response = {"type": "status", "status": "error", "message": error_message} try: - if pubsub: - await redis.publish(f"agent_run:{agent_run_id}:control:{instance_id}", "ERROR") - await redis.publish(f"agent_run:{agent_run_id}:control", "ERROR") - logger.debug(f"Sent ERROR signals for agent run: {agent_run_id} (instance: {instance_id})") + await redis.rpush(response_list_key, json.dumps(error_response)) + await redis.publish(response_channel, "new") + 
except Exception as redis_err: + logger.error(f"Failed to push error response to Redis for {agent_run_id}: {redis_err}") + + # Fetch final responses (including the error) + all_responses = [] + try: + all_responses_json = await redis.lrange(response_list_key, 0, -1) + all_responses = [json.loads(r) for r in all_responses_json] + except Exception as fetch_err: + logger.error(f"Failed to fetch responses from Redis after error for {agent_run_id}: {fetch_err}") + all_responses = [error_response] # Use the error message we tried to push + + # Update DB status + await update_agent_run_status(client, agent_run_id, "failed", error=f"{error_message}\n{traceback_str}", responses=all_responses) + + # Publish ERROR signal + try: + await redis.publish(global_control_channel, "ERROR") + logger.debug(f"Published ERROR signal to {global_control_channel}") except Exception as e: - logger.warning(f"Failed to publish ERROR signals: {str(e)}") - + logger.warning(f"Failed to publish ERROR signal: {str(e)}") + finally: - # Ensure we always clean up the pubsub and stop checker - if stop_checker: - try: - stop_checker.cancel() - logger.debug(f"Cancelled stop signal checker task for agent run: {agent_run_id} (instance: {instance_id})") - except Exception as e: - logger.warning(f"Error cancelling stop checker: {str(e)}") - + # Cleanup stop checker task + if stop_checker and not stop_checker.done(): + stop_checker.cancel() + try: await stop_checker + except asyncio.CancelledError: pass + except Exception as e: logger.warning(f"Error during stop_checker cancellation: {e}") + + # Close pubsub connection if pubsub: try: await pubsub.unsubscribe() - logger.debug(f"Successfully unsubscribed from pubsub for agent run: {agent_run_id} (instance: {instance_id})") + await pubsub.close() + logger.debug(f"Closed pubsub connection for {agent_run_id}") except Exception as e: - logger.warning(f"Error unsubscribing from pubsub: {str(e)}") - - # Clean up the Redis key - try: - await redis.delete(f"active_run:{instance_id}:{agent_run_id}") - logger.debug(f"Deleted active run key for agent run: {agent_run_id} (instance: {instance_id})") - except Exception as e: - logger.warning(f"Error deleting active run key: {str(e)}") - - logger.info(f"Agent run background task fully completed for: {agent_run_id} (instance: {instance_id})") + logger.warning(f"Error closing pubsub for {agent_run_id}: {str(e)}") + + # Set TTL on the response list in Redis + await _cleanup_redis_response_list(agent_run_id) + + # Remove the instance-specific active run key + await _cleanup_redis_instance_key(agent_run_id) + + logger.info(f"Agent run background task fully completed for: {agent_run_id} (Instance: {instance_id}) with final status: {final_status}") async def generate_and_update_project_name(project_id: str, prompt: str): """Generates a project name using an LLM and updates the database.""" logger.info(f"Starting background task to generate name for project: {project_id}") try: - # Ensure db client is ready (may need re-initialization in background task context) - # Getting a fresh connection within the task might be safer db_conn = DBConnection() client = await db_conn.client - - # Prepare LLM call - model_name = "openai/gpt-4o-mini" # Or claude-3-haiku + + model_name = "openai/gpt-4o-mini" system_prompt = "You are a helpful assistant that generates extremely concise titles (2-4 words maximum) for chat threads based on the user's message. Respond with only the title, no other text or punctuation." 
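# For reference when reading the parsing below: make_llm_api_call is assumed to
# return an OpenAI-style chat-completion dict, e.g. (illustrative values only):
#   {"choices": [{"message": {"content": "Data Pipeline Setup"}}]}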
user_message = f"Generate an extremely brief title (2-4 words only) for a chat thread that starts with this message: \"{prompt}\"" - - messages = [ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": user_message} - ] - + messages = [{"role": "system", "content": system_prompt}, {"role": "user", "content": user_message}] + logger.debug(f"Calling LLM ({model_name}) for project {project_id} naming.") - - # Use make_llm_api_call (ensure it's compatible with background task context) - response = await make_llm_api_call( - messages=messages, - model_name=model_name, - max_tokens=20, - temperature=0.7 - ) - - # Extract and clean the name + response = await make_llm_api_call(messages=messages, model_name=model_name, max_tokens=20, temperature=0.7) + generated_name = None if response and response.get('choices') and response['choices'][0].get('message'): raw_name = response['choices'][0]['message'].get('content', '').strip() - # Simple cleaning: remove quotes and extra whitespace cleaned_name = raw_name.strip('\'" \n\t') if cleaned_name: generated_name = cleaned_name @@ -853,13 +811,8 @@ async def generate_and_update_project_name(project_id: str, prompt: str): else: logger.warning(f"Failed to get valid response from LLM for project {project_id} naming. Response: {response}") - # Update database if name was generated if generated_name: - update_result = await client.table('projects') \ - .update({"name": generated_name}) \ - .eq("project_id", project_id) \ - .execute() - + update_result = await client.table('projects').update({"name": generated_name}).eq("project_id", project_id).execute() if hasattr(update_result, 'data') and update_result.data: logger.info(f"Successfully updated project {project_id} name to '{generated_name}'") else: @@ -870,8 +823,7 @@ async def generate_and_update_project_name(project_id: str, prompt: str): except Exception as e: logger.error(f"Error in background naming task for project {project_id}: {str(e)}\n{traceback.format_exc()}") finally: - if 'db_conn' in locals(): - pass + # No need to disconnect DBConnection singleton instance here logger.info(f"Finished background naming task for project: {project_id}") @router.post("/agent/initiate", response_model=InitiateAgentResponse) @@ -885,124 +837,78 @@ async def initiate_agent_with_files( files: List[UploadFile] = File(default=[]), user_id: str = Depends(get_current_user_id) ): - """ - Initiate a new agent session with optional file attachments. - Creates project, thread, and sandbox, then uploads files and starts the agent run. 
- """ - logger.info(f"Initiating new agent with prompt and {len(files)} files") + """Initiate a new agent session with optional file attachments.""" + global instance_id # Ensure instance_id is accessible + if not instance_id: + raise HTTPException(status_code=500, detail="Agent API not initialized with instance ID") + + logger.info(f"[\033[91mDEBUG\033[0m] Initiating new agent with prompt and {len(files)} files (Instance: {instance_id}), model: {model_name}, enable_thinking: {enable_thinking}") client = await db.client - - # In Basejump, personal account_id is the same as user_id - account_id = user_id - - # Check billing status + account_id = user_id # In Basejump, personal account_id is the same as user_id + can_run, message, subscription = await check_billing_status(client, account_id) if not can_run: - raise HTTPException(status_code=402, detail={ - "message": message, - "subscription": subscription - }) - + raise HTTPException(status_code=402, detail={"message": message, "subscription": subscription}) + try: # 1. Create Project - # Use prompt for placeholder name placeholder_name = f"{prompt[:30]}..." if len(prompt) > 30 else prompt - logger.info(f"Using placeholder name: '{placeholder_name}'") - project = await client.table('projects').insert({ - "project_id": str(uuid.uuid4()), - "account_id": account_id, - "name": placeholder_name, # Use placeholder + "project_id": str(uuid.uuid4()), "account_id": account_id, "name": placeholder_name, "created_at": datetime.now(timezone.utc).isoformat() }).execute() - project_id = project.data[0]['project_id'] logger.info(f"Created new project: {project_id}") - + # 2. Create Thread thread = await client.table('threads').insert({ - "thread_id": str(uuid.uuid4()), - "project_id": project_id, - "account_id": account_id, + "thread_id": str(uuid.uuid4()), "project_id": project_id, "account_id": account_id, "created_at": datetime.now(timezone.utc).isoformat() }).execute() - thread_id = thread.data[0]['thread_id'] logger.info(f"Created new thread: {thread_id}") - - # ---- Trigger Background Naming Task ---- - logger.info(f"Scheduling background task to generate name for project {project_id}") - asyncio.create_task( - generate_and_update_project_name( - project_id=project_id, - prompt=prompt - ) - ) - # ----------------------------------------- - # 3. Create Sandbox - try: - sandbox, sandbox_id, sandbox_pass = await get_or_create_project_sandbox(client, project_id) - logger.info(f"Using sandbox {sandbox_id} for new project {project_id}") - except Exception as e: - logger.error(f"Failed to create or get sandbox for project {project_id}: {str(e)}") - raise HTTPException(status_code=500, detail=f"Failed to initialize sandbox: {str(e)}") - + # Trigger Background Naming Task + asyncio.create_task(generate_and_update_project_name(project_id=project_id, prompt=prompt)) + + # 3. Create Sandbox + sandbox, sandbox_id, sandbox_pass = await get_or_create_project_sandbox(client, project_id) + logger.info(f"Using sandbox {sandbox_id} for new project {project_id}") + # 4. Upload Files to Sandbox (if any) + # ... (File upload logic remains the same as in original file) ... 
         # 4. Upload Files to Sandbox (if any)
+        # ... (File upload logic remains the same as in original file) ...
         message_content = prompt
         if files:
             successful_uploads = []
             failed_uploads = []
-
             for file in files:
                 if file.filename:
                     try:
-                        # Sanitize filename
                         safe_filename = file.filename.replace('/', '_').replace('\\', '_')
-                        # Ensure path is relative to /workspace and cleaned
-                        # Using clean_path from utils.files_utils might be safer here
-                        # target_path = clean_path(safe_filename, "/workspace") # Assuming clean_path handles prepending /workspace if needed by upload
-                        # For now, stick to the original target path format used by tools
                         target_path = f"/workspace/{safe_filename}"
-
                         logger.info(f"Attempting to upload {safe_filename} to {target_path} in sandbox {sandbox_id}")
-
-                        # Read file content
                         content = await file.read()
-
-                        # --- Simplified Upload Attempt ---
                         upload_successful = False
                         try:
-                            # Assume sandbox.fs.upload_file is the primary method
-                            # Ensure it's awaited if it's async (Daytona SDK methods often are)
                             if hasattr(sandbox, 'fs') and hasattr(sandbox.fs, 'upload_file'):
                                 import inspect
                                 if inspect.iscoroutinefunction(sandbox.fs.upload_file):
                                     await sandbox.fs.upload_file(target_path, content)
                                 else:
-                                    # If sync, run in executor? For now, assume async is standard
-                                    # Or handle potential blocking if called directly
-                                    sandbox.fs.upload_file(target_path, content) # Or run_in_executor if needed
+                                    sandbox.fs.upload_file(target_path, content)
                                 logger.debug(f"Called sandbox.fs.upload_file for {target_path}")
-                                upload_successful = True # Mark as attempted
+                                upload_successful = True
                             else:
-                                logger.error(f"Sandbox object missing 'fs.upload_file' method for {sandbox_id}")
                                 raise NotImplementedError("Suitable upload method not found on sandbox object.")
-
                         except Exception as upload_error:
                             logger.error(f"Error during sandbox upload call for {safe_filename}: {str(upload_error)}", exc_info=True)
-                            # Keep upload_successful as False
-
-                        # --- Verification Step ---
                         if upload_successful:
                             try:
-                                # Short delay to allow filesystem changes to propagate if needed
                                 await asyncio.sleep(0.2)
-                                # Verify by listing the directory containing the file
                                 parent_dir = os.path.dirname(target_path)
                                 files_in_dir = sandbox.fs.list_files(parent_dir)
                                 file_names_in_dir = [f.name for f in files_in_dir]
-
                                 if safe_filename in file_names_in_dir:
                                     successful_uploads.append(target_path)
                                     logger.info(f"Successfully uploaded and verified file {safe_filename} to sandbox path {target_path}")
@@ -1013,100 +919,61 @@ async def initiate_agent_with_files(
                                 logger.error(f"Error verifying file {safe_filename} after upload: {str(verify_error)}", exc_info=True)
                                 failed_uploads.append(safe_filename)
                         else:
-                            # If the upload call itself failed
                             failed_uploads.append(safe_filename)
-
                     except Exception as file_error:
                         logger.error(f"Error processing file {file.filename}: {str(file_error)}", exc_info=True)
                         failed_uploads.append(file.filename)
                     finally:
-                        # Ensure file is closed
                         await file.close()

-            # Append file references to message content
             if successful_uploads:
                 message_content += "\n\n" if message_content else ""
-                for file_path in successful_uploads:
-                    message_content += f"[Uploaded File: {file_path}]\n"
-
-                # Also mention failed uploads if any
+                for file_path in successful_uploads: message_content += f"[Uploaded File: {file_path}]\n"
                 if failed_uploads:
                     message_content += "\n\nThe following files failed to upload:\n"
-                    for failed_file in failed_uploads:
-                        message_content += f"- {failed_file}\n"
-
+                    for failed_file in failed_uploads: message_content += f"- {failed_file}\n"
+        # ... (End of file upload logic) ...
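When sandbox.fs.upload_file turns out to be synchronous, the hunk above calls it directly on the event loop, blocking other requests for the duration of the write; the removed "# Or run_in_executor if needed" comment pointed at the alternative. A minimal sketch of that alternative, assuming a synchronous upload_fn taking a path and bytes (the wrapper name is hypothetical):

import asyncio

async def upload_without_blocking(upload_fn, target_path: str, content: bytes) -> None:
    """Run a synchronous upload function in the default thread pool executor."""
    loop = asyncio.get_running_loop()
    # run_in_executor accepts positional args after the callable; None selects
    # the loop's default ThreadPoolExecutor
    await loop.run_in_executor(None, upload_fn, target_path, content)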
+        # 5. Add initial user message to thread
         message_id = str(uuid.uuid4())
-        # Prepare the message content in the standard format expected by the LLM/database
-        message_payload = {
-            "role": "user",
-            "content": message_content # This already contains the prompt + file references
-        }
+        message_payload = {"role": "user", "content": message_content}
         await client.table('messages').insert({
-            "message_id": message_id,
-            "thread_id": thread_id,
-            "type": "user", # Use the 'type' column
-            "is_llm_message": True, # Indicate it's part of the conversation flow
-            "content": json.dumps(message_payload), # Store the structured message in the content column
+            "message_id": message_id, "thread_id": thread_id, "type": "user",
+            "is_llm_message": True, "content": json.dumps(message_payload),
             "created_at": datetime.now(timezone.utc).isoformat()
         }).execute()
-
+
         # 6. Start Agent Run
         agent_run = await client.table('agent_runs').insert({
-            "thread_id": thread_id,
-            "status": "running",
+            "thread_id": thread_id, "status": "running",
             "started_at": datetime.now(timezone.utc).isoformat()
         }).execute()
-
         agent_run_id = agent_run.data[0]['id']
         logger.info(f"Created new agent run: {agent_run_id}")
-
-        # Initialize in-memory storage for this agent run
-        active_agent_runs[agent_run_id] = []
-
-        # Register this run in Redis with TTL
+
+        # Register run in Redis
+        instance_key = f"active_run:{instance_id}:{agent_run_id}"
         try:
-            await redis.set(
-                f"active_run:{instance_id}:{agent_run_id}",
-                "running",
-                ex=redis.REDIS_KEY_TTL
-            )
+            await redis.set(instance_key, "running", ex=redis.REDIS_KEY_TTL)
         except Exception as e:
-            logger.warning(f"Failed to register agent run in Redis, continuing without Redis tracking: {str(e)}")
-
-        # Run the agent in the background
+            logger.warning(f"Failed to register agent run in Redis ({instance_key}): {str(e)}")
+
+        # Run agent in background
         task = asyncio.create_task(
             run_agent_background(
-                agent_run_id=agent_run_id,
-                thread_id=thread_id,
-                instance_id=instance_id,
-                project_id=project_id,
-                sandbox=sandbox,
-                model_name=MODEL_NAME_ALIASES.get(model_name, model_name),
-                enable_thinking=enable_thinking,
-                reasoning_effort=reasoning_effort,
-                stream=stream,
-                enable_context_manager=enable_context_manager
+                agent_run_id=agent_run_id, thread_id=thread_id, instance_id=instance_id,
+                project_id=project_id, sandbox=sandbox,
+                model_name=MODEL_NAME_ALIASES.get(model_name, model_name),
+                enable_thinking=enable_thinking, reasoning_effort=reasoning_effort,
+                stream=stream, enable_context_manager=enable_context_manager
             )
         )
-
-        # Set a callback to clean up when task is done
-        task.add_done_callback(
-            lambda _: asyncio.create_task(
-                _cleanup_agent_run(agent_run_id)
-            )
-        )
-
-        # Return immediately without waiting for the naming task
+        task.add_done_callback(lambda _: asyncio.create_task(_cleanup_redis_instance_key(agent_run_id)))
+
         return {"thread_id": thread_id, "agent_run_id": agent_run_id}
-
+
     except Exception as e:
-        # Log the error
         logger.error(f"Error in agent initiation: {str(e)}\n{traceback.format_exc()}")
-
-        # Todo: Clean up resources if needed (project, thread, sandbox)
-
-        raise HTTPException(
-            status_code=500,
-            detail=f"Failed to initiate agent session: {str(e)}"
-        )
\ No newline at end of file
+        # TODO: Clean up created project/thread if initiation fails mid-way
+        raise HTTPException(status_code=500, detail=f"Failed to initiate agent session: {str(e)}")
\ No newline at end of file
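The TODO above notes that a failure partway through initiation leaks the already-created project and thread rows. A minimal sketch of a compensating cleanup, assuming the same fluent table client and table names used in this diff and a delete()/eq() API matching the insert()/update() calls shown; rollback_initiation is a hypothetical helper, and cleanup errors are swallowed so they do not mask the original exception:

from utils.logger import logger

async def rollback_initiation(client, project_id=None, thread_id=None):
    """Best-effort deletion of rows created before agent initiation failed."""
    try:
        # Delete the child row first so the project row is never orphan-referenced
        if thread_id:
            await client.table('threads').delete().eq("thread_id", thread_id).execute()
        if project_id:
            await client.table('projects').delete().eq("project_id", project_id).execute()
    except Exception as cleanup_error:
        # Log and continue: the caller re-raises the original initiation error
        logger.error(f"Failed to roll back initiation resources: {cleanup_error}")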
diff --git a/backend/services/redis.py b/backend/services/redis.py
index cfb0ef2f..e9a139bf 100644
--- a/backend/services/redis.py
+++ b/backend/services/redis.py
@@ -7,6 +7,7 @@ import ssl
 from utils.logger import logger
 import random
 from functools import wraps
+from typing import List  # Added for type hinting

 # Redis client
 client = None
@@ -76,11 +77,11 @@ def initialize():
             ssl=os.getenv('REDIS_SSL', 'True').lower() == 'true',
             ssl_ca_certs=certifi.where(),
             decode_responses=True,
-            socket_timeout=5.0, # Socket timeout
+            socket_timeout=None, # Changed from 5.0 to None to let listen() block indefinitely
             socket_connect_timeout=5.0, # Connection timeout
             retry_on_timeout=True, # Auto-retry on timeout
             health_check_interval=30, # Check connection health every 30 seconds
-            max_connections=10 # Limit connections to prevent overloading
+            max_connections=1000 # Raised from 10 to support many concurrent streaming clients
         )
     return client
@@ -167,7 +168,31 @@ async def keys(pattern):
     redis_client = await get_client()
     return await with_retry(redis_client.keys, pattern)

+async def rpush(key, *values):
+    """Append one or more values to a list with automatic retry."""
+    redis_client = await get_client()
+    return await with_retry(redis_client.rpush, key, *values)
+
+async def lrange(key, start, end):
+    """Get a range of elements from a list with automatic retry."""
+    redis_client = await get_client()
+    # The client is initialized with decode_responses=True, so lrange
+    # returns List[str] rather than bytes
+    result: List[str] = await with_retry(redis_client.lrange, key, start, end)
+    return result
+
+async def llen(key):
+    """Get the length of a list with automatic retry."""
+    redis_client = await get_client()
+    return await with_retry(redis_client.llen, key)
+
+async def expire(key, time):
+    """Set a key's time to live in seconds with automatic retry."""
+    redis_client = await get_client()
+    return await with_retry(redis_client.expire, key, time)
+
 async def create_pubsub():
     """Create a Redis pubsub object."""
     redis_client = await get_client()
+    # decode_responses=True in client init applies to pubsub messages too
     return redis_client.pubsub()
\ No newline at end of file
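Taken together, the new list helpers support an append-then-read pattern for streamed agent responses: a producer rpushes JSON-encoded chunks and refreshes a TTL so abandoned lists expire, and a consumer reads them back with lrange. A minimal usage sketch against the module above; the key format and the 24-hour TTL here are illustrative, not prescribed by this file:

import json
from services import redis

RESPONSE_TTL_SECONDS = 3600 * 24  # illustrative 24-hour expiry

async def append_response(agent_run_id: str, chunk: dict) -> None:
    """Append one response chunk and keep the list from living forever."""
    key = f"agent_run:{agent_run_id}:responses"  # illustrative key format
    await redis.rpush(key, json.dumps(chunk))
    # Refresh the TTL on every append so active runs never expire mid-stream
    await redis.expire(key, RESPONSE_TTL_SECONDS)

async def read_responses(agent_run_id: str) -> list:
    """Read back every chunk appended so far, oldest first."""
    key = f"agent_run:{agent_run_id}:responses"
    return [json.loads(r) for r in await redis.lrange(key, 0, -1)]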