suna/backend/composio_integration/connected_account_service.py

248 lines
10 KiB
Python
Raw Normal View History

2025-08-03 04:10:11 +08:00
from typing import Optional, List, Dict, Any
2025-08-02 16:22:17 +08:00
from pydantic import BaseModel
from utils.logger import logger
2025-08-03 04:10:11 +08:00
from enum import Enum
2025-08-02 16:22:17 +08:00
from .client import ComposioClient
class ConnectionState(BaseModel):
    """Authentication state attached to a connected account.

    An instance built with no arguments (``ConnectionState()``) is used as the
    fallback when an SDK response carries no usable connection data.
    """

    # Name of the auth scheme; "OAUTH2" unless the response supplies one.
    auth_scheme: str = "OAUTH2"
    # Scheme-specific payload, normalized to a plain dict (empty by default).
    val: Dict[str, Any] = dict()
2025-08-02 16:22:17 +08:00
class ConnectedAccount(BaseModel):
    """Normalized, SDK-independent view of a Composio connected account."""

    # Identifier and lifecycle status copied from the SDK response.
    id: str
    status: str
    # Both spellings are copied from the response when the SDK provides them.
    redirect_url: Optional[str] = None
    redirect_uri: Optional[str] = None
    # Parsed auth state (scheme + payload).
    connection_data: ConnectionState
    # Owning auth config and user; may be '' when the response omits them.
    auth_config_id: str
    user_id: str
    # Optional flag reduced from the SDK's ``deprecated`` attribute; None when absent.
    deprecated: Optional[bool] = None
2025-08-02 16:22:17 +08:00
class ConnectedAccountService:
    """Create, fetch, and list Composio connected accounts.

    Wraps the ``connected_accounts`` resource of the client returned by
    ``ComposioClient.get_client`` and converts SDK response objects into
    ``ConnectedAccount`` / ``ConnectionState`` models.

    NOTE(review): the SDK calls below are invoked synchronously (no ``await``)
    inside ``async`` methods, so any network I/O they presumably perform blocks
    the event loop while in flight. Consider ``asyncio.to_thread`` if the
    client is thread-safe — TODO confirm.
    """

    def __init__(self, api_key: Optional[str] = None):
        """Build the service around a Composio client.

        Args:
            api_key: Forwarded unchanged to ``ComposioClient.get_client``.
        """
        self.client = ComposioClient.get_client(api_key)

    def _extract_deprecated_value(self, deprecated_obj) -> Optional[bool]:
        """Reduce the SDK's ``deprecated`` attribute to an optional bool.

        Returns:
            ``None`` when the attribute is ``None``; the value itself when it
            is already a ``bool``; ``True`` for an object with a non-empty
            ``__dict__``; ``False`` otherwise (including plain dicts, which
            have no ``__dict__``).

        NOTE(review): any populated SDK object counts as "deprecated", even if
        its fields carry other information — confirm against the SDK's
        ``Deprecated`` model.
        """
        if deprecated_obj is None:
            return None
        if isinstance(deprecated_obj, bool):
            return deprecated_obj
        return bool(getattr(deprecated_obj, '__dict__', None))

    def _extract_val_dict(self, val_obj) -> Dict[str, Any]:
        """Convert the SDK's ``val`` payload into a plain dict.

        Accepts ``None`` (-> ``{}``), a dict (returned as-is), a Pydantic model
        (``model_dump()`` on v2, ``dict()`` on v1), or any object with a
        ``__dict__`` (shallow-copied). Anything else yields ``{}``.
        """
        if val_obj is None:
            return {}
        if isinstance(val_obj, dict):
            return val_obj
        if hasattr(val_obj, 'model_dump'):
            return val_obj.model_dump()
        if hasattr(val_obj, 'dict'):
            return val_obj.dict()
        if hasattr(val_obj, '__dict__'):
            # Copy: returning the live __dict__ would alias the SDK object's state.
            return dict(val_obj.__dict__)
        return {}

    def _build_connection_state(self, raw: Any) -> ConnectionState:
        """Build a ``ConnectionState`` from an SDK account object.

        Reads ``connection_data`` (falling back to ``connectionData``). A
        missing/falsy value, or one that is neither a dict nor an object with
        a ``__dict__``, yields the default ``ConnectionState()``.
        """
        state_obj = getattr(raw, 'connection_data', None)
        if not state_obj:
            state_obj = getattr(raw, 'connectionData', None)
        if not state_obj:
            return ConnectionState()

        if isinstance(state_obj, dict):
            fields = state_obj
        elif hasattr(state_obj, '__dict__'):
            fields = state_obj.__dict__
        else:
            return ConnectionState()

        return ConnectionState(
            # `or` also covers an explicit None, which the str field would reject.
            auth_scheme=fields.get('auth_scheme') or 'OAUTH2',
            val=self._extract_val_dict(fields.get('val')),
        )

    def _to_connected_account(self, raw: Any, auth_config_id: str, user_id: str) -> ConnectedAccount:
        """Convert one SDK account object into a ``ConnectedAccount``.

        Args:
            raw: SDK response object; ``id`` and ``status`` must be present.
            auth_config_id: Value stored as ``ConnectedAccount.auth_config_id``.
            user_id: Value stored as ``ConnectedAccount.user_id``.
        """
        connection_data = self._build_connection_state(raw)
        deprecated = self._extract_deprecated_value(getattr(raw, 'deprecated', None))
        return ConnectedAccount(
            id=raw.id,
            status=raw.status,
            redirect_url=getattr(raw, 'redirect_url', None),
            redirect_uri=getattr(raw, 'redirect_uri', None),
            connection_data=connection_data,
            auth_config_id=auth_config_id,
            user_id=user_id,
            deprecated=deprecated,
        )

    async def create_connected_account(
        self,
        auth_config_id: str,
        user_id: str = "default"
    ) -> ConnectedAccount:
        """Create a connected account for ``user_id`` under ``auth_config_id``.

        The request is sent with an OAUTH2 state whose ``val.status`` is
        ``"INITIALIZING"``.

        Raises:
            Exception: anything raised by the SDK or model validation is
                logged and re-raised.
        """
        try:
            logger.info(f"Creating connected account for auth_config: {auth_config_id}, user: {user_id}")
            response = self.client.connected_accounts.create(
                auth_config={"id": auth_config_id},
                connection={
                    "user_id": user_id,
                    "state": {
                        "authScheme": "OAUTH2",
                        "val": {
                            "status": "INITIALIZING"
                        }
                    }
                }
            )
            connected_account = self._to_connected_account(
                response, auth_config_id=auth_config_id, user_id=user_id
            )
            logger.info(f"Successfully created connected account: {connected_account.id}")
            return connected_account
        except Exception as e:
            logger.error(f"Failed to create connected account: {e}", exc_info=True)
            raise

    async def get_connected_account(self, connected_account_id: str) -> Optional[ConnectedAccount]:
        """Fetch one connected account by id.

        Returns:
            The normalized account, or ``None`` when the SDK returns a falsy
            response. ``auth_config_id`` / ``user_id`` fall back to ``''`` when
            the response lacks them or holds ``None``.

        Raises:
            Exception: anything raised by the SDK or model validation is
                logged and re-raised.
        """
        try:
            logger.info(f"Fetching connected account: {connected_account_id}")
            response = self.client.connected_accounts.get(connected_account_id)
            if not response:
                return None
            return self._to_connected_account(
                response,
                auth_config_id=getattr(response, 'auth_config_id', None) or '',
                user_id=getattr(response, 'user_id', None) or '',
            )
        except Exception as e:
            logger.error(f"Failed to get connected account {connected_account_id}: {e}", exc_info=True)
            raise

    async def get_auth_status(self, connected_account_id: str) -> Dict[str, Any]:
        """Summarize the status of a connected account.

        Returns:
            ``{"status": "not_found", "message": ...}`` when the account is not
            found; otherwise a dict with ``status``, ``redirect_url`` and the
            serialized ``connection_data``.

        Raises:
            Exception: anything raised while fetching is logged and re-raised.
        """
        try:
            logger.info(f"Getting auth status for connected account: {connected_account_id}")

            connected_account = await self.get_connected_account(connected_account_id)
            if not connected_account:
                return {"status": "not_found", "message": "Connected account not found"}

            state = connected_account.connection_data
            # Pydantic v2 deprecates .dict() in favour of .model_dump(); support both.
            connection_data = state.model_dump() if hasattr(state, 'model_dump') else state.dict()
            return {
                "status": connected_account.status,
                "redirect_url": connected_account.redirect_url,
                "connection_data": connection_data,
            }
        except Exception as e:
            logger.error(f"Failed to get auth status: {e}", exc_info=True)
            raise

    async def list_connected_accounts(self, auth_config_id: Optional[str] = None) -> List[ConnectedAccount]:
        """List connected accounts, optionally filtered by auth config.

        Args:
            auth_config_id: When truthy, passed to the SDK as a filter and used
                as the fallback ``auth_config_id`` for items lacking one.

        Raises:
            Exception: anything raised by the SDK or model validation is
                logged and re-raised.
        """
        try:
            logger.info(f"Listing connected accounts for auth_config: {auth_config_id}")
            if auth_config_id:
                response = self.client.connected_accounts.list(auth_config_id=auth_config_id)
            else:
                response = self.client.connected_accounts.list()

            # `or []` also tolerates an explicit ``items=None`` on the response.
            items = getattr(response, 'items', None) or []
            connected_accounts = [
                self._to_connected_account(
                    item,
                    auth_config_id=getattr(item, 'auth_config_id', None) or auth_config_id or '',
                    user_id=getattr(item, 'user_id', None) or '',
                )
                for item in items
            ]
            logger.info(f"Successfully listed {len(connected_accounts)} connected accounts")
            return connected_accounts
        except Exception as e:
            logger.error(f"Failed to list connected accounts: {e}", exc_info=True)
            raise