suna/backend/agent/tools/sb_upload_file_tool.py

237 lines
9.5 KiB
Python

import mimetypes
import os
import uuid
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import Optional, Dict, Any

import structlog

from agentpress.tool import ToolResult, openapi_schema, usage_example
from sandbox.tool_base import SandboxToolsBase
from agentpress.thread_manager import ThreadManager
from services.supabase import DBConnection
from utils.logger import logger
from utils.config import config

class SandboxUploadFileTool(SandboxToolsBase):
    """Tool that copies a file out of the sandbox workspace into Supabase storage.

    Objects are stored under ``<account_id>/<filename>`` in the chosen bucket,
    and the caller receives a time-limited signed URL for accessing them.
    """

    # Largest file (in bytes) accepted for upload; bigger files are rejected.
    MAX_UPLOAD_SIZE_BYTES = 50 * 1024 * 1024
    # Lifetime of the generated signed URL, in seconds (24 hours).
    SIGNED_URL_EXPIRES_IN = 24 * 60 * 60

    def __init__(self, project_id: str, thread_manager: ThreadManager):
        super().__init__(project_id, thread_manager)
        self.workspace_path = "/workspace"
        self.db = DBConnection()

    @openapi_schema({
        "type": "function",
        "function": {
            "name": "upload_file",
            "description": "Securely upload a file from the sandbox workspace to private cloud storage (Supabase S3). Returns a secure signed URL that expires after 24 hours for access control and security.",
            "parameters": {
                "type": "object",
                "properties": {
                    "file_path": {
                        "type": "string",
                        "description": "Path to the file in the sandbox, relative to /workspace (e.g., 'output.pdf', 'data/results.csv')"
                    },
                    "bucket_name": {
                        "type": "string",
                        "description": "Target storage bucket. Options: 'file-uploads' (default - secure private storage), 'browser-screenshots' (browser automation only). Default: 'file-uploads'",
                        "default": "file-uploads"
                    },
                    "custom_filename": {
                        "type": "string",
                        "description": "Optional custom filename for the uploaded file. If not provided, uses original filename with timestamp"
                    },
                },
                "required": ["file_path"]
            }
        }
    })
    @usage_example('''
<function_calls>
<invoke name="upload_file">
<parameter name="file_path">report.pdf</parameter>
</invoke>
</function_calls>
''')
    async def upload_file(
        self,
        file_path: str,
        bucket_name: str = "file-uploads",
        custom_filename: Optional[str] = None
    ) -> ToolResult:
        """Upload a workspace file to storage and return a signed access URL.

        Args:
            file_path: Path relative to ``/workspace`` (normalised by ``clean_path``).
            bucket_name: Target storage bucket.
            custom_filename: Optional object name. Only its final path component
                is used, so it cannot escape the account's storage prefix.

        Returns:
            A success response whose message contains the storage location,
            size, signed URL and its (UTC) expiry, or a failure response
            describing what went wrong.
        """
        try:
            await self._ensure_sandbox()
            file_path = self.clean_path(file_path)
            full_path = f"{self.workspace_path}/{file_path}"

            # Keep the try narrow: only a failed lookup means "not found".
            try:
                file_info = await self.sandbox.fs.get_file_info(full_path)
            except Exception:
                return self.fail_response(f"File '{file_path}' not found in workspace.")
            # NOTE(review): assumes the sandbox FileInfo exposes `is_dir`;
            # getattr keeps this a no-op if it does not.
            if getattr(file_info, 'is_dir', False):
                return self.fail_response(
                    f"'{file_path}' is a directory, not a file. Please specify a single file to upload."
                )
            if file_info.size > self.MAX_UPLOAD_SIZE_BYTES:
                limit_mb = self.MAX_UPLOAD_SIZE_BYTES // (1024 * 1024)
                return self.fail_response(
                    f"File '{file_path}' is too large (>{limit_mb}MB). Please reduce file size before uploading."
                )

            try:
                file_content = await self.sandbox.fs.download_file(full_path)
            except Exception as e:
                return self.fail_response(f"Failed to read file '{file_path}': {str(e)}")

            account_id = await self._get_current_account_id()
            original_filename = os.path.basename(file_path)
            content_type, _ = mimetypes.guess_type(original_filename)
            if not content_type:
                content_type = "application/octet-stream"

            storage_filename = self._build_storage_filename(original_filename, custom_filename)
            # Objects are namespaced per account; storage_filename is a single
            # path component, so it cannot leave this prefix.
            storage_path = f"{account_id}/{storage_filename}"

            try:
                client = await self.db.client
                bucket = client.storage.from_(bucket_name)
                await bucket.upload(
                    storage_path,
                    file_content,
                    {"content-type": content_type}
                )

                expires_in = self.SIGNED_URL_EXPIRES_IN
                signed_url_response = await bucket.create_signed_url(
                    storage_path,
                    expires_in
                )
                # Different storage client versions spell the key differently.
                signed_url = (
                    signed_url_response.get('signedURL')
                    or signed_url_response.get('signedUrl')
                )
                if not signed_url:
                    return self.fail_response("Failed to generate secure access URL.")

                # Timezone-aware UTC, so the "UTC" label below and the stored
                # ISO timestamp are both accurate.
                url_expires_at = datetime.now(timezone.utc) + timedelta(seconds=expires_in)

                await self._track_upload(
                    client,
                    account_id,
                    storage_path,
                    bucket_name,
                    original_filename,
                    file_info.size,
                    content_type,
                    signed_url,
                    url_expires_at
                )

                message = f"🔒 File '{original_filename}' uploaded securely!\n"
                message += f"📁 Storage: {bucket_name}/{storage_path}\n"
                message += f"📏 Size: {self._format_file_size(file_info.size)}\n"
                message += f"🔗 Secure Access URL: {signed_url}\n"
                message += f"⏰ URL expires: {url_expires_at.strftime('%Y-%m-%d %H:%M:%S UTC')}\n"
                message += "\n🔐 This file is stored in private, secure storage with account isolation."
                return self.success_response(message)

            except Exception as e:
                logger.error(f"Failed to upload file to Supabase: {str(e)}")
                return self.fail_response(f"Failed to upload file to secure storage: {str(e)}")

        except Exception as e:
            logger.error(f"Unexpected error in upload_file: {str(e)}")
            return self.fail_response(f"Unexpected error during secure file upload: {str(e)}")

    @staticmethod
    def _build_storage_filename(original_filename: str, custom_filename: Optional[str]) -> str:
        """Return the object name (without the account prefix) for an upload.

        A caller-supplied ``custom_filename`` is reduced to its last path
        component (e.g. ``"../other/x.pdf"`` -> ``"x.pdf"``). If that leaves
        nothing usable, or no custom name was given, a unique name is built
        as ``<stem>_<UTC YYYYmmdd_HHMMSS>_<8 hex chars><lower-cased suffix>``.
        """
        if custom_filename:
            safe_name = os.path.basename(custom_filename.strip())
            if safe_name not in ("", ".", ".."):
                return safe_name
        original = Path(original_filename)
        timestamp = datetime.now(timezone.utc).strftime('%Y%m%d_%H%M%S')
        unique_id = str(uuid.uuid4())[:8]
        return f"{original.stem}_{timestamp}_{unique_id}{original.suffix.lower()}"

    async def _get_current_account_id(self) -> str:
        """Resolve the account that owns the current thread.

        The thread id is read from structlog's context variables. It is then
        looked up in the ``threads`` table.

        Raises:
            ValueError: if there is no thread id in context, the thread does
                not exist, or it has no ``account_id``. Database errors
                propagate unchanged; all errors are logged before re-raising.
        """
        try:
            context_vars = structlog.contextvars.get_contextvars()
            thread_id = context_vars.get('thread_id')
            if not thread_id:
                raise ValueError("No thread_id available from execution context")
            client = await self.db.client
            thread_result = await client.table('threads').select('account_id').eq('thread_id', thread_id).limit(1).execute()
            if not thread_result.data:
                raise ValueError(f"Could not find thread with ID: {thread_id}")
            account_id = thread_result.data[0]['account_id']
            if not account_id:
                raise ValueError("Thread has no associated account_id")
            return account_id
        except Exception as e:
            logger.error(f"Error getting current account_id: {e}")
            raise

    async def _track_upload(
        self,
        client,
        account_id: str,
        storage_path: str,
        bucket_name: str,
        original_filename: str,
        file_size: int,
        content_type: str,
        signed_url: str,
        url_expires_at: datetime
    ):
        """Best-effort insert of an audit row into the ``file_uploads`` table.

        The thread id comes from context, and the agent id and user id are
        looked up when available. Any failure is logged and never raised, so
        a tracking problem does not fail an upload that already succeeded.
        """
        try:
            thread_id = None
            agent_id = None
            try:
                context_vars = structlog.contextvars.get_contextvars()
                thread_id = context_vars.get('thread_id')
            except Exception as e:
                logger.debug(f"Could not read thread_id from context for upload tracking: {e}")

            if thread_id:
                thread_result = await client.table('threads').select('agent_id').eq('thread_id', thread_id).limit(1).execute()
                if thread_result.data:
                    agent_id = thread_result.data[0].get('agent_id')

            user_id = None
            try:
                # `account_user` lives in the `basejump` schema; a dotted name
                # passed to table() is not resolved as schema-qualified.
                account_result = await (
                    client.schema('basejump')
                    .from_('account_user')
                    .select('user_id')
                    .eq('account_id', account_id)
                    .limit(1)
                    .execute()
                )
                if account_result.data:
                    user_id = account_result.data[0].get('user_id')
            except Exception as e:
                logger.debug(f"Could not resolve user_id for account {account_id}: {e}")

            upload_data = {
                'project_id': self.project_id,
                'thread_id': thread_id,
                'agent_id': agent_id,
                'account_id': account_id,
                'user_id': user_id,
                'bucket_name': bucket_name,
                'storage_path': storage_path,
                'original_filename': original_filename,
                'file_size': file_size,
                'content_type': content_type,
                'signed_url': signed_url,
                'url_expires_at': url_expires_at.isoformat(),
                'metadata': {
                    'uploaded_from': 'sandbox',
                    'tool': 'upload_file',
                    'secure_upload': True
                }
            }
            await client.table('file_uploads').insert(upload_data).execute()
        except Exception as e:
            logger.warning(f"Failed to track file upload in database: {str(e)}")

    def _format_file_size(self, size_bytes: int) -> str:
        """Render a byte count with one decimal in B/KB/MB/GB/TB (1024-based)."""
        size = float(size_bytes)
        for unit in ['B', 'KB', 'MB', 'GB']:
            if size < 1024.0:
                return f"{size:.1f} {unit}"
            size /= 1024.0
        return f"{size:.1f} TB"