mirror of https://github.com/kortix-ai/suna.git
Merge pull request #171 from kortix-ai/fix_upload
Fix upload & support images
This commit is contained in:
commit
dea49030f8
|
@ -76,7 +76,15 @@ You have the ability to execute operations using both Python and CLI tools:
|
|||
* YOU CAN DO ANYTHING ON THE BROWSER - including clicking on elements, filling forms, submitting data, etc.
|
||||
* The browser is in a sandboxed environment, so nothing to worry about.
|
||||
|
||||
### 2.2.6 DATA PROVIDERS
|
||||
### 2.2.6 VISUAL INPUT
|
||||
- You MUST use the 'see-image' tool to see image files. There is NO other way to access visual information.
|
||||
* Provide the relative path to the image in the `/workspace` directory.
|
||||
* Example: `<see-image file_path="path/to/your/image.png"></see-image>`
|
||||
* ALWAYS use this tool when visual information from a file is necessary for your task.
|
||||
* Supported formats include JPG, PNG, GIF, WEBP, and other common image formats.
|
||||
* Maximum file size limit is 10 MB.
|
||||
|
||||
### 2.2.7 DATA PROVIDERS
|
||||
- You have access to a variety of data providers that you can use to get data for your tasks.
|
||||
- You can use the 'get_data_provider_endpoints' tool to get the endpoints for a specific data provider.
|
||||
- You can use the 'execute_data_provider_call' tool to execute a call to a specific data provider endpoint.
|
||||
|
|
|
@ -22,6 +22,7 @@ from agent.prompt import get_system_prompt
|
|||
from utils import logger
|
||||
from utils.auth_utils import get_account_id_from_thread
|
||||
from services.billing import check_billing_status
|
||||
from agent.tools.sb_vision_tool import SandboxVisionTool
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
@ -66,8 +67,8 @@ async def run_agent(
|
|||
thread_manager.add_tool(SandboxDeployTool, project_id=project_id, thread_manager=thread_manager)
|
||||
thread_manager.add_tool(SandboxExposeTool, project_id=project_id, thread_manager=thread_manager)
|
||||
thread_manager.add_tool(MessageTool) # we are just doing this via prompt as there is no need to call it as a tool
|
||||
|
||||
thread_manager.add_tool(WebSearchTool)
|
||||
thread_manager.add_tool(SandboxVisionTool, project_id=project_id, thread_id=thread_id, thread_manager=thread_manager)
|
||||
|
||||
# Add data providers tool if RapidAPI key is available
|
||||
if config.RAPID_API_KEY:
|
||||
|
@ -102,36 +103,73 @@ async def run_agent(
|
|||
continue_execution = False
|
||||
break
|
||||
|
||||
# Get the latest message from messages table that its type is browser_state
|
||||
latest_browser_state = await client.table('messages').select('*').eq('thread_id', thread_id).eq('type', 'browser_state').order('created_at', desc=True).limit(1).execute()
|
||||
# ---- Temporary Message Handling (Browser State & Image Context) ----
|
||||
temporary_message = None
|
||||
if latest_browser_state.data and len(latest_browser_state.data) > 0:
|
||||
temp_message_content_list = [] # List to hold text/image blocks
|
||||
|
||||
# Get the latest browser_state message
|
||||
latest_browser_state_msg = await client.table('messages').select('*').eq('thread_id', thread_id).eq('type', 'browser_state').order('created_at', desc=True).limit(1).execute()
|
||||
if latest_browser_state_msg.data and len(latest_browser_state_msg.data) > 0:
|
||||
try:
|
||||
content = json.loads(latest_browser_state.data[0]["content"])
|
||||
screenshot_base64 = content["screenshot_base64"]
|
||||
browser_content = json.loads(latest_browser_state_msg.data[0]["content"])
|
||||
screenshot_base64 = browser_content.get("screenshot_base64")
|
||||
# Create a copy of the browser state without screenshot
|
||||
browser_state = content.copy()
|
||||
browser_state.pop('screenshot_base64', None)
|
||||
browser_state.pop('screenshot_url', None)
|
||||
browser_state.pop('screenshot_url_base64', None)
|
||||
temporary_message = { "role": "user", "content": [] }
|
||||
if browser_state:
|
||||
temporary_message["content"].append({
|
||||
browser_state_text = browser_content.copy()
|
||||
browser_state_text.pop('screenshot_base64', None)
|
||||
browser_state_text.pop('screenshot_url', None)
|
||||
browser_state_text.pop('screenshot_url_base64', None)
|
||||
|
||||
if browser_state_text:
|
||||
temp_message_content_list.append({
|
||||
"type": "text",
|
||||
"text": f"The following is the current state of the browser:\n{browser_state}"
|
||||
"text": f"The following is the current state of the browser:\n{json.dumps(browser_state_text, indent=2)}"
|
||||
})
|
||||
if screenshot_base64:
|
||||
temporary_message["content"].append({
|
||||
temp_message_content_list.append({
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/jpeg;base64,{screenshot_base64}",
|
||||
}
|
||||
})
|
||||
else:
|
||||
print("@@@@@ THIS TIME NO SCREENSHOT!!")
|
||||
logger.warning("Browser state found but no screenshot base64 data.")
|
||||
|
||||
await client.table('messages').delete().eq('message_id', latest_browser_state_msg.data[0]["message_id"]).execute()
|
||||
except Exception as e:
|
||||
print(f"Error parsing browser state: {e}")
|
||||
# print(latest_browser_state.data[0])
|
||||
logger.error(f"Error parsing browser state: {e}")
|
||||
|
||||
# Get the latest image_context message (NEW)
|
||||
latest_image_context_msg = await client.table('messages').select('*').eq('thread_id', thread_id).eq('type', 'image_context').order('created_at', desc=True).limit(1).execute()
|
||||
if latest_image_context_msg.data and len(latest_image_context_msg.data) > 0:
|
||||
try:
|
||||
image_context_content = json.loads(latest_image_context_msg.data[0]["content"])
|
||||
base64_image = image_context_content.get("base64")
|
||||
mime_type = image_context_content.get("mime_type")
|
||||
file_path = image_context_content.get("file_path", "unknown file")
|
||||
|
||||
if base64_image and mime_type:
|
||||
temp_message_content_list.append({
|
||||
"type": "text",
|
||||
"text": f"Here is the image you requested to see: '{file_path}'"
|
||||
})
|
||||
temp_message_content_list.append({
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:{mime_type};base64,{base64_image}",
|
||||
}
|
||||
})
|
||||
else:
|
||||
logger.warning(f"Image context found for '{file_path}' but missing base64 or mime_type.")
|
||||
|
||||
await client.table('messages').delete().eq('message_id', latest_image_context_msg.data[0]["message_id"]).execute()
|
||||
except Exception as e:
|
||||
logger.error(f"Error parsing image context: {e}")
|
||||
|
||||
# If we have any content, construct the temporary_message
|
||||
if temp_message_content_list:
|
||||
temporary_message = {"role": "user", "content": temp_message_content_list}
|
||||
# logger.debug(f"Constructed temporary message with {len(temp_message_content_list)} content blocks.")
|
||||
# ---- End Temporary Message Handling ----
|
||||
|
||||
max_tokens = 64000 if "sonnet" in model_name.lower() else None
|
||||
|
||||
|
|
|
@ -0,0 +1,128 @@
|
|||
import os
|
||||
import base64
|
||||
import mimetypes
|
||||
from typing import Optional
|
||||
|
||||
from agentpress.tool import ToolResult, openapi_schema, xml_schema
|
||||
from sandbox.sandbox import SandboxToolsBase, Sandbox
|
||||
from agentpress.thread_manager import ThreadManager
|
||||
from utils.logger import logger
|
||||
import json
|
||||
|
||||
# Add common image MIME types if mimetypes module is limited
|
||||
mimetypes.add_type("image/webp", ".webp")
|
||||
mimetypes.add_type("image/jpeg", ".jpg")
|
||||
mimetypes.add_type("image/jpeg", ".jpeg")
|
||||
mimetypes.add_type("image/png", ".png")
|
||||
mimetypes.add_type("image/gif", ".gif")
|
||||
|
||||
# Maximum file size in bytes (e.g., 5MB)
|
||||
MAX_IMAGE_SIZE = 10 * 1024 * 1024
|
||||
|
||||
class SandboxVisionTool(SandboxToolsBase):
|
||||
"""Tool for allowing the agent to 'see' images within the sandbox."""
|
||||
|
||||
def __init__(self, project_id: str, thread_id: str, thread_manager: ThreadManager):
|
||||
super().__init__(project_id, thread_manager)
|
||||
self.thread_id = thread_id
|
||||
# Make thread_manager accessible within the tool instance
|
||||
self.thread_manager = thread_manager
|
||||
|
||||
@openapi_schema({
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "see_image",
|
||||
"description": "Allows the agent to 'see' an image file located in the /workspace directory. Provide the relative path to the image. The image content will be made available in the next turn's context.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"file_path": {
|
||||
"type": "string",
|
||||
"description": "The relative path to the image file within the /workspace directory (e.g., 'screenshots/image.png'). Supported formats: JPG, PNG, GIF, WEBP. Max size: 5MB."
|
||||
}
|
||||
},
|
||||
"required": ["file_path"]
|
||||
}
|
||||
}
|
||||
})
|
||||
@xml_schema(
|
||||
tag_name="see-image",
|
||||
mappings=[
|
||||
{"param_name": "file_path", "node_type": "attribute", "path": "."}
|
||||
],
|
||||
example='''
|
||||
<!-- Example: Request to see an image named 'diagram.png' inside the 'docs' folder -->
|
||||
<see-image file_path="docs/diagram.png"></see-image>
|
||||
'''
|
||||
)
|
||||
async def see_image(self, file_path: str) -> ToolResult:
|
||||
"""Reads an image file, converts it to base64, and adds it as a temporary message."""
|
||||
try:
|
||||
# Ensure sandbox is initialized
|
||||
await self._ensure_sandbox()
|
||||
|
||||
# Clean and construct full path
|
||||
cleaned_path = self.clean_path(file_path)
|
||||
full_path = f"{self.workspace_path}/{cleaned_path}"
|
||||
logger.info(f"Attempting to see image: {full_path} (original: {file_path})")
|
||||
|
||||
# Check if file exists and get info
|
||||
try:
|
||||
file_info = self.sandbox.fs.get_file_info(full_path)
|
||||
if file_info.is_dir:
|
||||
return self.fail_response(f"Path '{cleaned_path}' is a directory, not an image file.")
|
||||
except Exception as e:
|
||||
logger.warning(f"File not found at {full_path}: {e}")
|
||||
return self.fail_response(f"Image file not found at path: '{cleaned_path}'")
|
||||
|
||||
# Check file size
|
||||
if file_info.size > MAX_IMAGE_SIZE:
|
||||
return self.fail_response(f"Image file '{cleaned_path}' is too large ({file_info.size / (1024*1024):.2f}MB). Maximum size is {MAX_IMAGE_SIZE / (1024*1024)}MB.")
|
||||
|
||||
# Read image file content
|
||||
try:
|
||||
image_bytes = self.sandbox.fs.download_file(full_path)
|
||||
except Exception as e:
|
||||
logger.error(f"Error reading image file {full_path}: {e}")
|
||||
return self.fail_response(f"Could not read image file: {cleaned_path}")
|
||||
|
||||
# Convert to base64
|
||||
base64_image = base64.b64encode(image_bytes).decode('utf-8')
|
||||
|
||||
# Determine MIME type
|
||||
mime_type, _ = mimetypes.guess_type(full_path)
|
||||
if not mime_type or not mime_type.startswith('image/'):
|
||||
# Basic fallback based on extension if mimetypes fails
|
||||
ext = os.path.splitext(cleaned_path)[1].lower()
|
||||
if ext == '.jpg' or ext == '.jpeg': mime_type = 'image/jpeg'
|
||||
elif ext == '.png': mime_type = 'image/png'
|
||||
elif ext == '.gif': mime_type = 'image/gif'
|
||||
elif ext == '.webp': mime_type = 'image/webp'
|
||||
else:
|
||||
return self.fail_response(f"Unsupported or unknown image format for file: '{cleaned_path}'. Supported: JPG, PNG, GIF, WEBP.")
|
||||
|
||||
logger.info(f"Successfully read and encoded image '{cleaned_path}' as {mime_type}")
|
||||
|
||||
# Prepare the temporary message content
|
||||
image_context_data = {
|
||||
"mime_type": mime_type,
|
||||
"base64": base64_image,
|
||||
"file_path": cleaned_path # Include path for context
|
||||
}
|
||||
|
||||
# Add the temporary message using the thread_manager callback
|
||||
# Use a distinct type like 'image_context'
|
||||
await self.thread_manager.add_message(
|
||||
thread_id=self.thread_id,
|
||||
type="image_context", # Use a specific type for this
|
||||
content=image_context_data, # Store the dict directly
|
||||
is_llm_message=False # This is context generated by a tool
|
||||
)
|
||||
logger.info(f"Added image context message for '{cleaned_path}' to thread {self.thread_id}")
|
||||
|
||||
# Inform the agent the image will be available next turn
|
||||
return self.success_response(f"Successfully loaded the image '{cleaned_path}'.")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing see_image for {file_path}: {e}", exc_info=True)
|
||||
return self.fail_response(f"An unexpected error occurred while trying to see the image: {str(e)}")
|
|
@ -32,7 +32,7 @@ function DashboardContent() {
|
|||
const personalAccount = accounts?.find(account => account.personal_account);
|
||||
const chatInputRef = useRef<ChatInputHandles>(null);
|
||||
|
||||
const handleSubmit = async (message: string, options?: { model_name?: string; enable_thinking?: boolean }) => {
|
||||
const handleSubmit = async (message: string, options?: { model_name?: string; enable_thinking?: boolean; reasoning_effort?: string; stream?: boolean; enable_context_manager?: boolean }) => {
|
||||
if ((!message.trim() && !(chatInputRef.current?.getPendingFiles().length)) || isSubmitting) return;
|
||||
|
||||
setIsSubmitting(true);
|
||||
|
@ -42,85 +42,73 @@ function DashboardContent() {
|
|||
localStorage.removeItem(PENDING_PROMPT_KEY);
|
||||
|
||||
if (files.length > 0) {
|
||||
// Create a FormData instance
|
||||
// ---- Handle submission WITH files ----
|
||||
console.log(`Submitting with message: "${message}" and ${files.length} files.`);
|
||||
const formData = new FormData();
|
||||
|
||||
// Append the message
|
||||
formData.append('message', message);
|
||||
// Use 'prompt' key instead of 'message'
|
||||
formData.append('prompt', message);
|
||||
|
||||
// Append all files
|
||||
files.forEach(file => {
|
||||
formData.append('files', file);
|
||||
// Append files
|
||||
files.forEach((file, index) => {
|
||||
formData.append('files', file, file.name);
|
||||
});
|
||||
|
||||
// Add any additional options
|
||||
if (options) {
|
||||
formData.append('options', JSON.stringify(options));
|
||||
}
|
||||
// Append options individually instead of bundled 'options' field
|
||||
if (options?.model_name) formData.append('model_name', options.model_name);
|
||||
// Default values from backend signature if not provided in options:
|
||||
formData.append('enable_thinking', String(options?.enable_thinking ?? false));
|
||||
formData.append('reasoning_effort', options?.reasoning_effort ?? 'low');
|
||||
formData.append('stream', String(options?.stream ?? true));
|
||||
formData.append('enable_context_manager', String(options?.enable_context_manager ?? false));
|
||||
|
||||
console.log('FormData content:', Array.from(formData.entries()));
|
||||
|
||||
// Call initiateAgent API
|
||||
const result = await initiateAgent(formData);
|
||||
console.log('Agent initiated with files:', result);
|
||||
|
||||
// Navigate to the thread
|
||||
if (result.thread_id) {
|
||||
router.push(`/agents/${result.thread_id}`);
|
||||
}
|
||||
} else {
|
||||
// ---- Text-only messages ----
|
||||
// 1. Generate a project name
|
||||
throw new Error("Agent initiation did not return a thread_id.");
|
||||
}
|
||||
chatInputRef.current?.clearPendingFiles();
|
||||
|
||||
} else {
|
||||
// ---- Handle text-only messages (NO CHANGES NEEDED HERE) ----
|
||||
console.log(`Submitting text-only message: "${message}"`);
|
||||
const projectName = await generateThreadName(message);
|
||||
|
||||
// 2. Create the project
|
||||
// Assuming createProject gets the account_id from the logged-in user
|
||||
const newProject = await createProject({
|
||||
name: projectName,
|
||||
description: "", // Or derive a description if desired
|
||||
});
|
||||
|
||||
// 3. Create the thread using the new project ID
|
||||
const thread = await createThread(newProject.id); // <-- Pass the actual project ID
|
||||
|
||||
// 4. Then add the user message
|
||||
const newProject = await createProject({ name: projectName, description: "" });
|
||||
const thread = await createThread(newProject.id);
|
||||
await addUserMessage(thread.thread_id, message);
|
||||
|
||||
// 5. Start the agent on this thread with the options
|
||||
await startAgent(thread.thread_id, options);
|
||||
|
||||
// 6. Navigate to thread
|
||||
await startAgent(thread.thread_id, options); // Pass original options here
|
||||
router.push(`/agents/${thread.thread_id}`);
|
||||
}
|
||||
} catch (error: any) {
|
||||
console.error('Error during submission process:', error);
|
||||
|
||||
// Check specifically for BillingError (402)
|
||||
if (error instanceof BillingError) {
|
||||
// Delegate billing error handling
|
||||
console.log("Handling BillingError:", error.detail);
|
||||
handleBillingError({
|
||||
// Pass details from the BillingError instance
|
||||
message: error.detail.message || 'Monthly usage limit reached. Please upgrade your plan.',
|
||||
currentUsage: error.detail.currentUsage as number | undefined, // Attempt to get usage/limit if backend adds them
|
||||
currentUsage: error.detail.currentUsage as number | undefined,
|
||||
limit: error.detail.limit as number | undefined,
|
||||
// Include subscription details if available in the error, otherwise provide defaults
|
||||
subscription: error.detail.subscription || {
|
||||
price_id: config.SUBSCRIPTION_TIERS.FREE.priceId, // Default to Free tier
|
||||
price_id: config.SUBSCRIPTION_TIERS.FREE.priceId,
|
||||
plan_name: "Free"
|
||||
}
|
||||
});
|
||||
// Don't show toast for billing errors, the modal handles it
|
||||
setIsSubmitting(false);
|
||||
return; // Stop execution
|
||||
return; // Stop further processing for billing errors
|
||||
}
|
||||
|
||||
// Handle other types of errors (e.g., network, other API errors)
|
||||
// Skip toast in local mode unless it's a connection error
|
||||
// Handle other errors
|
||||
const isConnectionError = error instanceof TypeError && error.message.includes('Failed to fetch');
|
||||
if (!isLocalMode() || isConnectionError) {
|
||||
toast.error(error.message || "An unexpected error occurred");
|
||||
}
|
||||
setIsSubmitting(false); // Reset submitting state on other errors too
|
||||
setIsSubmitting(false); // Reset submitting state on all errors
|
||||
}
|
||||
// No finally block needed, state is reset in catch blocks
|
||||
};
|
||||
|
||||
// Check for pending prompt in localStorage on mount
|
||||
|
|
|
@ -29,6 +29,7 @@ const API_URL = process.env.NEXT_PUBLIC_BACKEND_URL || '';
|
|||
|
||||
// Local storage keys
|
||||
const STORAGE_KEY_MODEL = 'suna-preferred-model';
|
||||
const DEFAULT_MODEL_ID = "sonnet-3.7"; // Define default model ID
|
||||
|
||||
interface ChatInputProps {
|
||||
onSubmit: (message: string, options?: { model_name?: string; enable_thinking?: boolean }) => void;
|
||||
|
@ -76,7 +77,16 @@ export const ChatInput = forwardRef<ChatInputHandles, ChatInputProps>(({
|
|||
const [uncontrolledValue, setUncontrolledValue] = useState('');
|
||||
const value = isControlled ? controlledValue : uncontrolledValue;
|
||||
|
||||
const [selectedModel, setSelectedModel] = useState("sonnet-3.7");
|
||||
// Define model options array earlier so it can be used in useEffect
|
||||
const modelOptions = [
|
||||
{ id: "sonnet-3.7", label: "Sonnet 3.7" },
|
||||
{ id: "sonnet-3.7-thinking", label: "Sonnet 3.7 (Thinking)" },
|
||||
{ id: "gpt-4.1", label: "GPT-4.1" },
|
||||
{ id: "gemini-flash-2.5", label: "Gemini Flash 2.5" }
|
||||
];
|
||||
|
||||
// Initialize state with the default model
|
||||
const [selectedModel, setSelectedModel] = useState(DEFAULT_MODEL_ID);
|
||||
const textareaRef = useRef<HTMLTextAreaElement | null>(null);
|
||||
const fileInputRef = useRef<HTMLInputElement>(null);
|
||||
const [uploadedFiles, setUploadedFiles] = useState<UploadedFile[]>([]);
|
||||
|
@ -94,8 +104,13 @@ export const ChatInput = forwardRef<ChatInputHandles, ChatInputProps>(({
|
|||
if (typeof window !== 'undefined') {
|
||||
try {
|
||||
const savedModel = localStorage.getItem(STORAGE_KEY_MODEL);
|
||||
if (savedModel) {
|
||||
// Check if the saved model exists and is one of the valid options
|
||||
if (savedModel && modelOptions.some(option => option.id === savedModel)) {
|
||||
setSelectedModel(savedModel);
|
||||
} else if (savedModel) {
|
||||
// If invalid model found in storage, clear it
|
||||
localStorage.removeItem(STORAGE_KEY_MODEL);
|
||||
console.log(`Removed invalid model '${savedModel}' from localStorage. Using default: ${DEFAULT_MODEL_ID}`);
|
||||
}
|
||||
} catch (error) {
|
||||
console.warn('Failed to load preferences from localStorage:', error);
|
||||
|
@ -154,8 +169,8 @@ export const ChatInput = forwardRef<ChatInputHandles, ChatInputProps>(({
|
|||
|
||||
let baseModelName = selectedModel;
|
||||
let thinkingEnabled = false;
|
||||
if (selectedModel === "sonnet-3.7-thinking") {
|
||||
baseModelName = "sonnet-3.7";
|
||||
if (selectedModel.endsWith("-thinking")) {
|
||||
baseModelName = selectedModel.replace(/-thinking$/, "");
|
||||
thinkingEnabled = true;
|
||||
}
|
||||
|
||||
|
@ -333,13 +348,6 @@ export const ChatInput = forwardRef<ChatInputHandles, ChatInputProps>(({
|
|||
setUploadedFiles(prev => prev.filter((_, i) => i !== index));
|
||||
};
|
||||
|
||||
const modelOptions = [
|
||||
{ id: "sonnet-3.7", label: "Sonnet 3.7" },
|
||||
{ id: "sonnet-3.7-thinking", label: "Sonnet 3.7 (Thinking)" },
|
||||
{ id: "gpt-4.1", label: "GPT-4.1" },
|
||||
{ id: "gemini-flash-2.5", label: "Gemini Flash 2.5" }
|
||||
];
|
||||
|
||||
return (
|
||||
<div className="mx-auto w-full max-w-3xl px-4 py-4">
|
||||
<AnimatePresence>
|
||||
|
|
Loading…
Reference in New Issue