mirror of https://github.com/kortix-ai/suna.git
432 lines
18 KiB
Python
432 lines
18 KiB
Python
from tavily import AsyncTavilyClient
|
|
import httpx
|
|
from dotenv import load_dotenv
|
|
from agentpress.tool import Tool, ToolResult, openapi_schema, xml_schema
|
|
from utils.config import config
|
|
from sandbox.tool_base import SandboxToolsBase
|
|
from agentpress.thread_manager import ThreadManager
|
|
import json
|
|
import os
|
|
import datetime
|
|
import asyncio
|
|
import logging
|
|
|
|
# TODO: support subpages (and similar options) in the search filters, since they are sometimes necessary.
|
|
|
|
class SandboxWebSearchTool(SandboxToolsBase):
    """Tool for performing web searches using Tavily API and web scraping using Firecrawl.

    ``web_search`` queries Tavily and returns its JSON response verbatim.
    ``scrape_webpage`` fetches pages through Firecrawl's ``/v1/scrape`` endpoint
    (markdown format) and saves each page as a JSON file under
    ``<workspace_path>/scrape`` in the sandbox.
    """

    # Upper bound on simultaneous Firecrawl requests made by one scrape_webpage call.
    _MAX_CONCURRENT_SCRAPES = 5
    # Attempts per URL and per-request timeout (seconds) for Firecrawl calls.
    _FIRECRAWL_MAX_RETRIES = 3
    _FIRECRAWL_TIMEOUT_SECONDS = 120
    # HTTP statuses treated as transient (rate limiting / upstream hiccups) and retried.
    _RETRYABLE_STATUS_CODES = frozenset({429, 500, 502, 503, 504})
    # Fallback and upper bound for web_search's num_results.
    _DEFAULT_NUM_RESULTS = 20
    _MAX_NUM_RESULTS = 50

    def __init__(self, project_id: str, thread_manager: ThreadManager):
        """Read API configuration and create the Tavily client.

        Raises:
            ValueError: if TAVILY_API_KEY or FIRECRAWL_API_KEY is missing from config.
        """
        super().__init__(project_id, thread_manager)
        # Load environment variables
        load_dotenv()
        # Use API keys from config
        self.tavily_api_key = config.TAVILY_API_KEY
        self.firecrawl_api_key = config.FIRECRAWL_API_KEY
        self.firecrawl_url = config.FIRECRAWL_URL

        if not self.tavily_api_key:
            raise ValueError("TAVILY_API_KEY not found in configuration")
        if not self.firecrawl_api_key:
            raise ValueError("FIRECRAWL_API_KEY not found in configuration")

        # Tavily asynchronous search client
        self.tavily_client = AsyncTavilyClient(api_key=self.tavily_api_key)

    @openapi_schema({
        "type": "function",
        "function": {
            "name": "web_search",
            "description": "Search the web for up-to-date information on a specific topic using the Tavily API. This tool allows you to gather real-time information from the internet to answer user queries, research topics, validate facts, and find recent developments. Results include titles, URLs, and publication dates. Use this tool for discovering relevant web pages before potentially crawling them for complete content.",
            "parameters": {
                "type": "object",
                "properties": {
                    "query": {
                        "type": "string",
                        "description": "The search query to find relevant web pages. Be specific and include key terms to improve search accuracy. For best results, use natural language questions or keyword combinations that precisely describe what you're looking for."
                    },
                    "num_results": {
                        "type": "integer",
                        "description": "The number of search results to return. Increase for more comprehensive research or decrease for focused, high-relevance results.",
                        "default": 20
                    }
                },
                "required": ["query"]
            }
        }
    })
    @xml_schema(
        tag_name="web-search",
        mappings=[
            {"param_name": "query", "node_type": "attribute", "path": "."},
            {"param_name": "num_results", "node_type": "attribute", "path": "."}
        ],
        example='''
        <!--
        The web-search tool allows you to search the internet for real-time information.
        Use this tool when you need to find current information, research topics, or verify facts.

        THE TOOL NOW RETURNS:
        - Direct answer to your query from search results
        - Relevant images when available
        - Detailed search results including titles, URLs, and snippets

        WORKFLOW RECOMMENDATION:
        1. Use web-search first with a specific question to get direct answers
        2. Only use scrape-webpage if you need more detailed information from specific pages
        -->

        <!-- Simple search example -->
        <web-search
            query="what is Kortix AI and what are they building?"
            num_results="20">
        </web-search>

        <!-- Another search example -->
        <web-search
            query="latest AI research on transformer models"
            num_results="20">
        </web-search>
        '''
    )
    async def web_search(
        self,
        query: str,
        num_results: int = 20
    ) -> ToolResult:
        """
        Search the web using the Tavily API to find relevant and up-to-date information.

        Args:
            query: Search text. Leading/trailing whitespace is stripped; a blank or
                non-string query is rejected.
            num_results: Requested result count. ints and integer strings are clamped
                to [1, 50]; None, bools, non-numeric strings and other types fall
                back to 20.

        Returns:
            ToolResult whose output is the JSON-serialized Tavily response, or a
            failure response carrying an error message truncated to 200 characters.
        """
        try:
            # Ensure we have a valid, non-blank query
            if not isinstance(query, str) or not query.strip():
                return self.fail_response("A valid search query is required.")
            query = query.strip()

            num_results = self._normalize_num_results(num_results)

            # Execute the search with Tavily
            logging.info(f"Executing web search for query: '{query}' with {num_results} results")
            search_response = await self.tavily_client.search(
                query=query,
                max_results=num_results,
                include_images=True,
                include_answer="advanced",
                search_depth="advanced",
            )

            # Return the complete Tavily response
            # This includes the query, answer, results, images and more
            logging.info(f"Retrieved search results for query: '{query}' with answer and {len(search_response.get('results', []))} results")

            return ToolResult(
                success=True,
                output=json.dumps(search_response, ensure_ascii=False)
            )

        except Exception as e:
            error_message = str(e)
            logging.error(f"Error performing web search for '{query}': {error_message}")
            simplified_message = f"Error performing web search: {error_message[:200]}"
            if len(error_message) > 200:
                simplified_message += "..."
            return self.fail_response(simplified_message)

    @classmethod
    def _normalize_num_results(cls, num_results) -> int:
        """Coerce a caller-supplied result count into the range [1, _MAX_NUM_RESULTS].

        ints and integer strings are clamped; bools, None, non-numeric strings and
        any other type fall back to _DEFAULT_NUM_RESULTS.
        """
        if isinstance(num_results, bool):
            # bool is an int subclass, but True/False is never a meaningful count.
            return cls._DEFAULT_NUM_RESULTS
        if isinstance(num_results, str):
            try:
                num_results = int(num_results)
            except ValueError:
                return cls._DEFAULT_NUM_RESULTS
        if isinstance(num_results, int):
            return max(1, min(num_results, cls._MAX_NUM_RESULTS))
        return cls._DEFAULT_NUM_RESULTS

    @openapi_schema({
        "type": "function",
        "function": {
            "name": "scrape_webpage",
            "description": "Extract full text content from multiple webpages in a single operation. IMPORTANT: You should ALWAYS collect multiple relevant URLs from web-search results and scrape them all in a single call for efficiency. This tool saves time by processing multiple pages simultaneously rather than one at a time. The extracted text includes the main content of each page without HTML markup.",
            "parameters": {
                "type": "object",
                "properties": {
                    "urls": {
                        "type": "string",
                        "description": "Multiple URLs to scrape, separated by commas. You should ALWAYS include several URLs when possible for efficiency. Example: 'https://example.com/page1,https://example.com/page2,https://example.com/page3'"
                    }
                },
                "required": ["urls"]
            }
        }
    })
    @xml_schema(
        tag_name="scrape-webpage",
        mappings=[
            {"param_name": "urls", "node_type": "attribute", "path": "."}
        ],
        example='''
        <!--
        IMPORTANT: The scrape-webpage tool should ONLY be used when you absolutely need
        the full content of specific web pages that can't be answered by web-search alone.

        WORKFLOW PRIORITY:
        1. ALWAYS use web-search first - it now provides direct answers to questions
        2. Only use scrape-webpage when you need specific details not found in the search results
        3. Remember that web-search now returns:
            - Direct answers to your query
            - Relevant images
            - Detailed search result snippets

        When to use scrape-webpage:
        - When you need complete article text beyond what search snippets provide
        - For extracting structured data from specific pages
        - When analyzing lengthy documentation or guides
        - For comparing detailed content across multiple sources

        When NOT to use scrape-webpage:
        - When web-search already answers the query
        - For simple fact-checking or basic information
        - When only a high-level overview is needed
        -->

        <!-- Example workflow: -->
        <!-- 1. First search for relevant content with a specific question -->
        <web-search
            query="what is Kortix AI and what are they building?"
            num_results="20">
        </web-search>

        <!-- 2. Only if you need specific details not in the search results, then scrape -->
        <scrape-webpage
            urls="https://www.kortix.ai/,https://github.com/kortix-ai/suna">
        </scrape-webpage>

        <!-- 3. Only if scrape fails or interaction needed, use browser tools -->
        <!-- Example of when to use browser tools:
            - Dynamic content loading
            - JavaScript-heavy sites
            - Pages requiring login
            - Interactive elements
            - Infinite scroll pages
        -->
        '''
    )
    async def scrape_webpage(
        self,
        urls: str
    ) -> ToolResult:
        """
        Retrieve the complete text content of multiple webpages in a single efficient operation.

        ALWAYS collect multiple relevant URLs from search results and scrape them all at once
        rather than making separate calls for each URL. This is much more efficient.

        Parameters:
        - urls: Multiple URLs to scrape, separated by commas (a list/tuple of URLs is
          also accepted). Entries without an http(s) scheme get "https://" prepended;
          duplicates are scraped once.

        Pages are fetched concurrently (at most _MAX_CONCURRENT_SCRAPES at a time) and
        each successful page is saved as a JSON file in the sandbox. The tool succeeds
        if at least one URL was scraped; its output lists saved files and any failures.
        """
        try:
            logging.info(f"Starting to scrape webpages: {urls}")

            # Validate the input before paying for sandbox initialization
            if not urls:
                logging.warning("Scrape attempt with empty URLs")
                return self.fail_response("Valid URLs are required.")

            url_list = self._parse_url_list(urls)
            if not url_list:
                logging.warning("No valid URLs found in the input")
                return self.fail_response("No valid URLs provided.")

            if len(url_list) == 1:
                logging.warning("Only a single URL provided - for efficiency you should scrape multiple URLs at once")

            # Ensure sandbox is initialized (results are written into it)
            await self._ensure_sandbox()

            logging.info(f"Processing {len(url_list)} URLs: {url_list}")

            # Bound concurrency so a long URL list doesn't flood Firecrawl (and trip its
            # rate limits); gather returns results in input order.
            semaphore = asyncio.Semaphore(self._MAX_CONCURRENT_SCRAPES)

            async def scrape_bounded(target: str) -> dict:
                async with semaphore:
                    try:
                        return await self._scrape_single_url(target)
                    except Exception as e:
                        logging.error(f"Error processing URL {target}: {str(e)}")
                        return {"url": target, "success": False, "error": str(e)}

            results = list(await asyncio.gather(*(scrape_bounded(u) for u in url_list)))
            return self._summarize_scrape_results(results)

        except Exception as e:
            error_message = str(e)
            logging.error(f"Error in scrape_webpage: {error_message}")
            return self.fail_response(f"Error processing scrape request: {error_message[:200]}")

    @staticmethod
    def _parse_url_list(urls) -> list:
        """Turn the ``urls`` argument into an ordered, de-duplicated list of absolute URLs.

        Accepts a comma-separated string or a list/tuple. Blank entries are dropped and
        entries lacking an http(s) scheme get "https://" prepended.
        """
        raw_items = urls if isinstance(urls, (list, tuple)) else str(urls).split(',')
        normalized = []
        for item in raw_items:
            url = str(item).strip()
            if not url:
                continue
            # Case-insensitive so "HTTPS://x" is not rewritten to "https://HTTPS://x".
            if not url.lower().startswith(('http://', 'https://')):
                url = 'https://' + url
                logging.info(f"Added https:// protocol to URL: {url}")
            normalized.append(url)
        # dict keeps insertion order, so this removes duplicates without reordering.
        return list(dict.fromkeys(normalized))

    def _summarize_scrape_results(self, results: list) -> ToolResult:
        """Build the tool response from per-URL result dicts.

        Fails only when no URL succeeded; otherwise lists the saved file paths and,
        for partial success, the failed URLs with their errors.
        """
        successful = sum(1 for r in results if r.get("success", False))
        failed = len(results) - successful

        if failed and not successful:
            error_details = "; ".join(f"{r.get('url')}: {r.get('error', 'Unknown error')}" for r in results)
            return self.fail_response(f"Failed to scrape all {len(results)} URLs. Errors: {error_details}")

        if failed == 0:
            lines = [f"Successfully scraped all {len(results)} URLs. Results saved to:"]
        else:
            lines = [f"Scraped {successful} URLs successfully and {failed} failed. Results saved to:"]
        lines.extend(
            f"- {r.get('file_path')}"
            for r in results
            if r.get("success", False) and r.get("file_path")
        )
        if failed:
            lines.append("")
            lines.append("Failed URLs:")
            lines.extend(
                f"- {r.get('url')}: {r.get('error', 'Unknown error')}"
                for r in results
                if not r.get("success", False)
            )

        return ToolResult(
            success=True,
            output="\n".join(lines)
        )

    async def _scrape_single_url(self, url: str) -> dict:
        """
        Helper function to scrape a single URL and return the result information.

        Returns a dict with "url" and "success". On success it also has "title",
        "file_path" (sandbox path of the saved JSON) and "content_length" (characters
        of markdown); on failure it has "error". Errors are reported in the dict
        rather than raised.
        """
        logging.info(f"Scraping single URL: {url}")

        try:
            data = await self._fetch_firecrawl_scrape(url)

            # Format the response; tolerate missing or null fields in Firecrawl's payload.
            page = data.get("data") or {}
            metadata = page.get("metadata") or {}
            title = metadata.get("title") or ""
            markdown_content = page.get("markdown") or ""
            logging.info(f"Extracted content from {url}: title='{title}', content length={len(markdown_content)}")

            formatted_result = {
                "title": title,
                "url": url,
                "text": markdown_content
            }

            # Add metadata if available
            if "metadata" in page:
                formatted_result["metadata"] = page["metadata"]
                logging.info(f"Added metadata: {metadata.keys()}")

            results_file_path = self._save_scrape_result(url, formatted_result)

            return {
                "url": url,
                "success": True,
                "title": title,
                "file_path": results_file_path,
                "content_length": len(markdown_content)
            }

        except Exception as e:
            error_message = str(e)
            logging.error(f"Error scraping URL '{url}': {error_message}")

            # Create an error result
            return {
                "url": url,
                "success": False,
                "error": error_message
            }

    async def _fetch_firecrawl_scrape(self, url: str) -> dict:
        """POST ``url`` to Firecrawl's /v1/scrape endpoint (markdown format) and return the JSON body.

        Timeouts, read errors and statuses in _RETRYABLE_STATUS_CODES are retried up to
        _FIRECRAWL_MAX_RETRIES attempts with exponential backoff (2s, 4s, ...). Any other
        error propagates immediately.

        Raises:
            Exception: when every attempt timed out / hit a read error.
            httpx.HTTPStatusError: for a non-retryable status, or a retryable one on the
                final attempt.
        """
        headers = {
            "Authorization": f"Bearer {self.firecrawl_api_key}",
            "Content-Type": "application/json",
        }
        payload = {
            "url": url,
            "formats": ["markdown"]
        }
        max_retries = self._FIRECRAWL_MAX_RETRIES
        timeout_seconds = self._FIRECRAWL_TIMEOUT_SECONDS

        # ---------- Firecrawl scrape endpoint ----------
        logging.info(f"Sending request to Firecrawl for URL: {url}")
        async with httpx.AsyncClient() as client:
            for attempt in range(1, max_retries + 1):
                try:
                    logging.info(f"Sending request to Firecrawl (attempt {attempt}/{max_retries})")
                    response = await client.post(
                        f"{self.firecrawl_url}/v1/scrape",
                        json=payload,
                        headers=headers,
                        timeout=timeout_seconds,
                    )
                    response.raise_for_status()
                    data = response.json()
                    logging.info(f"Successfully received response from Firecrawl for {url}")
                    return data
                except (httpx.TimeoutException, httpx.ReadError) as timeout_err:
                    # TimeoutException covers connect/read/write/pool timeouts.
                    logging.warning(f"Request timed out (attempt {attempt}/{max_retries}): {str(timeout_err)}")
                    if attempt >= max_retries:
                        raise Exception(
                            f"Request timed out after {max_retries} attempts with {timeout_seconds}s timeout"
                        ) from timeout_err
                except httpx.HTTPStatusError as status_err:
                    status = status_err.response.status_code
                    if status not in self._RETRYABLE_STATUS_CODES or attempt >= max_retries:
                        logging.error(f"Error during scraping: {str(status_err)}")
                        raise
                    logging.warning(f"Firecrawl returned HTTP {status} (attempt {attempt}/{max_retries}); retrying")
                except Exception as e:
                    # Don't retry on other (non-transient) errors
                    logging.error(f"Error during scraping: {str(e)}")
                    raise

                # Exponential backoff before the next attempt
                delay = 2 ** attempt
                logging.info(f"Waiting {delay}s before retry")
                await asyncio.sleep(delay)

        # Only reachable if max_retries < 1: every real attempt returns or raises above.
        raise RuntimeError("Firecrawl request was not attempted (max_retries < 1)")

    def _save_scrape_result(self, url: str, formatted_result: dict) -> str:
        """Write ``formatted_result`` as JSON under ``<workspace_path>/scrape`` and return the file path."""
        import hashlib
        from urllib.parse import urlparse

        # Create a simple filename from the URL domain and date
        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")

        # Extract domain from URL; strip only a *leading* "www." (str.replace
        # would also remove it from the middle of a host name).
        domain = urlparse(url).netloc
        if domain.startswith("www."):
            domain = domain[len("www."):]
        # Clean up domain for filename
        domain = "".join(c if c.isalnum() else "_" for c in domain)
        # Short URL digest: without it, different pages of the same domain saved in the
        # same second map to the same file and silently overwrite each other.
        url_digest = hashlib.sha256(url.encode("utf-8")).hexdigest()[:8]
        safe_filename = f"{timestamp}_{domain}_{url_digest}.json"

        logging.info(f"Generated filename: {safe_filename}")

        # Save results to a file in the /workspace/scrape directory
        scrape_dir = f"{self.workspace_path}/scrape"
        self.sandbox.fs.create_folder(scrape_dir, "755")

        results_file_path = f"{scrape_dir}/{safe_filename}"
        json_bytes = json.dumps(formatted_result, ensure_ascii=False, indent=2).encode("utf-8")
        logging.info(f"Saving content to file: {results_file_path}, size: {len(json_bytes)} bytes")

        self.sandbox.fs.upload_file(
            results_file_path,
            json_bytes
        )
        return results_file_path
|
|
|
|
if __name__ == "__main__":
    # Both placeholder checks report the same notice until they are ported to the sandbox.
    _SANDBOX_NOTICE = "Test function needs to be updated for sandbox version"

    async def test_web_search():
        """Placeholder manual check for web_search; not yet sandbox-compatible."""
        print(_SANDBOX_NOTICE)

    async def test_scrape_webpage():
        """Placeholder manual check for scrape_webpage; not yet sandbox-compatible."""
        print(_SANDBOX_NOTICE)

    async def run_tests():
        """Await every placeholder check, one after another, in declaration order."""
        for check in (test_web_search, test_scrape_webpage):
            await check()

    asyncio.run(run_tests())