mirror of https://github.com/kortix-ai/suna.git
Merge pull request #1659 from KrishavRajSingh/main
img search via serper
This commit is contained in:
commit
61c1f267dd
|
@ -10,11 +10,12 @@ import os
|
|||
import datetime
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Union, List
|
||||
|
||||
# TODO: add subpages, etc... in filters as sometimes its necessary
|
||||
|
||||
class SandboxWebSearchTool(SandboxToolsBase):
|
||||
"""Tool for performing web searches using Tavily API and web scraping using Firecrawl."""
|
||||
"""Tool for performing web searches using Tavily API, image searches using SERPER API, and web scraping using Firecrawl."""
|
||||
|
||||
def __init__(self, project_id: str, thread_manager: ThreadManager):
|
||||
super().__init__(project_id, thread_manager)
|
||||
|
@ -24,6 +25,7 @@ class SandboxWebSearchTool(SandboxToolsBase):
|
|||
self.tavily_api_key = config.TAVILY_API_KEY
|
||||
self.firecrawl_api_key = config.FIRECRAWL_API_KEY
|
||||
self.firecrawl_url = config.FIRECRAWL_URL
|
||||
self.serper_api_key = config.SERPER_API_KEY
|
||||
|
||||
if not self.tavily_api_key:
|
||||
raise ValueError("TAVILY_API_KEY not found in configuration")
|
||||
|
@ -399,6 +401,206 @@ class SandboxWebSearchTool(SandboxToolsBase):
|
|||
"error": error_message
|
||||
}
|
||||
|
||||
@openapi_schema({
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "image_search",
|
||||
"description": "Search for images using SERPER API. Supports both single and batch searches. Returns image URLs for the given search query(s). Perfect for finding visual content, illustrations, photos, or any images related to your search terms.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"oneOf": [
|
||||
{
|
||||
"type": "string",
|
||||
"description": "Single search query. Be specific about what kind of images you're looking for (e.g., 'cats playing', 'mountain landscape', 'modern architecture')"
|
||||
},
|
||||
{
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": "Multiple search queries for batch processing. More efficient for multiple searches (e.g., ['cats', 'dogs', 'birds'])"
|
||||
}
|
||||
],
|
||||
"description": "Search query or queries. Single string for one search, array of strings for batch search."
|
||||
},
|
||||
"num_results": {
|
||||
"type": "integer",
|
||||
"description": "The number of image results to return per query. Default is 12, maximum is 100.",
|
||||
"default": 12,
|
||||
"minimum": 1,
|
||||
"maximum": 100
|
||||
}
|
||||
},
|
||||
"required": ["query"]
|
||||
}
|
||||
}
|
||||
})
|
||||
@usage_example('''
|
||||
<!-- Single search -->
|
||||
<function_calls>
|
||||
<invoke name="image_search">
|
||||
<parameter name="query">cute cats playing</parameter>
|
||||
<parameter name="num_results">20</parameter>
|
||||
</invoke>
|
||||
</function_calls>
|
||||
|
||||
<!-- Batch search (more efficient for multiple queries) -->
|
||||
<function_calls>
|
||||
<invoke name="image_search">
|
||||
<parameter name="query">["cats", "dogs", "birds"]</parameter>
|
||||
<parameter name="num_results">15</parameter>
|
||||
</invoke>
|
||||
</function_calls>
|
||||
''')
|
||||
async def image_search(
|
||||
self,
|
||||
query: Union[str, List[str]],
|
||||
num_results: int = 12
|
||||
) -> ToolResult:
|
||||
"""
|
||||
Search for images using SERPER API and return image URLs.
|
||||
|
||||
Supports both single and batch searches:
|
||||
- Single: query="cats" returns {"images": [...]}
|
||||
- Batch: query=["cats", "dogs"] returns {"batch_results": [...]}
|
||||
"""
|
||||
# Initialize variables for error handling
|
||||
is_batch = False
|
||||
queries = []
|
||||
|
||||
try:
|
||||
# Validate inputs
|
||||
if isinstance(query, str):
|
||||
if not query or not query.strip():
|
||||
return self.fail_response("A valid search query is required.")
|
||||
is_batch = False
|
||||
queries = [query]
|
||||
elif isinstance(query, list):
|
||||
if not query or not all(isinstance(q, str) and q.strip() for q in query):
|
||||
return self.fail_response("All queries must be valid non-empty strings.")
|
||||
is_batch = True
|
||||
queries = query
|
||||
else:
|
||||
return self.fail_response("Query must be either a string or list of strings.")
|
||||
|
||||
# Check if SERPER API key is available
|
||||
if not self.serper_api_key:
|
||||
return self.fail_response("SERPER_API_KEY not configured. Image search is not available.")
|
||||
|
||||
# Normalize num_results
|
||||
if num_results is None:
|
||||
num_results = 12
|
||||
elif isinstance(num_results, str):
|
||||
try:
|
||||
num_results = int(num_results)
|
||||
except ValueError:
|
||||
num_results = 12
|
||||
|
||||
# Clamp num_results to valid range
|
||||
num_results = max(1, min(num_results, 100))
|
||||
|
||||
if is_batch:
|
||||
logging.info(f"Executing batch image search for {len(queries)} queries with {num_results} results each")
|
||||
# Batch API request
|
||||
payload = [{"q": q, "num": num_results} for q in queries]
|
||||
else:
|
||||
logging.info(f"Executing image search for query: '{queries[0]}' with {num_results} results")
|
||||
# Single API request
|
||||
payload = {"q": queries[0], "num": num_results}
|
||||
|
||||
# SERPER API request
|
||||
async with httpx.AsyncClient() as client:
|
||||
headers = {
|
||||
"X-API-KEY": self.serper_api_key,
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
|
||||
response = await client.post(
|
||||
"https://google.serper.dev/images",
|
||||
json=payload,
|
||||
headers=headers,
|
||||
timeout=30.0
|
||||
)
|
||||
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
if is_batch:
|
||||
# Handle batch response
|
||||
if not isinstance(data, list):
|
||||
return self.fail_response("Unexpected batch response format from SERPER API.")
|
||||
|
||||
batch_results = []
|
||||
for i, (q, result_data) in enumerate(zip(queries, data)):
|
||||
images = result_data.get("images", []) if isinstance(result_data, dict) else []
|
||||
|
||||
# Extract image URLs
|
||||
image_urls = []
|
||||
for img in images:
|
||||
img_url = img.get("imageUrl")
|
||||
if img_url:
|
||||
image_urls.append(img_url)
|
||||
|
||||
batch_results.append({
|
||||
"query": q,
|
||||
"total_found": len(image_urls),
|
||||
"images": image_urls
|
||||
})
|
||||
|
||||
logging.info(f"Found {len(image_urls)} image URLs for query: '{q}'")
|
||||
|
||||
result = {
|
||||
"batch_results": batch_results,
|
||||
"total_queries": len(queries)
|
||||
}
|
||||
else:
|
||||
# Handle single response
|
||||
images = data.get("images", [])
|
||||
|
||||
if not images:
|
||||
logging.warning(f"No images found for query: '{queries[0]}'")
|
||||
return self.fail_response(f"No images found for query: '{queries[0]}'")
|
||||
|
||||
# Extract just the image URLs - keep it simple
|
||||
image_urls = []
|
||||
for img in images:
|
||||
img_url = img.get("imageUrl")
|
||||
if img_url:
|
||||
image_urls.append(img_url)
|
||||
|
||||
logging.info(f"Found {len(image_urls)} image URLs for query: '{queries[0]}'")
|
||||
|
||||
result = {
|
||||
"query": queries[0],
|
||||
"total_found": len(image_urls),
|
||||
"images": image_urls
|
||||
}
|
||||
|
||||
return ToolResult(
|
||||
success=True,
|
||||
output=json.dumps(result, ensure_ascii=False)
|
||||
)
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
error_message = f"SERPER API error: {e.response.status_code}"
|
||||
if e.response.status_code == 429:
|
||||
error_message = "SERPER API rate limit exceeded. Please try again later."
|
||||
elif e.response.status_code == 401:
|
||||
error_message = "Invalid SERPER API key."
|
||||
|
||||
query_desc = f"batch queries {queries}" if is_batch else f"query '{queries[0]}'"
|
||||
logging.error(f"SERPER API error for {query_desc}: {error_message}")
|
||||
return self.fail_response(error_message)
|
||||
|
||||
except Exception as e:
|
||||
error_message = str(e)
|
||||
query_desc = f"batch queries {queries}" if is_batch else f"query '{queries[0]}'"
|
||||
logging.error(f"Error performing image search for {query_desc}: {error_message}")
|
||||
simplified_message = f"Error performing image search: {error_message[:200]}"
|
||||
if len(error_message) > 200:
|
||||
simplified_message += "..."
|
||||
return self.fail_response(simplified_message)
|
||||
|
||||
if __name__ == "__main__":
|
||||
async def test_web_search():
|
||||
"""Test function for the web search tool"""
|
||||
|
|
|
@ -295,6 +295,7 @@ class Configuration:
|
|||
# Search and other API keys
|
||||
TAVILY_API_KEY: str
|
||||
RAPID_API_KEY: str
|
||||
SERPER_API_KEY: Optional[str] = None
|
||||
CLOUDFLARE_API_TOKEN: Optional[str] = None
|
||||
FIRECRAWL_API_KEY: str
|
||||
FIRECRAWL_URL: Optional[str] = "https://api.firecrawl.dev"
|
||||
|
|
|
@ -41,6 +41,7 @@ export function getToolTitle(toolName: string): string {
|
|||
'full-file-rewrite': 'Rewrite File',
|
||||
'delete-file': 'Delete File',
|
||||
'web-search': 'Web Search',
|
||||
'image-search': 'Image Search',
|
||||
'crawl-webpage': 'Web Crawl',
|
||||
'scrape-webpage': 'Web Scrape',
|
||||
'browser-navigate-to': 'Browser Navigate',
|
||||
|
@ -1280,6 +1281,8 @@ export function getToolComponent(toolName: string): string {
|
|||
// Web operations
|
||||
case 'web-search':
|
||||
return 'WebSearchToolView';
|
||||
case 'image-search':
|
||||
return 'WebSearchToolView';
|
||||
case 'crawl-webpage':
|
||||
return 'WebCrawlToolView';
|
||||
case 'scrape-webpage':
|
||||
|
|
|
@ -14,7 +14,6 @@ import {
|
|||
import { ToolViewProps } from '../types';
|
||||
import { cleanUrl, formatTimestamp, getToolTitle } from '../utils';
|
||||
import { truncateString } from '@/lib/utils';
|
||||
import { useTheme } from 'next-themes';
|
||||
import { Card, CardContent, CardHeader, CardTitle } from '@/components/ui/card';
|
||||
import { Badge } from '@/components/ui/badge';
|
||||
import { Button } from '@/components/ui/button';
|
||||
|
@ -31,7 +30,6 @@ export function WebSearchToolView({
|
|||
isSuccess = true,
|
||||
isStreaming = false,
|
||||
}: ToolViewProps) {
|
||||
const { resolvedTheme } = useTheme();
|
||||
const [expandedResults, setExpandedResults] = useState<Record<number, boolean>>({});
|
||||
|
||||
const {
|
||||
|
@ -127,36 +125,40 @@ export function WebSearchToolView({
|
|||
<div className="mb-6">
|
||||
<h3 className="text-sm font-medium text-zinc-700 dark:text-zinc-300 mb-3 flex items-center">
|
||||
<ImageIcon className="h-4 w-4 mr-2 opacity-70" />
|
||||
Images
|
||||
Images {name === 'image-search' && `(${images.length})`}
|
||||
</h3>
|
||||
<div className="grid grid-cols-2 sm:grid-cols-3 gap-3 mb-1">
|
||||
{images.slice(0, 6).map((image, idx) => (
|
||||
<a
|
||||
key={idx}
|
||||
href={image}
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
className="group relative overflow-hidden rounded-lg border border-zinc-200 dark:border-zinc-800 bg-zinc-100 dark:bg-zinc-900 hover:border-blue-300 dark:hover:border-blue-700 transition-colors shadow-sm hover:shadow-md"
|
||||
>
|
||||
<img
|
||||
src={image}
|
||||
alt={`Search result ${idx + 1}`}
|
||||
className="object-cover w-full h-32 group-hover:opacity-90 transition-opacity"
|
||||
onError={(e) => {
|
||||
const target = e.target as HTMLImageElement;
|
||||
target.src = "data:image/svg+xml,%3Csvg xmlns='http://www.w3.org/2000/svg' width='24' height='24' viewBox='0 0 24 24' fill='none' stroke='%23888' stroke-width='2' stroke-linecap='round' stroke-linejoin='round'%3E%3Crect x='3' y='3' width='18' height='18' rx='2' ry='2'%3E%3C/rect%3E%3Ccircle cx='8.5' cy='8.5' r='1.5'%3E%3C/circle%3E%3Cpolyline points='21 15 16 10 5 21'%3E%3C/polyline%3E%3C/svg%3E";
|
||||
target.classList.add("p-4");
|
||||
}}
|
||||
/>
|
||||
<div className="absolute top-0 right-0 p-1">
|
||||
<Badge variant="secondary" className="bg-black/60 hover:bg-black/70 text-white border-none shadow-md">
|
||||
<ExternalLink className="h-3 w-3" />
|
||||
</Badge>
|
||||
</div>
|
||||
</a>
|
||||
))}
|
||||
<div className={`grid gap-3 mb-1 ${name === 'image-search' ? 'grid-cols-2 sm:grid-cols-3 md:grid-cols-4' : 'grid-cols-2 sm:grid-cols-3'}`}>
|
||||
{(name === 'image-search' ? images : images.slice(0, 6)).map((image, idx) => {
|
||||
const imageUrl = typeof image === 'string' ? image : (image as any).imageUrl;
|
||||
|
||||
return (
|
||||
<a
|
||||
key={idx}
|
||||
href={imageUrl}
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
className="group relative overflow-hidden rounded-lg border border-zinc-200 dark:border-zinc-800 bg-zinc-100 dark:bg-zinc-900 hover:border-blue-300 dark:hover:border-blue-700 transition-colors shadow-sm hover:shadow-md"
|
||||
>
|
||||
<img
|
||||
src={imageUrl}
|
||||
alt={`Search result ${idx + 1}`}
|
||||
className="object-cover w-full h-32 group-hover:opacity-90 transition-opacity"
|
||||
onError={(e) => {
|
||||
const target = e.target as HTMLImageElement;
|
||||
target.src = "data:image/svg+xml,%3Csvg xmlns='http://www.w3.org/2000/svg' width='24' height='24' viewBox='0 0 24 24' fill='none' stroke='%23888' stroke-width='2' stroke-linecap='round' stroke-linejoin='round'%3E%3Crect x='3' y='3' width='18' height='18' rx='2' ry='2'%3E%3C/rect%3E%3Ccircle cx='8.5' cy='8.5' r='1.5'%3E%3C/circle%3E%3Cpolyline points='21 15 16 10 5 21'%3E%3C/polyline%3E%3C/svg%3E";
|
||||
target.classList.add("p-4");
|
||||
}}
|
||||
/>
|
||||
<div className="absolute top-0 right-0 p-1">
|
||||
<Badge variant="secondary" className="bg-black/60 hover:bg-black/70 text-white border-none shadow-md">
|
||||
<ExternalLink className="h-3 w-3" />
|
||||
</Badge>
|
||||
</div>
|
||||
</a>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
{images.length > 6 && (
|
||||
{name !== 'image-search' && images.length > 6 && (
|
||||
<Button variant="outline" size="sm" className="mt-2 text-xs">
|
||||
View {images.length - 6} more images
|
||||
</Button>
|
||||
|
@ -164,7 +166,7 @@ export function WebSearchToolView({
|
|||
</div>
|
||||
)}
|
||||
|
||||
{searchResults.length > 0 && (
|
||||
{searchResults.length > 0 && name !== 'image-search' && (
|
||||
<div className="text-sm font-medium text-zinc-800 dark:text-zinc-200 mb-4 flex items-center justify-between">
|
||||
<span>Search Results ({searchResults.length})</span>
|
||||
<Badge variant="outline" className="text-xs font-normal">
|
||||
|
@ -174,8 +176,9 @@ export function WebSearchToolView({
|
|||
</div>
|
||||
)}
|
||||
|
||||
<div className="space-y-4">
|
||||
{searchResults.map((result, idx) => {
|
||||
{name !== 'image-search' && (
|
||||
<div className="space-y-4">
|
||||
{searchResults.map((result, idx) => {
|
||||
const { icon: ResultTypeIcon, label: resultTypeLabel } = getResultType(result);
|
||||
const isExpanded = expandedResults[idx] || false;
|
||||
const favicon = getFavicon(result.url);
|
||||
|
@ -292,9 +295,10 @@ export function WebSearchToolView({
|
|||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</ScrollArea>
|
||||
) : (
|
||||
|
@ -319,11 +323,21 @@ export function WebSearchToolView({
|
|||
|
||||
<div className="px-4 py-2 h-10 bg-gradient-to-r from-zinc-50/90 to-zinc-100/90 dark:from-zinc-900/90 dark:to-zinc-800/90 backdrop-blur-sm border-t border-zinc-200 dark:border-zinc-800 flex justify-between items-center gap-4">
|
||||
<div className="h-full flex items-center gap-2 text-sm text-zinc-500 dark:text-zinc-400">
|
||||
{!isStreaming && searchResults.length > 0 && (
|
||||
<Badge variant="outline" className="h-6 py-0.5">
|
||||
<Globe className="h-3 w-3" />
|
||||
{searchResults.length} results
|
||||
</Badge>
|
||||
{!isStreaming && (
|
||||
<>
|
||||
{name === 'image-search' && images.length > 0 && (
|
||||
<Badge variant="outline" className="h-6 py-0.5">
|
||||
<ImageIcon className="h-3 w-3" />
|
||||
{images.length} images
|
||||
</Badge>
|
||||
)}
|
||||
{name !== 'image-search' && searchResults.length > 0 && (
|
||||
<Badge variant="outline" className="h-6 py-0.5">
|
||||
<Globe className="h-3 w-3" />
|
||||
{searchResults.length} results
|
||||
</Badge>
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
|
||||
|
|
|
@ -40,15 +40,35 @@ const extractFromNewFormat = (content: any): WebSearchData => {
|
|||
}
|
||||
parsedOutput = parsedOutput || {};
|
||||
|
||||
// Handle both single and batch image search responses
|
||||
let images: string[] = [];
|
||||
let query = args.query || parsedOutput?.query || null;
|
||||
|
||||
if (parsedOutput?.batch_results && Array.isArray(parsedOutput.batch_results)) {
|
||||
// Batch response: flatten all images from all queries
|
||||
images = parsedOutput.batch_results.reduce((acc: string[], result: any) => {
|
||||
return acc.concat(result.images || []);
|
||||
}, []);
|
||||
|
||||
// Create combined query string for display
|
||||
const queries = parsedOutput.batch_results.map((r: any) => r.query).filter(Boolean);
|
||||
if (queries.length > 0) {
|
||||
query = queries.length > 1 ? `${queries.length} queries: ${queries.join(', ')}` : queries[0];
|
||||
}
|
||||
} else {
|
||||
// Single response
|
||||
images = parsedOutput?.images || [];
|
||||
}
|
||||
|
||||
const extractedData = {
|
||||
query: args.query || parsedOutput?.query || null,
|
||||
query,
|
||||
results: parsedOutput?.results?.map((result: any) => ({
|
||||
title: result.title || '',
|
||||
url: result.url || '',
|
||||
snippet: result.content || result.snippet || ''
|
||||
})) || [],
|
||||
answer: parsedOutput?.answer || null,
|
||||
images: parsedOutput?.images || [],
|
||||
images,
|
||||
success: toolExecution.result?.success,
|
||||
timestamp: toolExecution.execution_details?.timestamp
|
||||
};
|
||||
|
@ -157,7 +177,14 @@ export function extractWebSearchData(
|
|||
if (parsedContent.answer && typeof parsedContent.answer === 'string') {
|
||||
answer = parsedContent.answer;
|
||||
}
|
||||
if (parsedContent.images && Array.isArray(parsedContent.images)) {
|
||||
|
||||
// Handle both single and batch image responses in legacy format
|
||||
if (parsedContent.batch_results && Array.isArray(parsedContent.batch_results)) {
|
||||
// Batch response: flatten all images from all queries
|
||||
images = parsedContent.batch_results.reduce((acc: string[], result: any) => {
|
||||
return acc.concat(result.images || []);
|
||||
}, []);
|
||||
} else if (parsedContent.images && Array.isArray(parsedContent.images)) {
|
||||
images = parsedContent.images;
|
||||
}
|
||||
} catch (e) {
|
||||
|
|
|
@ -79,6 +79,7 @@ const defaultRegistry: ToolViewRegistryType = {
|
|||
'web-search': WebSearchToolView,
|
||||
'crawl-webpage': WebCrawlToolView,
|
||||
'scrape-webpage': WebScrapeToolView,
|
||||
'image-search': WebSearchToolView,
|
||||
|
||||
'execute-data-provider-call': ExecuteDataProviderCallToolView,
|
||||
'get-data-provider-endpoints': DataProviderEndpointsToolView,
|
||||
|
|
|
@ -265,6 +265,7 @@ export const extractPrimaryParam = (
|
|||
|
||||
// Web search
|
||||
case 'web-search':
|
||||
case 'image-search':
|
||||
match = content.match(/query=(?:"|')([^"|']+)(?:"|')/);
|
||||
return match
|
||||
? match[1].length > 30
|
||||
|
@ -343,6 +344,7 @@ const TOOL_DISPLAY_NAMES = new Map([
|
|||
['present-presentation', 'Presenting'],
|
||||
['clear-images-from-context', 'Clearing Images from context'],
|
||||
['load-image', 'Loading Image'],
|
||||
['image-search', 'Searching Image'],
|
||||
|
||||
['create-sheet', 'Creating Sheet'],
|
||||
['update-sheet', 'Updating Sheet'],
|
||||
|
|
Loading…
Reference in New Issue