mirror of https://github.com/kortix-ai/suna.git
wip
This commit is contained in:
parent
4d5d93e943
commit
602b65791f
|
@ -10,7 +10,6 @@ import jwt
|
|||
from pydantic import BaseModel
|
||||
import tempfile
|
||||
import os
|
||||
import threading
|
||||
|
||||
from agentpress.thread_manager import ThreadManager
|
||||
from services.supabase import DBConnection
|
||||
|
@ -30,9 +29,6 @@ db = None
|
|||
# In-memory storage for active agent runs and their responses
|
||||
active_agent_runs: Dict[str, List[Any]] = {}
|
||||
|
||||
# Add this near the top of the file with other global variables
|
||||
sandbox_locks: Dict[str, threading.Lock] = {}
|
||||
|
||||
MODEL_NAME_ALIASES = {
|
||||
"sonnet-3.7": "anthropic/claude-3-7-sonnet-latest",
|
||||
"gpt-4.1": "openai/gpt-4.1-2025-04-14",
|
||||
|
@ -94,7 +90,8 @@ async def update_agent_run_status(
|
|||
client,
|
||||
agent_run_id: str,
|
||||
status: str,
|
||||
error: Optional[str] = None
|
||||
error: Optional[str] = None,
|
||||
responses: Optional[List[Any]] = None
|
||||
) -> bool:
|
||||
"""
|
||||
Centralized function to update agent run status.
|
||||
|
@ -109,6 +106,9 @@ async def update_agent_run_status(
|
|||
if error:
|
||||
update_data["error"] = error
|
||||
|
||||
if responses:
|
||||
update_data["responses"] = responses
|
||||
|
||||
# Retry up to 3 times
|
||||
for retry in range(3):
|
||||
try:
|
||||
|
@ -129,16 +129,18 @@ async def update_agent_run_status(
|
|||
if retry == 2: # Last retry
|
||||
logger.error(f"Failed to update agent run status after all retries: {agent_run_id}")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating agent run status on retry {retry}: {str(e)}")
|
||||
if retry == 2: # Last retry
|
||||
raise
|
||||
|
||||
except Exception as db_error:
|
||||
logger.error(f"Database error on retry {retry} updating status: {str(db_error)}")
|
||||
if retry < 2: # Not the last retry yet
|
||||
await asyncio.sleep(0.5 * (2 ** retry)) # Exponential backoff
|
||||
else:
|
||||
logger.error(f"Failed to update agent run status after all retries: {agent_run_id}", exc_info=True)
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to update agent run status: {str(e)}")
|
||||
logger.error(f"Unexpected error updating agent run status: {str(e)}", exc_info=True)
|
||||
return False
|
||||
|
||||
return False
|
||||
|
||||
return False
|
||||
|
||||
async def stop_agent_run(agent_run_id: str, error_message: Optional[str] = None):
|
||||
"""Update database and publish stop signal to Redis."""
|
||||
|
@ -256,8 +258,7 @@ async def _cleanup_agent_run(agent_run_id: str):
|
|||
|
||||
async def get_or_create_project_sandbox(client, project_id: str, sandbox_cache={}):
|
||||
"""
|
||||
Get or create a sandbox for a project using in-memory locking to avoid race conditions
|
||||
within a single instance deployment.
|
||||
Safely get or create a sandbox for a project using distributed locking to avoid race conditions.
|
||||
|
||||
Args:
|
||||
client: The Supabase client
|
||||
|
@ -294,43 +295,59 @@ async def get_or_create_project_sandbox(client, project_id: str, sandbox_cache={
|
|||
logger.error(f"Failed to retrieve existing sandbox {sandbox_id} for project {project_id}: {str(e)}")
|
||||
# Fall through to create a new sandbox if retrieval fails
|
||||
|
||||
# Need to create a new sandbox - use simple in-memory locking
|
||||
# Make sure we have a lock for this project
|
||||
if project_id not in sandbox_locks:
|
||||
sandbox_locks[project_id] = threading.Lock()
|
||||
|
||||
# Try to acquire the lock
|
||||
lock_acquired = sandbox_locks[project_id].acquire(blocking=False)
|
||||
# Need to create a new sandbox - use Redis for distributed locking
|
||||
lock_key = f"project_sandbox_lock:{project_id}"
|
||||
lock_value = str(uuid.uuid4()) # Unique identifier for this lock acquisition
|
||||
lock_timeout = 60 # seconds
|
||||
|
||||
# Try to acquire a lock
|
||||
try:
|
||||
if not lock_acquired:
|
||||
# Someone else is creating a sandbox for this project
|
||||
logger.info(f"Waiting for another thread to create sandbox for project {project_id}")
|
||||
|
||||
# Wait and retry a few times with a blocking acquire
|
||||
sandbox_locks[project_id].acquire(blocking=True, timeout=60)
|
||||
|
||||
# Check if the other thread completed while we were waiting
|
||||
fresh_project = await client.table('projects').select('*').eq('project_id', project_id).execute()
|
||||
if fresh_project.data and fresh_project.data[0].get('sandbox', {}).get('id'):
|
||||
sandbox_id = fresh_project.data[0]['sandbox']['id']
|
||||
sandbox_pass = fresh_project.data[0]['sandbox']['pass']
|
||||
logger.info(f"Another thread created sandbox {sandbox_id} for project {project_id}")
|
||||
|
||||
sandbox = await get_or_start_sandbox(sandbox_id)
|
||||
# Cache the result
|
||||
sandbox_cache[project_id] = (sandbox, sandbox_id, sandbox_pass)
|
||||
return (sandbox, sandbox_id, sandbox_pass)
|
||||
|
||||
# If we got here and still don't have a sandbox, we'll create one now
|
||||
# (the other thread must have failed or timed out)
|
||||
# Attempt to get a lock with a timeout - this is atomic in Redis
|
||||
acquired = await redis.set(
|
||||
lock_key,
|
||||
lock_value,
|
||||
nx=True, # Only set if key doesn't exist (NX = not exists)
|
||||
ex=lock_timeout # Auto-expire the lock
|
||||
)
|
||||
|
||||
# Double-check the project data once more to avoid race conditions
|
||||
if not acquired:
|
||||
# Someone else is creating a sandbox for this project
|
||||
logger.info(f"Waiting for another process to create sandbox for project {project_id}")
|
||||
|
||||
# Wait and retry a few times
|
||||
max_retries = 5
|
||||
retry_delay = 2 # seconds
|
||||
|
||||
for retry in range(max_retries):
|
||||
await asyncio.sleep(retry_delay)
|
||||
|
||||
# Check if the other process completed
|
||||
fresh_project = await client.table('projects').select('*').eq('project_id', project_id).execute()
|
||||
if fresh_project.data and fresh_project.data[0].get('sandbox', {}).get('id'):
|
||||
sandbox_id = fresh_project.data[0]['sandbox']['id']
|
||||
sandbox_pass = fresh_project.data[0]['sandbox']['pass']
|
||||
logger.info(f"Another process created sandbox {sandbox_id} for project {project_id}")
|
||||
|
||||
sandbox = await get_or_start_sandbox(sandbox_id)
|
||||
# Cache the result
|
||||
sandbox_cache[project_id] = (sandbox, sandbox_id, sandbox_pass)
|
||||
return (sandbox, sandbox_id, sandbox_pass)
|
||||
|
||||
# If we got here, the other process didn't complete in time
|
||||
# Force-acquire the lock by deleting and recreating it
|
||||
logger.warning(f"Timeout waiting for sandbox creation for project {project_id}, acquiring lock forcefully")
|
||||
await redis.delete(lock_key)
|
||||
await redis.set(lock_key, lock_value, ex=lock_timeout)
|
||||
|
||||
# We have the lock now - check one more time to avoid race conditions
|
||||
fresh_project = await client.table('projects').select('*').eq('project_id', project_id).execute()
|
||||
if fresh_project.data and fresh_project.data[0].get('sandbox', {}).get('id'):
|
||||
sandbox_id = fresh_project.data[0]['sandbox']['id']
|
||||
sandbox_pass = fresh_project.data[0]['sandbox']['pass']
|
||||
logger.info(f"Sandbox {sandbox_id} was created by another thread while waiting for project {project_id}")
|
||||
logger.info(f"Sandbox {sandbox_id} was created by another process while acquiring lock for project {project_id}")
|
||||
|
||||
# Release the lock
|
||||
await redis.delete(lock_key)
|
||||
|
||||
sandbox = await get_or_start_sandbox(sandbox_id)
|
||||
# Cache the result
|
||||
|
@ -341,7 +358,7 @@ async def get_or_create_project_sandbox(client, project_id: str, sandbox_cache={
|
|||
try:
|
||||
logger.info(f"Creating new sandbox for project {project_id}")
|
||||
sandbox_pass = str(uuid.uuid4())
|
||||
sandbox = create_sandbox(sandbox_pass, sandbox_id=project_id)
|
||||
sandbox = create_sandbox(sandbox_pass)
|
||||
sandbox_id = sandbox.id
|
||||
|
||||
logger.info(f"Created new sandbox {sandbox_id} with preview: {sandbox.get_preview_link(6080)}/vnc_lite.html?password={sandbox_pass}")
|
||||
|
@ -389,10 +406,15 @@ async def get_or_create_project_sandbox(client, project_id: str, sandbox_cache={
|
|||
raise e
|
||||
|
||||
finally:
|
||||
# Always release the lock if we acquired it
|
||||
if lock_acquired:
|
||||
sandbox_locks[project_id].release()
|
||||
logger.debug(f"Released lock for project {project_id} sandbox creation")
|
||||
# Always try to release the lock if we have it
|
||||
try:
|
||||
# Only delete the lock if it's still ours
|
||||
current_value = await redis.get(lock_key)
|
||||
if current_value == lock_value:
|
||||
await redis.delete(lock_key)
|
||||
logger.debug(f"Released lock for project {project_id} sandbox creation")
|
||||
except Exception as lock_error:
|
||||
logger.warning(f"Error releasing sandbox creation lock for project {project_id}: {str(lock_error)}")
|
||||
|
||||
@router.post("/thread/{thread_id}/agent/start")
|
||||
async def start_agent(
|
||||
|
@ -468,6 +490,7 @@ async def start_agent(
|
|||
thread_id=thread_id,
|
||||
instance_id=instance_id,
|
||||
project_id=project_id,
|
||||
sandbox=sandbox,
|
||||
model_name=MODEL_NAME_ALIASES.get(body.model_name, body.model_name),
|
||||
enable_thinking=body.enable_thinking,
|
||||
reasoning_effort=body.reasoning_effort,
|
||||
|
@ -545,6 +568,22 @@ async def stream_agent_run(
|
|||
# Verify user has access to the agent run and get run data
|
||||
agent_run_data = await get_agent_run_with_access_check(client, agent_run_id, user_id)
|
||||
|
||||
# Initialize response storage if not already in memory
|
||||
if agent_run_id not in active_agent_runs:
|
||||
active_agent_runs[agent_run_id] = []
|
||||
logger.info(f"Initialized missing response storage for agent run: {agent_run_id}")
|
||||
|
||||
# If run is completed/failed, try to load responses from database
|
||||
if agent_run_data['status'] in ['completed', 'failed', 'stopped']:
|
||||
try:
|
||||
# Get responses from database
|
||||
run_with_responses = await client.table('agent_runs').select('responses').eq("id", agent_run_id).execute()
|
||||
if run_with_responses.data and run_with_responses.data[0].get('responses'):
|
||||
active_agent_runs[agent_run_id] = run_with_responses.data[0]['responses']
|
||||
logger.info(f"Loaded {len(active_agent_runs[agent_run_id])} responses from database for agent run: {agent_run_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load responses from database for agent run {agent_run_id}: {str(e)}")
|
||||
|
||||
# Define a streaming generator that uses in-memory responses
|
||||
async def stream_generator():
|
||||
logger.debug(f"Streaming responses for agent run: {agent_run_id}")
|
||||
|
@ -578,9 +617,9 @@ async def stream_agent_run(
|
|||
# Brief pause before checking again
|
||||
await asyncio.sleep(0.1)
|
||||
else:
|
||||
# If the run is not active or we don't have stored responses,
|
||||
# send a message indicating the run is not available for streaming
|
||||
logger.warning(f"Agent run {agent_run_id} not found in active runs")
|
||||
# This should not happen now, but just in case
|
||||
|
||||
logger.warning(f"Agent run {agent_run_id} not found in active runs even after initialization")
|
||||
yield f"data: {json.dumps({'type': 'status', 'status': agent_run_data['status'], 'message': 'Run data not available for streaming'})}\n\n"
|
||||
|
||||
# Always send a completion status at the end
|
||||
|
@ -605,6 +644,7 @@ async def run_agent_background(
|
|||
thread_id: str,
|
||||
instance_id: str,
|
||||
project_id: str,
|
||||
sandbox,
|
||||
model_name: str,
|
||||
enable_thinking: Optional[bool],
|
||||
reasoning_effort: Optional[str],
|
||||
|
@ -741,23 +781,27 @@ async def run_agent_background(
|
|||
enable_context_manager=enable_context_manager
|
||||
)
|
||||
|
||||
# Collect all responses to save to database
|
||||
all_responses = []
|
||||
|
||||
async for response in agent_gen:
|
||||
# Check if stop signal received
|
||||
if stop_signal_received:
|
||||
logger.info(f"Agent run stopped due to stop signal: {agent_run_id} (instance: {instance_id})")
|
||||
await update_agent_run_status(client, agent_run_id, "stopped")
|
||||
await update_agent_run_status(client, agent_run_id, "stopped", responses=all_responses)
|
||||
break
|
||||
|
||||
# Check for billing error status
|
||||
if response.get('type') == 'status' and response.get('status') == 'error':
|
||||
error_msg = response.get('message', '')
|
||||
logger.info(f"Agent run failed with error: {error_msg} (instance: {instance_id})")
|
||||
await update_agent_run_status(client, agent_run_id, "failed", error=error_msg)
|
||||
await update_agent_run_status(client, agent_run_id, "failed", error=error_msg, responses=all_responses)
|
||||
break
|
||||
|
||||
# Store response in memory
|
||||
if agent_run_id in active_agent_runs:
|
||||
active_agent_runs[agent_run_id].append(response)
|
||||
all_responses.append(response)
|
||||
total_responses += 1
|
||||
|
||||
# Signal all done if we weren't stopped
|
||||
|
@ -773,9 +817,10 @@ async def run_agent_background(
|
|||
}
|
||||
if agent_run_id in active_agent_runs:
|
||||
active_agent_runs[agent_run_id].append(completion_message)
|
||||
all_responses.append(completion_message)
|
||||
|
||||
# Update the agent run status
|
||||
await update_agent_run_status(client, agent_run_id, "completed")
|
||||
await update_agent_run_status(client, agent_run_id, "completed", responses=all_responses)
|
||||
|
||||
# Notify any clients monitoring the control channels that we're done
|
||||
try:
|
||||
|
@ -801,13 +846,18 @@ async def run_agent_background(
|
|||
}
|
||||
if agent_run_id in active_agent_runs:
|
||||
active_agent_runs[agent_run_id].append(error_response)
|
||||
if 'all_responses' in locals():
|
||||
all_responses.append(error_response)
|
||||
else:
|
||||
all_responses = [error_response]
|
||||
|
||||
# Update the agent run with the error
|
||||
await update_agent_run_status(
|
||||
client,
|
||||
agent_run_id,
|
||||
"failed",
|
||||
error=f"{error_message}\n{traceback_str}"
|
||||
error=f"{error_message}\n{traceback_str}",
|
||||
responses=all_responses
|
||||
)
|
||||
|
||||
# Notify any clients of the error
|
||||
|
@ -888,6 +938,8 @@ async def generate_and_update_project_name(project_id: str, prompt: str):
|
|||
else:
|
||||
logger.warning(f"Failed to get valid response from LLM for project {project_id} naming. Response: {response}")
|
||||
|
||||
print(f"\n\n\nGenerated name: {generated_name}\n\n\n")
|
||||
# Update database if name was generated
|
||||
if generated_name:
|
||||
update_result = await client.table('projects') \
|
||||
.update({"name": generated_name}) \
|
||||
|
@ -1115,6 +1167,7 @@ async def initiate_agent_with_files(
|
|||
thread_id=thread_id,
|
||||
instance_id=instance_id,
|
||||
project_id=project_id,
|
||||
sandbox=sandbox,
|
||||
model_name=MODEL_NAME_ALIASES.get(model_name, model_name),
|
||||
enable_thinking=enable_thinking,
|
||||
reasoning_effort=reasoning_effort,
|
||||
|
|
|
@ -68,7 +68,7 @@ async def run_agent(
|
|||
if os.getenv("TAVILY_API_KEY"):
|
||||
thread_manager.add_tool(WebSearchTool)
|
||||
else:
|
||||
print("TAVILY_API_KEY not found, WebSearchTool will not be available.")
|
||||
logger.warning("TAVILY_API_KEY not found, WebSearchTool will not be available.")
|
||||
|
||||
if os.getenv("RAPID_API_KEY"):
|
||||
thread_manager.add_tool(DataProvidersTool)
|
||||
|
@ -80,7 +80,7 @@ async def run_agent(
|
|||
|
||||
while continue_execution and iteration_count < max_iterations:
|
||||
iteration_count += 1
|
||||
print(f"Running iteration {iteration_count}...")
|
||||
# logger.debug(f"Running iteration {iteration_count}...")
|
||||
|
||||
# Billing check on each iteration - still needed within the iterations
|
||||
can_run, message, subscription = await check_billing_status(client, account_id)
|
||||
|
|
|
@ -1,7 +1,10 @@
|
|||
from typing import Dict
|
||||
import logging
|
||||
|
||||
from agent.tools.data_providers.RapidDataProviderBase import RapidDataProviderBase, EndpointSchema
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ZillowProvider(RapidDataProviderBase):
|
||||
def __init__(self):
|
||||
|
@ -112,10 +115,10 @@ if __name__ == "__main__":
|
|||
"doz": "any"
|
||||
}
|
||||
)
|
||||
print("Search Result:", search_result)
|
||||
print("***")
|
||||
print("***")
|
||||
print("***")
|
||||
logger.debug("Search Result: %s", search_result)
|
||||
logger.debug("***")
|
||||
logger.debug("***")
|
||||
logger.debug("***")
|
||||
sleep(1)
|
||||
# Example for searching by address
|
||||
address_result = tool.call_endpoint(
|
||||
|
@ -124,10 +127,10 @@ if __name__ == "__main__":
|
|||
"address": "1161 Natchez Dr College Station Texas 77845"
|
||||
}
|
||||
)
|
||||
print("Address Search Result:", address_result)
|
||||
print("***")
|
||||
print("***")
|
||||
print("***")
|
||||
logger.debug("Address Search Result: %s", address_result)
|
||||
logger.debug("***")
|
||||
logger.debug("***")
|
||||
logger.debug("***")
|
||||
sleep(1)
|
||||
# Example for getting property details
|
||||
property_result = tool.call_endpoint(
|
||||
|
@ -136,11 +139,11 @@ if __name__ == "__main__":
|
|||
"zpid": "7594920"
|
||||
}
|
||||
)
|
||||
print("Property Details Result:", property_result)
|
||||
logger.debug("Property Details Result: %s", property_result)
|
||||
sleep(1)
|
||||
print("***")
|
||||
print("***")
|
||||
print("***")
|
||||
logger.debug("***")
|
||||
logger.debug("***")
|
||||
logger.debug("***")
|
||||
|
||||
# Example for getting zestimate history
|
||||
zestimate_result = tool.call_endpoint(
|
||||
|
@ -149,11 +152,11 @@ if __name__ == "__main__":
|
|||
"zpid": "20476226"
|
||||
}
|
||||
)
|
||||
print("Zestimate History Result:", zestimate_result)
|
||||
logger.debug("Zestimate History Result: %s", zestimate_result)
|
||||
sleep(1)
|
||||
print("***")
|
||||
print("***")
|
||||
print("***")
|
||||
logger.debug("***")
|
||||
logger.debug("***")
|
||||
logger.debug("***")
|
||||
# Example for getting similar properties
|
||||
similar_result = tool.call_endpoint(
|
||||
route="similar_properties",
|
||||
|
@ -161,11 +164,11 @@ if __name__ == "__main__":
|
|||
"zpid": "28253016"
|
||||
}
|
||||
)
|
||||
print("Similar Properties Result:", similar_result)
|
||||
logger.debug("Similar Properties Result: %s", similar_result)
|
||||
sleep(1)
|
||||
print("***")
|
||||
print("***")
|
||||
print("***")
|
||||
logger.debug("***")
|
||||
logger.debug("***")
|
||||
logger.debug("***")
|
||||
# Example for getting mortgage rates
|
||||
mortgage_result = tool.call_endpoint(
|
||||
route="mortgage_rates",
|
||||
|
@ -180,5 +183,5 @@ if __name__ == "__main__":
|
|||
"duration": "30"
|
||||
}
|
||||
)
|
||||
print("Mortgage Rates Result:", mortgage_result)
|
||||
logger.debug("Mortgage Rates Result: %s", mortgage_result)
|
||||
|
|
@ -42,8 +42,8 @@ class SandboxBrowserTool(SandboxToolsBase):
|
|||
json_data = json.dumps(params)
|
||||
curl_cmd += f" -d '{json_data}'"
|
||||
|
||||
print(f"\033[95mExecuting curl command:\033[0m")
|
||||
print(f"{curl_cmd}")
|
||||
logger.debug("\033[95mExecuting curl command:\033[0m")
|
||||
logger.debug(f"{curl_cmd}")
|
||||
|
||||
response = self.sandbox.process.exec(curl_cmd, timeout=30)
|
||||
|
||||
|
@ -101,7 +101,7 @@ class SandboxBrowserTool(SandboxToolsBase):
|
|||
|
||||
except Exception as e:
|
||||
logger.error(f"Error executing browser action: {e}")
|
||||
print(traceback.format_exc())
|
||||
logger.debug(traceback.format_exc())
|
||||
return self.fail_response(f"Error executing browser action: {e}")
|
||||
|
||||
@openapi_schema({
|
||||
|
@ -141,7 +141,6 @@ class SandboxBrowserTool(SandboxToolsBase):
|
|||
Returns:
|
||||
dict: Result of the execution
|
||||
"""
|
||||
print(f"\033[95mNavigating to: {url}\033[0m")
|
||||
return await self._execute_browser_action("navigate_to", {"url": url})
|
||||
|
||||
# @openapi_schema({
|
||||
|
@ -181,7 +180,7 @@ class SandboxBrowserTool(SandboxToolsBase):
|
|||
# Returns:
|
||||
# dict: Result of the execution
|
||||
# """
|
||||
# print(f"\033[95mSearching Google for: {query}\033[0m")
|
||||
# logger.debug(f"\033[95mSearching Google for: {query}\033[0m")
|
||||
# return await self._execute_browser_action("search_google", {"query": query})
|
||||
|
||||
@openapi_schema({
|
||||
|
@ -208,7 +207,7 @@ class SandboxBrowserTool(SandboxToolsBase):
|
|||
Returns:
|
||||
dict: Result of the execution
|
||||
"""
|
||||
print(f"\033[95mNavigating back in browser history\033[0m")
|
||||
logger.debug(f"\033[95mNavigating back in browser history\033[0m")
|
||||
return await self._execute_browser_action("go_back", {})
|
||||
|
||||
@openapi_schema({
|
||||
|
@ -247,7 +246,7 @@ class SandboxBrowserTool(SandboxToolsBase):
|
|||
Returns:
|
||||
dict: Result of the execution
|
||||
"""
|
||||
print(f"\033[95mWaiting for {seconds} seconds\033[0m")
|
||||
logger.debug(f"\033[95mWaiting for {seconds} seconds\033[0m")
|
||||
return await self._execute_browser_action("wait", {"seconds": seconds})
|
||||
|
||||
@openapi_schema({
|
||||
|
@ -287,7 +286,7 @@ class SandboxBrowserTool(SandboxToolsBase):
|
|||
Returns:
|
||||
dict: Result of the execution
|
||||
"""
|
||||
print(f"\033[95mClicking element with index: {index}\033[0m")
|
||||
logger.debug(f"\033[95mClicking element with index: {index}\033[0m")
|
||||
return await self._execute_browser_action("click_element", {"index": index})
|
||||
|
||||
@openapi_schema({
|
||||
|
@ -333,7 +332,7 @@ class SandboxBrowserTool(SandboxToolsBase):
|
|||
Returns:
|
||||
dict: Result of the execution
|
||||
"""
|
||||
print(f"\033[95mInputting text into element {index}: {text}\033[0m")
|
||||
logger.debug(f"\033[95mInputting text into element {index}: {text}\033[0m")
|
||||
return await self._execute_browser_action("input_text", {"index": index, "text": text})
|
||||
|
||||
@openapi_schema({
|
||||
|
@ -373,7 +372,7 @@ class SandboxBrowserTool(SandboxToolsBase):
|
|||
Returns:
|
||||
dict: Result of the execution
|
||||
"""
|
||||
print(f"\033[95mSending keys: {keys}\033[0m")
|
||||
logger.debug(f"\033[95mSending keys: {keys}\033[0m")
|
||||
return await self._execute_browser_action("send_keys", {"keys": keys})
|
||||
|
||||
@openapi_schema({
|
||||
|
@ -413,7 +412,7 @@ class SandboxBrowserTool(SandboxToolsBase):
|
|||
Returns:
|
||||
dict: Result of the execution
|
||||
"""
|
||||
print(f"\033[95mSwitching to tab: {page_id}\033[0m")
|
||||
logger.debug(f"\033[95mSwitching to tab: {page_id}\033[0m")
|
||||
return await self._execute_browser_action("switch_tab", {"page_id": page_id})
|
||||
|
||||
# @openapi_schema({
|
||||
|
@ -453,7 +452,7 @@ class SandboxBrowserTool(SandboxToolsBase):
|
|||
# Returns:
|
||||
# dict: Result of the execution
|
||||
# """
|
||||
# print(f"\033[95mOpening new tab with URL: {url}\033[0m")
|
||||
# logger.debug(f"\033[95mOpening new tab with URL: {url}\033[0m")
|
||||
# return await self._execute_browser_action("open_tab", {"url": url})
|
||||
|
||||
@openapi_schema({
|
||||
|
@ -493,7 +492,7 @@ class SandboxBrowserTool(SandboxToolsBase):
|
|||
Returns:
|
||||
dict: Result of the execution
|
||||
"""
|
||||
print(f"\033[95mClosing tab: {page_id}\033[0m")
|
||||
logger.debug(f"\033[95mClosing tab: {page_id}\033[0m")
|
||||
return await self._execute_browser_action("close_tab", {"page_id": page_id})
|
||||
|
||||
# @openapi_schema({
|
||||
|
@ -533,25 +532,25 @@ class SandboxBrowserTool(SandboxToolsBase):
|
|||
# Returns:
|
||||
# dict: Result of the execution
|
||||
# """
|
||||
# print(f"\033[95mExtracting content with goal: {goal}\033[0m")
|
||||
# logger.debug(f"\033[95mExtracting content with goal: {goal}\033[0m")
|
||||
# result = await self._execute_browser_action("extract_content", {"goal": goal})
|
||||
|
||||
# # Format content for better readability
|
||||
# if result.get("success"):
|
||||
# print(f"\033[92mContent extraction successful\033[0m")
|
||||
# logger.debug(f"\033[92mContent extraction successful\033[0m")
|
||||
# content = result.data.get("content", "")
|
||||
# url = result.data.get("url", "")
|
||||
# title = result.data.get("title", "")
|
||||
|
||||
# if content:
|
||||
# content_preview = content[:200] + "..." if len(content) > 200 else content
|
||||
# print(f"\033[95mExtracted content from {title} ({url}):\033[0m")
|
||||
# print(f"\033[96m{content_preview}\033[0m")
|
||||
# print(f"\033[95mTotal content length: {len(content)} characters\033[0m")
|
||||
# logger.debug(f"\033[95mExtracted content from {title} ({url}):\033[0m")
|
||||
# logger.debug(f"\033[96m{content_preview}\033[0m")
|
||||
# logger.debug(f"\033[95mTotal content length: {len(content)} characters\033[0m")
|
||||
# else:
|
||||
# print(f"\033[93mNo content extracted from {url}\033[0m")
|
||||
# logger.debug(f"\033[93mNo content extracted from {url}\033[0m")
|
||||
# else:
|
||||
# print(f"\033[91mFailed to extract content: {result.data.get('error', 'Unknown error')}\033[0m")
|
||||
# logger.debug(f"\033[91mFailed to extract content: {result.data.get('error', 'Unknown error')}\033[0m")
|
||||
|
||||
# return result
|
||||
|
||||
|
@ -594,9 +593,9 @@ class SandboxBrowserTool(SandboxToolsBase):
|
|||
params = {}
|
||||
if amount is not None:
|
||||
params["amount"] = amount
|
||||
print(f"\033[95mScrolling down by {amount} pixels\033[0m")
|
||||
logger.debug(f"\033[95mScrolling down by {amount} pixels\033[0m")
|
||||
else:
|
||||
print(f"\033[95mScrolling down one page\033[0m")
|
||||
logger.debug(f"\033[95mScrolling down one page\033[0m")
|
||||
|
||||
return await self._execute_browser_action("scroll_down", params)
|
||||
|
||||
|
@ -639,9 +638,9 @@ class SandboxBrowserTool(SandboxToolsBase):
|
|||
params = {}
|
||||
if amount is not None:
|
||||
params["amount"] = amount
|
||||
print(f"\033[95mScrolling up by {amount} pixels\033[0m")
|
||||
logger.debug(f"\033[95mScrolling up by {amount} pixels\033[0m")
|
||||
else:
|
||||
print(f"\033[95mScrolling up one page\033[0m")
|
||||
logger.debug(f"\033[95mScrolling up one page\033[0m")
|
||||
|
||||
return await self._execute_browser_action("scroll_up", params)
|
||||
|
||||
|
@ -682,7 +681,7 @@ class SandboxBrowserTool(SandboxToolsBase):
|
|||
Returns:
|
||||
dict: Result of the execution
|
||||
"""
|
||||
print(f"\033[95mScrolling to text: {text}\033[0m")
|
||||
logger.debug(f"\033[95mScrolling to text: {text}\033[0m")
|
||||
return await self._execute_browser_action("scroll_to_text", {"text": text})
|
||||
|
||||
@openapi_schema({
|
||||
|
@ -722,7 +721,7 @@ class SandboxBrowserTool(SandboxToolsBase):
|
|||
Returns:
|
||||
dict: Result of the execution with the dropdown options
|
||||
"""
|
||||
print(f"\033[95mGetting options from dropdown with index: {index}\033[0m")
|
||||
logger.debug(f"\033[95mGetting options from dropdown with index: {index}\033[0m")
|
||||
return await self._execute_browser_action("get_dropdown_options", {"index": index})
|
||||
|
||||
@openapi_schema({
|
||||
|
@ -768,7 +767,7 @@ class SandboxBrowserTool(SandboxToolsBase):
|
|||
Returns:
|
||||
dict: Result of the execution
|
||||
"""
|
||||
print(f"\033[95mSelecting option '{text}' from dropdown with index: {index}\033[0m")
|
||||
logger.debug(f"\033[95mSelecting option '{text}' from dropdown with index: {index}\033[0m")
|
||||
return await self._execute_browser_action("select_dropdown_option", {"index": index, "text": text})
|
||||
|
||||
@openapi_schema({
|
||||
|
@ -842,13 +841,13 @@ class SandboxBrowserTool(SandboxToolsBase):
|
|||
if element_source and element_target:
|
||||
params["element_source"] = element_source
|
||||
params["element_target"] = element_target
|
||||
print(f"\033[95mDragging from element '{element_source}' to '{element_target}'\033[0m")
|
||||
logger.debug(f"\033[95mDragging from element '{element_source}' to '{element_target}'\033[0m")
|
||||
elif all(coord is not None for coord in [coord_source_x, coord_source_y, coord_target_x, coord_target_y]):
|
||||
params["coord_source_x"] = coord_source_x
|
||||
params["coord_source_y"] = coord_source_y
|
||||
params["coord_target_x"] = coord_target_x
|
||||
params["coord_target_y"] = coord_target_y
|
||||
print(f"\033[95mDragging from coordinates ({coord_source_x}, {coord_source_y}) to ({coord_target_x}, {coord_target_y})\033[0m")
|
||||
logger.debug(f"\033[95mDragging from coordinates ({coord_source_x}, {coord_source_y}) to ({coord_target_x}, {coord_target_y})\033[0m")
|
||||
else:
|
||||
return self.fail_response("Must provide either element selectors or coordinates for drag and drop")
|
||||
|
||||
|
@ -895,5 +894,5 @@ class SandboxBrowserTool(SandboxToolsBase):
|
|||
Returns:
|
||||
dict: Result of the execution
|
||||
"""
|
||||
print(f"\033[95mClicking at coordinates: ({x}, {y})\033[0m")
|
||||
logger.debug(f"\033[95mClicking at coordinates: ({x}, {y})\033[0m")
|
||||
return await self._execute_browser_action("click_coordinates", {"x": x, "y": y})
|
|
@ -113,7 +113,7 @@ app = FastAPI(lifespan=lifespan)
|
|||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["https://www.suna.so", "https://suna.so", "https://staging.suna.so"], #http://localhost:3000
|
||||
allow_origins=["https://www.suna.so", "https://suna.so", "https://staging.suna.so", "http://localhost:3000"], #http://localhost:3000
|
||||
allow_credentials=True,
|
||||
allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
|
||||
allow_headers=["Content-Type", "Authorization"],
|
||||
|
|
|
@ -179,19 +179,19 @@ class SandboxToolsBase(Tool):
|
|||
# Get or start the sandbox
|
||||
self._sandbox = await get_or_start_sandbox(self._sandbox_id)
|
||||
|
||||
# Log URLs if not already printed
|
||||
if not SandboxToolsBase._urls_printed:
|
||||
vnc_link = self._sandbox.get_preview_link(6080)
|
||||
website_link = self._sandbox.get_preview_link(8080)
|
||||
# # Log URLs if not already printed
|
||||
# if not SandboxToolsBase._urls_printed:
|
||||
# vnc_link = self._sandbox.get_preview_link(6080)
|
||||
# website_link = self._sandbox.get_preview_link(8080)
|
||||
|
||||
vnc_url = vnc_link.url if hasattr(vnc_link, 'url') else str(vnc_link)
|
||||
website_url = website_link.url if hasattr(website_link, 'url') else str(website_link)
|
||||
# vnc_url = vnc_link.url if hasattr(vnc_link, 'url') else str(vnc_link)
|
||||
# website_url = website_link.url if hasattr(website_link, 'url') else str(website_link)
|
||||
|
||||
print("\033[95m***")
|
||||
print(f"VNC URL: {vnc_url}")
|
||||
print(f"Website URL: {website_url}")
|
||||
print("***\033[0m")
|
||||
SandboxToolsBase._urls_printed = True
|
||||
# print("\033[95m***")
|
||||
# print(f"VNC URL: {vnc_url}")
|
||||
# print(f"Website URL: {website_url}")
|
||||
# print("***\033[0m")
|
||||
# SandboxToolsBase._urls_printed = True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving sandbox for project {self.project_id}: {str(e)}", exc_info=True)
|
||||
|
|
|
@ -80,7 +80,7 @@ def initialize():
|
|||
socket_connect_timeout=5.0, # Connection timeout
|
||||
retry_on_timeout=True, # Auto-retry on timeout
|
||||
health_check_interval=30, # Check connection health every 30 seconds
|
||||
max_connections=10 # Limit connections to prevent overloading
|
||||
max_connections=100 # Limit connections to prevent overloading
|
||||
)
|
||||
|
||||
return client
|
||||
|
|
Loading…
Reference in New Issue