mirror of https://github.com/kortix-ai/suna.git
Exa -> Tavily working, tested
commit 833e4fbad8 (parent 8669e40312)
@@ -53,8 +53,10 @@ async def run_agent(
     thread_manager.add_tool(SandboxDeployTool, sandbox=sandbox)
     thread_manager.add_tool(MessageTool) # we are just doing this via prompt as there is no need to call it as a tool

-    if os.getenv("EXA_API_KEY"):
+    if os.getenv("TAVILY_API_KEY"):
         thread_manager.add_tool(WebSearchTool)
+    else:
+        print("TAVILY_API_KEY not found, WebSearchTool will not be available.")

     if os.getenv("RAPID_API_KEY"):
         thread_manager.add_tool(DataProvidersTool)
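Note: with this change the web search tool is only registered when a Tavily key is configured. A minimal pre-flight check in the same spirit (the script below is illustrative and not part of the repo; it only assumes python-dotenv, which the tool already uses):

    import os

    from dotenv import load_dotenv

    # Mirrors the gate in run_agent(): without TAVILY_API_KEY the agent simply
    # skips WebSearchTool, so it can be handy to fail fast before starting it.
    load_dotenv()
    if not os.getenv("TAVILY_API_KEY"):
        raise SystemExit("TAVILY_API_KEY not found, WebSearchTool will not be available.")
    print("TAVILY_API_KEY found; WebSearchTool will be registered.")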
@@ -1,4 +1,5 @@
-from exa_py import Exa
+from tavily import AsyncTavilyClient
+import httpx
 from typing import List, Optional
 from datetime import datetime
 import os
@@ -15,10 +16,12 @@ class WebSearchTool(Tool):
         # Load environment variables
         load_dotenv()
         # Use the provided API key or get it from environment variables
-        self.api_key = api_key or os.getenv("EXA_API_KEY")
+        self.api_key = api_key or os.getenv("TAVILY_API_KEY")
         if not self.api_key:
-            raise ValueError("EXA_API_KEY not found in environment variables")
-        self.exa = Exa(api_key=self.api_key)
+            raise ValueError("TAVILY_API_KEY not found in environment variables")
+
+        # Tavily asynchronous search client
+        self.tavily_client = AsyncTavilyClient(api_key=self.api_key)

     @openapi_schema({
         "type": "function",
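Note: the constructor now builds an AsyncTavilyClient instead of an Exa client. A minimal standalone sketch of the same client, assuming tavily-python is installed and TAVILY_API_KEY is exported (the query string is just an example):

    import asyncio
    import os

    from tavily import AsyncTavilyClient

    async def main():
        # Same client WebSearchTool now creates in __init__, used directly here.
        client = AsyncTavilyClient(api_key=os.getenv("TAVILY_API_KEY"))
        # Same call shape as web_search below: plain search, no answer/images.
        response = await client.search(
            query="latest python release",
            max_results=5,
            include_answer=False,
            include_images=False,
        )
        for item in response.get("results", []):
            print(item.get("title"), "-", item.get("url"))

    asyncio.run(main())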
@@ -111,57 +114,49 @@ class WebSearchTool(Tool):
         if not query or not isinstance(query, str):
             return self.fail_response("A valid search query is required.")

-        # Basic parameters - use only the minimum required to avoid API errors
-        params = {
-            "query": query,
-            "type": "auto",
-            "livecrawl": "auto"
-        }
-
-        # Handle summary parameter (boolean conversion)
-        if summary is None:
-            params["summary"] = True
-        elif isinstance(summary, bool):
-            params["summary"] = summary
-        elif isinstance(summary, str):
-            params["summary"] = summary.lower() == "true"
-        else:
-            params["summary"] = True
-
-        # Handle num_results parameter (integer conversion)
+        # ---------- Tavily search parameters ----------
+        # num_results normalisation (1‑50)
         if num_results is None:
-            params["num_results"] = 20
+            num_results = 20
         elif isinstance(num_results, int):
-            params["num_results"] = max(1, min(num_results, 50))
+            num_results = max(1, min(num_results, 50))
         elif isinstance(num_results, str):
             try:
-                params["num_results"] = max(1, min(int(num_results), 50))
+                num_results = max(1, min(int(num_results), 50))
             except ValueError:
-                params["num_results"] = 20
+                num_results = 20
         else:
-            params["num_results"] = 20
-
-        # Execute the search with minimal parameters
-        search_response = self.exa.search_and_contents(**params)
-
-        # Format the results
+            num_results = 20
+
+        # Execute the search with Tavily
+        search_response = await self.tavily_client.search(
+            query=query,
+            max_results=num_results,
+            include_answer=False,
+            include_images=False,
+        )
+
+        # `tavily` may return a dict with `results` or a bare list
+        raw_results = (
+            search_response.get("results")
+            if isinstance(search_response, dict)
+            else search_response
+        )
+
         formatted_results = []
-        for result in search_response.results:
+        for result in raw_results:
             formatted_result = {
-                "Title": result.title,
-                "URL": result.url
+                "Title": result.get("title"),
+                "URL": result.get("url"),
             }

-            # Add optional fields if they exist
-            if hasattr(result, 'summary') and result.summary:
-                formatted_result["Summary"] = result.summary
-
-            if hasattr(result, 'published_date') and result.published_date:
-                formatted_result["Published Date"] = result.published_date
-
-            if hasattr(result, 'score'):
-                formatted_result["Score"] = result.score
+            if summary:
+                # Prefer full content; fall back to description
+                if result.get("content"):
+                    formatted_result["Summary"] = result["content"]
+                elif result.get("description"):
+                    formatted_result["Summary"] = result["description"]

             formatted_results.append(formatted_result)

         return self.success_response(formatted_results)
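Note: the raw_results guard above accepts either response shape. A tiny standalone illustration of that normalisation (the sample data is made up for the example):

    def normalise(search_response):
        # Same logic as web_search: accept a dict with a "results" key (the
        # usual tavily-python shape) or an already-bare list of results.
        return (
            search_response.get("results")
            if isinstance(search_response, dict)
            else search_response
        )

    as_dict = {"results": [{"title": "Example", "url": "https://example.com"}]}
    as_list = [{"title": "Example", "url": "https://example.com"}]
    assert normalise(as_dict) == normalise(as_list)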
@@ -243,26 +238,50 @@ class WebSearchTool(Tool):
         else:
             return self.fail_response("URL must be a string.")

-        # Execute the crawl with the parsed URL
-        result = self.exa.get_contents(
-            [url],
-            text=True,
-            livecrawl="auto"
-        )
-
-        # Format the results to include all available fields
-        formatted_results = []
-        for content in result.results:
-            formatted_result = {
-                "Title": content.title,
-                "URL": content.url,
-                "Text": content.text
-            }
-
-            # Add optional fields if they exist
-            if hasattr(content, 'published_date') and content.published_date:
-                formatted_result["Published Date"] = content.published_date
-
+        # ---------- Tavily extract endpoint ----------
+        async with httpx.AsyncClient() as client:
+            headers = {
+                "Authorization": f"Bearer {self.api_key}",
+                "Content-Type": "application/json",
+            }
+            payload = {
+                "urls": url,
+                "include_images": False,
+                "extract_depth": "basic",
+            }
+            response = await client.post(
+                "https://api.tavily.com/extract",
+                json=payload,
+                headers=headers,
+                timeout=60,
+            )
+            response.raise_for_status()
+            data = response.json()
+            print(f"--- Raw Tavily Response ---")
+            print(data)
+            print(f"--------------------------")
+
+        # Normalise Tavily extract output to a list of dicts
+        extracted = []
+        if isinstance(data, list):
+            extracted = data
+        elif isinstance(data, dict):
+            if "results" in data and isinstance(data["results"], list):
+                extracted = data["results"]
+            elif "urls" in data and isinstance(data["urls"], dict):
+                extracted = list(data["urls"].values())
+            else:
+                extracted = [data]
+
+        formatted_results = []
+        for item in extracted:
+            formatted_result = {
+                "Title": item.get("title"),
+                "URL": item.get("url") or url,
+                "Text": item.get("content") or item.get("text") or "",
+            }
+            if item.get("published_date"):
+                formatted_result["Published Date"] = item["published_date"]
             formatted_results.append(formatted_result)

         return self.success_response(formatted_results)
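Note: crawl_webpage now posts straight to Tavily's /extract endpoint over httpx rather than going through an SDK. A stripped-down sketch of the same request outside the tool (endpoint, headers and payload as in the diff; the target URL is only an example):

    import asyncio
    import os

    import httpx

    async def extract(url: str) -> dict:
        # Same request crawl_webpage issues: Bearer auth plus a JSON payload.
        headers = {
            "Authorization": f"Bearer {os.getenv('TAVILY_API_KEY')}",
            "Content-Type": "application/json",
        }
        payload = {"urls": url, "include_images": False, "extract_depth": "basic"}
        async with httpx.AsyncClient() as client:
            response = await client.post(
                "https://api.tavily.com/extract",
                json=payload,
                headers=headers,
                timeout=60,
            )
            response.raise_for_status()
            return response.json()

    print(asyncio.run(extract("https://example.com")))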
@@ -279,27 +298,27 @@ class WebSearchTool(Tool):
 if __name__ == "__main__":
     import asyncio

-    # async def test_web_search():
-    #     """Test function for the web search tool"""
-    #     search_tool = WebSearchTool()
-    #     result = await search_tool.web_search(
-    #         query="rubber gym mats best prices comparison",
-    #         summary=True,
-    #         num_results=20
-    #     )
-    #     print(result)
+    async def test_web_search():
+        """Test function for the web search tool"""
+        search_tool = WebSearchTool()
+        result = await search_tool.web_search(
+            query="rubber gym mats best prices comparison",
+            summary=True,
+            num_results=20
+        )
+        print(result)

     async def test_crawl_webpage():
         """Test function for the webpage crawl tool"""
         search_tool = WebSearchTool()
         result = await search_tool.crawl_webpage(
-            url="https://example.com"
+            url="https://google.com"
         )
         print(result)

     async def run_tests():
         """Run all test functions"""
-        # await test_web_search()
+        await test_web_search()
         await test_crawl_webpage()

     asyncio.run(run_tests())
@@ -22,5 +22,5 @@ certifi==2024.2.2
 python-ripgrep==0.0.6
 daytona_sdk>=0.12.0
 boto3>=1.34.0
-exa-py>=1.9.1
 pydantic
+tavily-python>=0.5.4
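Note: for an already-provisioned environment, matching the new pins by hand would look roughly like this (commands illustrative):

    pip uninstall -y exa-py
    pip install "tavily-python>=0.5.4"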