From f3746d9dbf2b954fe3dc5c08ceeac8643613b4ea Mon Sep 17 00:00:00 2001 From: LE Quoc Dat Date: Mon, 28 Apr 2025 01:28:21 +0100 Subject: [PATCH] support images --- backend/agent/prompt.py | 10 +- backend/agent/run.py | 82 ++++++++++++----- backend/agent/tools/sb_vision_tool.py | 128 ++++++++++++++++++++++++++ 3 files changed, 197 insertions(+), 23 deletions(-) create mode 100644 backend/agent/tools/sb_vision_tool.py diff --git a/backend/agent/prompt.py b/backend/agent/prompt.py index a0aabb35..3abe07dc 100644 --- a/backend/agent/prompt.py +++ b/backend/agent/prompt.py @@ -74,7 +74,15 @@ You have the ability to execute operations using both Python and CLI tools: * YOU CAN DO ANYTHING ON THE BROWSER - including clicking on elements, filling forms, submitting data, etc. * The browser is in a sandboxed environment, so nothing to worry about. -### 2.2.6 DATA PROVIDERS +### 2.2.6 VISUAL INPUT +- You MUST use the 'see-image' tool to see image files. There is NO other way to access visual information. + * Provide the relative path to the image in the `/workspace` directory. + * Example: `<see-image file_path="path/to/image.jpg"></see-image>` + * ALWAYS use this tool when visual information from a file is necessary for your task. + * Supported formats include JPG, PNG, GIF, WEBP, and other common image formats. + * Maximum file size limit is 10 MB. + +### 2.2.7 DATA PROVIDERS - You have access to a variety of data providers that you can use to get data for your tasks. - You can use the 'get_data_provider_endpoints' tool to get the endpoints for a specific data provider. - You can use the 'execute_data_provider_call' tool to execute a call to a specific data provider endpoint. 
diff --git a/backend/agent/run.py b/backend/agent/run.py index e4ec43d2..7e1dc6ce 100644 --- a/backend/agent/run.py +++ b/backend/agent/run.py @@ -22,6 +22,7 @@ from agent.prompt import get_system_prompt from utils import logger from utils.auth_utils import get_account_id_from_thread from services.billing import check_billing_status +from agent.tools.sb_vision_tool import SandboxVisionTool load_dotenv() @@ -66,8 +67,8 @@ async def run_agent( thread_manager.add_tool(SandboxDeployTool, project_id=project_id, thread_manager=thread_manager) thread_manager.add_tool(SandboxExposeTool, project_id=project_id, thread_manager=thread_manager) thread_manager.add_tool(MessageTool) # we are just doing this via prompt as there is no need to call it as a tool - thread_manager.add_tool(WebSearchTool) + thread_manager.add_tool(SandboxVisionTool, project_id=project_id, thread_id=thread_id, thread_manager=thread_manager) # Add data providers tool if RapidAPI key is available if config.RAPID_API_KEY: @@ -102,37 +103,74 @@ async def run_agent( continue_execution = False break - # Get the latest message from messages table that its type is browser_state - latest_browser_state = await client.table('messages').select('*').eq('thread_id', thread_id).eq('type', 'browser_state').order('created_at', desc=True).limit(1).execute() + # ---- Temporary Message Handling (Browser State & Image Context) ---- temporary_message = None - if latest_browser_state.data and len(latest_browser_state.data) > 0: + temp_message_content_list = [] # List to hold text/image blocks + + # Get the latest browser_state message + latest_browser_state_msg = await client.table('messages').select('*').eq('thread_id', thread_id).eq('type', 'browser_state').order('created_at', desc=True).limit(1).execute() + if latest_browser_state_msg.data and len(latest_browser_state_msg.data) > 0: try: - content = json.loads(latest_browser_state.data[0]["content"]) - screenshot_base64 = content["screenshot_base64"] + browser_content = 
json.loads(latest_browser_state_msg.data[0]["content"]) + screenshot_base64 = browser_content.get("screenshot_base64") # Create a copy of the browser state without screenshot - browser_state = content.copy() - browser_state.pop('screenshot_base64', None) - browser_state.pop('screenshot_url', None) - browser_state.pop('screenshot_url_base64', None) - temporary_message = { "role": "user", "content": [] } - if browser_state: - temporary_message["content"].append({ + browser_state_text = browser_content.copy() + browser_state_text.pop('screenshot_base64', None) + browser_state_text.pop('screenshot_url', None) + browser_state_text.pop('screenshot_url_base64', None) + + if browser_state_text: + temp_message_content_list.append({ "type": "text", - "text": f"The following is the current state of the browser:\n{browser_state}" + "text": f"The following is the current state of the browser:\n{json.dumps(browser_state_text, indent=2)}" }) if screenshot_base64: - temporary_message["content"].append({ + temp_message_content_list.append({ "type": "image_url", - "image_url": { - "url": f"data:image/jpeg;base64,{screenshot_base64}", - } + "image_url": { + "url": f"data:image/jpeg;base64,{screenshot_base64}", + } }) else: - print("@@@@@ THIS TIME NO SCREENSHOT!!") + logger.warning("Browser state found but no screenshot base64 data.") + + await client.table('messages').delete().eq('message_id', latest_browser_state_msg.data[0]["message_id"]).execute() except Exception as e: - print(f"Error parsing browser state: {e}") - # print(latest_browser_state.data[0]) - + logger.error(f"Error parsing browser state: {e}") + + # Get the latest image_context message (NEW) + latest_image_context_msg = await client.table('messages').select('*').eq('thread_id', thread_id).eq('type', 'image_context').order('created_at', desc=True).limit(1).execute() + if latest_image_context_msg.data and len(latest_image_context_msg.data) > 0: + try: + image_context_content = 
json.loads(latest_image_context_msg.data[0]["content"]) + base64_image = image_context_content.get("base64") + mime_type = image_context_content.get("mime_type") + file_path = image_context_content.get("file_path", "unknown file") + + if base64_image and mime_type: + temp_message_content_list.append({ + "type": "text", + "text": f"Here is the image you requested to see: '{file_path}'" + }) + temp_message_content_list.append({ + "type": "image_url", + "image_url": { + "url": f"data:{mime_type};base64,{base64_image}", + } + }) + else: + logger.warning(f"Image context found for '{file_path}' but missing base64 or mime_type.") + + await client.table('messages').delete().eq('message_id', latest_image_context_msg.data[0]["message_id"]).execute() + except Exception as e: + logger.error(f"Error parsing image context: {e}") + + # If we have any content, construct the temporary_message + if temp_message_content_list: + temporary_message = {"role": "user", "content": temp_message_content_list} + # logger.debug(f"Constructed temporary message with {len(temp_message_content_list)} content blocks.") + # ---- End Temporary Message Handling ---- + max_tokens = 64000 if "sonnet" in model_name.lower() else None response = await thread_manager.run_thread( diff --git a/backend/agent/tools/sb_vision_tool.py b/backend/agent/tools/sb_vision_tool.py new file mode 100644 index 00000000..a1e0abad --- /dev/null +++ b/backend/agent/tools/sb_vision_tool.py @@ -0,0 +1,128 @@ +import os +import base64 +import mimetypes +from typing import Optional + +from agentpress.tool import ToolResult, openapi_schema, xml_schema +from sandbox.sandbox import SandboxToolsBase, Sandbox +from agentpress.thread_manager import ThreadManager +from utils.logger import logger +import json + +# Add common image MIME types if mimetypes module is limited +mimetypes.add_type("image/webp", ".webp") +mimetypes.add_type("image/jpeg", ".jpg") +mimetypes.add_type("image/jpeg", ".jpeg") +mimetypes.add_type("image/png", ".png") 
+mimetypes.add_type("image/gif", ".gif")
+
+# Maximum accepted image size in bytes (10 MB). The same limit is quoted in the
+# see_image schema description below, so update both together.
+MAX_IMAGE_SIZE = 10 * 1024 * 1024
+
+class SandboxVisionTool(SandboxToolsBase):
+    """Tool for allowing the agent to 'see' images within the sandbox.
+
+    The image is read from the sandbox filesystem, base64-encoded and added to
+    the thread as a message of type 'image_context'.
+    """
+
+    def __init__(self, project_id: str, thread_id: str, thread_manager: ThreadManager):
+        """Bind the tool to a project and to the thread that receives image context.
+
+        Args:
+            project_id: Forwarded to SandboxToolsBase.
+            thread_id: Thread that 'image_context' messages are added to.
+            thread_manager: Used to add messages; also forwarded to SandboxToolsBase.
+        """
+        super().__init__(project_id, thread_manager)
+        self.thread_id = thread_id
+        # Make thread_manager accessible within the tool instance
+        self.thread_manager = thread_manager
+
+    @openapi_schema({
+        "type": "function",
+        "function": {
+            "name": "see_image",
+            "description": "Allows the agent to 'see' an image file located in the /workspace directory. Provide the relative path to the image. The image content will be made available in the next turn's context.",
+            "parameters": {
+                "type": "object",
+                "properties": {
+                    "file_path": {
+                        "type": "string",
+                        "description": "The relative path to the image file within the /workspace directory (e.g., 'screenshots/image.png'). Supported formats: JPG, PNG, GIF, WEBP. Max size: 10MB."
+                    }
+                },
+                "required": ["file_path"]
+            }
+        }
+    })
+    @xml_schema(
+        tag_name="see-image",
+        mappings=[
+            {"param_name": "file_path", "node_type": "attribute", "path": "."}
+        ],
+        example='''
+        <!-- Example: Request to see an image named 'diagram.png' inside the 'docs' folder -->
+        <see-image file_path="docs/diagram.png"></see-image>
+        '''
+    )
+    async def see_image(self, file_path: str) -> ToolResult:
+        """Read an image from the sandbox and add it to the thread as 'image_context'.
+
+        Args:
+            file_path: Path of the image relative to the workspace directory.
+
+        Returns:
+            The result of self.success_response once the image context message
+            has been added, or of self.fail_response when the path is a
+            directory, the file cannot be found or read, it exceeds
+            MAX_IMAGE_SIZE, its type is not a recognised image format, or any
+            other exception is raised along the way.
+        """
+        try:
+            # Ensure sandbox is initialized
+            await self._ensure_sandbox()
+
+            # Clean and construct full path
+            cleaned_path = self.clean_path(file_path)
+            full_path = f"{self.workspace_path}/{cleaned_path}"
+            logger.info(f"Attempting to see image: {full_path} (original: {file_path})")
+
+            # Check if file exists and get info
+            try:
+                file_info = self.sandbox.fs.get_file_info(full_path)
+                if file_info.is_dir:
+                    return self.fail_response(f"Path '{cleaned_path}' is a directory, not an image file.")
+            except Exception as e:
+                logger.warning(f"File not found at {full_path}: {e}")
+                return self.fail_response(f"Image file not found at path: '{cleaned_path}'")
+
+            # Check file size before downloading so oversized files are never read
+            if file_info.size > MAX_IMAGE_SIZE:
+                return self.fail_response(f"Image file '{cleaned_path}' is too large ({file_info.size / (1024*1024):.2f}MB). Maximum size is {MAX_IMAGE_SIZE / (1024*1024)}MB.")
+
+            # Read image file content
+            try:
+                image_bytes = self.sandbox.fs.download_file(full_path)
+            except Exception as e:
+                logger.error(f"Error reading image file {full_path}: {e}")
+                return self.fail_response(f"Could not read image file: {cleaned_path}")
+
+            # Convert to base64
+            base64_image = base64.b64encode(image_bytes).decode('utf-8')
+
+            # Determine MIME type
+            mime_type, _ = mimetypes.guess_type(full_path)
+            if not mime_type or not mime_type.startswith('image/'):
+                # Basic fallback based on extension if mimetypes fails
+                ext = os.path.splitext(cleaned_path)[1].lower()
+                if ext == '.jpg' or ext == '.jpeg': mime_type = 'image/jpeg'
+                elif ext == '.png': mime_type = 'image/png'
+                elif ext == '.gif': mime_type = 'image/gif'
+                elif ext == '.webp': mime_type = 'image/webp'
+                else:
+                    return self.fail_response(f"Unsupported or unknown image format for file: '{cleaned_path}'. Supported: JPG, PNG, GIF, WEBP.")
+
+            logger.info(f"Successfully read and encoded image '{cleaned_path}' as {mime_type}")
+
+            # Prepare the temporary message content
+            image_context_data = {
+                "mime_type": mime_type,
+                "base64": base64_image,
+                "file_path": cleaned_path # Include path for context
+            }
+
+            # Add the temporary message using the thread_manager callback
+            # Use a distinct type like 'image_context'
+            await self.thread_manager.add_message(
+                thread_id=self.thread_id,
+                type="image_context", # Use a specific type for this
+                content=image_context_data, # Store the dict directly
+                is_llm_message=False # This is context generated by a tool
+            )
+            logger.info(f"Added image context message for '{cleaned_path}' to thread {self.thread_id}")
+
+            # Inform the agent the image will be available next turn
+            return self.success_response(f"Successfully loaded the image '{cleaned_path}'.")
+
+        except Exception as e:
+            logger.error(f"Error processing see_image for {file_path}: {e}", exc_info=True)
+            return self.fail_response(f"An unexpected error occurred while trying to see the image: {str(e)}")
\ No newline at end of file