Exa -> Tavily working, tested

LE Quoc Dat 2025-04-18 08:11:53 +01:00
parent 8669e40312
commit 833e4fbad8
3 changed files with 100 additions and 79 deletions

View File

@@ -53,8 +53,10 @@ async def run_agent(
     thread_manager.add_tool(SandboxDeployTool, sandbox=sandbox)
     thread_manager.add_tool(MessageTool) # we are just doing this via prompt as there is no need to call it as a tool
-    if os.getenv("EXA_API_KEY"):
+    if os.getenv("TAVILY_API_KEY"):
         thread_manager.add_tool(WebSearchTool)
+    else:
+        print("TAVILY_API_KEY not found, WebSearchTool will not be available.")
     if os.getenv("RAPID_API_KEY"):
         thread_manager.add_tool(DataProvidersTool)
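Since WebSearchTool is now registered only when the key is present, it can help to confirm the key is visible to the backend process before launch. A minimal sketch; only the variable name TAVILY_API_KEY comes from this diff, the rest is illustrative:

import os

# Hypothetical pre-flight check mirroring the gate in run_agent above.
if os.getenv("TAVILY_API_KEY"):
    print("TAVILY_API_KEY found; WebSearchTool will be registered.")
else:
    print("TAVILY_API_KEY missing; web search will be disabled.")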

View File

@@ -1,4 +1,5 @@
-from exa_py import Exa
+from tavily import AsyncTavilyClient
+import httpx
 from typing import List, Optional
 from datetime import datetime
 import os
@@ -15,10 +16,12 @@ class WebSearchTool(Tool):
         # Load environment variables
         load_dotenv()
         # Use the provided API key or get it from environment variables
-        self.api_key = api_key or os.getenv("EXA_API_KEY")
+        self.api_key = api_key or os.getenv("TAVILY_API_KEY")
         if not self.api_key:
-            raise ValueError("EXA_API_KEY not found in environment variables")
-        self.exa = Exa(api_key=self.api_key)
+            raise ValueError("TAVILY_API_KEY not found in environment variables")
+        # Tavily asynchronous search client
+        self.tavily_client = AsyncTavilyClient(api_key=self.api_key)
     @openapi_schema({
         "type": "function",
@@ -111,57 +114,49 @@ class WebSearchTool(Tool):
         if not query or not isinstance(query, str):
             return self.fail_response("A valid search query is required.")
-        # Basic parameters - use only the minimum required to avoid API errors
-        params = {
-            "query": query,
-            "type": "auto",
-            "livecrawl": "auto"
-        }
-        # Handle summary parameter (boolean conversion)
-        if summary is None:
-            params["summary"] = True
-        elif isinstance(summary, bool):
-            params["summary"] = summary
-        elif isinstance(summary, str):
-            params["summary"] = summary.lower() == "true"
-        else:
-            params["summary"] = True
-        # Handle num_results parameter (integer conversion)
+        # ---------- Tavily search parameters ----------
+        # num_results normalisation (1-50)
         if num_results is None:
-            params["num_results"] = 20
+            num_results = 20
         elif isinstance(num_results, int):
-            params["num_results"] = max(1, min(num_results, 50))
+            num_results = max(1, min(num_results, 50))
         elif isinstance(num_results, str):
             try:
-                params["num_results"] = max(1, min(int(num_results), 50))
+                num_results = max(1, min(int(num_results), 50))
             except ValueError:
-                params["num_results"] = 20
+                num_results = 20
         else:
-            params["num_results"] = 20
-        # Execute the search with minimal parameters
-        search_response = self.exa.search_and_contents(**params)
-        # Format the results
+            num_results = 20
+        # Execute the search with Tavily
+        search_response = await self.tavily_client.search(
+            query=query,
+            max_results=num_results,
+            include_answer=False,
+            include_images=False,
+        )
+        # `tavily` may return a dict with `results` or a bare list
+        raw_results = (
+            search_response.get("results")
+            if isinstance(search_response, dict)
+            else search_response
+        )
         formatted_results = []
-        for result in search_response.results:
+        for result in raw_results:
             formatted_result = {
-                "Title": result.title,
-                "URL": result.url
+                "Title": result.get("title"),
+                "URL": result.get("url"),
             }
-            # Add optional fields if they exist
-            if hasattr(result, 'summary') and result.summary:
-                formatted_result["Summary"] = result.summary
-            if hasattr(result, 'published_date') and result.published_date:
-                formatted_result["Published Date"] = result.published_date
-            if hasattr(result, 'score'):
-                formatted_result["Score"] = result.score
+            if summary:
+                # Prefer full content; fall back to description
+                if result.get("content"):
+                    formatted_result["Summary"] = result["content"]
+                elif result.get("description"):
+                    formatted_result["Summary"] = result["description"]
             formatted_results.append(formatted_result)
         return self.success_response(formatted_results)
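For context on the normalisation above, a minimal standalone sketch of the same client call. The AsyncTavilyClient.search call mirrors this diff; the printed field names ("title", "url") are an assumption about the usual shape of the returned "results" list, not something this commit guarantees:

import asyncio
import os
from tavily import AsyncTavilyClient

async def demo_search():
    # Hedged sketch: a one-off Tavily search outside the tool class.
    client = AsyncTavilyClient(api_key=os.getenv("TAVILY_API_KEY"))
    response = await client.search(query="rubber gym mats best prices comparison", max_results=5)
    # The client normally returns a dict whose "results" key holds the per-page
    # hits, which is why the tool code reaches for search_response.get("results").
    for hit in response.get("results", []):
        print(hit.get("title"), hit.get("url"))

asyncio.run(demo_search())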
@@ -243,26 +238,50 @@ class WebSearchTool(Tool):
         else:
             return self.fail_response("URL must be a string.")
-        # Execute the crawl with the parsed URL
-        result = self.exa.get_contents(
-            [url],
-            text=True,
-            livecrawl="auto"
-        )
-        # Format the results to include all available fields
-        formatted_results = []
-        for content in result.results:
-            formatted_result = {
-                "Title": content.title,
-                "URL": content.url,
-                "Text": content.text
-            }
-            # Add optional fields if they exist
-            if hasattr(content, 'published_date') and content.published_date:
-                formatted_result["Published Date"] = content.published_date
+        # ---------- Tavily extract endpoint ----------
+        async with httpx.AsyncClient() as client:
+            headers = {
+                "Authorization": f"Bearer {self.api_key}",
+                "Content-Type": "application/json",
+            }
+            payload = {
+                "urls": url,
+                "include_images": False,
+                "extract_depth": "basic",
+            }
+            response = await client.post(
+                "https://api.tavily.com/extract",
+                json=payload,
+                headers=headers,
+                timeout=60,
+            )
+            response.raise_for_status()
+            data = response.json()
+        print(f"--- Raw Tavily Response ---")
+        print(data)
+        print(f"--------------------------")
+        # Normalise Tavily extract output to a list of dicts
+        extracted = []
+        if isinstance(data, list):
+            extracted = data
+        elif isinstance(data, dict):
+            if "results" in data and isinstance(data["results"], list):
+                extracted = data["results"]
+            elif "urls" in data and isinstance(data["urls"], dict):
+                extracted = list(data["urls"].values())
+            else:
+                extracted = [data]
+        formatted_results = []
+        for item in extracted:
+            formatted_result = {
+                "Title": item.get("title"),
+                "URL": item.get("url") or url,
+                "Text": item.get("content") or item.get("text") or "",
+            }
+            if item.get("published_date"):
+                formatted_result["Published Date"] = item["published_date"]
             formatted_results.append(formatted_result)
         return self.success_response(formatted_results)
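The extract call above raises on non-2xx responses via raise_for_status(). A hedged sketch of how a caller might surface that, assuming the error propagates out of crawl_webpage (any surrounding try/except in the tool is not visible in this hunk); safe_crawl is an illustrative helper, not part of this commit:

import httpx

async def safe_crawl(tool, url: str):
    # Illustrative wrapper: turn an HTTP error from the Tavily extract
    # endpoint into a readable failure string instead of a raw traceback.
    try:
        return await tool.crawl_webpage(url=url)
    except httpx.HTTPStatusError as exc:
        return f"Tavily extract failed with status {exc.response.status_code} for {url}"
    except httpx.RequestError as exc:
        return f"Tavily extract request error for {url}: {exc}"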
@@ -279,27 +298,27 @@ class WebSearchTool(Tool):
 if __name__ == "__main__":
     import asyncio
-    # async def test_web_search():
-    #     """Test function for the web search tool"""
-    #     search_tool = WebSearchTool()
-    #     result = await search_tool.web_search(
-    #         query="rubber gym mats best prices comparison",
-    #         summary=True,
-    #         num_results=20
-    #     )
-    #     print(result)
+    async def test_web_search():
+        """Test function for the web search tool"""
+        search_tool = WebSearchTool()
+        result = await search_tool.web_search(
+            query="rubber gym mats best prices comparison",
+            summary=True,
+            num_results=20
+        )
+        print(result)
     async def test_crawl_webpage():
         """Test function for the webpage crawl tool"""
         search_tool = WebSearchTool()
         result = await search_tool.crawl_webpage(
-            url="https://example.com"
+            url="https://google.com"
         )
         print(result)
     async def run_tests():
         """Run all test functions"""
-        # await test_web_search()
+        await test_web_search()
         await test_crawl_webpage()
     asyncio.run(run_tests())

View File

@@ -22,5 +22,5 @@ certifi==2024.2.2
 python-ripgrep==0.0.6
 daytona_sdk>=0.12.0
 boto3>=1.34.0
-exa-py>=1.9.1
 pydantic
+tavily-python>=0.5.4