suna/backend/triggers/oauth/base.py

268 lines
9.8 KiB
Python

import abc
import base64
import hashlib
import hmac
import json
import os
import secrets
from dataclasses import dataclass
from enum import Enum
from typing import Dict, Any, Optional, List
from urllib.parse import urlencode

import httpx

from utils.logger import logger
class OAuthProvider(str, Enum):
    """Supported OAuth providers.

    Members also subclass ``str``, so a member compares equal to its
    lowercase value (``OAuthProvider.SLACK == "slack"``) and can be looked up
    from it (``OAuthProvider("slack")``).
    """

    # Each value is the provider identifier. BaseOAuthProvider uses ``.value``
    # as the trigger ``provider_id`` and embeds it in the OAuth state token.
    SLACK = "slack"
    DISCORD = "discord"
    TEAMS = "teams"
    GITHUB = "github"
    GOOGLE = "google"
    NOTION = "notion"
@dataclass
class OAuthConfig:
    """OAuth configuration for a provider.

    Holds everything ``BaseOAuthProvider`` needs to run the authorization-code
    flow against one provider.
    """

    # Which provider this configuration belongs to.
    provider: OAuthProvider
    # Application credentials sent to the token endpoint.
    client_id: str
    client_secret: str
    # Endpoint the user is redirected to in order to grant access.
    authorization_url: str
    # Endpoint the authorization code is POSTed to in exchange for a token.
    token_url: str
    # Requested scopes; joined with single spaces in the authorization URL.
    scopes: List[str]
    # Callback URL sent both in the authorization request and the token request.
    redirect_uri: str
    # Optional endpoint for user/profile lookups (not used by the base class).
    user_info_url: Optional[str] = None
    # Extra query parameters merged into the authorization URL. ``None`` means
    # "no extras"; the annotation is Optional to match that default.
    additional_params: Optional[Dict[str, str]] = None
@dataclass
class OAuthTokenResponse:
    """Standardized OAuth token response.

    A provider-independent view of the token endpoint's reply. Only
    ``access_token`` is required; every other field falls back to a default
    when the provider omits it.
    """

    # The credential used to call the provider's API.
    access_token: str
    # Present only when the provider issues refresh tokens.
    refresh_token: Optional[str] = None
    # Token lifetime as reported by the provider (presumably seconds, per
    # OAuth 2.0 -- the value is stored as-is and not interpreted here).
    expires_in: Optional[int] = None
    # Granted scope string exactly as returned by the provider.
    scope: Optional[str] = None
    token_type: str = "Bearer"
    # Raw token payload for provider-specific fields;
    # BaseOAuthProvider._exchange_code_for_token stores the full decoded JSON
    # body here. ``None`` when no payload was supplied.
    additional_data: Optional[Dict[str, Any]] = None
@dataclass
class IntegrationResult:
    """Result of OAuth integration setup.

    Returned by ``BaseOAuthProvider.handle_callback``. On failure only
    ``success``, ``error`` and ``provider_name`` are populated; on success the
    trigger and workspace details are filled in and ``error`` stays ``None``.
    """

    # True when the callback produced a trigger.
    success: bool
    # Identifier of the created trigger (success only).
    trigger_id: Optional[str] = None
    # The provider's string identifier, e.g. "slack".
    provider_name: str = ""
    # Display details taken from the provider data, when available.
    workspace_name: Optional[str] = None
    bot_name: Optional[str] = None
    # URL the provider should deliver events to (success only).
    webhook_url: Optional[str] = None
    # Human-readable failure reason (failure only).
    error: Optional[str] = None
    # Provider-specific data gathered during setup. ``None`` when not set;
    # the annotation is Optional to match that default.
    metadata: Optional[Dict[str, Any]] = None
class BaseOAuthProvider(abc.ABC):
    """Base class for OAuth providers.

    Implements the provider-independent part of the authorization-code flow:
    building the authorization URL, validating the ``state`` value on
    callback, exchanging the code for a token and creating the trigger.
    Subclasses supply the provider-specific pieces through the abstract hooks
    ``_get_provider_data``, ``_is_token_response_valid`` and
    ``_get_trigger_config``.
    """

    def __init__(self, config: OAuthConfig, db_connection):
        """Store the provider configuration and the database handle.

        Args:
            config: Endpoints, credentials and scopes for this provider.
            db_connection: Database handle handed to ``TriggerManager`` when
                a trigger is created.
        """
        self.config = config
        self.db = db_connection

    def generate_authorization_url(self, agent_id: str, user_id: str) -> str:
        """Generate OAuth authorization URL.

        Args:
            agent_id: Agent the resulting trigger will be attached to.
            user_id: User starting the flow.

        Returns:
            The provider's authorization URL with the query string appended.
            The ``state`` parameter carries ``agent_id``/``user_id`` so that
            ``handle_callback`` can recover them.
        """
        params = {
            "client_id": self.config.client_id,
            "redirect_uri": self.config.redirect_uri,
            "scope": " ".join(self.config.scopes),
            "state": self._create_state_token(agent_id, user_id),
            "response_type": "code",
        }
        # Provider-specific parameters are applied last so they can override
        # the defaults above.
        if self.config.additional_params:
            params.update(self.config.additional_params)

        # Some authorization endpoints already carry a query string; extend
        # it instead of emitting a second "?".
        base = self.config.authorization_url
        separator = "&" if "?" in base else "?"
        return f"{base}{separator}{urlencode(params)}"

    async def handle_callback(self, code: str, state: str) -> IntegrationResult:
        """Handle OAuth callback and set up integration.

        Args:
            code: Authorization code returned by the provider.
            state: The ``state`` value returned by the provider; it must be a
                token produced by ``_create_state_token``.

        Returns:
            An ``IntegrationResult``. This method does not raise: any
            exception is logged and reported as ``success=False`` with the
            exception text in ``error``.
        """
        provider_name = self.config.provider.value
        try:
            state_data = await self._verify_state_token(state)
            if not state_data:
                return IntegrationResult(
                    success=False,
                    error="Invalid state parameter",
                    provider_name=provider_name,
                )

            token_response = await self._exchange_code_for_token(code)
            if not token_response:
                return IntegrationResult(
                    success=False,
                    error="Failed to exchange code for token",
                    provider_name=provider_name,
                )

            provider_data = await self._get_provider_data(token_response)
            trigger_result = await self._create_trigger(
                state_data["agent_id"],
                state_data["user_id"],
                token_response,
                provider_data,
            )
            return IntegrationResult(
                success=True,
                trigger_id=trigger_result["trigger_id"],
                provider_name=provider_name,
                workspace_name=provider_data.get("workspace_name"),
                bot_name=provider_data.get("bot_name"),
                webhook_url=trigger_result["webhook_url"],
                metadata=provider_data,
            )
        except Exception as e:
            # Top-level boundary of the flow: convert every failure into a
            # result object rather than propagating to the HTTP layer.
            logger.error(f"Error handling {self.config.provider.value} OAuth callback: {e}")
            return IntegrationResult(
                success=False,
                error=str(e),
                provider_name=provider_name,
            )

    async def _exchange_code_for_token(self, code: str) -> Optional[OAuthTokenResponse]:
        """Exchange authorization code for access token.

        Args:
            code: Authorization code received on the callback.

        Returns:
            The parsed token response, or ``None`` when the request fails,
            the endpoint does not answer 200, or the subclass rejects the
            payload via ``_is_token_response_valid``.
        """
        try:
            async with httpx.AsyncClient() as client:
                response = await client.post(
                    self.config.token_url,
                    data={
                        "client_id": self.config.client_id,
                        "client_secret": self.config.client_secret,
                        "code": code,
                        "redirect_uri": self.config.redirect_uri,
                        "grant_type": "authorization_code",
                    },
                )
            data = response.json()

            if response.status_code != 200 or not self._is_token_response_valid(data):
                # Log only the status and error fields: a 200 response that
                # fails provider validation may still contain tokens, which
                # must not end up in the logs.
                detail = None
                if isinstance(data, dict):
                    detail = data.get("error_description") or data.get("error")
                logger.error(
                    f"Token exchange failed (HTTP {response.status_code}): "
                    f"{detail or 'invalid token response'}"
                )
                return None

            return OAuthTokenResponse(
                access_token=data["access_token"],
                refresh_token=data.get("refresh_token"),
                expires_in=data.get("expires_in"),
                scope=data.get("scope"),
                token_type=data.get("token_type", "Bearer"),
                additional_data=data,
            )
        except Exception as e:
            # Best effort: network errors, non-JSON bodies and missing keys
            # all surface to the caller as "no token".
            logger.error(f"Error exchanging code for token: {e}")
            return None

    @abc.abstractmethod
    async def _get_provider_data(self, token_response: OAuthTokenResponse) -> Dict[str, Any]:
        """Get provider-specific data (workspace info, bot info, etc.).

        ``handle_callback`` reads the optional ``workspace_name`` and
        ``bot_name`` keys from the returned dict and passes the whole dict on
        as the result metadata.
        """

    @abc.abstractmethod
    def _is_token_response_valid(self, data: Dict[str, Any]) -> bool:
        """Check if token response is valid for this provider.

        Args:
            data: Decoded JSON body of the token endpoint's response.
        """

    @abc.abstractmethod
    def _get_trigger_config(
        self,
        token_response: OAuthTokenResponse,
        provider_data: Dict[str, Any],
    ) -> Dict[str, Any]:
        """Generate trigger configuration for this provider.

        ``_create_trigger`` adds the ``oauth_installed`` and ``provider`` keys
        to a copy of the returned dict.
        """

    async def _create_trigger(
        self,
        agent_id: str,
        user_id: str,
        token_response: OAuthTokenResponse,
        provider_data: Dict[str, Any],
    ) -> Dict[str, Any]:
        """Create trigger with provider-specific configuration.

        Args:
            agent_id: Agent the trigger is created for.
            user_id: User who ran the OAuth flow (currently not used here).
            token_response: Token obtained from the provider.
            provider_data: Result of ``_get_provider_data``.

        Returns:
            A dict with the new ``trigger_id`` and the ``webhook_url`` the
            provider should deliver events to.
        """
        # Imported lazily; presumably this avoids a circular import with the
        # triggers core package -- confirm before hoisting to module level.
        from ..core import TriggerManager

        provider_id = self.config.provider.value

        trigger_manager = TriggerManager(self.db)
        await trigger_manager.load_provider_definitions()

        # Copy so the dict handed back by the subclass is not mutated.
        config = dict(self._get_trigger_config(token_response, provider_data))
        config["oauth_installed"] = True
        config["provider"] = provider_id

        # ``or`` (not a .get default) so an explicit None workspace name does
        # not yield a trigger called "<Provider> - None".
        workspace_label = provider_data.get("workspace_name") or "Integration"
        trigger_config = await trigger_manager.create_trigger(
            agent_id=agent_id,
            provider_id=provider_id,
            name=f"{provider_id.title()} - {workspace_label}",
            description=f"Auto-configured {provider_id.title()} integration",
            config=config,
        )

        # NOTE: the former _store_oauth_data() step is intentionally omitted;
        # the oauth_installations table is deprecated.

        # Strip a trailing slash so the joined path never contains "//api".
        base_url = os.getenv("WEBHOOK_BASE_URL", "http://localhost:8000").rstrip("/")

        # Slack requires a single Event Request URL per app, so it gets the
        # universal webhook URL (no trigger_id in the path).
        if self.config.provider == OAuthProvider.SLACK:
            webhook_url = f"{base_url}/api/triggers/slack/webhook"
        else:
            webhook_url = f"{base_url}/api/triggers/{trigger_config.trigger_id}/webhook"

        return {
            "trigger_id": trigger_config.trigger_id,
            "webhook_url": webhook_url,
        }

    def _sign_state(self, payload: Dict[str, Any]) -> str:
        """Return the hex HMAC-SHA256 of *payload*, keyed by the client secret."""
        # Canonical form (sorted keys, no whitespace) so the signature is
        # stable across the JSON round trip through the provider.
        canonical = json.dumps(payload, sort_keys=True, separators=(",", ":"))
        key = (self.config.client_secret or "").encode()
        return hmac.new(key, canonical.encode(), hashlib.sha256).hexdigest()

    def _create_state_token(self, agent_id: str, user_id: str) -> str:
        """Create secure state token.

        The token is base64-encoded JSON holding ``agent_id``, ``user_id``,
        ``provider``, a random ``nonce`` and ``sig``, an HMAC over the other
        fields. The signature is what stops a caller from forging a state
        that names somebody else's agent or user.
        """
        state_data = {
            "agent_id": agent_id,
            "user_id": user_id,
            "provider": self.config.provider.value,
            "nonce": secrets.token_urlsafe(16),
        }
        signature = self._sign_state(state_data)
        state_data["sig"] = signature
        return base64.b64encode(json.dumps(state_data).encode()).decode()

    async def _verify_state_token(self, state: str) -> Optional[Dict[str, Any]]:
        """Verify and decode state token.

        Args:
            state: The ``state`` value received on the callback.

        Returns:
            The decoded payload (without ``sig``) when the token decodes, its
            signature matches and it was issued for this provider; otherwise
            ``None``. Unsigned tokens are rejected.
        """
        try:
            state_json = base64.b64decode(state.encode()).decode()
            state_data = json.loads(state_json)
            if not isinstance(state_data, dict):
                return None

            signature = state_data.pop("sig", None)
            if not isinstance(signature, str):
                logger.error("OAuth state token is missing its signature")
                return None

            expected = self._sign_state(state_data)
            # Constant-time comparison to avoid leaking the signature through
            # timing differences.
            if not hmac.compare_digest(signature.encode(), expected.encode()):
                logger.error("OAuth state token signature mismatch")
                return None

            # A token minted for another provider must not be accepted here.
            if state_data.get("provider") != self.config.provider.value:
                return None
            return state_data
        except Exception as e:
            # Malformed base64/JSON or a non-string state: treat as invalid.
            logger.error(f"Error verifying state token: {e}")
            return None
# COMMENTED OUT: OAuth installations functionality deprecated
# async def _store_oauth_data(
# self,
# trigger_id: str,
# token_response: OAuthTokenResponse,
# provider_data: Dict[str, Any]
# ):
# """Store OAuth installation data."""
# client = await self.db.client
# await client.table('oauth_installations').insert({
# 'trigger_id': trigger_id,
# 'provider': self.config.provider.value,
# 'access_token': token_response.access_token,
# 'refresh_token': token_response.refresh_token,
# 'expires_in': token_response.expires_in,
# 'scope': token_response.scope,
# 'provider_data': provider_data,
# 'installed_at': 'now()'
# }).execute()