mirror of https://github.com/kortix-ai/suna.git
268 lines
9.8 KiB
Python
268 lines
9.8 KiB
Python
import abc
|
|
import secrets
|
|
import json
|
|
import base64
|
|
from typing import Dict, Any, Optional, List
|
|
from urllib.parse import urlencode
|
|
from dataclasses import dataclass
|
|
from enum import Enum
|
|
import httpx
|
|
from utils.logger import logger
|
|
|
|
class OAuthProvider(str, Enum):
    """Identifiers of the external services that can be connected via OAuth.

    Every member is also a ``str``, so it compares equal to its lowercase
    value and can be looked up from it (``OAuthProvider("slack")``).
    """

    # Chat / collaboration platforms
    SLACK = "slack"
    DISCORD = "discord"
    TEAMS = "teams"

    # Other supported services
    GITHUB = "github"
    GOOGLE = "google"
    NOTION = "notion"
|
|
|
|
@dataclass
class OAuthConfig:
    """OAuth configuration for a provider.

    Holds the application credentials and endpoint URLs used to build the
    authorization URL and to exchange an authorization code for a token.
    """

    # Which provider this configuration belongs to.
    provider: OAuthProvider
    # Application credentials; both are sent to ``token_url`` during the
    # code exchange, only ``client_id`` is placed in the authorization URL.
    client_id: str
    client_secret: str
    # Endpoint the user is redirected to in order to grant access.
    authorization_url: str
    # Endpoint the authorization code is POSTed to.
    token_url: str
    # Requested scopes; joined with single spaces in the authorization URL.
    scopes: List[str]
    # Callback URL sent in both the authorization and the token request.
    redirect_uri: str
    # Optional user-info endpoint (not read by the base provider class).
    user_info_url: Optional[str] = None
    # Extra query parameters merged into the authorization URL when set.
    # Defaults to None (not an empty dict), hence the Optional annotation.
    additional_params: Optional[Dict[str, str]] = None
|
|
|
|
@dataclass
class OAuthTokenResponse:
    """Standardized OAuth token response.

    A provider-independent view of the JSON returned by a token endpoint.
    """

    # The access token issued by the provider (the only required field).
    access_token: str
    # Present only when the provider's response included these keys.
    refresh_token: Optional[str] = None
    expires_in: Optional[int] = None
    scope: Optional[str] = None
    token_type: str = "Bearer"
    # Extra provider-specific payload. Defaults to None (not an empty dict),
    # hence the Optional annotation.
    additional_data: Optional[Dict[str, Any]] = None
|
|
|
|
@dataclass
class IntegrationResult:
    """Result of OAuth integration setup.

    ``success`` is the only required field; on failure ``error`` carries a
    human-readable reason and the trigger-related fields keep their defaults.
    """

    # True when the whole setup completed.
    success: bool
    # Identifier of the created trigger, if any.
    trigger_id: Optional[str] = None
    # Name of the provider the result refers to.
    provider_name: str = ""
    # Display information reported by the provider, when available.
    workspace_name: Optional[str] = None
    bot_name: Optional[str] = None
    # URL the provider should deliver events to.
    webhook_url: Optional[str] = None
    # Failure reason when ``success`` is False.
    error: Optional[str] = None
    # Provider-specific details. Defaults to None (not an empty dict),
    # hence the Optional annotation.
    metadata: Optional[Dict[str, Any]] = None
|
|
|
|
class BaseOAuthProvider(abc.ABC):
    """Base class for OAuth providers.

    Implements the provider-independent part of the authorization-code flow:
    building the authorization URL, validating the callback ``state``,
    exchanging the code for a token and creating the trigger. Subclasses
    supply the provider-specific pieces through the abstract methods.
    """

    def __init__(self, config: OAuthConfig, db_connection):
        """Store the provider configuration and the database handle."""
        self.config = config
        self.db = db_connection

    def generate_authorization_url(self, agent_id: str, user_id: str) -> str:
        """Generate OAuth authorization URL.

        Args:
            agent_id: Agent the integration is being installed for.
            user_id: User starting the installation.

        Returns:
            The provider's authorization URL with the query string attached.
        """
        state = self._create_state_token(agent_id, user_id)

        params = {
            "client_id": self.config.client_id,
            "redirect_uri": self.config.redirect_uri,
            "scope": " ".join(self.config.scopes),
            "state": state,
            "response_type": "code",
        }

        # Provider-specific parameters may add to or override the defaults.
        if self.config.additional_params:
            params.update(self.config.additional_params)

        return f"{self.config.authorization_url}?{urlencode(params)}"

    async def handle_callback(self, code: str, state: str) -> IntegrationResult:
        """Handle OAuth callback and set up integration.

        Args:
            code: Authorization code received on the callback.
            state: State value received on the callback.

        Returns:
            An ``IntegrationResult``. Any exception raised along the way is
            logged and converted into a failed result rather than propagated.
        """
        try:
            # Verify state
            state_data = await self._verify_state_token(state)
            if not state_data:
                return IntegrationResult(
                    success=False,
                    error="Invalid state parameter",
                    provider_name=self.config.provider.value,
                )

            # Exchange code for token
            token_response = await self._exchange_code_for_token(code)
            if not token_response:
                return IntegrationResult(
                    success=False,
                    error="Failed to exchange code for token",
                    provider_name=self.config.provider.value,
                )

            # Get provider-specific data
            provider_data = await self._get_provider_data(token_response)

            # Create trigger
            trigger_result = await self._create_trigger(
                state_data["agent_id"],
                state_data["user_id"],
                token_response,
                provider_data,
            )

            return IntegrationResult(
                success=True,
                trigger_id=trigger_result["trigger_id"],
                provider_name=self.config.provider.value,
                workspace_name=provider_data.get("workspace_name"),
                bot_name=provider_data.get("bot_name"),
                webhook_url=trigger_result["webhook_url"],
                metadata=provider_data,
            )

        except Exception as e:
            logger.error(f"Error handling {self.config.provider.value} OAuth callback: {e}")
            return IntegrationResult(
                success=False,
                error=str(e),
                provider_name=self.config.provider.value,
            )

    async def _exchange_code_for_token(self, code: str) -> Optional[OAuthTokenResponse]:
        """Exchange authorization code for access token.

        POSTs the code together with the client credentials to
        ``config.token_url``.

        Returns:
            The parsed token response, or ``None`` when the request fails,
            the status is not 200, the subclass rejects the payload, or any
            exception occurs (including a non-JSON body).
        """
        try:
            async with httpx.AsyncClient() as client:
                response = await client.post(
                    self.config.token_url,
                    data={
                        "client_id": self.config.client_id,
                        "client_secret": self.config.client_secret,
                        "code": code,
                        "redirect_uri": self.config.redirect_uri,
                        "grant_type": "authorization_code",
                    },
                )

                data = response.json()

                if response.status_code == 200 and self._is_token_response_valid(data):
                    return OAuthTokenResponse(
                        access_token=data["access_token"],
                        refresh_token=data.get("refresh_token"),
                        expires_in=data.get("expires_in"),
                        scope=data.get("scope"),
                        token_type=data.get("token_type", "Bearer"),
                        # Keep the raw payload for provider-specific fields.
                        additional_data=data,
                    )

                logger.error(f"Token exchange failed: {data}")
                return None

        except Exception as e:
            logger.error(f"Error exchanging code for token: {e}")
            return None

    @abc.abstractmethod
    async def _get_provider_data(self, token_response: OAuthTokenResponse) -> Dict[str, Any]:
        """Get provider-specific data (workspace info, bot info, etc.).

        The base class reads the ``workspace_name`` and ``bot_name`` keys of
        the returned mapping, when present.
        """
        pass

    @abc.abstractmethod
    def _is_token_response_valid(self, data: Dict[str, Any]) -> bool:
        """Check if token response is valid for this provider."""
        pass

    @abc.abstractmethod
    def _get_trigger_config(self, token_response: OAuthTokenResponse, provider_data: Dict[str, Any]) -> Dict[str, Any]:
        """Generate trigger configuration for this provider.

        The returned dict is mutated by ``_create_trigger``, which adds the
        ``oauth_installed`` and ``provider`` keys.
        """
        pass

    async def _create_trigger(
        self,
        agent_id: str,
        user_id: str,
        token_response: OAuthTokenResponse,
        provider_data: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Create trigger with provider-specific configuration.

        Args:
            agent_id: Agent the trigger is attached to.
            user_id: User that authorized the integration (not used here).
            token_response: Token obtained from the code exchange.
            provider_data: Data returned by ``_get_provider_data``.

        Returns:
            A dict with the keys ``trigger_id`` and ``webhook_url``.
        """
        # Imported at call time, as in the original module layout.
        import os
        from ..core import TriggerManager

        trigger_manager = TriggerManager(self.db)
        await trigger_manager.load_provider_definitions()

        provider_name = self.config.provider.value

        config = self._get_trigger_config(token_response, provider_data)
        config["oauth_installed"] = True
        config["provider"] = provider_name

        # Fall back to a generic label when the workspace name is missing
        # *or* present but empty/None, so the name never ends in "None".
        workspace_label = provider_data.get("workspace_name") or "Integration"

        trigger_config = await trigger_manager.create_trigger(
            agent_id=agent_id,
            provider_id=provider_name,
            name=f"{provider_name.title()} - {workspace_label}",
            description=f"Auto-configured {provider_name.title()} integration",
            config=config,
        )

        # Store OAuth data
        # COMMENTED OUT: OAuth installations table deprecated
        # await self._store_oauth_data(trigger_config.trigger_id, token_response, provider_data)

        # Strip a trailing slash so the joined URL never contains "//api".
        base_url = os.getenv("WEBHOOK_BASE_URL", "http://localhost:8000").rstrip("/")

        # For Slack, use the universal webhook URL (no trigger_id parameter)
        # Slack requires a single Event Request URL per app
        if self.config.provider == OAuthProvider.SLACK:
            webhook_url = f"{base_url}/api/triggers/slack/webhook"
        else:
            webhook_url = f"{base_url}/api/triggers/{trigger_config.trigger_id}/webhook"

        return {
            "trigger_id": trigger_config.trigger_id,
            "webhook_url": webhook_url,
        }

    def _create_state_token(self, agent_id: str, user_id: str) -> str:
        """Create the ``state`` value sent with the authorization request.

        The token is base64-encoded JSON holding the agent id, user id,
        provider name and a random nonce.

        NOTE(review): the token is neither signed nor stored, and the nonce
        is not checked by ``_verify_state_token``, so a client can forge a
        state with arbitrary ids. Consider signing it (e.g. HMAC) or keeping
        the nonce server-side; the format is left unchanged here because
        other code may decode it.
        """
        state_data = {
            "agent_id": agent_id,
            "user_id": user_id,
            "provider": self.config.provider.value,
            "nonce": secrets.token_urlsafe(16),
        }

        state_json = json.dumps(state_data)
        return base64.b64encode(state_json.encode()).decode()

    async def _verify_state_token(self, state: str) -> Optional[Dict[str, Any]]:
        """Verify and decode state token.

        Returns:
            The decoded payload, or ``None`` when the token cannot be
            decoded, is not a JSON object, names a different provider, or
            lacks the ``agent_id`` / ``user_id`` keys the callback needs.
        """
        try:
            state_json = base64.b64decode(state.encode()).decode()
            state_data = json.loads(state_json)

            # Anything but a JSON object cannot carry the expected fields.
            if not isinstance(state_data, dict):
                return None

            # Verify provider matches
            if state_data.get("provider") != self.config.provider.value:
                return None

            # Reject tokens missing the ids read by handle_callback, so the
            # caller reports an invalid state instead of a bare KeyError.
            if "agent_id" not in state_data or "user_id" not in state_data:
                return None

            return state_data
        except Exception as e:
            logger.error(f"Error verifying state token: {e}")
            return None
|
|
|
|
# COMMENTED OUT: OAuth installations functionality deprecated
|
|
# async def _store_oauth_data(
|
|
# self,
|
|
# trigger_id: str,
|
|
# token_response: OAuthTokenResponse,
|
|
# provider_data: Dict[str, Any]
|
|
# ):
|
|
# """Store OAuth installation data."""
|
|
# client = await self.db.client
|
|
# await client.table('oauth_installations').insert({
|
|
# 'trigger_id': trigger_id,
|
|
# 'provider': self.config.provider.value,
|
|
# 'access_token': token_response.access_token,
|
|
# 'refresh_token': token_response.refresh_token,
|
|
# 'expires_in': token_response.expires_in,
|
|
# 'scope': token_response.scope,
|
|
# 'provider_data': provider_data,
|
|
# 'installed_at': 'now()'
|
|
# }).execute() |