suna/backend/triggers/unified_oauth_api.py

243 lines
8.8 KiB
Python
Raw Normal View History

2025-06-30 18:57:34 +08:00
"""
Unified OAuth API for all providers.
This single API handles OAuth flows for Slack, Discord, Teams, and any future providers
with a consistent interface.
"""
from fastapi import APIRouter, HTTPException, Depends, Query
from fastapi.responses import RedirectResponse
from typing import Optional, List
from pydantic import BaseModel
from enum import Enum
2025-07-01 02:03:46 +08:00
import os
2025-06-30 18:57:34 +08:00
from .oauth.base import OAuthProvider
from .oauth.providers import get_oauth_provider
from services.supabase import DBConnection
from utils.auth_utils import get_current_user_id_from_jwt
from utils.logger import logger
router = APIRouter(
    prefix="/api/integrations",
    tags=["oauth-integrations"],
)

# Shared database handle for every endpoint below; assigned by initialize().
db: Optional[DBConnection] = None
def initialize(database: DBConnection) -> None:
    """Bind the database connection used by this module's route handlers.

    The handlers dereference the module-level ``db`` directly, so this must
    run before any of them is invoked.

    Args:
        database: Connection object stored as the module-level ``db``.
    """
    global db
    db = database
class IntegrationInstallRequest(BaseModel):
    """Request body for ``POST /api/integrations/install``."""
    # Agent the integration is being installed for; ownership is checked by verify_agent_access.
    agent_id: str
    # OAuth provider whose authorization flow should be started.
    provider: OAuthProvider
class IntegrationInstallResponse(BaseModel):
    """Response body for ``POST /api/integrations/install``."""
    # Authorization URL produced by the provider's generate_authorization_url().
    install_url: str
    # String value of the OAuthProvider enum member that was requested.
    provider: str
class IntegrationStatusResponse(BaseModel):
    """Response body for ``GET /api/integrations/status/{agent_id}``."""
    # Agent whose integrations are listed.
    agent_id: str
    # One plain dict per installed OAuth integration (trigger_id, provider, name,
    # is_active, workspace_name, bot_name, installed_at, created_at).
    integrations: List[dict]
@router.get("/available")
async def get_available_integrations():
    """Return the fixed list of OAuth providers this API advertises.

    Each entry holds the provider id plus display metadata: a human-readable
    name, a short description, an icon key and a brand colour.
    """
    field_names = ("id", "name", "description", "icon", "color")
    catalogue = (
        ("slack", "Slack", "Connect to Slack workspaces", "slack", "#4A154B"),
        ("discord", "Discord", "Connect to Discord servers", "discord", "#5865F2"),
        ("teams", "Microsoft Teams", "Connect to Teams organizations", "teams", "#6264A7"),
    )
    # A fresh list of fresh dicts on every call, so callers may mutate the result safely.
    providers = [dict(zip(field_names, entry)) for entry in catalogue]
    return {"providers": providers}
@router.post("/install", response_model=IntegrationInstallResponse)
async def initiate_integration_install(
    request: IntegrationInstallRequest,
    user_id: str = Depends(get_current_user_id_from_jwt)
):
    """Start an OAuth install flow for an agent and return the provider URL.

    Args:
        request: Target agent id and OAuth provider.
        user_id: Caller id resolved from the JWT dependency.

    Returns:
        IntegrationInstallResponse holding the authorization URL and the
        provider's string value.

    Raises:
        HTTPException: 404/403 propagated unchanged from
            ``verify_agent_access``; 400 when provider setup raises
            ``ValueError`` (with a friendlier message when a required
            environment variable is missing); 500 for any other failure.
    """
    try:
        await verify_agent_access(request.agent_id, user_id)
        oauth_provider = get_oauth_provider(request.provider, db)
        install_url = oauth_provider.generate_authorization_url(request.agent_id, user_id)
        return IntegrationInstallResponse(
            install_url=install_url,
            provider=request.provider.value
        )
    except HTTPException:
        # Keep the 404/403 from verify_agent_access intact; without this the
        # generic handler below would re-wrap it as a 500.
        raise
    except ValueError as e:
        error_msg = str(e)
        if "environment variable is required" in error_msg:
            logger.error(f"Missing OAuth configuration for {request.provider.value}: {e}")
            raise HTTPException(
                status_code=400,
                detail=f"OAuth integration for {request.provider.value.title()} is not configured. Please contact your administrator to set up the required environment variables."
            ) from e
        raise HTTPException(status_code=400, detail=error_msg) from e
    except Exception as e:
        logger.error(f"Error initiating {request.provider.value} install: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
def _frontend_redirect(params: dict) -> RedirectResponse:
    """Return a 302 redirect to the frontend ``/agents`` page carrying *params*.

    Values are URL-encoded, so provider-supplied text (error strings,
    workspace or bot names) cannot break the query string or inject extra
    parameters into it. Non-string values are converted with ``str()``.
    """
    from urllib.parse import urlencode

    # Strip a trailing slash so a configured "https://app/" does not yield "//agents".
    frontend_url = os.getenv("FRONTEND_URL", "http://localhost:3000").rstrip("/")
    return RedirectResponse(
        url=f"{frontend_url}/agents?{urlencode(params)}",
        status_code=302
    )


@router.get("/{provider}/callback")
async def handle_oauth_callback(
    provider: OAuthProvider,
    code: Optional[str] = Query(None, description="OAuth authorization code"),
    state: Optional[str] = Query(None, description="State parameter"),
    error: Optional[str] = Query(None, description="OAuth error")
):
    """Handle the OAuth redirect for any provider and bounce back to the frontend.

    Every outcome is a 302 to ``{FRONTEND_URL}/agents``:

    * success: ``{provider}_success=true`` plus ``trigger_id``,
      ``workspace`` and ``bot_name``;
    * failure: ``{provider}_error`` set to the provider's ``error``
      parameter, the error reported by ``handle_callback``,
      ``missing_parameters`` (no ``code``/``state``), or
      ``callback_failed`` (unexpected exception).

    ``code`` and ``state`` are optional at the FastAPI level. An OAuth error
    redirect carries ``error`` without ``code`` (RFC 6749 section 4.1.2.1);
    making them required would make FastAPI reject that redirect with a 422
    before the error branch could run.
    """
    error_key = f"{provider.value}_error"

    if error:
        logger.error(f"{provider.value} OAuth error: {error}")
        return _frontend_redirect({error_key: error})

    if not code or not state:
        logger.error(f"{provider.value} OAuth callback missing code or state")
        return _frontend_redirect({error_key: "missing_parameters"})

    try:
        oauth_provider = get_oauth_provider(provider, db)
        result = await oauth_provider.handle_callback(code, state)
        if result.success:
            return _frontend_redirect({
                f"{provider.value}_success": "true",
                "trigger_id": result.trigger_id,
                "workspace": result.workspace_name or "",
                "bot_name": result.bot_name or "",
            })
        return _frontend_redirect({error_key: result.error})
    except Exception as e:
        logger.error(f"Error handling {provider.value} callback: {e}")
        return _frontend_redirect({error_key: "callback_failed"})
@router.get("/status/{agent_id}", response_model=IntegrationStatusResponse)
async def get_integration_status(
    agent_id: str,
    user_id: str = Depends(get_current_user_id_from_jwt)
):
    """List the OAuth-backed integrations installed on an agent.

    Only ``agent_triggers`` rows of type slack/discord/teams that also have
    a matching ``oauth_installations`` row are reported.

    Args:
        agent_id: Agent to inspect.
        user_id: Caller id resolved from the JWT dependency.

    Returns:
        IntegrationStatusResponse with one dict per integration, in the
        order the trigger rows were returned.

    Raises:
        HTTPException: 404/403 propagated unchanged from
            ``verify_agent_access``; 500 for any other failure.
    """
    try:
        await verify_agent_access(agent_id, user_id)
        client = await db.client
        trigger_result = await client.table('agent_triggers')\
            .select('trigger_id, trigger_type, name, is_active, created_at')\
            .eq('agent_id', agent_id)\
            .in_('trigger_type', ['slack', 'discord', 'teams'])\
            .execute()
        triggers = trigger_result.data or []

        # Load every relevant installation in a single query instead of one
        # round-trip per trigger. Skip the query when there is nothing to match.
        installations = {}
        trigger_ids = [trigger['trigger_id'] for trigger in triggers]
        if trigger_ids:
            oauth_result = await client.table('oauth_installations')\
                .select('trigger_id, provider, provider_data, installed_at')\
                .in_('trigger_id', trigger_ids)\
                .execute()
            for row in oauth_result.data or []:
                # Keep the first row per trigger, as the per-trigger data[0] lookup did.
                installations.setdefault(row['trigger_id'], row)

        integrations = []
        for trigger in triggers:
            oauth_data = installations.get(trigger['trigger_id'])
            if oauth_data is None:
                continue
            # A NULL provider_data column arrives as None; treat it as empty.
            provider_data = oauth_data.get('provider_data') or {}
            integrations.append({
                "trigger_id": trigger["trigger_id"],
                "provider": oauth_data["provider"],
                "name": trigger["name"],
                "is_active": trigger["is_active"],
                "workspace_name": provider_data.get("workspace_name"),
                "bot_name": provider_data.get("bot_name"),
                "installed_at": oauth_data["installed_at"],
                "created_at": trigger["created_at"]
            })
        return IntegrationStatusResponse(
            agent_id=agent_id,
            integrations=integrations
        )
    except HTTPException:
        # Preserve 404/403 from verify_agent_access instead of masking them as 500.
        raise
    except Exception as e:
        logger.error(f"Error getting integration status: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
@router.delete("/uninstall/{trigger_id}")
async def uninstall_integration(
    trigger_id: str,
    user_id: str = Depends(get_current_user_id_from_jwt)
):
    """Uninstall any OAuth integration identified by its trigger id.

    Deletes the trigger through ``TriggerManager.delete_trigger``. When that
    succeeds, it also deletes the matching ``oauth_installations`` rows.

    Args:
        trigger_id: Trigger backing the integration.
        user_id: Caller id resolved from the JWT dependency.

    Returns:
        A dict with ``success`` and a human-readable ``message``.

    Raises:
        HTTPException: 404 if the trigger does not exist; 404/403 from
            ``verify_agent_access``; 500 if deletion reports failure or any
            other error occurs.
    """
    try:
        client = await db.client
        trigger_result = await client.table('agent_triggers')\
            .select('agent_id, trigger_type')\
            .eq('trigger_id', trigger_id)\
            .execute()
        if not trigger_result.data:
            raise HTTPException(status_code=404, detail="Integration not found")
        agent_id = trigger_result.data[0]['agent_id']
        provider_type = trigger_result.data[0]['trigger_type']
        await verify_agent_access(agent_id, user_id)

        from .core import TriggerManager
        trigger_manager = TriggerManager(db)
        success = await trigger_manager.delete_trigger(trigger_id)
        if not success:
            raise HTTPException(status_code=500, detail="Failed to uninstall integration")

        await client.table('oauth_installations')\
            .delete()\
            .eq('trigger_id', trigger_id)\
            .execute()
        return {
            "success": True,
            "message": f"{provider_type.title()} integration uninstalled successfully"
        }
    except HTTPException:
        # Deliberate HTTP errors (404/403/500 above) must reach the client
        # as-is, not be re-wrapped by the generic handler below.
        raise
    except Exception as e:
        logger.error(f"Error uninstalling integration: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
async def verify_agent_access(agent_id: str, user_id: str):
    """Ensure *user_id* owns the agent identified by *agent_id*.

    Ownership means the agent row's ``account_id`` equals *user_id*.
    Returns ``None`` when access is allowed.

    Raises:
        HTTPException: 404 when no agent row matches *agent_id*; 403 when
            the agent belongs to a different account.
    """
    client = await db.client
    lookup = await client.table('agents')\
        .select('account_id')\
        .eq('agent_id', agent_id)\
        .execute()
    rows = lookup.data
    if not rows:
        raise HTTPException(status_code=404, detail="Agent not found")
    owner_id = rows[0]['account_id']
    if owner_id != user_id:
        raise HTTPException(status_code=403, detail="Access denied")