diff --git a/README.md b/README.md index 7971a08f..ea9b1dae 100644 --- a/README.md +++ b/README.md @@ -88,7 +88,7 @@ You'll need the following components: - Redis database for caching and session management - Daytona sandbox for secure agent execution - Python 3.11 for the API backend -- API keys for LLM providers (Anthropic) +- API keys for LLM providers (Anthropic, OpenRouter) - Tavily API key for enhanced search capabilities - Firecrawl API key for web scraping capabilities @@ -99,23 +99,16 @@ You'll need the following components: - Save your project's API URL, anon key, and service role key for later use - Install the [Supabase CLI](https://supabase.com/docs/guides/cli/getting-started) -2. **Redis**: Set up a Redis instance using one of these options: - - [Upstash Redis](https://upstash.com/) (recommended for cloud deployments) - - Local installation: - - [Mac](https://formulae.brew.sh/formula/redis): `brew install redis` - - [Linux](https://redis.io/docs/getting-started/installation/install-redis-on-linux/): Follow distribution-specific instructions - - [Windows](https://redis.io/docs/getting-started/installation/install-redis-on-windows/): Use WSL2 or Docker - - Docker Compose (included in our setup): - - If you're using our Docker Compose setup, Redis is included and configured automatically - - No additional installation is needed - - Save your Redis connection details for later use (not needed if using Docker Compose) +2. **Redis**: + - Go to the `/backend` folder + - Run `docker compose up redis` 3. **Daytona**: - Create an account on [Daytona](https://app.daytona.io/) - Generate an API key from your account settings - Go to [Images](https://app.daytona.io/dashboard/images) - Click "Add Image" - - Enter `adamcohenhillel/kortix-suna:0.0.20` as the image name + - Enter `kortix/suna:0.1` as the image name - Set `/usr/bin/supervisord -n -c /etc/supervisor/conf.d/supervisord.conf` as the Entrypoint 4. **LLM API Keys**: diff --git a/backend/agent/api.py b/backend/agent/api.py index 6a517658..3b65a7e4 100644 --- a/backend/agent/api.py +++ b/backend/agent/api.py @@ -597,7 +597,7 @@ async def stream_agent_run( if new_responses_json: new_responses = [json.loads(r) for r in new_responses_json] num_new = len(new_responses) - logger.debug(f"Received {num_new} new responses for {agent_run_id} (index {new_start_index} onwards)") + # logger.debug(f"Received {num_new} new responses for {agent_run_id} (index {new_start_index} onwards)") for response in new_responses: yield f"data: {json.dumps(response)}\n\n" # Check if this response signals completion diff --git a/backend/agent/prompt.py b/backend/agent/prompt.py index 8bda3d23..99de8dd5 100644 --- a/backend/agent/prompt.py +++ b/backend/agent/prompt.py @@ -59,11 +59,11 @@ You have the ability to execute operations using both Python and CLI tools: * Always expose ports when you need to show running services to users ### 2.2.4 WEB SEARCH CAPABILITIES -- Searching the web for up-to-date information -- Retrieving and extracting content from specific webpages -- Filtering search results by date, relevance, and content +- Searching the web for up-to-date information with direct question answering +- Retrieving relevant images related to search queries +- Getting comprehensive search results with titles, URLs, and snippets - Finding recent news, articles, and information beyond training data -- Scraping webpage content for detailed information extraction +- Scraping webpage content for detailed information extraction when needed ### 2.2.5 BROWSER TOOLS AND CAPABILITIES - BROWSER OPERATIONS: @@ -312,8 +312,8 @@ You have the ability to execute operations using both Python and CLI tools: ## 4.4 WEB SEARCH & CONTENT EXTRACTION - Research Best Practices: 1. ALWAYS use a multi-source approach for thorough research: - * Start with web-search to find relevant URLs and sources - * Use scrape-webpage on URLs from web-search results to get detailed content + * Start with web-search to find direct answers, images, and relevant URLs + * Only use scrape-webpage when you need detailed content not available in the search results * Utilize data providers for real-time, accurate data when available * Only use browser tools when scrape-webpage fails or interaction is needed 2. Data Provider Priority: @@ -330,8 +330,9 @@ You have the ability to execute operations using both Python and CLI tools: 3. Research Workflow: a. First check for relevant data providers b. If no data provider exists: - - Use web-search to find relevant URLs - - Use scrape-webpage on URLs from web-search results + - Use web-search to get direct answers, images, and relevant URLs + - Only if you need specific details not found in search results: + * Use scrape-webpage on specific URLs from web-search results - Only if scrape-webpage fails or if the page requires interaction: * Use direct browser tools (browser_navigate_to, browser_go_back, browser_wait, browser_click_element, browser_input_text, browser_send_keys, browser_switch_tab, browser_close_tab, browser_scroll_down, browser_scroll_up, browser_scroll_to_text, browser_get_dropdown_options, browser_select_dropdown_option, browser_drag_drop, browser_click_coordinates etc.) * This is needed for: @@ -345,31 +346,41 @@ You have the ability to execute operations using both Python and CLI tools: e. Document sources and timestamps - Web Search Best Practices: - 1. Use specific, targeted search queries to obtain the most relevant results + 1. Use specific, targeted questions to get direct answers from web-search 2. Include key terms and contextual information in search queries 3. Filter search results by date when freshness is important - 4. Use include_text/exclude_text parameters to refine search results + 4. Review the direct answer, images, and search results 5. Analyze multiple search results to cross-validate information -- Web Content Extraction Workflow: - 1. ALWAYS start with web-search to find relevant URLs - 2. Use scrape-webpage on URLs from web-search results - 3. Only if scrape-webpage fails or if the page requires interaction: - - Use direct browser tools (browser_navigate_to, browser_go_back, browser_wait, browser_click_element, browser_input_text, browser_send_keys, browser_switch_tab, browser_close_tab, browser_scroll_down, browser_scroll_up, browser_scroll_to_text, browser_get_dropdown_options, browser_select_dropdown_option, browser_drag_drop, browser_click_coordinates etc.) +- Content Extraction Decision Tree: + 1. ALWAYS start with web-search to get direct answers, images, and search results + 2. Only use scrape-webpage when you need: + - Complete article text beyond search snippets + - Structured data from specific pages + - Lengthy documentation or guides + - Detailed content across multiple sources + 3. Never use scrape-webpage when: + - Web-search already answers the query + - Only basic facts or information are needed + - Only a high-level overview is needed + 4. Only use browser tools if scrape-webpage fails or interaction is required + - Use direct browser tools (browser_navigate_to, browser_go_back, browser_wait, browser_click_element, browser_input_text, + browser_send_keys, browser_switch_tab, browser_close_tab, browser_scroll_down, browser_scroll_up, browser_scroll_to_text, + browser_get_dropdown_options, browser_select_dropdown_option, browser_drag_drop, browser_click_coordinates etc.) - This is needed for: * Dynamic content loading * JavaScript-heavy sites * Pages requiring login * Interactive elements * Infinite scroll pages - 4. DO NOT use browser tools directly unless scrape-webpage fails or interaction is required - 5. Maintain this strict workflow order: web-search → scrape-webpage → direct browser tools (if needed) + DO NOT use browser tools directly unless interaction is required. + 5. Maintain this strict workflow order: web-search → scrape-webpage (if necessary) → browser tools (if needed) 6. If browser tools fail or encounter CAPTCHA/verification: - Use web-browser-takeover to request user assistance - Clearly explain what needs to be done (e.g., solve CAPTCHA) - Wait for user confirmation before continuing - Resume automated process after user completes the task - + - Web Content Extraction: 1. Verify URL validity before scraping 2. Extract and save content to files for further processing diff --git a/backend/agent/prompt.txt b/backend/agent/prompt.txt index 977bddc2..72e75f61 100644 --- a/backend/agent/prompt.txt +++ b/backend/agent/prompt.txt @@ -840,7 +840,7 @@ Ask user a question and wait for response. Use for: 1) Requesting clarification + num_results="20"> diff --git a/backend/agent/run.py b/backend/agent/run.py index 81795398..5ada055a 100644 --- a/backend/agent/run.py +++ b/backend/agent/run.py @@ -8,7 +8,7 @@ from typing import Optional from agent.tools.message_tool import MessageTool from agent.tools.sb_deploy_tool import SandboxDeployTool from agent.tools.sb_expose_tool import SandboxExposeTool -from agent.tools.web_search_tool import WebSearchTool +from agent.tools.web_search_tool import SandboxWebSearchTool from dotenv import load_dotenv from utils.config import config @@ -19,7 +19,7 @@ from agent.tools.sb_files_tool import SandboxFilesTool from agent.tools.sb_browser_tool import SandboxBrowserTool from agent.tools.data_providers_tool import DataProvidersTool from agent.prompt import get_system_prompt -from utils import logger +from utils.logger import logger from utils.auth_utils import get_account_id_from_thread from services.billing import check_billing_status from agent.tools.sb_vision_tool import SandboxVisionTool @@ -32,14 +32,14 @@ async def run_agent( stream: bool, thread_manager: Optional[ThreadManager] = None, native_max_auto_continues: int = 25, - max_iterations: int = 150, + max_iterations: int = 100, model_name: str = "anthropic/claude-3-7-sonnet-latest", enable_thinking: Optional[bool] = False, reasoning_effort: Optional[str] = 'low', enable_context_manager: bool = True ): """Run the development agent with specified configuration.""" - print(f"🚀 Starting agent with model: {model_name}") + logger.info(f"🚀 Starting agent with model: {model_name}") thread_manager = ThreadManager() @@ -68,7 +68,7 @@ async def run_agent( thread_manager.add_tool(SandboxDeployTool, project_id=project_id, thread_manager=thread_manager) thread_manager.add_tool(SandboxExposeTool, project_id=project_id, thread_manager=thread_manager) thread_manager.add_tool(MessageTool) # we are just doing this via prompt as there is no need to call it as a tool - thread_manager.add_tool(WebSearchTool) + thread_manager.add_tool(SandboxWebSearchTool, project_id=project_id, thread_manager=thread_manager) thread_manager.add_tool(SandboxVisionTool, project_id=project_id, thread_id=thread_id, thread_manager=thread_manager) # Add data providers tool if RapidAPI key is available if config.RAPID_API_KEY: @@ -90,7 +90,7 @@ async def run_agent( while continue_execution and iteration_count < max_iterations: iteration_count += 1 - # logger.debug(f"Running iteration {iteration_count}...") + logger.info(f"🔄 Running iteration {iteration_count} of {max_iterations}...") # Billing check on each iteration - still needed within the iterations can_run, message, subscription = await check_billing_status(client, account_id) @@ -108,7 +108,7 @@ async def run_agent( if latest_message.data and len(latest_message.data) > 0: message_type = latest_message.data[0].get('type') if message_type == 'assistant': - print(f"Last message was from assistant, stopping execution") + logger.info(f"Last message was from assistant, stopping execution") continue_execution = False break @@ -186,100 +186,116 @@ async def run_agent( max_tokens = 64000 elif "gpt-4" in model_name.lower(): max_tokens = 4096 + + try: + # Make the LLM call and process the response + response = await thread_manager.run_thread( + thread_id=thread_id, + system_prompt=system_message, + stream=stream, + llm_model=model_name, + llm_temperature=0, + llm_max_tokens=max_tokens, + tool_choice="auto", + max_xml_tool_calls=1, + temporary_message=temporary_message, + processor_config=ProcessorConfig( + xml_tool_calling=True, + native_tool_calling=False, + execute_tools=True, + execute_on_stream=True, + tool_execution_strategy="parallel", + xml_adding_strategy="user_message" + ), + native_max_auto_continues=native_max_auto_continues, + include_xml_examples=True, + enable_thinking=enable_thinking, + reasoning_effort=reasoning_effort, + enable_context_manager=enable_context_manager + ) - response = await thread_manager.run_thread( - thread_id=thread_id, - system_prompt=system_message, - stream=stream, - llm_model=model_name, - llm_temperature=0, - llm_max_tokens=max_tokens, - tool_choice="auto", - max_xml_tool_calls=1, - temporary_message=temporary_message, - processor_config=ProcessorConfig( - xml_tool_calling=True, - native_tool_calling=False, - execute_tools=True, - execute_on_stream=True, - tool_execution_strategy="parallel", - xml_adding_strategy="user_message" - ), - native_max_auto_continues=native_max_auto_continues, - include_xml_examples=True, - enable_thinking=enable_thinking, - reasoning_effort=reasoning_effort, - enable_context_manager=enable_context_manager - ) + if isinstance(response, dict) and "status" in response and response["status"] == "error": + logger.error(f"Error response from run_thread: {response.get('message', 'Unknown error')}") + yield response + break - if isinstance(response, dict) and "status" in response and response["status"] == "error": - yield response - return + # Track if we see ask, complete, or web-browser-takeover tool calls + last_tool_call = None - # Track if we see ask, complete, or web-browser-takeover tool calls - last_tool_call = None + # Process the response + error_detected = False + try: + async for chunk in response: + # If we receive an error chunk, we should stop after this iteration + if isinstance(chunk, dict) and chunk.get('type') == 'status' and chunk.get('status') == 'error': + logger.error(f"Error chunk detected: {chunk.get('message', 'Unknown error')}") + error_detected = True + yield chunk # Forward the error chunk + continue # Continue processing other chunks but don't break yet + + # Check for XML versions like , , or in assistant content chunks + if chunk.get('type') == 'assistant' and 'content' in chunk: + try: + # The content field might be a JSON string or object + content = chunk.get('content', '{}') + if isinstance(content, str): + assistant_content_json = json.loads(content) + else: + assistant_content_json = content - async for chunk in response: - # print(f"CHUNK: {chunk}") # Uncomment for detailed chunk logging + # The actual text content is nested within + assistant_text = assistant_content_json.get('content', '') + if isinstance(assistant_text, str): # Ensure it's a string + # Check for the closing tags as they signal the end of the tool usage + if '' in assistant_text or '' in assistant_text or '' in assistant_text: + if '' in assistant_text: + xml_tool = 'ask' + elif '' in assistant_text: + xml_tool = 'complete' + elif '' in assistant_text: + xml_tool = 'web-browser-takeover' - # Check for XML versions like , , or in assistant content chunks - if chunk.get('type') == 'assistant' and 'content' in chunk: - try: - # The content field might be a JSON string or object - content = chunk.get('content', '{}') - if isinstance(content, str): - assistant_content_json = json.loads(content) - else: - assistant_content_json = content + last_tool_call = xml_tool + logger.info(f"Agent used XML tool: {xml_tool}") + except json.JSONDecodeError: + # Handle cases where content might not be valid JSON + logger.warning(f"Warning: Could not parse assistant content JSON: {chunk.get('content')}") + except Exception as e: + logger.error(f"Error processing assistant chunk: {e}") - # The actual text content is nested within - assistant_text = assistant_content_json.get('content', '') - if isinstance(assistant_text, str): # Ensure it's a string - # Check for the closing tags as they signal the end of the tool usage - if '' in assistant_text or '' in assistant_text or '' in assistant_text: - if '' in assistant_text: - xml_tool = 'ask' - elif '' in assistant_text: - xml_tool = 'complete' - elif '' in assistant_text: - xml_tool = 'web-browser-takeover' + yield chunk - last_tool_call = xml_tool - print(f"Agent used XML tool: {xml_tool}") - except json.JSONDecodeError: - # Handle cases where content might not be valid JSON - print(f"Warning: Could not parse assistant content JSON: {chunk.get('content')}") - except Exception as e: - print(f"Error processing assistant chunk: {e}") - - # # Check for native function calls (OpenAI format) - # elif chunk.get('type') == 'status' and 'content' in chunk: - # try: - # # Parse the status content - # status_content = chunk.get('content', '{}') - # if isinstance(status_content, str): - # status_content = json.loads(status_content) - - # # Check if this is a tool call status - # status_type = status_content.get('status_type') - # function_name = status_content.get('function_name', '') - - # # Check for special function names that should stop execution - # if status_type == 'tool_started' and function_name in ['ask', 'complete', 'web-browser-takeover']: - # last_tool_call = function_name - # print(f"Agent used native function call: {function_name}") - # except json.JSONDecodeError: - # # Handle cases where content might not be valid JSON - # print(f"Warning: Could not parse status content JSON: {chunk.get('content')}") - # except Exception as e: - # print(f"Error processing status chunk: {e}") - - yield chunk - - # Check if we should stop based on the last tool call - if last_tool_call in ['ask', 'complete', 'web-browser-takeover']: - print(f"Agent decided to stop with tool: {last_tool_call}") - continue_execution = False + # Check if we should stop based on the last tool call or error + if error_detected: + logger.info(f"Stopping due to error detected in response") + break + + if last_tool_call in ['ask', 'complete', 'web-browser-takeover']: + logger.info(f"Agent decided to stop with tool: {last_tool_call}") + continue_execution = False + except Exception as e: + # Just log the error and re-raise to stop all iterations + error_msg = f"Error during response streaming: {str(e)}" + logger.error(f"Error: {error_msg}") + yield { + "type": "status", + "status": "error", + "message": error_msg + } + # Stop execution immediately on any error + break + + except Exception as e: + # Just log the error and re-raise to stop all iterations + error_msg = f"Error running thread: {str(e)}" + logger.error(f"Error: {error_msg}") + yield { + "type": "status", + "status": "error", + "message": error_msg + } + # Stop execution immediately on any error + break # # TESTING diff --git a/backend/agent/sample_responses/2.txt b/backend/agent/sample_responses/2.txt index c868d7e4..c7d74926 100644 --- a/backend/agent/sample_responses/2.txt +++ b/backend/agent/sample_responses/2.txt @@ -43,31 +43,31 @@ First, I'll create a todo list to organize our work and make sure we cover all a Now I'll start researching flight options and the best regions to visit based on your interests. Let's begin by searching for flight information and gathering essential data about traveling to Japan from Seattle. - + Let me also search for information about the best places to visit in Japan, especially focusing on historical sites and cultural experiences: - + Now let's look for information about Nara deer park and other cultural experiences you mentioned: - + - + - + - + Let's also research hidden gems and romantic spots in Japan for your proposal: - + - + Now let's look at transportation options within Japan: - + Based on the searches, I'm going to compile information to create your itinerary. First, let me mark some research tasks as complete and update our todo list: @@ -112,7 +112,7 @@ Based on the searches, I'm going to compile information to create your itinerary Now I'll start crafting your detailed 7-day itinerary based on the research. First, let me research accommodation options within your budget: - + Based on all the research, I'm now going to create a detailed itinerary file: diff --git a/backend/agent/tools/computer_use_tool.py b/backend/agent/tools/computer_use_tool.py index 21766463..dcdb2ddf 100644 --- a/backend/agent/tools/computer_use_tool.py +++ b/backend/agent/tools/computer_use_tool.py @@ -4,11 +4,11 @@ import base64 import aiohttp import asyncio import logging -from typing import Optional, Dict, Any, Union -from PIL import Image +from typing import Optional, Dict +import os from agentpress.tool import Tool, ToolResult, openapi_schema, xml_schema -from sandbox.sandbox import SandboxToolsBase, Sandbox +from sandbox.tool_base import SandboxToolsBase, Sandbox KEYBOARD_KEYS = [ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', diff --git a/backend/agent/tools/data_providers/AmazonProvider.py b/backend/agent/tools/data_providers/AmazonProvider.py index 5ecea89e..b2972089 100644 --- a/backend/agent/tools/data_providers/AmazonProvider.py +++ b/backend/agent/tools/data_providers/AmazonProvider.py @@ -1,4 +1,4 @@ -from typing import Dict, Optional +from typing import Dict from agent.tools.data_providers.RapidDataProviderBase import RapidDataProviderBase, EndpointSchema diff --git a/backend/agent/tools/message_tool.py b/backend/agent/tools/message_tool.py index 23c44cd9..3c958c49 100644 --- a/backend/agent/tools/message_tool.py +++ b/backend/agent/tools/message_tool.py @@ -1,4 +1,3 @@ -import os from typing import List, Optional, Union from agentpress.tool import Tool, ToolResult, openapi_schema, xml_schema diff --git a/backend/agent/tools/sb_browser_tool.py b/backend/agent/tools/sb_browser_tool.py index 94fdf666..ce9130ec 100644 --- a/backend/agent/tools/sb_browser_tool.py +++ b/backend/agent/tools/sb_browser_tool.py @@ -3,7 +3,7 @@ import json from agentpress.tool import ToolResult, openapi_schema, xml_schema from agentpress.thread_manager import ThreadManager -from sandbox.sandbox import SandboxToolsBase, Sandbox +from sandbox.tool_base import SandboxToolsBase from utils.logger import logger diff --git a/backend/agent/tools/sb_deploy_tool.py b/backend/agent/tools/sb_deploy_tool.py index adce2ce0..7fa9b737 100644 --- a/backend/agent/tools/sb_deploy_tool.py +++ b/backend/agent/tools/sb_deploy_tool.py @@ -1,7 +1,7 @@ import os from dotenv import load_dotenv from agentpress.tool import ToolResult, openapi_schema, xml_schema -from sandbox.sandbox import SandboxToolsBase, Sandbox +from sandbox.tool_base import SandboxToolsBase from utils.files_utils import clean_path from agentpress.thread_manager import ThreadManager diff --git a/backend/agent/tools/sb_expose_tool.py b/backend/agent/tools/sb_expose_tool.py index d437accf..b618d200 100644 --- a/backend/agent/tools/sb_expose_tool.py +++ b/backend/agent/tools/sb_expose_tool.py @@ -1,6 +1,5 @@ -from typing import Optional from agentpress.tool import ToolResult, openapi_schema, xml_schema -from sandbox.sandbox import SandboxToolsBase, Sandbox +from sandbox.tool_base import SandboxToolsBase from agentpress.thread_manager import ThreadManager class SandboxExposeTool(SandboxToolsBase): diff --git a/backend/agent/tools/sb_files_tool.py b/backend/agent/tools/sb_files_tool.py index 549f6777..30272d10 100644 --- a/backend/agent/tools/sb_files_tool.py +++ b/backend/agent/tools/sb_files_tool.py @@ -1,9 +1,7 @@ -from daytona_sdk.process import SessionExecuteRequest -from typing import Optional from agentpress.tool import ToolResult, openapi_schema, xml_schema -from sandbox.sandbox import SandboxToolsBase, Sandbox, get_or_start_sandbox -from utils.files_utils import EXCLUDED_FILES, EXCLUDED_DIRS, EXCLUDED_EXT, should_exclude_file, clean_path +from sandbox.tool_base import SandboxToolsBase +from utils.files_utils import should_exclude_file, clean_path from agentpress.thread_manager import ThreadManager from utils.logger import logger import os diff --git a/backend/agent/tools/sb_shell_tool.py b/backend/agent/tools/sb_shell_tool.py index 33fae063..bea10729 100644 --- a/backend/agent/tools/sb_shell_tool.py +++ b/backend/agent/tools/sb_shell_tool.py @@ -1,7 +1,8 @@ -from typing import Optional, Dict, List +from typing import Optional, Dict, Any +import time from uuid import uuid4 from agentpress.tool import ToolResult, openapi_schema, xml_schema -from sandbox.sandbox import SandboxToolsBase, Sandbox +from sandbox.tool_base import SandboxToolsBase from agentpress.thread_manager import ThreadManager class SandboxShellTool(SandboxToolsBase): @@ -39,13 +40,13 @@ class SandboxShellTool(SandboxToolsBase): "type": "function", "function": { "name": "execute_command", - "description": "Execute a shell command in the workspace directory. IMPORTANT: By default, commands are blocking and will wait for completion before returning. For long-running operations, use background execution techniques (& operator, nohup) to prevent timeouts. Uses sessions to maintain state between commands. This tool is essential for running CLI tools, installing packages, and managing system operations. Always verify command outputs before using the data. Commands can be chained using && for sequential execution, || for fallback execution, and | for piping output.", + "description": "Execute a shell command in the workspace directory. IMPORTANT: Commands are non-blocking by default and run in a tmux session. This is ideal for long-running operations like starting servers or build processes. Uses sessions to maintain state between commands. This tool is essential for running CLI tools, installing packages, and managing system operations.", "parameters": { "type": "object", "properties": { "command": { "type": "string", - "description": "The shell command to execute. Use this for running CLI tools, installing packages, or system operations. Commands can be chained using &&, ||, and | operators. Example: 'find . -type f | sort && grep -r \"pattern\" . | awk \"{print $1}\" | sort | uniq -c'" + "description": "The shell command to execute. Use this for running CLI tools, installing packages, or system operations. Commands can be chained using &&, ||, and | operators." }, "folder": { "type": "string", @@ -53,12 +54,16 @@ class SandboxShellTool(SandboxToolsBase): }, "session_name": { "type": "string", - "description": "Optional name of the session to use. Use named sessions for related commands that need to maintain state. Defaults to 'default'.", - "default": "default" + "description": "Optional name of the tmux session to use. Use named sessions for related commands that need to maintain state. Defaults to a random session name.", + }, + "blocking": { + "type": "boolean", + "description": "Whether to wait for the command to complete. Defaults to false for non-blocking execution.", + "default": False }, "timeout": { "type": "integer", - "description": "Optional timeout in seconds. Increase for long-running commands. Defaults to 60. For commands that might exceed this timeout, use background execution with & operator instead.", + "description": "Optional timeout in seconds for blocking commands. Defaults to 60. Ignored for non-blocking commands.", "default": 60 } }, @@ -72,79 +77,30 @@ class SandboxShellTool(SandboxToolsBase): {"param_name": "command", "node_type": "content", "path": "."}, {"param_name": "folder", "node_type": "attribute", "path": ".", "required": False}, {"param_name": "session_name", "node_type": "attribute", "path": ".", "required": False}, + {"param_name": "blocking", "node_type": "attribute", "path": ".", "required": False}, {"param_name": "timeout", "node_type": "attribute", "path": ".", "required": False} ], example=''' - - - - ls -la + + + + npm run dev - - npm install - - - - + npm run build + + + + npm install + + - - export NODE_ENV=production && npm run preview - - - - - npm run build > build.log 2>&1 - - - - - - tmux new-session -d -s vite_dev "cd /workspace && npm run dev" - - - - - tmux list-sessions | grep -q vite_dev && echo "Vite server running" || echo "Vite server not found" - - - - - tmux capture-pane -pt vite_dev - - - - - tmux kill-session -t vite_dev - - - - - tmux new-session -d -s vite_build "cd /workspace && npm run build" - - - - - tmux capture-pane -pt vite_build - - - - - tmux new-session -d -s vite_services "cd /workspace && npm run start:all" - - - - - tmux list-sessions - - - - - tmux kill-server + + export NODE_ENV=production && npm run build ''' ) @@ -152,61 +108,300 @@ class SandboxShellTool(SandboxToolsBase): self, command: str, folder: Optional[str] = None, - session_name: str = "default", + session_name: Optional[str] = None, + blocking: bool = False, timeout: int = 60 ) -> ToolResult: try: # Ensure sandbox is initialized await self._ensure_sandbox() - # Ensure session exists - session_id = await self._ensure_session(session_name) - # Set up working directory cwd = self.workspace_path if folder: folder = folder.strip('/') cwd = f"{self.workspace_path}/{folder}" - # Ensure we're in the correct directory before executing the command - command = f"cd {cwd} && {command}" + # Generate a session name if not provided + if not session_name: + session_name = f"session_{str(uuid4())[:8]}" - # Execute command in session - from sandbox.sandbox import SessionExecuteRequest - req = SessionExecuteRequest( - command=command, - var_async=False, # This makes the command blocking by default - cwd=cwd # Still set the working directory for reference - ) + # Check if tmux session already exists + check_session = await self._execute_raw_command(f"tmux has-session -t {session_name} 2>/dev/null || echo 'not_exists'") + session_exists = "not_exists" not in check_session.get("output", "") - response = self.sandbox.process.execute_session_command( - session_id=session_id, - req=req, - timeout=timeout - ) + if not session_exists: + # Create a new tmux session + await self._execute_raw_command(f"tmux new-session -d -s {session_name}") + + # Ensure we're in the correct directory and send command to tmux + full_command = f"cd {cwd} && {command}" + wrapped_command = full_command.replace('"', '\\"') # Escape double quotes - # Get detailed logs - logs = self.sandbox.process.get_session_command_logs( - session_id=session_id, - command_id=response.cmd_id - ) + # Send command to tmux session + await self._execute_raw_command(f'tmux send-keys -t {session_name} "{wrapped_command}" Enter') - if response.exit_code == 0: + if blocking: + # For blocking execution, wait and capture output + start_time = time.time() + while (time.time() - start_time) < timeout: + # Wait a bit before checking + time.sleep(2) + + # Check if session still exists (command might have exited) + check_result = await self._execute_raw_command(f"tmux has-session -t {session_name} 2>/dev/null || echo 'ended'") + if "ended" in check_result.get("output", ""): + break + + # Get current output and check for common completion indicators + output_result = await self._execute_raw_command(f"tmux capture-pane -t {session_name} -p") + current_output = output_result.get("output", "") + + # Check for prompt indicators that suggest command completion + last_lines = current_output.split('\n')[-3:] + completion_indicators = ['$', '#', '>', 'Done', 'Completed', 'Finished', '✓'] + if any(indicator in line for indicator in completion_indicators for line in last_lines): + break + + # Capture final output + output_result = await self._execute_raw_command(f"tmux capture-pane -t {session_name} -p") + final_output = output_result.get("output", "") + + # Kill the session after capture + await self._execute_raw_command(f"tmux kill-session -t {session_name}") + return self.success_response({ - "output": logs, - "exit_code": response.exit_code, - "cwd": cwd + "output": final_output, + "session_name": session_name, + "cwd": cwd, + "completed": True }) else: - error_msg = f"Command failed with exit code {response.exit_code}" - if logs: - error_msg += f": {logs}" - return self.fail_response(error_msg) + # For non-blocking, just return immediately + return self.success_response({ + "session_name": session_name, + "cwd": cwd, + "message": f"Command sent to tmux session '{session_name}'. Use check_command_output to view results.", + "completed": False + }) except Exception as e: + # Attempt to clean up session in case of error + if session_name: + try: + await self._execute_raw_command(f"tmux kill-session -t {session_name}") + except: + pass return self.fail_response(f"Error executing command: {str(e)}") + async def _execute_raw_command(self, command: str) -> Dict[str, Any]: + """Execute a raw command directly in the sandbox.""" + # Ensure session exists for raw commands + session_id = await self._ensure_session("raw_commands") + + # Execute command in session + from sandbox.sandbox import SessionExecuteRequest + req = SessionExecuteRequest( + command=command, + var_async=False, + cwd=self.workspace_path + ) + + response = self.sandbox.process.execute_session_command( + session_id=session_id, + req=req, + timeout=30 # Short timeout for utility commands + ) + + logs = self.sandbox.process.get_session_command_logs( + session_id=session_id, + command_id=response.cmd_id + ) + + return { + "output": logs, + "exit_code": response.exit_code + } + + @openapi_schema({ + "type": "function", + "function": { + "name": "check_command_output", + "description": "Check the output of a previously executed command in a tmux session. Use this to monitor the progress or results of non-blocking commands.", + "parameters": { + "type": "object", + "properties": { + "session_name": { + "type": "string", + "description": "The name of the tmux session to check." + }, + "kill_session": { + "type": "boolean", + "description": "Whether to terminate the tmux session after checking. Set to true when you're done with the command.", + "default": False + } + }, + "required": ["session_name"] + } + } + }) + @xml_schema( + tag_name="check-command-output", + mappings=[ + {"param_name": "session_name", "node_type": "attribute", "path": ".", "required": True}, + {"param_name": "kill_session", "node_type": "attribute", "path": ".", "required": False} + ], + example=''' + + + + + + ''' + ) + async def check_command_output( + self, + session_name: str, + kill_session: bool = False + ) -> ToolResult: + try: + # Ensure sandbox is initialized + await self._ensure_sandbox() + + # Check if session exists + check_result = await self._execute_raw_command(f"tmux has-session -t {session_name} 2>/dev/null || echo 'not_exists'") + if "not_exists" in check_result.get("output", ""): + return self.fail_response(f"Tmux session '{session_name}' does not exist.") + + # Get output from tmux pane + output_result = await self._execute_raw_command(f"tmux capture-pane -t {session_name} -p") + output = output_result.get("output", "") + + # Kill session if requested + if kill_session: + await self._execute_raw_command(f"tmux kill-session -t {session_name}") + termination_status = "Session terminated." + else: + termination_status = "Session still running." + + return self.success_response({ + "output": output, + "session_name": session_name, + "status": termination_status + }) + + except Exception as e: + return self.fail_response(f"Error checking command output: {str(e)}") + + @openapi_schema({ + "type": "function", + "function": { + "name": "terminate_command", + "description": "Terminate a running command by killing its tmux session.", + "parameters": { + "type": "object", + "properties": { + "session_name": { + "type": "string", + "description": "The name of the tmux session to terminate." + } + }, + "required": ["session_name"] + } + } + }) + @xml_schema( + tag_name="terminate-command", + mappings=[ + {"param_name": "session_name", "node_type": "attribute", "path": ".", "required": True} + ], + example=''' + + + ''' + ) + async def terminate_command( + self, + session_name: str + ) -> ToolResult: + try: + # Ensure sandbox is initialized + await self._ensure_sandbox() + + # Check if session exists + check_result = await self._execute_raw_command(f"tmux has-session -t {session_name} 2>/dev/null || echo 'not_exists'") + if "not_exists" in check_result.get("output", ""): + return self.fail_response(f"Tmux session '{session_name}' does not exist.") + + # Kill the session + await self._execute_raw_command(f"tmux kill-session -t {session_name}") + + return self.success_response({ + "message": f"Tmux session '{session_name}' terminated successfully." + }) + + except Exception as e: + return self.fail_response(f"Error terminating command: {str(e)}") + + @openapi_schema({ + "type": "function", + "function": { + "name": "list_commands", + "description": "List all running tmux sessions and their status.", + "parameters": { + "type": "object", + "properties": {} + } + } + }) + @xml_schema( + tag_name="list-commands", + mappings=[], + example=''' + + + ''' + ) + async def list_commands(self) -> ToolResult: + try: + # Ensure sandbox is initialized + await self._ensure_sandbox() + + # List all tmux sessions + result = await self._execute_raw_command("tmux list-sessions 2>/dev/null || echo 'No sessions'") + output = result.get("output", "") + + if "No sessions" in output or not output.strip(): + return self.success_response({ + "message": "No active tmux sessions found.", + "sessions": [] + }) + + # Parse session list + sessions = [] + for line in output.split('\n'): + if line.strip(): + parts = line.split(':') + if parts: + session_name = parts[0].strip() + sessions.append(session_name) + + return self.success_response({ + "message": f"Found {len(sessions)} active sessions.", + "sessions": sessions + }) + + except Exception as e: + return self.fail_response(f"Error listing commands: {str(e)}") + async def cleanup(self): """Clean up all sessions.""" for session_name in list(self._sessions.keys()): - await self._cleanup_session(session_name) \ No newline at end of file + await self._cleanup_session(session_name) + + # Also clean up any tmux sessions + try: + await self._ensure_sandbox() + await self._execute_raw_command("tmux kill-server 2>/dev/null || true") + except: + pass \ No newline at end of file diff --git a/backend/agent/tools/sb_vision_tool.py b/backend/agent/tools/sb_vision_tool.py index a1e0abad..c07e3df3 100644 --- a/backend/agent/tools/sb_vision_tool.py +++ b/backend/agent/tools/sb_vision_tool.py @@ -4,7 +4,7 @@ import mimetypes from typing import Optional from agentpress.tool import ToolResult, openapi_schema, xml_schema -from sandbox.sandbox import SandboxToolsBase, Sandbox +from sandbox.tool_base import SandboxToolsBase from agentpress.thread_manager import ThreadManager from utils.logger import logger import json diff --git a/backend/agent/tools/web_search_tool.py b/backend/agent/tools/web_search_tool.py index dfa1c87b..25ef62b9 100644 --- a/backend/agent/tools/web_search_tool.py +++ b/backend/agent/tools/web_search_tool.py @@ -1,24 +1,27 @@ from tavily import AsyncTavilyClient import httpx -from typing import List, Optional -from datetime import datetime -import os from dotenv import load_dotenv from agentpress.tool import Tool, ToolResult, openapi_schema, xml_schema from utils.config import config +from sandbox.tool_base import SandboxToolsBase +from agentpress.thread_manager import ThreadManager import json +import os +import datetime +import asyncio +import logging # TODO: add subpages, etc... in filters as sometimes its necessary -class WebSearchTool(Tool): +class SandboxWebSearchTool(SandboxToolsBase): """Tool for performing web searches using Tavily API and web scraping using Firecrawl.""" - def __init__(self, api_key: str = None): - super().__init__() + def __init__(self, project_id: str, thread_manager: ThreadManager): + super().__init__(project_id, thread_manager) # Load environment variables load_dotenv() - # Use the provided API key or get it from environment variables - self.tavily_api_key = api_key or config.TAVILY_API_KEY + # Use API keys from config + self.tavily_api_key = config.TAVILY_API_KEY self.firecrawl_api_key = config.FIRECRAWL_API_KEY self.firecrawl_url = config.FIRECRAWL_URL @@ -34,7 +37,7 @@ class WebSearchTool(Tool): "type": "function", "function": { "name": "web_search", - "description": "Search the web for up-to-date information on a specific topic using the Tavily API. This tool allows you to gather real-time information from the internet to answer user queries, research topics, validate facts, and find recent developments. Results include titles, URLs, summaries, and publication dates. Use this tool for discovering relevant web pages before potentially crawling them for complete content.", + "description": "Search the web for up-to-date information on a specific topic using the Tavily API. This tool allows you to gather real-time information from the internet to answer user queries, research topics, validate facts, and find recent developments. Results include titles, URLs, and publication dates. Use this tool for discovering relevant web pages before potentially crawling them for complete content.", "parameters": { "type": "object", "properties": { @@ -42,11 +45,6 @@ class WebSearchTool(Tool): "type": "string", "description": "The search query to find relevant web pages. Be specific and include key terms to improve search accuracy. For best results, use natural language questions or keyword combinations that precisely describe what you're looking for." }, - # "summary": { - # "type": "boolean", - # "description": "Whether to include a summary of each search result. Summaries provide key context about each page without requiring full content extraction. Set to true to get concise descriptions of each result.", - # "default": True - # }, "num_results": { "type": "integer", "description": "The number of search results to return. Increase for more comprehensive research or decrease for focused, high-relevance results.", @@ -61,7 +59,6 @@ class WebSearchTool(Tool): tag_name="web-search", mappings=[ {"param_name": "query", "node_type": "attribute", "path": "."}, - # {"param_name": "summary", "node_type": "attribute", "path": "."}, {"param_name": "num_results", "node_type": "attribute", "path": "."} ], example=''' @@ -69,29 +66,32 @@ class WebSearchTool(Tool): The web-search tool allows you to search the internet for real-time information. Use this tool when you need to find current information, research topics, or verify facts. - The tool returns information including: - - Titles of relevant web pages - - URLs for accessing the pages - - Published dates (when available) + THE TOOL NOW RETURNS: + - Direct answer to your query from search results + - Relevant images when available + - Detailed search results including titles, URLs, and snippets + + WORKFLOW RECOMMENDATION: + 1. Use web-search first with a specific question to get direct answers + 2. Only use scrape-webpage if you need more detailed information from specific pages --> ''' ) async def web_search( self, - query: str, - # summary: bool = True, + query: str, num_results: int = 20 ) -> ToolResult: """ @@ -116,46 +116,27 @@ class WebSearchTool(Tool): num_results = 20 # Execute the search with Tavily + logging.info(f"Executing web search for query: '{query}' with {num_results} results") search_response = await self.tavily_client.search( query=query, max_results=num_results, - include_answer=False, - include_images=False, + include_images=True, + include_answer="advanced", + search_depth="advanced", ) - - # Normalize the response format - raw_results = ( - search_response.get("results") - if isinstance(search_response, dict) - else search_response - ) - - # Format results consistently - formatted_results = [] - for result in raw_results: - formatted_result = { - "title": result.get("title", ""), - "url": result.get("url", ""), - } - - # if summary: - # # Prefer full content; fall back to description - # formatted_result["snippet"] = ( - # result.get("content") or - # result.get("description") or - # "" - # ) - - formatted_results.append(formatted_result) - # Return a properly formatted ToolResult + # Return the complete Tavily response + # This includes the query, answer, results, images and more + logging.info(f"Retrieved search results for query: '{query}' with answer and {len(search_response.get('results', []))} results") + return ToolResult( success=True, - output=json.dumps(formatted_results, ensure_ascii=False) + output=json.dumps(search_response, ensure_ascii=False) ) except Exception as e: error_message = str(e) + logging.error(f"Error performing web search for '{query}': {error_message}") simplified_message = f"Error performing web search: {error_message[:200]}" if len(error_message) > 200: simplified_message += "..." @@ -165,53 +146,59 @@ class WebSearchTool(Tool): "type": "function", "function": { "name": "scrape_webpage", - "description": "Retrieve the complete text content of a specific webpage using Firecrawl. This tool extracts the full text content from any accessible web page and returns it for analysis, processing, or reference. The extracted text includes the main content of the page without HTML markup. Note that some pages may have limitations on access due to paywalls, access restrictions, or dynamic content loading.", + "description": "Extract full text content from multiple webpages in a single operation. IMPORTANT: You should ALWAYS collect multiple relevant URLs from web-search results and scrape them all in a single call for efficiency. This tool saves time by processing multiple pages simultaneously rather than one at a time. The extracted text includes the main content of each page without HTML markup.", "parameters": { "type": "object", "properties": { - "url": { + "urls": { "type": "string", - "description": "The complete URL of the webpage to scrape. This should be a valid, accessible web address including the protocol (http:// or https://). The tool will attempt to extract all text content from this URL." + "description": "Multiple URLs to scrape, separated by commas. You should ALWAYS include several URLs when possible for efficiency. Example: 'https://example.com/page1,https://example.com/page2,https://example.com/page3'" } }, - "required": ["url"] + "required": ["urls"] } } }) @xml_schema( tag_name="scrape-webpage", mappings=[ - {"param_name": "url", "node_type": "attribute", "path": "."} + {"param_name": "urls", "node_type": "attribute", "path": "."} ], example=''' - - + + query="what is Kortix AI and what are they building?" + num_results="20"> - + + urls="https://www.kortix.ai/,https://github.com/kortix-ai/suna"> @@ -226,41 +213,103 @@ class WebSearchTool(Tool): ) async def scrape_webpage( self, - url: str + urls: str ) -> ToolResult: """ - Retrieve the complete text content of a webpage using Firecrawl. + Retrieve the complete text content of multiple webpages in a single efficient operation. - This function scrapes the specified URL and extracts the full text content from the page. - The extracted text is returned in the response, making it available for further analysis, - processing, or reference. - - The returned data includes: - - Title: The title of the webpage - - URL: The URL of the scraped page - - Published Date: When the content was published (if available) - - Text: The complete text content of the webpage in markdown format - - Note that some pages may have limitations on access due to paywalls, - access restrictions, or dynamic content loading. + ALWAYS collect multiple relevant URLs from search results and scrape them all at once + rather than making separate calls for each URL. This is much more efficient. Parameters: - - url: The URL of the webpage to scrape + - urls: Multiple URLs to scrape, separated by commas """ try: - # Parse the URL parameter exactly as it would appear in XML - if not url: - return self.fail_response("A valid URL is required.") + logging.info(f"Starting to scrape webpages: {urls}") + + # Ensure sandbox is initialized + await self._ensure_sandbox() + + # Parse the URLs parameter + if not urls: + logging.warning("Scrape attempt with empty URLs") + return self.fail_response("Valid URLs are required.") + + # Split the URLs string into a list + url_list = [url.strip() for url in urls.split(',') if url.strip()] + + if not url_list: + logging.warning("No valid URLs found in the input") + return self.fail_response("No valid URLs provided.") - # Handle url parameter (as it would appear in XML) - if isinstance(url, str): - # Add protocol if missing - if not (url.startswith('http://') or url.startswith('https://')): - url = 'https://' + url + if len(url_list) == 1: + logging.warning("Only a single URL provided - for efficiency you should scrape multiple URLs at once") + + logging.info(f"Processing {len(url_list)} URLs: {url_list}") + + # Process each URL and collect results + results = [] + for url in url_list: + try: + # Add protocol if missing + if not (url.startswith('http://') or url.startswith('https://')): + url = 'https://' + url + logging.info(f"Added https:// protocol to URL: {url}") + + # Scrape this URL + result = await self._scrape_single_url(url) + results.append(result) + + except Exception as e: + logging.error(f"Error processing URL {url}: {str(e)}") + results.append({ + "url": url, + "success": False, + "error": str(e) + }) + + # Summarize results + successful = sum(1 for r in results if r.get("success", False)) + failed = len(results) - successful + + # Create success/failure message + if successful == len(results): + message = f"Successfully scraped all {len(results)} URLs. Results saved to:" + for r in results: + if r.get("file_path"): + message += f"\n- {r.get('file_path')}" + elif successful > 0: + message = f"Scraped {successful} URLs successfully and {failed} failed. Results saved to:" + for r in results: + if r.get("success", False) and r.get("file_path"): + message += f"\n- {r.get('file_path')}" + message += "\n\nFailed URLs:" + for r in results: + if not r.get("success", False): + message += f"\n- {r.get('url')}: {r.get('error', 'Unknown error')}" else: - return self.fail_response("URL must be a string.") - + error_details = "; ".join([f"{r.get('url')}: {r.get('error', 'Unknown error')}" for r in results]) + return self.fail_response(f"Failed to scrape all {len(results)} URLs. Errors: {error_details}") + + return ToolResult( + success=True, + output=message + ) + + except Exception as e: + error_message = str(e) + logging.error(f"Error in scrape_webpage: {error_message}") + return self.fail_response(f"Error processing scrape request: {error_message[:200]}") + + async def _scrape_single_url(self, url: str) -> dict: + """ + Helper function to scrape a single URL and return the result information. + """ + logging.info(f"Scraping single URL: {url}") + + try: # ---------- Firecrawl scrape endpoint ---------- + logging.info(f"Sending request to Firecrawl for URL: {url}") async with httpx.AsyncClient() as client: headers = { "Authorization": f"Bearer {self.firecrawl_api_key}", @@ -270,57 +319,110 @@ class WebSearchTool(Tool): "url": url, "formats": ["markdown"] } - response = await client.post( - f"{self.firecrawl_url}/v1/scrape", - json=payload, - headers=headers, - timeout=60, - ) - response.raise_for_status() - data = response.json() + + # Use longer timeout and retry logic for more reliability + max_retries = 3 + timeout_seconds = 120 + retry_count = 0 + + while retry_count < max_retries: + try: + logging.info(f"Sending request to Firecrawl (attempt {retry_count + 1}/{max_retries})") + response = await client.post( + f"{self.firecrawl_url}/v1/scrape", + json=payload, + headers=headers, + timeout=timeout_seconds, + ) + response.raise_for_status() + data = response.json() + logging.info(f"Successfully received response from Firecrawl for {url}") + break + except (httpx.ReadTimeout, httpx.ConnectTimeout, httpx.ReadError) as timeout_err: + retry_count += 1 + logging.warning(f"Request timed out (attempt {retry_count}/{max_retries}): {str(timeout_err)}") + if retry_count >= max_retries: + raise Exception(f"Request timed out after {max_retries} attempts with {timeout_seconds}s timeout") + # Exponential backoff + logging.info(f"Waiting {2 ** retry_count}s before retry") + await asyncio.sleep(2 ** retry_count) + except Exception as e: + # Don't retry on non-timeout errors + logging.error(f"Error during scraping: {str(e)}") + raise e # Format the response + title = data.get("data", {}).get("metadata", {}).get("title", "") + markdown_content = data.get("data", {}).get("markdown", "") + logging.info(f"Extracted content from {url}: title='{title}', content length={len(markdown_content)}") + formatted_result = { - "Title": data.get("data", {}).get("metadata", {}).get("title", ""), - "URL": url, - "Text": data.get("data", {}).get("markdown", "") + "title": title, + "url": url, + "text": markdown_content } # Add metadata if available if "metadata" in data.get("data", {}): - formatted_result["Metadata"] = data["data"]["metadata"] + formatted_result["metadata"] = data["data"]["metadata"] + logging.info(f"Added metadata: {data['data']['metadata'].keys()}") - return self.success_response([formatted_result]) + # Create a simple filename from the URL domain and date + timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") + + # Extract domain from URL for the filename + from urllib.parse import urlparse + parsed_url = urlparse(url) + domain = parsed_url.netloc.replace("www.", "") + + # Clean up domain for filename + domain = "".join([c if c.isalnum() else "_" for c in domain]) + safe_filename = f"{timestamp}_{domain}.json" + + logging.info(f"Generated filename: {safe_filename}") + + # Save results to a file in the /workspace/scrape directory + scrape_dir = f"{self.workspace_path}/scrape" + self.sandbox.fs.create_folder(scrape_dir, "755") + + results_file_path = f"{scrape_dir}/{safe_filename}" + json_content = json.dumps(formatted_result, ensure_ascii=False, indent=2) + logging.info(f"Saving content to file: {results_file_path}, size: {len(json_content)} bytes") + + self.sandbox.fs.upload_file( + results_file_path, + json_content.encode() + ) + + return { + "url": url, + "success": True, + "title": title, + "file_path": results_file_path, + "content_length": len(markdown_content) + } except Exception as e: error_message = str(e) - # Truncate very long error messages - simplified_message = f"Error scraping webpage: {error_message[:200]}" - if len(error_message) > 200: - simplified_message += "..." - return self.fail_response(simplified_message) - + logging.error(f"Error scraping URL '{url}': {error_message}") + + # Create an error result + return { + "url": url, + "success": False, + "error": error_message + } if __name__ == "__main__": - import asyncio - async def test_web_search(): """Test function for the web search tool""" - search_tool = WebSearchTool() - result = await search_tool.web_search( - query="rubber gym mats best prices comparison", - # summary=True, - num_results=20 - ) - print(result) + # This test function is not compatible with the sandbox version + print("Test function needs to be updated for sandbox version") async def test_scrape_webpage(): """Test function for the webpage scrape tool""" - search_tool = WebSearchTool() - result = await search_tool.scrape_webpage( - url="https://www.wired.com/story/anthropic-benevolent-artificial-intelligence/" - ) - print(result) + # This test function is not compatible with the sandbox version + print("Test function needs to be updated for sandbox version") async def run_tests(): """Run all test functions""" diff --git a/backend/agentpress/context_manager.py b/backend/agentpress/context_manager.py index 3fd297ec..11405f40 100644 --- a/backend/agentpress/context_manager.py +++ b/backend/agentpress/context_manager.py @@ -8,7 +8,7 @@ reaching the context window limitations of LLM models. import json from typing import List, Dict, Any, Optional -from litellm import token_counter, completion, completion_cost +from litellm import token_counter, completion_cost from services.supabase import DBConnection from services.llm import make_llm_api_call from utils.logger import logger diff --git a/backend/agentpress/response_processor.py b/backend/agentpress/response_processor.py index 41a7df50..78c95af5 100644 --- a/backend/agentpress/response_processor.py +++ b/backend/agentpress/response_processor.py @@ -16,7 +16,7 @@ from typing import List, Dict, Any, Optional, Tuple, AsyncGenerator, Callable, U from dataclasses import dataclass from datetime import datetime, timezone -from litellm import completion_cost, token_counter +from litellm import completion_cost from agentpress.tool import Tool, ToolResult from agentpress.tool_registry import ToolRegistry @@ -560,15 +560,22 @@ class ResponseProcessor: is_llm_message=False, metadata={"thread_run_id": thread_run_id if 'thread_run_id' in locals() else None} ) if err_msg_obj: yield err_msg_obj # Yield the saved error message + + # Re-raise the same exception (not a new one) to ensure proper error propagation + logger.critical(f"Re-raising error to stop further processing: {str(e)}") + raise # Use bare 'raise' to preserve the original exception with its traceback finally: # Save and Yield the final thread_run_end status - end_content = {"status_type": "thread_run_end"} - end_msg_obj = await self.add_message( - thread_id=thread_id, type="status", content=end_content, - is_llm_message=False, metadata={"thread_run_id": thread_run_id if 'thread_run_id' in locals() else None} - ) - if end_msg_obj: yield end_msg_obj + try: + end_content = {"status_type": "thread_run_end"} + end_msg_obj = await self.add_message( + thread_id=thread_id, type="status", content=end_content, + is_llm_message=False, metadata={"thread_run_id": thread_run_id if 'thread_run_id' in locals() else None} + ) + if end_msg_obj: yield end_msg_obj + except Exception as final_e: + logger.error(f"Error in finally block: {str(final_e)}", exc_info=True) async def process_non_streaming_response( self, @@ -763,6 +770,10 @@ class ResponseProcessor: is_llm_message=False, metadata={"thread_run_id": thread_run_id if 'thread_run_id' in locals() else None} ) if err_msg_obj: yield err_msg_obj + + # Re-raise the same exception (not a new one) to ensure proper error propagation + logger.critical(f"Re-raising error to stop further processing: {str(e)}") + raise # Use bare 'raise' to preserve the original exception with its traceback finally: # Save and Yield the final thread_run_end status diff --git a/backend/agentpress/thread_manager.py b/backend/agentpress/thread_manager.py index bf66133e..be8b48a6 100644 --- a/backend/agentpress/thread_manager.py +++ b/backend/agentpress/thread_manager.py @@ -353,22 +353,19 @@ Here are the XML tools available with examples: return response_generator else: logger.debug("Processing non-streaming response") - try: - # Return the async generator directly, don't await it - response_generator = self.response_processor.process_non_streaming_response( - llm_response=llm_response, - thread_id=thread_id, - config=processor_config, - prompt_messages=prepared_messages, - llm_model=llm_model - ) - return response_generator # Return the generator - except Exception as e: - logger.error(f"Error setting up non-streaming response: {str(e)}", exc_info=True) - raise # Re-raise the exception to be caught by the outer handler + # Pass through the response generator without try/except to let errors propagate up + response_generator = self.response_processor.process_non_streaming_response( + llm_response=llm_response, + thread_id=thread_id, + config=processor_config, + prompt_messages=prepared_messages, + llm_model=llm_model + ) + return response_generator # Return the generator except Exception as e: logger.error(f"Error in run_thread: {str(e)}", exc_info=True) + # Return the error as a dict to be handled by the caller return { "status": "error", "message": str(e) @@ -384,37 +381,58 @@ Here are the XML tools available with examples: # Run the thread once, passing the potentially modified system prompt # Pass temp_msg only on the first iteration - response_gen = await _run_once(temporary_message if auto_continue_count == 0 else None) + try: + response_gen = await _run_once(temporary_message if auto_continue_count == 0 else None) - # Handle error responses - if isinstance(response_gen, dict) and "status" in response_gen and response_gen["status"] == "error": - yield response_gen - return + # Handle error responses + if isinstance(response_gen, dict) and "status" in response_gen and response_gen["status"] == "error": + logger.error(f"Error in auto_continue_wrapper: {response_gen.get('message', 'Unknown error')}") + yield response_gen + return # Exit the generator on error - # Process each chunk - async for chunk in response_gen: - # Check if this is a finish reason chunk with tool_calls or xml_tool_limit_reached - if chunk.get('type') == 'finish': - if chunk.get('finish_reason') == 'tool_calls': - # Only auto-continue if enabled (max > 0) - if native_max_auto_continues > 0: - logger.info(f"Detected finish_reason='tool_calls', auto-continuing ({auto_continue_count + 1}/{native_max_auto_continues})") - auto_continue = True - auto_continue_count += 1 - # Don't yield the finish chunk to avoid confusing the client - continue - elif chunk.get('finish_reason') == 'xml_tool_limit_reached': - # Don't auto-continue if XML tool limit was reached - logger.info(f"Detected finish_reason='xml_tool_limit_reached', stopping auto-continue") - auto_continue = False - # Still yield the chunk to inform the client + # Process each chunk + try: + async for chunk in response_gen: + # Check if this is a finish reason chunk with tool_calls or xml_tool_limit_reached + if chunk.get('type') == 'finish': + if chunk.get('finish_reason') == 'tool_calls': + # Only auto-continue if enabled (max > 0) + if native_max_auto_continues > 0: + logger.info(f"Detected finish_reason='tool_calls', auto-continuing ({auto_continue_count + 1}/{native_max_auto_continues})") + auto_continue = True + auto_continue_count += 1 + # Don't yield the finish chunk to avoid confusing the client + continue + elif chunk.get('finish_reason') == 'xml_tool_limit_reached': + # Don't auto-continue if XML tool limit was reached + logger.info(f"Detected finish_reason='xml_tool_limit_reached', stopping auto-continue") + auto_continue = False + # Still yield the chunk to inform the client - # Otherwise just yield the chunk normally - yield chunk + # Otherwise just yield the chunk normally + yield chunk - # If not auto-continuing, we're done - if not auto_continue: - break + # If not auto-continuing, we're done + if not auto_continue: + break + except Exception as e: + # If there's an exception, log it, yield an error status, and stop execution + logger.error(f"Error in auto_continue_wrapper generator: {str(e)}", exc_info=True) + yield { + "type": "status", + "status": "error", + "message": f"Error in thread processing: {str(e)}" + } + return # Exit the generator on any error + except Exception as outer_e: + # Catch exceptions from _run_once itself + logger.error(f"Error executing thread: {str(outer_e)}", exc_info=True) + yield { + "type": "status", + "status": "error", + "message": f"Error executing thread: {str(outer_e)}" + } + return # Exit immediately on exception from _run_once # If we've reached the max auto-continues, log a warning if auto_continue and auto_continue_count >= native_max_auto_continues: diff --git a/backend/agentpress/tool.py b/backend/agentpress/tool.py index c804602e..de7a5045 100644 --- a/backend/agentpress/tool.py +++ b/backend/agentpress/tool.py @@ -7,7 +7,7 @@ This module defines the base classes and decorators for creating tools in AgentP - Result containers for standardized tool outputs """ -from typing import Dict, Any, Union, Optional, List, Type +from typing import Dict, Any, Union, Optional, List from dataclasses import dataclass, field from abc import ABC import json diff --git a/backend/agentpress/tool_registry.py b/backend/agentpress/tool_registry.py index 238b7b33..b50438a1 100644 --- a/backend/agentpress/tool_registry.py +++ b/backend/agentpress/tool_registry.py @@ -1,5 +1,5 @@ from typing import Dict, Type, Any, List, Optional, Callable -from agentpress.tool import Tool, SchemaType, ToolSchema +from agentpress.tool import Tool, SchemaType from utils.logger import logger diff --git a/backend/api.py b/backend/api.py index bc1e08e4..7de122bc 100644 --- a/backend/api.py +++ b/backend/api.py @@ -157,5 +157,5 @@ if __name__ == "__main__": host="0.0.0.0", port=8000, workers=workers, - reload=True + # reload=True ) \ No newline at end of file diff --git a/backend/pyproject.toml b/backend/pyproject.toml index 4bd8921a..3992795e 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -18,7 +18,6 @@ classifiers = [ [tool.poetry.dependencies] python = "^3.11" -streamlit-quill = "0.0.3" python-dotenv = "1.0.1" litellm = "1.66.1" click = "8.1.7" @@ -45,7 +44,6 @@ python-ripgrep = "0.0.6" daytona_sdk = "^0.14.0" boto3 = "^1.34.0" openai = "^1.72.0" -streamlit = "^1.44.1" nest-asyncio = "^1.6.0" vncdotool = "^1.2.0" tavily-python = "^0.5.4" @@ -63,4 +61,4 @@ daytona-sdk = "^0.14.0" [build-system] requires = ["poetry-core"] -build-backend = "poetry.core.masonry.api" \ No newline at end of file +build-backend = "poetry.core.masonry.api" diff --git a/backend/sandbox/README.md b/backend/sandbox/README.md new file mode 100644 index 00000000..cafa4b7b --- /dev/null +++ b/backend/sandbox/README.md @@ -0,0 +1,32 @@ +# Agent Sandbox + +This directory contains the agent sandbox implementation - a Docker-based virtual environment that agents use as their own computer to execute tasks, access the web, and manipulate files. + +## Overview + +The sandbox provides a complete containerized Linux environment with: +- Chrome browser for web interactions +- VNC server for accessing the Web User +- Web server for serving content (port 8080) -> loading html files from the /workspace directory +- Full file system access +- Full sudo access + +## Customizing the Sandbox + +You can modify the sandbox environment for development or to add new capabilities: + +1. Edit files in the `docker/` directory +2. Build a custom image: + ``` + cd backend/sandbox/docker + docker-compose build + ``` +3. Test your changes locally using docker-compose + +## Using a Custom Image + +To use your custom sandbox image: + +1. Change the `image` parameter in `docker-compose.yml` (that defines the image name `kortix/suna:___`) +2. Update the same image name in `backend/sandbox/sandbox.py` in the `create_sandbox` function +3. If using Daytona for deployment, update the image reference there as well diff --git a/backend/sandbox/api.py b/backend/sandbox/api.py index b1fa3677..63fde35e 100644 --- a/backend/sandbox/api.py +++ b/backend/sandbox/api.py @@ -1,17 +1,16 @@ import os -from typing import List, Optional +from typing import Optional from fastapi import FastAPI, UploadFile, File, HTTPException, APIRouter, Form, Depends, Request -from fastapi.responses import Response, JSONResponse +from fastapi.responses import Response from pydantic import BaseModel -from utils.logger import logger -from utils.auth_utils import get_current_user_id_from_jwt, get_user_id_from_stream_auth, get_optional_user_id from sandbox.sandbox import get_or_start_sandbox +from utils.logger import logger +from utils.auth_utils import get_optional_user_id from services.supabase import DBConnection from agent.api import get_or_create_project_sandbox - # Initialize shared resources router = APIRouter(tags=["sandbox"]) db = None @@ -92,19 +91,15 @@ async def get_sandbox_by_id_safely(client, sandbox_id: str): logger.error(f"No project found for sandbox ID: {sandbox_id}") raise HTTPException(status_code=404, detail="Sandbox not found - no project owns this sandbox ID") - project_id = project_result.data[0]['project_id'] - logger.debug(f"Found project {project_id} for sandbox {sandbox_id}") + # project_id = project_result.data[0]['project_id'] + # logger.debug(f"Found project {project_id} for sandbox {sandbox_id}") try: # Get the sandbox - sandbox, retrieved_sandbox_id, sandbox_pass = await get_or_create_project_sandbox(client, project_id) - - # Verify we got the right sandbox - if retrieved_sandbox_id != sandbox_id: - logger.warning(f"Retrieved sandbox ID {retrieved_sandbox_id} doesn't match requested ID {sandbox_id} for project {project_id}") - # Fall back to the direct method if IDs don't match (shouldn't happen but just in case) - sandbox = await get_or_start_sandbox(sandbox_id) - + sandbox = await get_or_start_sandbox(sandbox_id) + # Extract just the sandbox object from the tuple (sandbox, sandbox_id, sandbox_pass) + # sandbox = sandbox_tuple[0] + return sandbox except Exception as e: logger.error(f"Error retrieving sandbox {sandbox_id}: {str(e)}") @@ -141,46 +136,6 @@ async def create_file( logger.error(f"Error creating file in sandbox {sandbox_id}: {str(e)}") raise HTTPException(status_code=500, detail=str(e)) -# For backward compatibility, keep the JSON version too -@router.post("/sandboxes/{sandbox_id}/files/json") -async def create_file_json( - sandbox_id: str, - file_request: dict, - request: Request = None, - user_id: Optional[str] = Depends(get_optional_user_id) -): - """Create a file in the sandbox using JSON (legacy support)""" - logger.info(f"Received JSON file creation request for sandbox {sandbox_id}, user_id: {user_id}") - client = await db.client - - # Verify the user has access to this sandbox - await verify_sandbox_access(client, sandbox_id, user_id) - - try: - # Get sandbox using the safer method - sandbox = await get_sandbox_by_id_safely(client, sandbox_id) - - # Get file path and content - path = file_request.get("path") - content = file_request.get("content", "") - - if not path: - logger.error(f"Missing file path in request for sandbox {sandbox_id}") - raise HTTPException(status_code=400, detail="File path is required") - - # Convert string content to bytes - if isinstance(content, str): - content = content.encode('utf-8') - - # Create file - sandbox.fs.upload_file(path, content) - logger.info(f"File created at {path} in sandbox {sandbox_id}") - - return {"status": "success", "created": True, "path": path} - except Exception as e: - logger.error(f"Error creating file in sandbox {sandbox_id}: {str(e)}") - raise HTTPException(status_code=500, detail=str(e)) - @router.get("/sandboxes/{sandbox_id}/files") async def list_files( sandbox_id: str, @@ -256,56 +211,57 @@ async def read_file( logger.error(f"Error reading file in sandbox {sandbox_id}: {str(e)}") raise HTTPException(status_code=500, detail=str(e)) -@router.post("/project/{project_id}/sandbox/ensure-active") -async def ensure_project_sandbox_active( - project_id: str, - request: Request = None, - user_id: Optional[str] = Depends(get_optional_user_id) -): - """ - Ensure that a project's sandbox is active and running. - Checks the sandbox status and starts it if it's not running. - """ - logger.info(f"Received ensure sandbox active request for project {project_id}, user_id: {user_id}") - client = await db.client +# Should happen on server-side fully +# @router.post("/project/{project_id}/sandbox/ensure-active") +# async def ensure_project_sandbox_active( +# project_id: str, +# request: Request = None, +# user_id: Optional[str] = Depends(get_optional_user_id) +# ): +# """ +# Ensure that a project's sandbox is active and running. +# Checks the sandbox status and starts it if it's not running. +# """ +# logger.info(f"Received ensure sandbox active request for project {project_id}, user_id: {user_id}") +# client = await db.client - # Find the project and sandbox information - project_result = await client.table('projects').select('*').eq('project_id', project_id).execute() +# # Find the project and sandbox information +# project_result = await client.table('projects').select('*').eq('project_id', project_id).execute() - if not project_result.data or len(project_result.data) == 0: - logger.error(f"Project not found: {project_id}") - raise HTTPException(status_code=404, detail="Project not found") +# if not project_result.data or len(project_result.data) == 0: +# logger.error(f"Project not found: {project_id}") +# raise HTTPException(status_code=404, detail="Project not found") - project_data = project_result.data[0] +# project_data = project_result.data[0] - # For public projects, no authentication is needed - if not project_data.get('is_public'): - # For private projects, we must have a user_id - if not user_id: - logger.error(f"Authentication required for private project {project_id}") - raise HTTPException(status_code=401, detail="Authentication required for this resource") +# # For public projects, no authentication is needed +# if not project_data.get('is_public'): +# # For private projects, we must have a user_id +# if not user_id: +# logger.error(f"Authentication required for private project {project_id}") +# raise HTTPException(status_code=401, detail="Authentication required for this resource") - account_id = project_data.get('account_id') +# account_id = project_data.get('account_id') - # Verify account membership - if account_id: - account_user_result = await client.schema('basejump').from_('account_user').select('account_role').eq('user_id', user_id).eq('account_id', account_id).execute() - if not (account_user_result.data and len(account_user_result.data) > 0): - logger.error(f"User {user_id} not authorized to access project {project_id}") - raise HTTPException(status_code=403, detail="Not authorized to access this project") +# # Verify account membership +# if account_id: +# account_user_result = await client.schema('basejump').from_('account_user').select('account_role').eq('user_id', user_id).eq('account_id', account_id).execute() +# if not (account_user_result.data and len(account_user_result.data) > 0): +# logger.error(f"User {user_id} not authorized to access project {project_id}") +# raise HTTPException(status_code=403, detail="Not authorized to access this project") - try: - # Get or create the sandbox - logger.info(f"Ensuring sandbox is active for project {project_id}") - sandbox, sandbox_id, sandbox_pass = await get_or_create_project_sandbox(client, project_id) +# try: +# # Get or create the sandbox +# logger.info(f"Ensuring sandbox is active for project {project_id}") +# sandbox, sandbox_id, sandbox_pass = await get_or_create_project_sandbox(client, project_id) - logger.info(f"Successfully ensured sandbox {sandbox_id} is active for project {project_id}") +# logger.info(f"Successfully ensured sandbox {sandbox_id} is active for project {project_id}") - return { - "status": "success", - "sandbox_id": sandbox_id, - "message": "Sandbox is active" - } - except Exception as e: - logger.error(f"Error ensuring sandbox is active for project {project_id}: {str(e)}") - raise HTTPException(status_code=500, detail=str(e)) +# return { +# "status": "success", +# "sandbox_id": sandbox_id, +# "message": "Sandbox is active" +# } +# except Exception as e: +# logger.error(f"Error ensuring sandbox is active for project {project_id}: {str(e)}") +# raise HTTPException(status_code=500, detail=str(e)) diff --git a/backend/sandbox/docker/Dockerfile b/backend/sandbox/docker/Dockerfile index 79fe5b5d..418fe524 100644 --- a/backend/sandbox/docker/Dockerfile +++ b/backend/sandbox/docker/Dockerfile @@ -125,4 +125,4 @@ COPY supervisord.conf /etc/supervisor/conf.d/supervisord.conf EXPOSE 7788 6080 5901 8000 8080 -CMD ["/usr/bin/supervisord", "-c", "/etc/supervisor/conf.d/supervisord.conf"] +CMD ["/usr/bin/supervisord", "-c", "/etc/supervisor/conf.d/supervisord.conf"] \ No newline at end of file diff --git a/backend/sandbox/docker/browser_api.py b/backend/sandbox/docker/browser_api.py index 579c8458..471fc6b0 100644 --- a/backend/sandbox/docker/browser_api.py +++ b/backend/sandbox/docker/browser_api.py @@ -1,11 +1,10 @@ from fastapi import FastAPI, APIRouter, HTTPException, Body -from playwright.async_api import async_playwright, Browser, Page, ElementHandle +from playwright.async_api import async_playwright, Browser, Page from pydantic import BaseModel -from typing import Optional, List, Dict, Any, Union +from typing import Optional, List, Dict, Any import asyncio import json import logging -import re import base64 from dataclasses import dataclass, field from datetime import datetime diff --git a/backend/sandbox/docker/docker-compose.yml b/backend/sandbox/docker/docker-compose.yml index 64298126..27432984 100644 --- a/backend/sandbox/docker/docker-compose.yml +++ b/backend/sandbox/docker/docker-compose.yml @@ -6,7 +6,7 @@ services: dockerfile: ${DOCKERFILE:-Dockerfile} args: TARGETPLATFORM: ${TARGETPLATFORM:-linux/amd64} - image: adamcohenhillel/kortix-suna:0.0.20 + image: kortix/suna:0.1.2 ports: - "6080:6080" # noVNC web interface - "5901:5901" # VNC port diff --git a/backend/sandbox/sandbox.py b/backend/sandbox/sandbox.py index 0dc365a5..ec432e8d 100644 --- a/backend/sandbox/sandbox.py +++ b/backend/sandbox/sandbox.py @@ -1,15 +1,8 @@ -import os -from typing import Optional - from daytona_sdk import Daytona, DaytonaConfig, CreateSandboxParams, Sandbox, SessionExecuteRequest from daytona_api_client.models.workspace_state import WorkspaceState from dotenv import load_dotenv - -from agentpress.tool import Tool from utils.logger import logger from utils.config import config -from utils.files_utils import clean_path -from agentpress.thread_manager import ThreadManager load_dotenv() @@ -98,7 +91,7 @@ def create_sandbox(password: str, project_id: str = None): labels = {'id': project_id} params = CreateSandboxParams( - image="adamcohenhillel/kortix-suna:0.0.20", + image="kortix/suna:0.1.2", public=True, labels=labels, env_vars={ @@ -131,83 +124,3 @@ def create_sandbox(password: str, project_id: str = None): logger.debug(f"Sandbox environment successfully initialized") return sandbox - -class SandboxToolsBase(Tool): - """Base class for all sandbox tools that provides project-based sandbox access.""" - - # Class variable to track if sandbox URLs have been printed - _urls_printed = False - - def __init__(self, project_id: str, thread_manager: Optional[ThreadManager] = None): - super().__init__() - self.project_id = project_id - self.thread_manager = thread_manager - self.workspace_path = "/workspace" - self._sandbox = None - self._sandbox_id = None - self._sandbox_pass = None - - async def _ensure_sandbox(self) -> Sandbox: - """Ensure we have a valid sandbox instance, retrieving it from the project if needed.""" - if self._sandbox is None: - try: - # Get database client - client = await self.thread_manager.db.client - - # Get project data - project = await client.table('projects').select('*').eq('project_id', self.project_id).execute() - if not project.data or len(project.data) == 0: - raise ValueError(f"Project {self.project_id} not found") - - project_data = project.data[0] - sandbox_info = project_data.get('sandbox', {}) - - if not sandbox_info.get('id'): - raise ValueError(f"No sandbox found for project {self.project_id}") - - # Store sandbox info - self._sandbox_id = sandbox_info['id'] - self._sandbox_pass = sandbox_info.get('pass') - - # Get or start the sandbox - self._sandbox = await get_or_start_sandbox(self._sandbox_id) - - # # Log URLs if not already printed - # if not SandboxToolsBase._urls_printed: - # vnc_link = self._sandbox.get_preview_link(6080) - # website_link = self._sandbox.get_preview_link(8080) - - # vnc_url = vnc_link.url if hasattr(vnc_link, 'url') else str(vnc_link) - # website_url = website_link.url if hasattr(website_link, 'url') else str(website_link) - - # print("\033[95m***") - # print(f"VNC URL: {vnc_url}") - # print(f"Website URL: {website_url}") - # print("***\033[0m") - # SandboxToolsBase._urls_printed = True - - except Exception as e: - logger.error(f"Error retrieving sandbox for project {self.project_id}: {str(e)}", exc_info=True) - raise e - - return self._sandbox - - @property - def sandbox(self) -> Sandbox: - """Get the sandbox instance, ensuring it exists.""" - if self._sandbox is None: - raise RuntimeError("Sandbox not initialized. Call _ensure_sandbox() first.") - return self._sandbox - - @property - def sandbox_id(self) -> str: - """Get the sandbox ID, ensuring it exists.""" - if self._sandbox_id is None: - raise RuntimeError("Sandbox ID not initialized. Call _ensure_sandbox() first.") - return self._sandbox_id - - def clean_path(self, path: str) -> str: - """Clean and normalize a path to be relative to /workspace.""" - cleaned_path = clean_path(path, self.workspace_path) - logger.debug(f"Cleaned path: {path} -> {cleaned_path}") - return cleaned_path \ No newline at end of file diff --git a/backend/sandbox/tool_base.py b/backend/sandbox/tool_base.py new file mode 100644 index 00000000..4e8359a9 --- /dev/null +++ b/backend/sandbox/tool_base.py @@ -0,0 +1,90 @@ + +from typing import Optional + +from agentpress.thread_manager import ThreadManager +from agentpress.tool import Tool +from daytona_sdk import Sandbox +from sandbox.sandbox import get_or_start_sandbox +from utils import logger +from utils.files_utils import clean_path + + +class SandboxToolsBase(Tool): + """Base class for all sandbox tools that provides project-based sandbox access.""" + + # Class variable to track if sandbox URLs have been printed + _urls_printed = False + + def __init__(self, project_id: str, thread_manager: Optional[ThreadManager] = None): + super().__init__() + self.project_id = project_id + self.thread_manager = thread_manager + self.workspace_path = "/workspace" + self._sandbox = None + self._sandbox_id = None + self._sandbox_pass = None + + async def _ensure_sandbox(self) -> Sandbox: + """Ensure we have a valid sandbox instance, retrieving it from the project if needed.""" + if self._sandbox is None: + try: + # Get database client + client = await self.thread_manager.db.client + + # Get project data + project = await client.table('projects').select('*').eq('project_id', self.project_id).execute() + if not project.data or len(project.data) == 0: + raise ValueError(f"Project {self.project_id} not found") + + project_data = project.data[0] + sandbox_info = project_data.get('sandbox', {}) + + if not sandbox_info.get('id'): + raise ValueError(f"No sandbox found for project {self.project_id}") + + # Store sandbox info + self._sandbox_id = sandbox_info['id'] + self._sandbox_pass = sandbox_info.get('pass') + + # Get or start the sandbox + self._sandbox = await get_or_start_sandbox(self._sandbox_id) + + # # Log URLs if not already printed + # if not SandboxToolsBase._urls_printed: + # vnc_link = self._sandbox.get_preview_link(6080) + # website_link = self._sandbox.get_preview_link(8080) + + # vnc_url = vnc_link.url if hasattr(vnc_link, 'url') else str(vnc_link) + # website_url = website_link.url if hasattr(website_link, 'url') else str(website_link) + + # print("\033[95m***") + # print(f"VNC URL: {vnc_url}") + # print(f"Website URL: {website_url}") + # print("***\033[0m") + # SandboxToolsBase._urls_printed = True + + except Exception as e: + logger.error(f"Error retrieving sandbox for project {self.project_id}: {str(e)}", exc_info=True) + raise e + + return self._sandbox + + @property + def sandbox(self) -> Sandbox: + """Get the sandbox instance, ensuring it exists.""" + if self._sandbox is None: + raise RuntimeError("Sandbox not initialized. Call _ensure_sandbox() first.") + return self._sandbox + + @property + def sandbox_id(self) -> str: + """Get the sandbox ID, ensuring it exists.""" + if self._sandbox_id is None: + raise RuntimeError("Sandbox ID not initialized. Call _ensure_sandbox() first.") + return self._sandbox_id + + def clean_path(self, path: str) -> str: + """Clean and normalize a path to be relative to /workspace.""" + cleaned_path = clean_path(path, self.workspace_path) + logger.debug(f"Cleaned path: {path} -> {cleaned_path}") + return cleaned_path \ No newline at end of file diff --git a/backend/services/billing.py b/backend/services/billing.py index 109237dc..b2438bf4 100644 --- a/backend/services/billing.py +++ b/backend/services/billing.py @@ -5,14 +5,14 @@ stripe listen --forward-to localhost:8000/api/billing/webhook """ from fastapi import APIRouter, HTTPException, Depends, Request -from typing import Optional, Dict, Any, List, Tuple +from typing import Optional, Dict, Tuple import stripe from datetime import datetime, timezone from utils.logger import logger from utils.config import config, EnvMode from services.supabase import DBConnection from utils.auth_utils import get_current_user_id_from_jwt -from pydantic import BaseModel, Field +from pydantic import BaseModel # Initialize Stripe stripe.api_key = config.STRIPE_SECRET_KEY diff --git a/backend/services/llm.py b/backend/services/llm.py index b7a83251..06d92d7e 100644 --- a/backend/services/llm.py +++ b/backend/services/llm.py @@ -18,16 +18,14 @@ from openai import OpenAIError import litellm from utils.logger import logger from utils.config import config -from datetime import datetime -import traceback # litellm.set_verbose=True litellm.modify_params=True # Constants -MAX_RETRIES = 3 +MAX_RETRIES = 2 RATE_LIMIT_DELAY = 30 -RETRY_DELAY = 5 +RETRY_DELAY = 0.1 class LLMError(Exception): """Base exception for LLM-related errors.""" diff --git a/backend/services/supabase.py b/backend/services/supabase.py index e2930075..0bb1419a 100644 --- a/backend/services/supabase.py +++ b/backend/services/supabase.py @@ -2,7 +2,6 @@ Centralized database connection management for AgentPress using Supabase. """ -import os from typing import Optional from supabase import create_async_client, AsyncClient from utils.logger import logger diff --git a/backend/utils/auth_utils.py b/backend/utils/auth_utils.py index f051815b..e2a090b8 100644 --- a/backend/utils/auth_utils.py +++ b/backend/utils/auth_utils.py @@ -1,8 +1,7 @@ -from fastapi import HTTPException, Request, Depends -from typing import Optional, List, Dict, Any +from fastapi import HTTPException, Request +from typing import Optional import jwt from jwt.exceptions import PyJWTError -from utils.logger import logger # This function extracts the user ID from Supabase JWT async def get_current_user_id_from_jwt(request: Request) -> str: diff --git a/backend/utils/logger.py b/backend/utils/logger.py index db7dca53..6f107ee6 100644 --- a/backend/utils/logger.py +++ b/backend/utils/logger.py @@ -13,7 +13,6 @@ import json import sys import os from datetime import datetime -from typing import Any, Dict, Optional from contextvars import ContextVar from functools import wraps import traceback diff --git a/backend/utils/scripts/archive_old_sandboxes.py b/backend/utils/scripts/archive_old_sandboxes.py index 00573943..c0ebce9d 100644 --- a/backend/utils/scripts/archive_old_sandboxes.py +++ b/backend/utils/scripts/archive_old_sandboxes.py @@ -16,6 +16,8 @@ Make sure your environment variables are properly set: - DAYTONA_SERVER_URL """ +# TODO: SAVE THE LATEST SANDBOX STATE SOMEWHERE OR LIKE MASS CHECK THE STATE BEFORE STARTING TO ARCHIVE - AS ITS GOING TO GO OVER A BUNCH THAT ARE ALREADY ARCHIVED – MAYBE BEST TO GET ALL FROM DAYTONA AND THEN RUN THE ARCHIVE ONLY ON THE ONES THAT MEET THE CRITERIA (STOPPED STATE) + import asyncio import sys import os @@ -81,7 +83,7 @@ async def get_old_projects(days_threshold: int = 1) -> List[Dict[str, Any]]: 'created_at', 'account_id', 'sandbox' - ).range(start_range, end_range).execute() + ).order('created_at', desc=True).range(start_range, end_range).execute() # Debug info - print raw response print(f"Response data length: {len(result.data)}") diff --git a/frontend/package-lock.json b/frontend/package-lock.json index 1be86af1..8f21ab22 100644 --- a/frontend/package-lock.json +++ b/frontend/package-lock.json @@ -23,6 +23,7 @@ "@radix-ui/react-label": "^2.1.4", "@radix-ui/react-navigation-menu": "^1.2.5", "@radix-ui/react-popover": "^1.1.7", + "@radix-ui/react-progress": "^1.1.6", "@radix-ui/react-radio-group": "^1.3.3", "@radix-ui/react-scroll-area": "^1.2.4", "@radix-ui/react-select": "^2.1.7", @@ -3321,6 +3322,68 @@ } } }, + "node_modules/@radix-ui/react-progress": { + "version": "1.1.6", + "resolved": "https://registry.npmjs.org/@radix-ui/react-progress/-/react-progress-1.1.6.tgz", + "integrity": "sha512-QzN9a36nKk2eZKMf9EBCia35x3TT+SOgZuzQBVIHyRrmYYi73VYBRK3zKwdJ6az/F5IZ6QlacGJBg7zfB85liA==", + "dependencies": { + "@radix-ui/react-context": "1.1.2", + "@radix-ui/react-primitive": "2.1.2" + }, + "peerDependencies": { + "@types/react": "*", + "@types/react-dom": "*", + "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc", + "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "@types/react-dom": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-progress/node_modules/@radix-ui/react-primitive": { + "version": "2.1.2", + "resolved": "https://registry.npmjs.org/@radix-ui/react-primitive/-/react-primitive-2.1.2.tgz", + "integrity": "sha512-uHa+l/lKfxuDD2zjN/0peM/RhhSmRjr5YWdk/37EnSv1nJ88uvG85DPexSm8HdFQROd2VdERJ6ynXbkCFi+APw==", + "dependencies": { + "@radix-ui/react-slot": "1.2.2" + }, + "peerDependencies": { + "@types/react": "*", + "@types/react-dom": "*", + "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc", + "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "@types/react-dom": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-progress/node_modules/@radix-ui/react-slot": { + "version": "1.2.2", + "resolved": "https://registry.npmjs.org/@radix-ui/react-slot/-/react-slot-1.2.2.tgz", + "integrity": "sha512-y7TBO4xN4Y94FvcWIOIh18fM4R1A8S4q1jhoz4PNzOoHsFcN8pogcFmZrTYAm4F9VRUrWP/Mw7xSKybIeRI+CQ==", + "dependencies": { + "@radix-ui/react-compose-refs": "1.1.2" + }, + "peerDependencies": { + "@types/react": "*", + "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + } + } + }, "node_modules/@radix-ui/react-radio-group": { "version": "1.3.3", "resolved": "https://registry.npmjs.org/@radix-ui/react-radio-group/-/react-radio-group-1.3.3.tgz", diff --git a/frontend/package.json b/frontend/package.json index 50f4651d..39ff87c2 100644 --- a/frontend/package.json +++ b/frontend/package.json @@ -26,6 +26,7 @@ "@radix-ui/react-label": "^2.1.4", "@radix-ui/react-navigation-menu": "^1.2.5", "@radix-ui/react-popover": "^1.1.7", + "@radix-ui/react-progress": "^1.1.6", "@radix-ui/react-radio-group": "^1.3.3", "@radix-ui/react-scroll-area": "^1.2.4", "@radix-ui/react-select": "^2.1.7", diff --git a/frontend/src/app/(dashboard)/agents/[threadId]/page.tsx b/frontend/src/app/(dashboard)/agents/[threadId]/page.tsx index beddc22f..0ed147e8 100644 --- a/frontend/src/app/(dashboard)/agents/[threadId]/page.tsx +++ b/frontend/src/app/(dashboard)/agents/[threadId]/page.tsx @@ -8,7 +8,7 @@ import React, { useState, } from 'react'; import Image from 'next/image'; -import { useRouter } from 'next/navigation'; +import { useRouter, useSearchParams } from 'next/navigation'; import { ArrowDown, CheckCircle, @@ -87,6 +87,7 @@ export default function ThreadPage({ const unwrappedParams = React.use(params); const threadId = unwrappedParams.threadId; const isMobile = useIsMobile(); + const searchParams = useSearchParams(); const router = useRouter(); const [messages, setMessages] = useState([]); @@ -132,6 +133,9 @@ export default function ThreadPage({ const agentRunsCheckedRef = useRef(false); const previousAgentStatus = useRef('idle'); + // Add debug mode state - check for debug=true in URL + const [debugMode, setDebugMode] = useState(false); + const handleProjectRenamed = useCallback((newName: string) => { setProjectName(newName); }, []); @@ -1040,6 +1044,12 @@ export default function ThreadPage({ isLoading, ]); + // Check for debug mode in URL on initial load and when URL changes + useEffect(() => { + const debugParam = searchParams.get('debug'); + setDebugMode(debugParam === 'true'); + }, [searchParams]); + // Main rendering function for the thread page if (!initialLoadCompleted.current || isLoading) { // Use the new ThreadSkeleton component instead of inline skeleton @@ -1058,6 +1068,7 @@ export default function ThreadPage({ onViewFiles={handleOpenFileViewer} onToggleSidePanel={toggleSidePanel} isMobileView={isMobile} + debugMode={debugMode} />
@@ -1122,6 +1133,12 @@ export default function ThreadPage({ } else { return (
+ {/* Render debug mode indicator when active */} + {debugMode && ( +
+ Debug Mode +
+ )}
@@ -1133,9 +1150,10 @@ export default function ThreadPage({ onToggleSidePanel={toggleSidePanel} onProjectRenamed={handleProjectRenamed} isMobileView={isMobile} + debugMode={debugMode} /> - {/* Use ThreadContent component instead of custom message rendering */} + {/* Pass debugMode to ThreadContent component */}
); } -} +} \ No newline at end of file diff --git a/frontend/src/app/providers.tsx b/frontend/src/app/providers.tsx index 5ad9285e..b88f159a 100644 --- a/frontend/src/app/providers.tsx +++ b/frontend/src/app/providers.tsx @@ -4,8 +4,7 @@ import { ThemeProvider } from 'next-themes'; import { useState, createContext, useEffect } from 'react'; import { AuthProvider } from '@/components/AuthProvider'; import { ReactQueryProvider } from '@/providers/react-query-provider'; -import { dehydrate, QueryClient, QueryClientProvider } from '@tanstack/react-query'; -import { initializeCacheSystem } from '@/lib/cache-init'; +import { dehydrate, QueryClient } from '@tanstack/react-query'; export interface ParsedTag { tagName: string; @@ -38,35 +37,16 @@ export const ToolCallsContext = createContext<{ export function Providers({ children }: { children: React.ReactNode }) { // Shared state for tool calls across the app const [toolCalls, setToolCalls] = useState([]); - const queryClient = new QueryClient({ - defaultOptions: { - queries: { - refetchOnWindowFocus: false, - }, - }, - }); + const queryClient = new QueryClient(); const dehydratedState = dehydrate(queryClient); - // Initialize the file caching system when the app starts - useEffect(() => { - // Start the cache maintenance system - const { stopCacheSystem } = initializeCacheSystem(); - - // Clean up when the component unmounts - return () => { - stopCacheSystem(); - }; - }, []); - return ( - - - {children} - - + + {children} + diff --git a/frontend/src/app/share/[threadId]/page.tsx b/frontend/src/app/share/[threadId]/page.tsx index 4d0f7a5d..cec47582 100644 --- a/frontend/src/app/share/[threadId]/page.tsx +++ b/frontend/src/app/share/[threadId]/page.tsx @@ -186,10 +186,7 @@ export default function ThreadPage({ const handleStreamError = useCallback((errorMessage: string) => { console.error(`[PAGE] Stream hook error: ${errorMessage}`); - if (!errorMessage.toLowerCase().includes('not found') && - !errorMessage.toLowerCase().includes('agent run is not running')) { - toast.error(`Stream Error: ${errorMessage}`); - } + toast.error(errorMessage, { duration: 15000 }); }, []); const handleStreamClose = useCallback(() => { diff --git a/frontend/src/components/billing/usage-limit-alert.tsx b/frontend/src/components/billing/usage-limit-alert.tsx index 78d88670..e831121e 100644 --- a/frontend/src/components/billing/usage-limit-alert.tsx +++ b/frontend/src/components/billing/usage-limit-alert.tsx @@ -1,8 +1,9 @@ 'use client'; -import { AlertTriangle } from 'lucide-react'; +import { AlertTriangle, X } from 'lucide-react'; import { Button } from '@/components/ui/button'; import { useRouter } from 'next/navigation'; +import { cn } from '@/lib/utils'; interface BillingErrorAlertProps { message?: string; @@ -26,17 +27,28 @@ export function BillingErrorAlert({ if (!isOpen) return null; return ( -
-
-
-
+
+
+
+
-

- Usage Limit Reached -

+
+

+ Usage Limit Reached +

+ +

{message}

+
diff --git a/frontend/src/components/thread/chat-input.tsx b/frontend/src/components/thread/chat-input.tsx deleted file mode 100644 index cb4c021d..00000000 --- a/frontend/src/components/thread/chat-input.tsx +++ /dev/null @@ -1,619 +0,0 @@ -'use client'; - -import React, { - useState, - useRef, - useEffect, - forwardRef, - useImperativeHandle, -} from 'react'; -import { Textarea } from '@/components/ui/textarea'; -import { Button } from '@/components/ui/button'; -import { - Send, - Square, - Loader2, - X, - Paperclip, - Settings, - ChevronDown, - AlertTriangle, - Info, - ArrowUp, -} from 'lucide-react'; -import { createClient } from '@/lib/supabase/client'; -import { toast } from 'sonner'; -import { AnimatePresence, motion } from 'framer-motion'; -import { - Tooltip, - TooltipContent, - TooltipProvider, - TooltipTrigger, -} from '@/components/ui/tooltip'; -import { - Dialog, - DialogContent, - DialogHeader, - DialogTitle, - DialogTrigger, - DialogFooter, - DialogDescription, -} from '@/components/ui/dialog'; -import { - Select, - SelectContent, - SelectItem, - SelectTrigger, - SelectValue, -} from '@/components/ui/select'; -import { Badge } from '@/components/ui/badge'; -import { Label } from '@/components/ui/label'; -import { RadioGroup, RadioGroupItem } from '@/components/ui/radio-group'; -import { Card, CardContent } from '@/components/ui/card'; -import { cn } from '@/lib/utils'; -import { FileAttachment } from './file-attachment'; -import { AttachmentGroup } from './attachment-group'; - -// Define API_URL -const API_URL = process.env.NEXT_PUBLIC_BACKEND_URL || ''; - -// Local storage keys -const STORAGE_KEY_MODEL = 'suna-preferred-model'; -const DEFAULT_MODEL_ID = 'deepseek'; // Define default model ID - -interface ChatInputProps { - onSubmit: ( - message: string, - options?: { model_name?: string; enable_thinking?: boolean }, - ) => void; - placeholder?: string; - loading?: boolean; - disabled?: boolean; - isAgentRunning?: boolean; - onStopAgent?: () => void; - autoFocus?: boolean; - value?: string; - onChange?: (value: string) => void; - onFileBrowse?: () => void; - sandboxId?: string; - hideAttachments?: boolean; -} - -interface UploadedFile { - name: string; - path: string; - size: number; - localUrl?: string; -} - -// Define interface for the ref -export interface ChatInputHandles { - getPendingFiles: () => File[]; - clearPendingFiles: () => void; -} - -export const ChatInput = forwardRef( - ( - { - onSubmit, - placeholder = 'Describe what you need help with...', - loading = false, - disabled = false, - isAgentRunning = false, - onStopAgent, - autoFocus = true, - value: controlledValue, - onChange: controlledOnChange, - onFileBrowse, - sandboxId, - hideAttachments = false, - }, - ref, - ) => { - const isControlled = - controlledValue !== undefined && controlledOnChange !== undefined; - - const [uncontrolledValue, setUncontrolledValue] = useState(''); - const value = isControlled ? controlledValue : uncontrolledValue; - - // Define model options array earlier so it can be used in useEffect - const modelOptions = [ - { id: 'sonnet-3.7', label: 'Sonnet 3.7' }, - { id: 'sonnet-3.7-thinking', label: 'Sonnet 3.7 (Thinking)' }, - { id: 'gpt-4.1', label: 'GPT-4.1' }, - { id: 'gemini-flash-2.5', label: 'Gemini Flash 2.5' }, - ]; - - // Initialize state with the default model - const [selectedModel, setSelectedModel] = useState(DEFAULT_MODEL_ID); - const [showModelDialog, setShowModelDialog] = useState(false); - const textareaRef = useRef(null); - const fileInputRef = useRef(null); - const [uploadedFiles, setUploadedFiles] = useState([]); - const [pendingFiles, setPendingFiles] = useState([]); - const [isUploading, setIsUploading] = useState(false); - const [isDraggingOver, setIsDraggingOver] = useState(false); - - // Expose methods through the ref - useImperativeHandle(ref, () => ({ - getPendingFiles: () => pendingFiles, - clearPendingFiles: () => setPendingFiles([]), - })); - - useEffect(() => { - if (typeof window !== 'undefined') { - try { - const savedModel = localStorage.getItem(STORAGE_KEY_MODEL); - // Check if the saved model exists and is one of the valid options - if ( - savedModel && - modelOptions.some((option) => option.id === savedModel) - ) { - setSelectedModel(savedModel); - } else if (savedModel) { - // If invalid model found in storage, clear it - localStorage.removeItem(STORAGE_KEY_MODEL); - console.log( - `Removed invalid model '${savedModel}' from localStorage. Using default: ${DEFAULT_MODEL_ID}`, - ); - } - } catch (error) { - console.warn('Failed to load preferences from localStorage:', error); - } - } - }, []); - - useEffect(() => { - if (autoFocus && textareaRef.current) { - textareaRef.current.focus(); - } - }, [autoFocus]); - - useEffect(() => { - const textarea = textareaRef.current; - if (!textarea) return; - - const adjustHeight = () => { - textarea.style.height = 'auto'; - const newHeight = Math.min(Math.max(textarea.scrollHeight, 24), 200); - textarea.style.height = `${newHeight}px`; - }; - - adjustHeight(); - - adjustHeight(); - - window.addEventListener('resize', adjustHeight); - return () => window.removeEventListener('resize', adjustHeight); - }, [value]); - - const handleModelChange = (value: string) => { - setSelectedModel(value); - - // Save to localStorage - try { - localStorage.setItem(STORAGE_KEY_MODEL, value); - } catch (error) { - console.warn('Failed to save model preference to localStorage:', error); - } - }; - - const handleSubmit = async (e: React.FormEvent) => { - e.preventDefault(); - if ( - (!value.trim() && uploadedFiles.length === 0) || - loading || - (disabled && !isAgentRunning) - ) - return; - - if (isAgentRunning && onStopAgent) { - onStopAgent(); - return; - } - - let message = value; - - if (uploadedFiles.length > 0) { - const fileInfo = uploadedFiles - .map((file) => `[Uploaded File: ${file.path}]`) - .join('\n'); - message = message ? `${message}\n\n${fileInfo}` : fileInfo; - } - - let baseModelName = selectedModel; - let thinkingEnabled = false; - if (selectedModel.endsWith('-thinking')) { - baseModelName = selectedModel.replace(/-thinking$/, ''); - thinkingEnabled = true; - } - - onSubmit(message, { - model_name: baseModelName, - enable_thinking: thinkingEnabled, - }); - - if (!isControlled) { - setUncontrolledValue(''); - } - - setUploadedFiles([]); - }; - - const handleChange = (e: React.ChangeEvent) => { - const newValue = e.target.value; - if (isControlled) { - controlledOnChange(newValue); - } else { - setUncontrolledValue(newValue); - } - }; - - const handleKeyDown = (e: React.KeyboardEvent) => { - if (e.key === 'Enter' && !e.shiftKey) { - e.preventDefault(); - if ( - (value.trim() || uploadedFiles.length > 0) && - !loading && - (!disabled || isAgentRunning) - ) { - handleSubmit(e as React.FormEvent); - } - } - }; - - const handleFileUpload = () => { - if (fileInputRef.current) { - fileInputRef.current.click(); - } - }; - - const handleDragOver = (e: React.DragEvent) => { - e.preventDefault(); - e.stopPropagation(); - setIsDraggingOver(true); - }; - - const handleDragLeave = (e: React.DragEvent) => { - e.preventDefault(); - e.stopPropagation(); - setIsDraggingOver(false); - }; - - const handleDrop = async (e: React.DragEvent) => { - e.preventDefault(); - e.stopPropagation(); - setIsDraggingOver(false); - - if (!e.dataTransfer.files || e.dataTransfer.files.length === 0) return; - - const files = Array.from(e.dataTransfer.files); - - if (sandboxId) { - // If we have a sandboxId, upload files directly - await uploadFiles(files); - } else { - // Otherwise, store files locally - handleLocalFiles(files); - } - }; - - const processFileUpload = async ( - event: React.ChangeEvent, - ) => { - if (!event.target.files || event.target.files.length === 0) return; - - const files = Array.from(event.target.files); - - if (sandboxId) { - // If we have a sandboxId, upload files directly - await uploadFiles(files); - } else { - // Otherwise, store files locally - handleLocalFiles(files); - } - - event.target.value = ''; - }; - - // New function to handle files locally when there's no sandboxId - const handleLocalFiles = (files: File[]) => { - const filteredFiles = files.filter(file => { - if (file.size > 50 * 1024 * 1024) { - toast.error(`File size exceeds 50MB limit: ${file.name}`); - return false; - } - return true; - }); - - // Store the files in pendingFiles state - setPendingFiles(prevFiles => [...prevFiles, ...filteredFiles]); - - // Create object URLs for the files and add to uploadedFiles for UI display - const newUploadedFiles: UploadedFile[] = filteredFiles.map(file => ({ - name: file.name, - path: `/workspace/${file.name}`, // This is just for display purposes - size: file.size, - localUrl: URL.createObjectURL(file) // Add local preview URL - })); - - setUploadedFiles(prev => [...prev, ...newUploadedFiles]); - filteredFiles.forEach(file => { - toast.success(`File attached: ${file.name}`); - }); - }; - - // Clean up object URLs when component unmounts or files are removed - useEffect(() => { - return () => { - // Clean up any object URLs to avoid memory leaks - uploadedFiles.forEach(file => { - if (file.localUrl) { - URL.revokeObjectURL(file.localUrl); - } - }); - }; - }, []); - - // // Add a function to clean up URL when removing a file - // const removeUploadedFile = (index: number) => { - // const file = uploadedFiles[index]; - // if (file?.localUrl) { - // URL.revokeObjectURL(file.localUrl); - // } - // setUploadedFiles(prev => prev.filter((_, i) => i !== index)); - // }; - - const uploadFiles = async (files: File[]) => { - try { - setIsUploading(true); - - const newUploadedFiles: UploadedFile[] = []; - - for (const file of files) { - if (file.size > 50 * 1024 * 1024) { - toast.error(`File size exceeds 50MB limit: ${file.name}`); - continue; - } - - const formData = new FormData(); - formData.append('file', file); - - const uploadPath = `/workspace/${file.name}`; - formData.append('path', uploadPath); - - const supabase = createClient(); - const { - data: { session }, - } = await supabase.auth.getSession(); - - if (!session?.access_token) { - throw new Error('No access token available'); - } - - const response = await fetch( - `${API_URL}/sandboxes/${sandboxId}/files`, - { - method: 'POST', - headers: { - Authorization: `Bearer ${session.access_token}`, - }, - body: formData, - }, - ); - - if (!response.ok) { - throw new Error(`Upload failed: ${response.statusText}`); - } - - newUploadedFiles.push({ - name: file.name, - path: uploadPath, - size: file.size, - }); - - toast.success(`File uploaded: ${file.name}`); - } - - setUploadedFiles((prev) => [...prev, ...newUploadedFiles]); - } catch (error) { - console.error('File upload failed:', error); - toast.error( - typeof error === 'string' - ? error - : error instanceof Error - ? error.message - : 'Failed to upload file', - ); - } finally { - setIsUploading(false); - } - }; - - const formatFileSize = (bytes: number): string => { - if (bytes < 1024) return `${bytes} B`; - if (bytes < 1024 * 1024) return `${(bytes / 1024).toFixed(1)} KB`; - return `${(bytes / (1024 * 1024)).toFixed(1)} MB`; - }; - - const removeUploadedFile = (index: number) => { - setUploadedFiles((prev) => prev.filter((_, i) => i !== index)); - // Also remove from pendingFiles if needed - if (!sandboxId && pendingFiles.length > index) { - setPendingFiles((prev) => prev.filter((_, i) => i !== index)); - } - }; - - return ( -
- -
- - { - removeUploadedFile(index); - // Also remove from pendingFiles if needed - if (!sandboxId && pendingFiles.length > index) { - setPendingFiles(prev => prev.filter((_, i) => i !== index)); - } - }} - layout="inline" - maxHeight="216px" - showPreviews={true} - /> - -
-