mirror of https://github.com/kortix-ai/suna.git
feat(image-editing): introduce image generation and editing tool with updated documentation
This commit is contained in:
parent
d40e2e3829
commit
30f88aed99
|
@ -89,7 +89,31 @@ You have the ability to execute operations using both Python and CLI tools:
|
||||||
* Supported formats include JPG, PNG, GIF, WEBP, and other common image formats.
|
* Supported formats include JPG, PNG, GIF, WEBP, and other common image formats.
|
||||||
* Maximum file size limit is 10 MB.
|
* Maximum file size limit is 10 MB.
|
||||||
|
|
||||||
### 2.2.7 DATA PROVIDERS
|
### 2.2.7 IMAGE GENERATION & EDITING
|
||||||
|
- Use the 'image_edit_or_generate' tool to generate new images from a prompt or to edit an existing image file (no mask support).
|
||||||
|
* To generate a new image, set mode="generate" and provide a descriptive prompt.
|
||||||
|
* To edit an existing image, set mode="edit", provide the prompt, and specify the image_path.
|
||||||
|
* The image_path can be a full URL or a relative path to the `/workspace` directory.
|
||||||
|
* Example (generate):
|
||||||
|
<function_calls>
|
||||||
|
<invoke name="image_edit_or_generate">
|
||||||
|
<parameter name="mode">generate</parameter>
|
||||||
|
<parameter name="prompt">A futuristic cityscape at sunset</parameter>
|
||||||
|
</invoke>
|
||||||
|
</function_calls>
|
||||||
|
* Example (edit):
|
||||||
|
<function_calls>
|
||||||
|
<invoke name="image_edit_or_generate">
|
||||||
|
<parameter name="mode">edit</parameter>
|
||||||
|
<parameter name="prompt">Add a red hat to the person in the image</parameter>
|
||||||
|
<parameter name="image_path">http://example.com/images/person.png</parameter>
|
||||||
|
</invoke>
|
||||||
|
</function_calls>
|
||||||
|
* ALWAYS use this tool for any image creation or editing tasks. Do not attempt to generate or edit images by any other means.
|
||||||
|
* You must use edit mode when the user asks you to edit an image or change an existing image in any way.
|
||||||
|
* Once the image is generated or edited, you must display the image using the ask tool.
|
||||||
|
|
||||||
|
### 2.2.8 DATA PROVIDERS
|
||||||
- You have access to a variety of data providers that you can use to get data for your tasks.
|
- You have access to a variety of data providers that you can use to get data for your tasks.
|
||||||
- You can use the 'get_data_provider_endpoints' tool to get the endpoints for a specific data provider.
|
- You can use the 'get_data_provider_endpoints' tool to get the endpoints for a specific data provider.
|
||||||
- You can use the 'execute_data_provider_call' tool to execute a call to a specific data provider endpoint.
|
- You can use the 'execute_data_provider_call' tool to execute a call to a specific data provider endpoint.
|
||||||
|
|
|
@ -25,6 +25,7 @@ from utils.logger import logger
|
||||||
from utils.auth_utils import get_account_id_from_thread
|
from utils.auth_utils import get_account_id_from_thread
|
||||||
from services.billing import check_billing_status
|
from services.billing import check_billing_status
|
||||||
from agent.tools.sb_vision_tool import SandboxVisionTool
|
from agent.tools.sb_vision_tool import SandboxVisionTool
|
||||||
|
from agent.tools.sb_image_edit_tool import SandboxImageEditTool
|
||||||
from services.langfuse import langfuse
|
from services.langfuse import langfuse
|
||||||
from langfuse.client import StatefulTraceClient
|
from langfuse.client import StatefulTraceClient
|
||||||
from services.langfuse import langfuse
|
from services.langfuse import langfuse
|
||||||
|
@ -107,6 +108,7 @@ async def run_agent(
|
||||||
thread_manager.add_tool(MessageTool)
|
thread_manager.add_tool(MessageTool)
|
||||||
thread_manager.add_tool(SandboxWebSearchTool, project_id=project_id, thread_manager=thread_manager)
|
thread_manager.add_tool(SandboxWebSearchTool, project_id=project_id, thread_manager=thread_manager)
|
||||||
thread_manager.add_tool(SandboxVisionTool, project_id=project_id, thread_id=thread_id, thread_manager=thread_manager)
|
thread_manager.add_tool(SandboxVisionTool, project_id=project_id, thread_id=thread_id, thread_manager=thread_manager)
|
||||||
|
thread_manager.add_tool(SandboxImageEditTool, project_id=project_id, thread_id=thread_id, thread_manager=thread_manager)
|
||||||
if config.RAPID_API_KEY:
|
if config.RAPID_API_KEY:
|
||||||
thread_manager.add_tool(DataProvidersTool)
|
thread_manager.add_tool(DataProvidersTool)
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -0,0 +1,167 @@
|
||||||
|
from typing import Optional
|
||||||
|
from agentpress.tool import ToolResult, openapi_schema, xml_schema
|
||||||
|
from sandbox.tool_base import SandboxToolsBase
|
||||||
|
from agentpress.thread_manager import ThreadManager
|
||||||
|
from openai import OpenAI
|
||||||
|
import httpx
|
||||||
|
import os
|
||||||
|
from io import BytesIO
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
|
||||||
|
class SandboxImageEditTool(SandboxToolsBase):
|
||||||
|
"""Tool for generating or editing images using OpenAI DALL-E via OpenAI SDK (no mask support)."""
|
||||||
|
|
||||||
|
def __init__(self, project_id: str, thread_id: str, thread_manager: ThreadManager):
|
||||||
|
super().__init__(project_id, thread_manager)
|
||||||
|
self.thread_id = thread_id
|
||||||
|
self.thread_manager = thread_manager
|
||||||
|
self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
|
||||||
|
|
||||||
|
@openapi_schema(
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "image_edit_or_generate",
|
||||||
|
"description": "Generate a new image from a prompt, or edit an existing image (no mask support) using OpenAI DALL-E via OpenAI SDK. Stores the result in the thread context.",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"mode": {
|
||||||
|
"type": "string",
|
||||||
|
"enum": ["generate", "edit"],
|
||||||
|
"description": "'generate' to create a new image from a prompt, 'edit' to edit an existing image.",
|
||||||
|
},
|
||||||
|
"prompt": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Text prompt describing the desired image or edit.",
|
||||||
|
},
|
||||||
|
"image_path": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "(edit mode only) Path to the image file to edit, relative to /workspace. Required for 'edit'.",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["mode", "prompt"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
@xml_schema(
|
||||||
|
tag_name="image-edit-or-generate",
|
||||||
|
mappings=[
|
||||||
|
{"param_name": "mode", "node_type": "attribute", "path": "."},
|
||||||
|
{"param_name": "prompt", "node_type": "attribute", "path": "."},
|
||||||
|
{"param_name": "image_path", "node_type": "attribute", "path": "."},
|
||||||
|
],
|
||||||
|
example="""
|
||||||
|
<function_calls>
|
||||||
|
<invoke name="image_edit_or_generate">
|
||||||
|
<parameter name="mode">generate</parameter>
|
||||||
|
<parameter name="prompt">A futuristic cityscape at sunset</parameter>
|
||||||
|
</invoke>
|
||||||
|
</function_calls>
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
async def image_edit_or_generate(
|
||||||
|
self,
|
||||||
|
mode: str,
|
||||||
|
prompt: str,
|
||||||
|
image_path: Optional[str] = None,
|
||||||
|
) -> ToolResult:
|
||||||
|
"""Generate or edit images using OpenAI DALL-E via OpenAI SDK (no mask support)."""
|
||||||
|
try:
|
||||||
|
await self._ensure_sandbox()
|
||||||
|
|
||||||
|
if mode == "generate":
|
||||||
|
response = self.client.images.generate(
|
||||||
|
prompt=prompt, n=1, size="1024x1024"
|
||||||
|
)
|
||||||
|
elif mode == "edit":
|
||||||
|
if not image_path:
|
||||||
|
return self.fail_response("'image_path' is required for edit mode.")
|
||||||
|
|
||||||
|
image_bytes = await self._get_image_bytes(image_path)
|
||||||
|
if isinstance(image_bytes, ToolResult): # Error occurred
|
||||||
|
return image_bytes
|
||||||
|
|
||||||
|
# Create BytesIO object with proper filename to set MIME type
|
||||||
|
image_io = BytesIO(image_bytes)
|
||||||
|
image_io.name = "image.png" # Set filename to ensure proper MIME type detection
|
||||||
|
|
||||||
|
response = self.client.images.edit(
|
||||||
|
image=image_io, prompt=prompt, n=1, size="1024x1024"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return self.fail_response("Invalid mode. Use 'generate' or 'edit'.")
|
||||||
|
|
||||||
|
# Download and save the generated image to sandbox
|
||||||
|
image_filename = await self._process_image_response(response)
|
||||||
|
if isinstance(image_filename, ToolResult): # Error occurred
|
||||||
|
return image_filename
|
||||||
|
|
||||||
|
return self.success_response(
|
||||||
|
f"Successfully generated image using mode '{mode}'. Image saved as: {image_filename}. You can use the ask tool to display the image."
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return self.fail_response(
|
||||||
|
f"An error occurred during image generation/editing: {str(e)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _get_image_bytes(self, image_path: str) -> bytes | ToolResult:
|
||||||
|
"""Get image bytes from URL or local file path."""
|
||||||
|
if image_path.startswith(("http://", "https://")):
|
||||||
|
return await self._download_image_from_url(image_path)
|
||||||
|
else:
|
||||||
|
return await self._read_image_from_sandbox(image_path)
|
||||||
|
|
||||||
|
async def _download_image_from_url(self, url: str) -> bytes | ToolResult:
|
||||||
|
"""Download image from URL."""
|
||||||
|
try:
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
response = await client.get(url)
|
||||||
|
response.raise_for_status()
|
||||||
|
return response.content
|
||||||
|
except Exception:
|
||||||
|
return self.fail_response(f"Could not download image from URL: {url}")
|
||||||
|
|
||||||
|
async def _read_image_from_sandbox(self, image_path: str) -> bytes | ToolResult:
|
||||||
|
"""Read image from sandbox filesystem."""
|
||||||
|
try:
|
||||||
|
cleaned_path = self.clean_path(image_path)
|
||||||
|
full_path = f"{self.workspace_path}/{cleaned_path}"
|
||||||
|
|
||||||
|
# Check if file exists and is not a directory
|
||||||
|
file_info = self.sandbox.fs.get_file_info(full_path)
|
||||||
|
if file_info.is_dir:
|
||||||
|
return self.fail_response(
|
||||||
|
f"Path '{cleaned_path}' is a directory, not an image file."
|
||||||
|
)
|
||||||
|
|
||||||
|
return self.sandbox.fs.download_file(full_path)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return self.fail_response(
|
||||||
|
f"Could not read image file from sandbox: {image_path} - {str(e)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _process_image_response(self, response) -> str | ToolResult:
|
||||||
|
"""Download generated image and save to sandbox with random name."""
|
||||||
|
try:
|
||||||
|
original_url = response.data[0].url
|
||||||
|
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
img_response = await client.get(original_url)
|
||||||
|
img_response.raise_for_status()
|
||||||
|
|
||||||
|
# Generate random filename
|
||||||
|
|
||||||
|
random_filename = f"generated_image_{uuid.uuid4().hex[:8]}.png"
|
||||||
|
sandbox_path = f"{self.workspace_path}/{random_filename}"
|
||||||
|
|
||||||
|
# Save image to sandbox
|
||||||
|
self.sandbox.fs.upload_file(sandbox_path, img_response.content)
|
||||||
|
return random_filename
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return self.fail_response(f"Failed to download and save image: {str(e)}")
|
|
@ -1428,14 +1428,14 @@ openai = ["openai (>=0.27.8)"]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "litellm"
|
name = "litellm"
|
||||||
version = "1.66.1"
|
version = "1.72.2"
|
||||||
description = "Library to easily interface with LLM API providers"
|
description = "Library to easily interface with LLM API providers"
|
||||||
optional = false
|
optional = false
|
||||||
python-versions = "!=2.7.*,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,!=3.7.*,>=3.8"
|
python-versions = "!=2.7.*,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,!=3.7.*,>=3.8"
|
||||||
groups = ["main"]
|
groups = ["main"]
|
||||||
files = [
|
files = [
|
||||||
{file = "litellm-1.66.1-py3-none-any.whl", hash = "sha256:1f601fea3f086c1d2d91be60b9db115082a2f3a697e4e0def72f8b9c777c7232"},
|
{file = "litellm-1.72.2-py3-none-any.whl", hash = "sha256:51e70f5cd98748a603d725ef29ede0ecad3d55e1a89cbbcec8d12d6fff55bff4"},
|
||||||
{file = "litellm-1.66.1.tar.gz", hash = "sha256:98f7add913e5eae2131dd412ee27532d9a309defd9dbb64f6c6c42ea8a2af068"},
|
{file = "litellm-1.72.2.tar.gz", hash = "sha256:b50c7f7a0df67117889479264a12b0dea9c566a02173d4c3159540a13760d38b"},
|
||||||
]
|
]
|
||||||
|
|
||||||
[package.dependencies]
|
[package.dependencies]
|
||||||
|
@ -1453,7 +1453,8 @@ tokenizers = "*"
|
||||||
|
|
||||||
[package.extras]
|
[package.extras]
|
||||||
extra-proxy = ["azure-identity (>=1.15.0,<2.0.0)", "azure-keyvault-secrets (>=4.8.0,<5.0.0)", "google-cloud-kms (>=2.21.3,<3.0.0)", "prisma (==0.11.0)", "redisvl (>=0.4.1,<0.5.0) ; python_version >= \"3.9\" and python_version < \"3.14\"", "resend (>=0.8.0,<0.9.0)"]
|
extra-proxy = ["azure-identity (>=1.15.0,<2.0.0)", "azure-keyvault-secrets (>=4.8.0,<5.0.0)", "google-cloud-kms (>=2.21.3,<3.0.0)", "prisma (==0.11.0)", "redisvl (>=0.4.1,<0.5.0) ; python_version >= \"3.9\" and python_version < \"3.14\"", "resend (>=0.8.0,<0.9.0)"]
|
||||||
proxy = ["PyJWT (>=2.8.0,<3.0.0)", "apscheduler (>=3.10.4,<4.0.0)", "backoff", "boto3 (==1.34.34)", "cryptography (>=43.0.1,<44.0.0)", "fastapi (>=0.115.5,<0.116.0)", "fastapi-sso (>=0.16.0,<0.17.0)", "gunicorn (>=23.0.0,<24.0.0)", "litellm-proxy-extras (==0.1.7)", "mcp (==1.5.0) ; python_version >= \"3.10\"", "orjson (>=3.9.7,<4.0.0)", "pynacl (>=1.5.0,<2.0.0)", "python-multipart (>=0.0.18,<0.0.19)", "pyyaml (>=6.0.1,<7.0.0)", "rq", "uvicorn (>=0.29.0,<0.30.0)", "uvloop (>=0.21.0,<0.22.0)", "websockets (>=13.1.0,<14.0.0)"]
|
proxy = ["PyJWT (>=2.8.0,<3.0.0)", "apscheduler (>=3.10.4,<4.0.0)", "backoff", "boto3 (==1.34.34)", "cryptography (>=43.0.1,<44.0.0)", "fastapi (>=0.115.5,<0.116.0)", "fastapi-sso (>=0.16.0,<0.17.0)", "gunicorn (>=23.0.0,<24.0.0)", "litellm-enterprise (==0.1.7)", "litellm-proxy-extras (==0.2.3)", "mcp (==1.5.0) ; python_version >= \"3.10\"", "orjson (>=3.9.7,<4.0.0)", "pynacl (>=1.5.0,<2.0.0)", "python-multipart (>=0.0.18,<0.0.19)", "pyyaml (>=6.0.1,<7.0.0)", "rich (==13.7.1)", "rq", "uvicorn (>=0.29.0,<0.30.0)", "uvloop (>=0.21.0,<0.22.0) ; sys_platform != \"win32\"", "websockets (>=13.1.0,<14.0.0)"]
|
||||||
|
utils = ["numpydoc"]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "mailtrap"
|
name = "mailtrap"
|
||||||
|
@ -3904,4 +3905,4 @@ testing = ["coverage[toml]", "zope.event", "zope.testing"]
|
||||||
[metadata]
|
[metadata]
|
||||||
lock-version = "2.1"
|
lock-version = "2.1"
|
||||||
python-versions = "^3.11"
|
python-versions = "^3.11"
|
||||||
content-hash = "09a851f3db2d0b1f130405a69c1661c453f82ce23e078256bc6749662af897a7"
|
content-hash = "3b983fbe8614f4e59280b2087fa4bcc574502d58fc75aa73a44426279f99e3d2"
|
||||||
|
|
|
@ -19,7 +19,7 @@ classifiers = [
|
||||||
[tool.poetry.dependencies]
|
[tool.poetry.dependencies]
|
||||||
python = "^3.11"
|
python = "^3.11"
|
||||||
python-dotenv = "1.0.1"
|
python-dotenv = "1.0.1"
|
||||||
litellm = "1.66.1"
|
litellm = "1.72.2"
|
||||||
click = "8.1.7"
|
click = "8.1.7"
|
||||||
questionary = "2.0.1"
|
questionary = "2.0.1"
|
||||||
requests = "^2.31.0"
|
requests = "^2.31.0"
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
python-dotenv==1.0.1
|
python-dotenv==1.0.1
|
||||||
litellm==1.66.1
|
litellm==1.72.2
|
||||||
click==8.1.7
|
click==8.1.7
|
||||||
questionary==2.0.1
|
questionary==2.0.1
|
||||||
requests>=2.31.0
|
requests>=2.31.0
|
||||||
|
|
Loading…
Reference in New Issue