mirror of https://github.com/kortix-ai/suna.git
Merge pull request #689 from tnfssc/sharath/suna-283-generate-image-tool
This commit is contained in:
commit
9a4d9305e0
|
@ -89,7 +89,31 @@ You have the ability to execute operations using both Python and CLI tools:
|
|||
* Supported formats include JPG, PNG, GIF, WEBP, and other common image formats.
|
||||
* Maximum file size limit is 10 MB.
|
||||
|
||||
### 2.2.7 DATA PROVIDERS
|
||||
### 2.2.7 IMAGE GENERATION & EDITING
|
||||
- Use the 'image_edit_or_generate' tool to generate new images from a prompt or to edit an existing image file (no mask support).
|
||||
* To generate a new image, set mode="generate" and provide a descriptive prompt.
|
||||
* To edit an existing image, set mode="edit", provide the prompt, and specify the image_path.
|
||||
* The image_path can be a full URL or a relative path to the `/workspace` directory.
|
||||
* Example (generate):
|
||||
<function_calls>
|
||||
<invoke name="image_edit_or_generate">
|
||||
<parameter name="mode">generate</parameter>
|
||||
<parameter name="prompt">A futuristic cityscape at sunset</parameter>
|
||||
</invoke>
|
||||
</function_calls>
|
||||
* Example (edit):
|
||||
<function_calls>
|
||||
<invoke name="image_edit_or_generate">
|
||||
<parameter name="mode">edit</parameter>
|
||||
<parameter name="prompt">Add a red hat to the person in the image</parameter>
|
||||
<parameter name="image_path">http://example.com/images/person.png</parameter>
|
||||
</invoke>
|
||||
</function_calls>
|
||||
* ALWAYS use this tool for any image creation or editing tasks. Do not attempt to generate or edit images by any other means.
|
||||
* You must use edit mode when the user asks you to edit an image or change an existing image in any way.
|
||||
* Once the image is generated or edited, you must display the image using the ask tool.
|
||||
|
||||
### 2.2.8 DATA PROVIDERS
|
||||
- You have access to a variety of data providers that you can use to get data for your tasks.
|
||||
- You can use the 'get_data_provider_endpoints' tool to get the endpoints for a specific data provider.
|
||||
- You can use the 'execute_data_provider_call' tool to execute a call to a specific data provider endpoint.
|
||||
|
|
|
@ -25,6 +25,7 @@ from utils.logger import logger
|
|||
from utils.auth_utils import get_account_id_from_thread
|
||||
from services.billing import check_billing_status
|
||||
from agent.tools.sb_vision_tool import SandboxVisionTool
|
||||
from agent.tools.sb_image_edit_tool import SandboxImageEditTool
|
||||
from services.langfuse import langfuse
|
||||
from langfuse.client import StatefulTraceClient
|
||||
from services.langfuse import langfuse
|
||||
|
@ -107,6 +108,7 @@ async def run_agent(
|
|||
thread_manager.add_tool(MessageTool)
|
||||
thread_manager.add_tool(SandboxWebSearchTool, project_id=project_id, thread_manager=thread_manager)
|
||||
thread_manager.add_tool(SandboxVisionTool, project_id=project_id, thread_id=thread_id, thread_manager=thread_manager)
|
||||
thread_manager.add_tool(SandboxImageEditTool, project_id=project_id, thread_id=thread_id, thread_manager=thread_manager)
|
||||
if config.RAPID_API_KEY:
|
||||
thread_manager.add_tool(DataProvidersTool)
|
||||
else:
|
||||
|
|
|
@ -0,0 +1,172 @@
|
|||
from typing import Optional
|
||||
from agentpress.tool import ToolResult, openapi_schema, xml_schema
|
||||
from sandbox.tool_base import SandboxToolsBase
|
||||
from agentpress.thread_manager import ThreadManager
|
||||
import httpx
|
||||
from io import BytesIO
|
||||
import uuid
|
||||
from litellm import aimage_generation, aimage_edit
|
||||
import base64
|
||||
|
||||
|
||||
class SandboxImageEditTool(SandboxToolsBase):
|
||||
"""Tool for generating or editing images using OpenAI GPT Image 1 via OpenAI SDK (no mask support)."""
|
||||
|
||||
def __init__(self, project_id: str, thread_id: str, thread_manager: ThreadManager):
|
||||
super().__init__(project_id, thread_manager)
|
||||
self.thread_id = thread_id
|
||||
self.thread_manager = thread_manager
|
||||
|
||||
@openapi_schema(
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "image_edit_or_generate",
|
||||
"description": "Generate a new image from a prompt, or edit an existing image (no mask support) using OpenAI GPT Image 1 via OpenAI SDK. Stores the result in the thread context.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"mode": {
|
||||
"type": "string",
|
||||
"enum": ["generate", "edit"],
|
||||
"description": "'generate' to create a new image from a prompt, 'edit' to edit an existing image.",
|
||||
},
|
||||
"prompt": {
|
||||
"type": "string",
|
||||
"description": "Text prompt describing the desired image or edit.",
|
||||
},
|
||||
"image_path": {
|
||||
"type": "string",
|
||||
"description": "(edit mode only) Path to the image file to edit, relative to /workspace. Required for 'edit'.",
|
||||
},
|
||||
},
|
||||
"required": ["mode", "prompt"],
|
||||
},
|
||||
},
|
||||
}
|
||||
)
|
||||
@xml_schema(
|
||||
tag_name="image-edit-or-generate",
|
||||
mappings=[
|
||||
{"param_name": "mode", "node_type": "attribute", "path": "."},
|
||||
{"param_name": "prompt", "node_type": "attribute", "path": "."},
|
||||
{"param_name": "image_path", "node_type": "attribute", "path": "."},
|
||||
],
|
||||
example="""
|
||||
<function_calls>
|
||||
<invoke name="image_edit_or_generate">
|
||||
<parameter name="mode">generate</parameter>
|
||||
<parameter name="prompt">A futuristic cityscape at sunset</parameter>
|
||||
</invoke>
|
||||
</function_calls>
|
||||
""",
|
||||
)
|
||||
async def image_edit_or_generate(
|
||||
self,
|
||||
mode: str,
|
||||
prompt: str,
|
||||
image_path: Optional[str] = None,
|
||||
) -> ToolResult:
|
||||
"""Generate or edit images using OpenAI GPT Image 1 via OpenAI SDK (no mask support)."""
|
||||
try:
|
||||
await self._ensure_sandbox()
|
||||
|
||||
if mode == "generate":
|
||||
response = await aimage_generation(
|
||||
model="gpt-image-1",
|
||||
prompt=prompt,
|
||||
n=1,
|
||||
size="1024x1024",
|
||||
)
|
||||
elif mode == "edit":
|
||||
if not image_path:
|
||||
return self.fail_response("'image_path' is required for edit mode.")
|
||||
|
||||
image_bytes = await self._get_image_bytes(image_path)
|
||||
if isinstance(image_bytes, ToolResult): # Error occurred
|
||||
return image_bytes
|
||||
|
||||
# Create BytesIO object with proper filename to set MIME type
|
||||
image_io = BytesIO(image_bytes)
|
||||
image_io.name = (
|
||||
"image.png" # Set filename to ensure proper MIME type detection
|
||||
)
|
||||
|
||||
response = await aimage_edit(
|
||||
image=[image_io], # Type in the LiteLLM SDK is wrong
|
||||
prompt=prompt,
|
||||
model="gpt-image-1",
|
||||
n=1,
|
||||
size="1024x1024",
|
||||
)
|
||||
else:
|
||||
return self.fail_response("Invalid mode. Use 'generate' or 'edit'.")
|
||||
|
||||
# Download and save the generated image to sandbox
|
||||
image_filename = await self._process_image_response(response)
|
||||
if isinstance(image_filename, ToolResult): # Error occurred
|
||||
return image_filename
|
||||
|
||||
return self.success_response(
|
||||
f"Successfully generated image using mode '{mode}'. Image saved as: {image_filename}. You can use the ask tool to display the image."
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
return self.fail_response(
|
||||
f"An error occurred during image generation/editing: {str(e)}"
|
||||
)
|
||||
|
||||
async def _get_image_bytes(self, image_path: str) -> bytes | ToolResult:
|
||||
"""Get image bytes from URL or local file path."""
|
||||
if image_path.startswith(("http://", "https://")):
|
||||
return await self._download_image_from_url(image_path)
|
||||
else:
|
||||
return await self._read_image_from_sandbox(image_path)
|
||||
|
||||
async def _download_image_from_url(self, url: str) -> bytes | ToolResult:
|
||||
"""Download image from URL."""
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(url)
|
||||
response.raise_for_status()
|
||||
return response.content
|
||||
except Exception:
|
||||
return self.fail_response(f"Could not download image from URL: {url}")
|
||||
|
||||
async def _read_image_from_sandbox(self, image_path: str) -> bytes | ToolResult:
|
||||
"""Read image from sandbox filesystem."""
|
||||
try:
|
||||
cleaned_path = self.clean_path(image_path)
|
||||
full_path = f"{self.workspace_path}/{cleaned_path}"
|
||||
|
||||
# Check if file exists and is not a directory
|
||||
file_info = await self.sandbox.fs.get_file_info(full_path)
|
||||
if file_info.is_dir:
|
||||
return self.fail_response(
|
||||
f"Path '{cleaned_path}' is a directory, not an image file."
|
||||
)
|
||||
|
||||
return await self.sandbox.fs.download_file(full_path)
|
||||
|
||||
except Exception as e:
|
||||
return self.fail_response(
|
||||
f"Could not read image file from sandbox: {image_path} - {str(e)}"
|
||||
)
|
||||
|
||||
async def _process_image_response(self, response) -> str | ToolResult:
|
||||
"""Download generated image and save to sandbox with random name."""
|
||||
try:
|
||||
original_b64_str = response.data[0].b64_json
|
||||
# Decode base64 image data
|
||||
image_data = base64.b64decode(original_b64_str)
|
||||
|
||||
# Generate random filename
|
||||
random_filename = f"generated_image_{uuid.uuid4().hex[:8]}.png"
|
||||
sandbox_path = f"{self.workspace_path}/{random_filename}"
|
||||
|
||||
# Save image to sandbox
|
||||
await self.sandbox.fs.upload_file(image_data, sandbox_path)
|
||||
return random_filename
|
||||
|
||||
except Exception as e:
|
||||
return self.fail_response(f"Failed to download and save image: {str(e)}")
|
|
@ -12,7 +12,7 @@ classifiers = [
|
|||
requires-python = ">=3.11"
|
||||
dependencies = [
|
||||
"python-dotenv==1.0.1",
|
||||
"litellm==1.66.1",
|
||||
"litellm==1.72.2",
|
||||
"click==8.1.7",
|
||||
"questionary==2.0.1",
|
||||
"requests==2.32.3",
|
||||
|
@ -38,7 +38,7 @@ dependencies = [
|
|||
"daytona-api-client==0.21.0",
|
||||
"daytona-api-client-async==0.21.0",
|
||||
"boto3==1.37.3",
|
||||
"openai==1.72.0",
|
||||
"openai==1.90.0",
|
||||
"nest-asyncio==1.6.0",
|
||||
"vncdotool==1.2.0",
|
||||
"tavily-python==0.5.4",
|
||||
|
|
2700
backend/uv.lock
2700
backend/uv.lock
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue