suna/backend/agent/tools/data_providers_tool.py

175 lines
6.9 KiB
Python
Raw Normal View History

2025-04-16 10:00:22 +08:00
import json
2025-06-01 04:53:53 +08:00
from typing import Union, Dict, Any
2025-04-16 10:00:22 +08:00
from agentpress.tool import Tool, ToolResult, openapi_schema, usage_example
2025-04-17 01:23:28 +08:00
from agent.tools.data_providers.LinkedinProvider import LinkedinProvider
from agent.tools.data_providers.YahooFinanceProvider import YahooFinanceProvider
from agent.tools.data_providers.AmazonProvider import AmazonProvider
2025-04-17 01:53:28 +08:00
from agent.tools.data_providers.ZillowProvider import ZillowProvider
2025-04-17 02:10:13 +08:00
from agent.tools.data_providers.TwitterProvider import TwitterProvider
2025-04-16 10:00:22 +08:00
2025-04-17 01:23:28 +08:00
class DataProvidersTool(Tool):
    """Tool for making requests to various data providers.

    Exposes two agent-callable functions:

    * ``get_data_provider_endpoints`` - list the endpoints a provider offers.
    * ``execute_data_provider_call`` - call one of those endpoints with a payload.

    Providers are instantiated once in ``__init__`` and looked up by the keys
    of ``register_data_providers``.
    """

    def __init__(self):
        super().__init__()

        # Maps the service name accepted by the tool functions to a provider
        # instance; each provider is used via get_endpoints() and
        # call_endpoint(route, payload).
        self.register_data_providers = {
            "linkedin": LinkedinProvider(),
            "yahoo_finance": YahooFinanceProvider(),
            "amazon": AmazonProvider(),
            "zillow": ZillowProvider(),
            "twitter": TwitterProvider(),
        }

    @openapi_schema({
        "type": "function",
        "function": {
            "name": "get_data_provider_endpoints",
            "description": "Get available endpoints for a specific data provider",
            "parameters": {
                "type": "object",
                "properties": {
                    "service_name": {
                        "type": "string",
                        "description": "The name of the data provider (e.g., 'linkedin', 'twitter', 'zillow', 'amazon', 'yahoo_finance')"
                    }
                },
                "required": ["service_name"]
            }
        }
    })
    @usage_example('''
        <!--
        The get-data-provider-endpoints tool returns available endpoints for a specific data provider.
        Use this tool when you need to discover what endpoints are available.
        -->
        <!-- Example to get LinkedIn API endpoints -->
        <function_calls>
        <invoke name="get_data_provider_endpoints">
        <parameter name="service_name">linkedin</parameter>
        </invoke>
        </function_calls>
        ''')
    async def get_data_provider_endpoints(
        self,
        service_name: str
    ) -> ToolResult:
        """
        Get available endpoints for a specific data provider.

        Parameters:
        - service_name: The name of the data provider (e.g., 'linkedin')

        Returns:
            A success result wrapping whatever the provider's get_endpoints()
            returns, or a failure result when the name is empty, unknown, or
            the provider raises.
        """
        try:
            if not service_name:
                return self.fail_response("Data provider name is required.")

            data_provider = self.register_data_providers.get(service_name)
            if data_provider is None:
                return self.fail_response(f"Data provider '{service_name}' not found. Available data providers: {list(self.register_data_providers.keys())}")

            return self.success_response(data_provider.get_endpoints())
        except Exception as e:
            return self._exception_response("Error getting data provider endpoints", e)

    @openapi_schema({
        "type": "function",
        "function": {
            "name": "execute_data_provider_call",
            "description": "Execute a call to a specific data provider endpoint",
            "parameters": {
                "type": "object",
                "properties": {
                    "service_name": {
                        "type": "string",
                        "description": "The name of the API service (e.g., 'linkedin')"
                    },
                    "route": {
                        "type": "string",
                        "description": "The key of the endpoint to call"
                    },
                    "payload": {
                        "type": "object",
                        "description": "The payload to send with the API call"
                    }
                },
                "required": ["service_name", "route"]
            }
        }
    })
    @usage_example('''
        <!--
        The execute-data-provider-call tool makes a request to a specific data provider endpoint.
        Use this tool when you need to call an data provider endpoint with specific parameters.
        The route must be a valid endpoint key obtained from get-data-provider-endpoints tool!!
        -->
        <!-- Example to call linkedIn service with the specific route person -->
        <function_calls>
        <invoke name="execute_data_provider_call">
        <parameter name="service_name">linkedin</parameter>
        <parameter name="route">person</parameter>
        <parameter name="payload">{"link": "https://www.linkedin.com/in/johndoe/"}</parameter>
        </invoke>
        </function_calls>
        ''')
    async def execute_data_provider_call(
        self,
        service_name: str,
        route: str,
        payload: Union[Dict[str, Any], str, None] = None
    ) -> ToolResult:
        """
        Execute a call to a specific data provider endpoint.

        Parameters:
        - service_name: The name of the data provider (e.g., 'linkedin')
        - route: The key of the endpoint to call
        - payload: The payload to send with the data provider call (dict or JSON string).
          None, an empty string, or the JSON literal null all mean "no payload" ({}).
          Anything that is not (or does not decode to) a JSON object is rejected.

        Returns:
            A success result wrapping the provider's call_endpoint() result, or a
            failure result describing the first validation error or the exception
            raised by the provider.
        """
        try:
            # Cheap required-argument checks first, before touching the payload.
            if not service_name:
                return self.fail_response("service_name is required.")
            if not route:
                return self.fail_response("route is required.")

            payload, payload_error = self._normalize_payload(payload)
            if payload_error is not None:
                return self.fail_response(payload_error)

            data_provider = self.register_data_providers.get(service_name)
            if data_provider is None:
                return self.fail_response(f"API '{service_name}' not found. Available APIs: {list(self.register_data_providers.keys())}")

            # A common model mistake is repeating the provider name as the route.
            if route == service_name:
                return self.fail_response(
                    f"route '{route}' is the same as service_name '{service_name}'. "
                    "Use get_data_provider_endpoints to list the valid routes for this data provider."
                )

            endpoints = data_provider.get_endpoints()
            if route not in endpoints:
                return self.fail_response(f"Endpoint '{route}' not found in {service_name} data provider. Available endpoints: {list(endpoints)}")

            # NOTE(review): call_endpoint is invoked synchronously (no await) inside
            # this coroutine; if it performs blocking network I/O it stalls the
            # event loop for the duration of the call — confirm and consider
            # offloading to a thread.
            result = data_provider.call_endpoint(route, payload)
            return self.success_response(result)
        except Exception as e:
            return self._exception_response("Error executing data provider call", e)

    @staticmethod
    def _normalize_payload(payload: Any) -> tuple:
        """Coerce a tool-call payload into a dict.

        Returns a ``(payload_dict, error_message)`` pair: exactly one of the two
        is None. Accepts a dict as-is, maps None / blank string / JSON ``null``
        to ``{}``, decodes other strings as JSON, and rejects any value that is
        not a JSON object.
        """
        if payload is None:
            return {}, None
        if isinstance(payload, str):
            if not payload.strip():
                # Models sometimes send an empty parameter instead of omitting it.
                return {}, None
            try:
                payload = json.loads(payload)
            except json.JSONDecodeError as e:
                return None, f"Invalid JSON in payload: {str(e)}"
            if payload is None:
                return {}, None
        if not isinstance(payload, dict):
            return None, f"Payload must be a JSON object, got {type(payload).__name__}."
        return payload, None

    def _exception_response(self, prefix: str, error: Exception, limit: int = 200) -> ToolResult:
        """Log *error* with its traceback and return a truncated failure result.

        The message is ``"<prefix>: <str(error)>"``, with the error text cut to
        *limit* characters and suffixed with "..." when it was longer. Must be
        called from inside an ``except`` block so the traceback is available.
        """
        import logging  # imported locally; the module header does not import logging

        logging.getLogger(__name__).exception(prefix)
        message = str(error)
        summary = f"{prefix}: {message[:limit]}"
        if len(message) > limit:
            summary += "..."
        return self.fail_response(summary)