suna/backend/sandbox/api.py

268 lines
10 KiB
Python
Raw Normal View History

2025-04-11 10:45:32 +08:00
import os
2025-05-07 21:24:00 +08:00
from typing import Optional
2025-04-11 10:45:32 +08:00
2025-04-18 22:04:11 +08:00
from fastapi import FastAPI, UploadFile, File, HTTPException, APIRouter, Form, Depends, Request
2025-05-07 21:24:00 +08:00
from fastapi.responses import Response
2025-04-11 10:45:32 +08:00
from pydantic import BaseModel
2025-05-09 12:06:51 +08:00
from sandbox.sandbox import get_or_start_sandbox
2025-04-11 10:45:32 +08:00
from utils.logger import logger
2025-05-07 21:24:00 +08:00
from utils.auth_utils import get_optional_user_id
2025-04-18 22:04:11 +08:00
from services.supabase import DBConnection
2025-04-23 13:20:38 +08:00
from agent.api import get_or_create_project_sandbox
2025-04-11 10:45:32 +08:00
2025-04-18 22:04:11 +08:00
# Initialize shared resources
router = APIRouter(tags=["sandbox"])
# Database connection injected via initialize(); it stays None until then, so
# the route handlers in this module must not run before initialize() is called.
db: Optional[DBConnection] = None
def initialize(_db: DBConnection) -> None:
    """
    Initialize the sandbox API with resources from the main API.

    Stores ``_db`` in the module-level ``db`` global that the route handlers
    read their Supabase client from.

    Args:
        _db: The shared database connection provided by the main API.
    """
    global db
    db = _db
    logger.info("Initialized sandbox API with database connection")
2025-04-11 10:45:32 +08:00
class FileInfo(BaseModel):
    """
    Model for file information.

    Describes one entry (file or directory) returned by a sandbox directory
    listing; list_files serializes a list of these into its JSON response.
    """
    name: str  # entry name as reported by the sandbox file listing
    path: str  # requested directory and `name` joined with '/' by list_files
    is_dir: bool  # True when the entry is a directory
    size: int  # size as reported by the sandbox SDK (presumably bytes — TODO confirm)
    mod_time: str  # modification time, stringified from the SDK's value
    permissions: Optional[str] = None  # permission string if the SDK entry exposes one, else None
2025-04-21 22:08:28 +08:00
async def verify_sandbox_access(client, sandbox_id: str, user_id: Optional[str] = None):
    """
    Check that the caller may access the sandbox ``sandbox_id``.

    Access is granted when the owning project is public, or when ``user_id``
    is listed in ``basejump.account_user`` for the project's ``account_id``.

    Args:
        client: The Supabase client
        sandbox_id: The sandbox ID to check access for
        user_id: The requesting user's ID, or None for an anonymous caller
            (which is only sufficient for public projects).

    Returns:
        dict: The full ``projects`` row that owns the sandbox.

    Raises:
        HTTPException: 404 if no project owns the sandbox, 401 if the project
            is private and no user ID was given, 403 if the user is not a
            member of the owning account (or the project has no account_id).
    """
    # Locate the owning project through the id stored in its `sandbox` JSON field
    lookup = await client.table('projects').select('*').filter('sandbox->>id', 'eq', sandbox_id).execute()
    if not lookup.data:
        raise HTTPException(status_code=404, detail="Sandbox not found")
    project = lookup.data[0]

    # Public projects are readable by anyone, authenticated or not
    if project.get('is_public'):
        return project

    if not user_id:
        raise HTTPException(status_code=401, detail="Authentication required for this resource")

    account_id = project.get('account_id')
    is_member = False
    if account_id:
        membership = await (
            client.schema('basejump')
            .from_('account_user')
            .select('account_role')
            .eq('user_id', user_id)
            .eq('account_id', account_id)
            .execute()
        )
        is_member = bool(membership.data)

    if not is_member:
        raise HTTPException(status_code=403, detail="Not authorized to access this sandbox")
    return project
2025-04-11 10:45:32 +08:00
2025-04-23 13:20:38 +08:00
async def get_sandbox_by_id_safely(client, sandbox_id: str):
    """
    Look up and return the sandbox object for ``sandbox_id``.

    The ``projects`` table is queried first so that only sandbox IDs owned by
    some project are resolved; the sandbox itself is then obtained from
    ``get_or_start_sandbox``.

    Args:
        client: The Supabase client
        sandbox_id: The sandbox ID to retrieve

    Returns:
        The object returned by ``get_or_start_sandbox(sandbox_id)``.

    Raises:
        HTTPException: 404 if no project references this sandbox ID,
            500 if ``get_or_start_sandbox`` raises.
    """
    owners = await client.table('projects').select('project_id').filter('sandbox->>id', 'eq', sandbox_id).execute()
    if not owners.data:
        logger.error(f"No project found for sandbox ID: {sandbox_id}")
        raise HTTPException(status_code=404, detail="Sandbox not found - no project owns this sandbox ID")

    try:
        return await get_or_start_sandbox(sandbox_id)
    except Exception as exc:
        logger.error(f"Error retrieving sandbox {sandbox_id}: {str(exc)}")
        raise HTTPException(status_code=500, detail=f"Failed to retrieve sandbox: {str(exc)}")
2025-04-11 10:45:32 +08:00
@router.post("/sandboxes/{sandbox_id}/files")
async def create_file(
    sandbox_id: str,
    path: str = Form(...),
    file: UploadFile = File(...),
    request: Request = None,
    user_id: Optional[str] = Depends(get_optional_user_id)
):
    """
    Create a file in the sandbox using direct file upload.

    Args:
        sandbox_id: ID of the target sandbox.
        path: Destination path inside the sandbox (multipart form field).
        file: Uploaded file whose raw bytes are written to ``path``.
        request: Not used by this handler.
        user_id: ID of the calling user, or None for anonymous callers.

    Returns:
        dict: ``{"status": "success", "created": True, "path": path}``.

    Raises:
        HTTPException: Errors from the access check and from the sandbox
            lookup propagate with their original status code; any other
            failure is reported as 500.
    """
    logger.info(f"Received file upload request for sandbox {sandbox_id}, path: {path}, user_id: {user_id}")
    client = await db.client

    # Verify the user has access to this sandbox
    await verify_sandbox_access(client, sandbox_id, user_id)

    try:
        sandbox = await get_sandbox_by_id_safely(client, sandbox_id)
        # Read the uploaded content and write the raw bytes unchanged
        content = await file.read()
        sandbox.fs.upload_file(path, content)
        logger.info(f"File created at {path} in sandbox {sandbox_id}")
        return {"status": "success", "created": True, "path": path}
    except HTTPException:
        # Already carries the intended status (e.g. 404 from the sandbox
        # lookup); re-wrapping it would turn it into a generic 500.
        raise
    except Exception as e:
        logger.error(f"Error creating file in sandbox {sandbox_id}: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e
2025-04-11 10:45:32 +08:00
@router.get("/sandboxes/{sandbox_id}/files")
async def list_files(
    sandbox_id: str,
    path: str,
    request: Request = None,
    user_id: Optional[str] = Depends(get_optional_user_id)
):
    """
    List files and directories at the specified path.

    Args:
        sandbox_id: ID of the sandbox to inspect.
        path: Directory inside the sandbox to list (query parameter).
        request: Not used by this handler.
        user_id: ID of the calling user, or None for anonymous callers.

    Returns:
        dict: ``{"files": [...]}`` where each item is a serialized FileInfo.

    Raises:
        HTTPException: Errors from the access check and from the sandbox
            lookup propagate with their original status code; any other
            failure is reported as 500.
    """
    logger.info(f"Received list files request for sandbox {sandbox_id}, path: {path}, user_id: {user_id}")
    client = await db.client

    # Verify the user has access to this sandbox
    await verify_sandbox_access(client, sandbox_id, user_id)

    try:
        sandbox = await get_sandbox_by_id_safely(client, sandbox_id)
        entries = sandbox.fs.list_files(path)

        # Join with forward slashes regardless of OS. rstrip turns "/" into "",
        # so root entries come out as "/name" without a special case.
        prefix = path.rstrip('/')
        result = [
            FileInfo(
                name=entry.name,
                path=f"{prefix}/{entry.name}",
                is_dir=entry.is_dir,
                size=entry.size,
                mod_time=str(entry.mod_time),
                permissions=getattr(entry, 'permissions', None),
            )
            for entry in entries
        ]

        logger.info(f"Successfully listed {len(result)} files in sandbox {sandbox_id}")
        return {"files": [info.dict() for info in result]}
    except HTTPException:
        # Keep the status code chosen by the sandbox lookup instead of
        # re-wrapping it as a generic 500.
        raise
    except Exception as e:
        logger.error(f"Error listing files in sandbox {sandbox_id}: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e
def _content_disposition(filename: str) -> str:
    """
    Return an ``attachment`` Content-Disposition value that is safe for any filename.

    A bare ``filename=<name>`` is malformed for names containing spaces,
    quotes or semicolons. Starlette also encodes header values as latin-1,
    so names with other characters (e.g. CJK) would fail outright. This
    emits a quoted ASCII-only fallback plus an RFC 5987 ``filename*``
    parameter carrying the exact UTF-8 name.

    Args:
        filename: The file's base name.

    Returns:
        str: The header value.
    """
    from urllib.parse import quote

    # Replace characters that are non-printable, non-ASCII, or would end the
    # quoted string; clients that understand filename* prefer that instead.
    fallback = "".join(
        ch if " " <= ch <= "~" and ch not in '"\\' else "_"
        for ch in filename
    )
    return f"attachment; filename=\"{fallback}\"; filename*=UTF-8''{quote(filename, safe='')}"


@router.get("/sandboxes/{sandbox_id}/files/content")
async def read_file(
    sandbox_id: str,
    path: str,
    request: Request = None,
    user_id: Optional[str] = Depends(get_optional_user_id)
):
    """
    Read a file from the sandbox and return it as a download.

    Args:
        sandbox_id: ID of the sandbox to read from.
        path: Path of the file inside the sandbox (query parameter).
        request: Not used by this handler.
        user_id: ID of the calling user, or None for anonymous callers.

    Returns:
        Response: The raw file bytes as ``application/octet-stream`` with an
        ``attachment`` Content-Disposition named after the file.

    Raises:
        HTTPException: Errors from the access check and from the sandbox
            lookup propagate with their original status code; any other
            failure is reported as 500.
    """
    logger.info(f"Received file read request for sandbox {sandbox_id}, path: {path}, user_id: {user_id}")
    client = await db.client

    # Verify the user has access to this sandbox
    await verify_sandbox_access(client, sandbox_id, user_id)

    try:
        sandbox = await get_sandbox_by_id_safely(client, sandbox_id)
        content = sandbox.fs.download_file(path)

        filename = os.path.basename(path)
        logger.info(f"Successfully read file {filename} from sandbox {sandbox_id}")

        # Return the bytes directly rather than wrapping them in JSON
        return Response(
            content=content,
            media_type="application/octet-stream",
            headers={"Content-Disposition": _content_disposition(filename)}
        )
    except HTTPException:
        # Keep the status code chosen by the sandbox lookup instead of
        # re-wrapping it as a generic 500.
        raise
    except Exception as e:
        logger.error(f"Error reading file in sandbox {sandbox_id}: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e
2025-05-09 12:06:51 +08:00
# Disabled: ensuring a project's sandbox is active should be handled entirely on the server side.
# @router.post("/project/{project_id}/sandbox/ensure-active")
# async def ensure_project_sandbox_active(
# project_id: str,
# request: Request = None,
# user_id: Optional[str] = Depends(get_optional_user_id)
# ):
# """
# Ensure that a project's sandbox is active and running.
# Checks the sandbox status and starts it if it's not running.
# """
# logger.info(f"Received ensure sandbox active request for project {project_id}, user_id: {user_id}")
# client = await db.client
2025-05-09 12:06:51 +08:00
# # Find the project and sandbox information
# project_result = await client.table('projects').select('*').eq('project_id', project_id).execute()
2025-05-09 12:06:51 +08:00
# if not project_result.data or len(project_result.data) == 0:
# logger.error(f"Project not found: {project_id}")
# raise HTTPException(status_code=404, detail="Project not found")
2025-05-09 12:06:51 +08:00
# project_data = project_result.data[0]
2025-05-09 12:06:51 +08:00
# # For public projects, no authentication is needed
# if not project_data.get('is_public'):
# # For private projects, we must have a user_id
# if not user_id:
# logger.error(f"Authentication required for private project {project_id}")
# raise HTTPException(status_code=401, detail="Authentication required for this resource")
2025-04-21 22:08:28 +08:00
2025-05-09 12:06:51 +08:00
# account_id = project_data.get('account_id')
2025-04-21 22:08:28 +08:00
2025-05-09 12:06:51 +08:00
# # Verify account membership
# if account_id:
# account_user_result = await client.schema('basejump').from_('account_user').select('account_role').eq('user_id', user_id).eq('account_id', account_id).execute()
# if not (account_user_result.data and len(account_user_result.data) > 0):
# logger.error(f"User {user_id} not authorized to access project {project_id}")
# raise HTTPException(status_code=403, detail="Not authorized to access this project")
2025-05-09 12:06:51 +08:00
# try:
# # Get or create the sandbox
# logger.info(f"Ensuring sandbox is active for project {project_id}")
# sandbox, sandbox_id, sandbox_pass = await get_or_create_project_sandbox(client, project_id)
2025-04-23 13:20:38 +08:00
2025-05-09 12:06:51 +08:00
# logger.info(f"Successfully ensured sandbox {sandbox_id} is active for project {project_id}")
2025-05-09 12:06:51 +08:00
# return {
# "status": "success",
# "sandbox_id": sandbox_id,
# "message": "Sandbox is active"
# }
# except Exception as e:
# logger.error(f"Error ensuring sandbox is active for project {project_id}: {str(e)}")
# raise HTTPException(status_code=500, detail=str(e))