mirror of https://github.com/kortix-ai/suna.git
wip
This commit is contained in:
parent
7bffa72056
commit
e985cbdc2b
|
@ -256,23 +256,17 @@ async def _cleanup_agent_run(agent_run_id: str):
|
|||
logger.warning(f"Failed to clean up Redis keys for agent run {agent_run_id}: {str(e)}")
|
||||
# Non-fatal error, can continue
|
||||
|
||||
async def get_or_create_project_sandbox(client, project_id: str, sandbox_cache={}):
|
||||
async def get_or_create_project_sandbox(client, project_id: str):
|
||||
"""
|
||||
Safely get or create a sandbox for a project using distributed locking to avoid race conditions.
|
||||
Get or create a sandbox for a project without distributed locking.
|
||||
|
||||
Args:
|
||||
client: The Supabase client
|
||||
project_id: The project ID to get or create a sandbox for
|
||||
sandbox_cache: Optional in-memory cache to avoid repeated lookups in the same process
|
||||
|
||||
Returns:
|
||||
Tuple of (sandbox object, sandbox_id, sandbox_pass)
|
||||
"""
|
||||
# Check in-memory cache first (optimization for repeated calls within same process)
|
||||
if project_id in sandbox_cache:
|
||||
logger.debug(f"Using cached sandbox for project {project_id}")
|
||||
return sandbox_cache[project_id]
|
||||
|
||||
# First get the current project data to check if a sandbox already exists
|
||||
project = await client.table('projects').select('*').eq('project_id', project_id).execute()
|
||||
if not project.data or len(project.data) == 0:
|
||||
|
@ -288,133 +282,55 @@ async def get_or_create_project_sandbox(client, project_id: str, sandbox_cache={
|
|||
|
||||
try:
|
||||
sandbox = await get_or_start_sandbox(sandbox_id)
|
||||
# Cache the result
|
||||
sandbox_cache[project_id] = (sandbox, sandbox_id, sandbox_pass)
|
||||
return (sandbox, sandbox_id, sandbox_pass)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to retrieve existing sandbox {sandbox_id} for project {project_id}: {str(e)}")
|
||||
# Fall through to create a new sandbox if retrieval fails
|
||||
|
||||
# Need to create a new sandbox - use Redis for distributed locking
|
||||
lock_key = f"project_sandbox_lock:{project_id}"
|
||||
lock_value = str(uuid.uuid4()) # Unique identifier for this lock acquisition
|
||||
lock_timeout = 60 # seconds
|
||||
|
||||
# Try to acquire a lock
|
||||
# Create a new sandbox
|
||||
try:
|
||||
# Attempt to get a lock with a timeout - this is atomic in Redis
|
||||
acquired = await redis.set(
|
||||
lock_key,
|
||||
lock_value,
|
||||
nx=True, # Only set if key doesn't exist (NX = not exists)
|
||||
ex=lock_timeout # Auto-expire the lock
|
||||
)
|
||||
logger.info(f"Creating new sandbox for project {project_id}")
|
||||
sandbox_pass = str(uuid.uuid4())
|
||||
sandbox = create_sandbox(sandbox_pass)
|
||||
sandbox_id = sandbox.id
|
||||
|
||||
if not acquired:
|
||||
# Someone else is creating a sandbox for this project
|
||||
logger.info(f"Waiting for another process to create sandbox for project {project_id}")
|
||||
|
||||
# Wait and retry a few times
|
||||
max_retries = 5
|
||||
retry_delay = 2 # seconds
|
||||
|
||||
for retry in range(max_retries):
|
||||
await asyncio.sleep(retry_delay)
|
||||
|
||||
# Check if the other process completed
|
||||
fresh_project = await client.table('projects').select('*').eq('project_id', project_id).execute()
|
||||
if fresh_project.data and fresh_project.data[0].get('sandbox', {}).get('id'):
|
||||
sandbox_id = fresh_project.data[0]['sandbox']['id']
|
||||
sandbox_pass = fresh_project.data[0]['sandbox']['pass']
|
||||
logger.info(f"Another process created sandbox {sandbox_id} for project {project_id}")
|
||||
|
||||
sandbox = await get_or_start_sandbox(sandbox_id)
|
||||
# Cache the result
|
||||
sandbox_cache[project_id] = (sandbox, sandbox_id, sandbox_pass)
|
||||
return (sandbox, sandbox_id, sandbox_pass)
|
||||
|
||||
# If we got here, the other process didn't complete in time
|
||||
# Force-acquire the lock by deleting and recreating it
|
||||
logger.warning(f"Timeout waiting for sandbox creation for project {project_id}, acquiring lock forcefully")
|
||||
await redis.delete(lock_key)
|
||||
await redis.set(lock_key, lock_value, ex=lock_timeout)
|
||||
logger.info(f"Created new sandbox {sandbox_id} with preview: {sandbox.get_preview_link(6080)}/vnc_lite.html?password={sandbox_pass}")
|
||||
|
||||
# We have the lock now - check one more time to avoid race conditions
|
||||
fresh_project = await client.table('projects').select('*').eq('project_id', project_id).execute()
|
||||
if fresh_project.data and fresh_project.data[0].get('sandbox', {}).get('id'):
|
||||
sandbox_id = fresh_project.data[0]['sandbox']['id']
|
||||
sandbox_pass = fresh_project.data[0]['sandbox']['pass']
|
||||
logger.info(f"Sandbox {sandbox_id} was created by another process while acquiring lock for project {project_id}")
|
||||
|
||||
# Release the lock
|
||||
await redis.delete(lock_key)
|
||||
|
||||
sandbox = await get_or_start_sandbox(sandbox_id)
|
||||
# Cache the result
|
||||
sandbox_cache[project_id] = (sandbox, sandbox_id, sandbox_pass)
|
||||
return (sandbox, sandbox_id, sandbox_pass)
|
||||
# Get preview links
|
||||
vnc_link = sandbox.get_preview_link(6080)
|
||||
website_link = sandbox.get_preview_link(8080)
|
||||
|
||||
# Extract the actual URLs and token from the preview link objects
|
||||
vnc_url = vnc_link.url if hasattr(vnc_link, 'url') else str(vnc_link).split("url='")[1].split("'")[0]
|
||||
website_url = website_link.url if hasattr(website_link, 'url') else str(website_link).split("url='")[1].split("'")[0]
|
||||
|
||||
# Extract token if available
|
||||
token = None
|
||||
if hasattr(vnc_link, 'token'):
|
||||
token = vnc_link.token
|
||||
elif "token='" in str(vnc_link):
|
||||
token = str(vnc_link).split("token='")[1].split("'")[0]
|
||||
|
||||
# Update the project with the new sandbox info
|
||||
update_result = await client.table('projects').update({
|
||||
'sandbox': {
|
||||
'id': sandbox_id,
|
||||
'pass': sandbox_pass,
|
||||
'vnc_preview': vnc_url,
|
||||
'sandbox_url': website_url,
|
||||
'token': token
|
||||
}
|
||||
}).eq('project_id', project_id).execute()
|
||||
|
||||
if not update_result.data:
|
||||
logger.error(f"Failed to update project {project_id} with new sandbox {sandbox_id}")
|
||||
raise Exception("Database update failed")
|
||||
|
||||
return (sandbox, sandbox_id, sandbox_pass)
|
||||
|
||||
# Create a new sandbox
|
||||
try:
|
||||
logger.info(f"Creating new sandbox for project {project_id}")
|
||||
sandbox_pass = str(uuid.uuid4())
|
||||
sandbox = create_sandbox(sandbox_pass)
|
||||
sandbox_id = sandbox.id
|
||||
|
||||
logger.info(f"Created new sandbox {sandbox_id} with preview: {sandbox.get_preview_link(6080)}/vnc_lite.html?password={sandbox_pass}")
|
||||
|
||||
# Get preview links
|
||||
vnc_link = sandbox.get_preview_link(6080)
|
||||
website_link = sandbox.get_preview_link(8080)
|
||||
|
||||
# Extract the actual URLs and token from the preview link objects
|
||||
vnc_url = vnc_link.url if hasattr(vnc_link, 'url') else str(vnc_link).split("url='")[1].split("'")[0]
|
||||
website_url = website_link.url if hasattr(website_link, 'url') else str(website_link).split("url='")[1].split("'")[0]
|
||||
|
||||
# Extract token if available
|
||||
token = None
|
||||
if hasattr(vnc_link, 'token'):
|
||||
token = vnc_link.token
|
||||
elif "token='" in str(vnc_link):
|
||||
token = str(vnc_link).split("token='")[1].split("'")[0]
|
||||
|
||||
# Update the project with the new sandbox info
|
||||
update_result = await client.table('projects').update({
|
||||
'sandbox': {
|
||||
'id': sandbox_id,
|
||||
'pass': sandbox_pass,
|
||||
'vnc_preview': vnc_url,
|
||||
'sandbox_url': website_url,
|
||||
'token': token
|
||||
}
|
||||
}).eq('project_id', project_id).execute()
|
||||
|
||||
if not update_result.data:
|
||||
logger.error(f"Failed to update project {project_id} with new sandbox {sandbox_id}")
|
||||
raise Exception("Database update failed")
|
||||
|
||||
# Cache the result
|
||||
sandbox_cache[project_id] = (sandbox, sandbox_id, sandbox_pass)
|
||||
return (sandbox, sandbox_id, sandbox_pass)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating sandbox for project {project_id}: {str(e)}")
|
||||
raise e
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in sandbox creation process for project {project_id}: {str(e)}")
|
||||
logger.error(f"Error creating sandbox for project {project_id}: {str(e)}")
|
||||
raise e
|
||||
|
||||
finally:
|
||||
# Always try to release the lock if we have it
|
||||
try:
|
||||
# Only delete the lock if it's still ours
|
||||
current_value = await redis.get(lock_key)
|
||||
if current_value == lock_value:
|
||||
await redis.delete(lock_key)
|
||||
logger.debug(f"Released lock for project {project_id} sandbox creation")
|
||||
except Exception as lock_error:
|
||||
logger.warning(f"Error releasing sandbox creation lock for project {project_id}: {str(lock_error)}")
|
||||
|
||||
@router.post("/thread/{thread_id}/agent/start")
|
||||
async def start_agent(
|
||||
|
@ -894,7 +810,6 @@ async def run_agent_background(
|
|||
|
||||
logger.info(f"Agent run background task fully completed for: {agent_run_id} (instance: {instance_id})")
|
||||
|
||||
# New background task function
|
||||
async def generate_and_update_project_name(project_id: str, prompt: str):
|
||||
"""Generates a project name using an LLM and updates the database."""
|
||||
logger.info(f"Starting background task to generate name for project: {project_id}")
|
||||
|
@ -938,7 +853,6 @@ async def generate_and_update_project_name(project_id: str, prompt: str):
|
|||
else:
|
||||
logger.warning(f"Failed to get valid response from LLM for project {project_id} naming. Response: {response}")
|
||||
|
||||
print(f"\n\n\nGenerated name: {generated_name}\n\n\n")
|
||||
# Update database if name was generated
|
||||
if generated_name:
|
||||
update_result = await client.table('projects') \
|
||||
|
@ -1026,7 +940,7 @@ async def initiate_agent_with_files(
|
|||
)
|
||||
# -----------------------------------------
|
||||
|
||||
# 3. Create Sandbox - Using safe method with distributed locking
|
||||
# 3. Create Sandbox
|
||||
try:
|
||||
sandbox, sandbox_id, sandbox_pass = await get_or_create_project_sandbox(client, project_id)
|
||||
logger.info(f"Using sandbox {sandbox_id} for new project {project_id}")
|
||||
|
|
|
@ -247,9 +247,9 @@ class WebSearchTool(Tool):
|
|||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
print(f"--- Raw Tavily Response ---")
|
||||
print(data)
|
||||
print(f"--------------------------")
|
||||
# print(f"--- Raw Tavily Response ---")
|
||||
# print(data)
|
||||
# print(f"--------------------------")
|
||||
|
||||
# Normalise Tavily extract output to a list of dicts
|
||||
extracted = []
|
||||
|
|
|
@ -46,9 +46,9 @@ async def lifespan(app: FastAPI):
|
|||
# Initialize the sandbox API with shared resources
|
||||
sandbox_api.initialize(db)
|
||||
|
||||
# Initialize Redis before restoring agent runs
|
||||
from services import redis
|
||||
await redis.initialize_async()
|
||||
# Redis is no longer needed for a single-server setup
|
||||
# from services import redis
|
||||
# await redis.initialize_async()
|
||||
|
||||
asyncio.create_task(agent_api.restore_running_agent_runs())
|
||||
|
||||
|
|
|
@ -74,7 +74,6 @@ async def verify_sandbox_access(client, sandbox_id: str, user_id: Optional[str]
|
|||
async def get_sandbox_by_id_safely(client, sandbox_id: str):
|
||||
"""
|
||||
Safely retrieve a sandbox object by its ID, using the project that owns it.
|
||||
This prevents race conditions by leveraging the distributed locking mechanism.
|
||||
|
||||
Args:
|
||||
client: The Supabase client
|
||||
|
@ -97,7 +96,7 @@ async def get_sandbox_by_id_safely(client, sandbox_id: str):
|
|||
logger.debug(f"Found project {project_id} for sandbox {sandbox_id}")
|
||||
|
||||
try:
|
||||
# Use the race-condition-safe function to get the sandbox
|
||||
# Get the sandbox
|
||||
sandbox, retrieved_sandbox_id, sandbox_pass = await get_or_create_project_sandbox(client, project_id)
|
||||
|
||||
# Verify we got the right sandbox
|
||||
|
@ -259,7 +258,6 @@ async def ensure_project_sandbox_active(
|
|||
"""
|
||||
Ensure that a project's sandbox is active and running.
|
||||
Checks the sandbox status and starts it if it's not running.
|
||||
Uses distributed locking to prevent race conditions.
|
||||
"""
|
||||
client = await db.client
|
||||
|
||||
|
@ -286,8 +284,8 @@ async def ensure_project_sandbox_active(
|
|||
raise HTTPException(status_code=403, detail="Not authorized to access this project")
|
||||
|
||||
try:
|
||||
# Use the safer function that handles race conditions with distributed locking
|
||||
logger.info(f"Ensuring sandbox is active for project {project_id} using distributed locking")
|
||||
# Get or create the sandbox
|
||||
logger.info(f"Ensuring sandbox is active for project {project_id}")
|
||||
sandbox, sandbox_id, sandbox_pass = await get_or_create_project_sandbox(client, project_id)
|
||||
|
||||
logger.info(f"Successfully ensured sandbox {sandbox_id} is active for project {project_id}")
|
||||
|
|
|
@ -3,7 +3,7 @@ from typing import Dict, Optional, Tuple
|
|||
|
||||
# Define subscription tiers and their monthly limits (in minutes)
|
||||
SUBSCRIPTION_TIERS = {
|
||||
'price_1RGJ9GG6l1KZGqIroxSqgphC': {'name': 'free', 'minutes': 0},
|
||||
'price_1RGJ9GG6l1KZGqIroxSqgphC': {'name': 'free', 'minutes': 100000},
|
||||
'price_1RGJ9LG6l1KZGqIrd9pwzeNW': {'name': 'base', 'minutes': 300}, # 100 hours = 6000 minutes
|
||||
'price_1RGJ9JG6l1KZGqIrVUU4ZRv6': {'name': 'extra', 'minutes': 2400} # 100 hours = 6000 minutes
|
||||
}
|
||||
|
|
|
@ -243,7 +243,7 @@ export default function ThreadPage({ params }: { params: Promise<ThreadParams> }
|
|||
const messagesLoadedRef = useRef(false);
|
||||
const agentRunsCheckedRef = useRef(false);
|
||||
const previousAgentStatus = useRef<typeof agentStatus>('idle');
|
||||
const pollingIntervalRef = useRef<NodeJS.Timeout | null>(null);
|
||||
const pollingIntervalRef = useRef<NodeJS.Timeout | null>(null); // POLLING FOR MESSAGES
|
||||
|
||||
const handleProjectRenamed = useCallback((newName: string) => {
|
||||
setProjectName(newName);
|
||||
|
@ -969,6 +969,7 @@ export default function ThreadPage({ params }: { params: Promise<ThreadParams> }
|
|||
}
|
||||
}, [projectName]);
|
||||
|
||||
// POLLING FOR MESSAGES
|
||||
// Set up polling for messages
|
||||
useEffect(() => {
|
||||
// Function to fetch messages
|
||||
|
@ -1024,6 +1025,7 @@ export default function ThreadPage({ params }: { params: Promise<ThreadParams> }
|
|||
}
|
||||
};
|
||||
}, [threadId, userHasScrolled, initialLoadCompleted]);
|
||||
// POLLING FOR MESSAGES
|
||||
|
||||
// Add another useEffect to ensure messages are refreshed when agent status changes to idle
|
||||
useEffect(() => {
|
||||
|
|
Loading…
Reference in New Issue