from typing import Optional, List, Dict, Any from pydantic import BaseModel from utils.logger import logger from enum import Enum from .client import ComposioClient class ConnectionState(BaseModel): auth_scheme: str = "OAUTH2" val: Dict[str, Any] = {} class ConnectedAccount(BaseModel): id: str status: str redirect_url: Optional[str] = None redirect_uri: Optional[str] = None connection_data: ConnectionState auth_config_id: str user_id: str deprecated: Optional[bool] = None class ConnectedAccountService: def __init__(self, api_key: Optional[str] = None): self.client = ComposioClient.get_client(api_key) def _extract_deprecated_value(self, deprecated_obj) -> Optional[bool]: """Extract boolean value from Composio SDK's Deprecated object""" if deprecated_obj is None: return None # If it's already a boolean, return it if isinstance(deprecated_obj, bool): return deprecated_obj # If it's a Composio Deprecated object, try to extract meaningful info if hasattr(deprecated_obj, '__dict__'): # Check if it has any deprecation info - if so, consider it deprecated deprecated_dict = deprecated_obj.__dict__ if deprecated_dict: return True # Default to not deprecated return False def _extract_val_dict(self, val_obj) -> Dict[str, Any]: """Extract dictionary from Composio SDK's val object""" if val_obj is None: return {} # If it's already a dict, return it if isinstance(val_obj, dict): return val_obj # If it's a Pydantic model, convert it to dict if hasattr(val_obj, 'model_dump'): return val_obj.model_dump() elif hasattr(val_obj, 'dict'): return val_obj.dict() elif hasattr(val_obj, '__dict__'): return val_obj.__dict__ # Fallback to empty dict return {} async def create_connected_account( self, auth_config_id: str, user_id: str = "default" ) -> ConnectedAccount: try: logger.info(f"Creating connected account for auth_config: {auth_config_id}, user: {user_id}") response = self.client.connected_accounts.create( auth_config={"id": auth_config_id}, connection={ "user_id": user_id, "state": { "authScheme": "OAUTH2", "val": { "status": "INITIALIZING" } } } ) # Access Pydantic model attributes directly connection_data_obj = getattr(response, 'connection_data', None) if not connection_data_obj: # Try alternative attribute names connection_data_obj = getattr(response, 'connectionData', None) if connection_data_obj and hasattr(connection_data_obj, '__dict__'): connection_data_dict = connection_data_obj.__dict__ # Extract val field properly - it might be a Pydantic object val_obj = connection_data_dict.get('val', {}) val_dict = self._extract_val_dict(val_obj) connection_data = ConnectionState( auth_scheme=connection_data_dict.get('auth_scheme', 'OAUTH2'), val=val_dict ) else: connection_data = ConnectionState() # Handle the deprecated field properly deprecated_obj = getattr(response, 'deprecated', None) deprecated_value = self._extract_deprecated_value(deprecated_obj) connected_account = ConnectedAccount( id=response.id, status=response.status, redirect_url=getattr(response, 'redirect_url', None), redirect_uri=getattr(response, 'redirect_uri', None), connection_data=connection_data, auth_config_id=auth_config_id, user_id=user_id, deprecated=deprecated_value ) logger.info(f"Successfully created connected account: {connected_account.id}") return connected_account except Exception as e: logger.error(f"Failed to create connected account: {e}", exc_info=True) raise async def get_connected_account(self, connected_account_id: str) -> Optional[ConnectedAccount]: try: logger.info(f"Fetching connected account: {connected_account_id}") response = self.client.connected_accounts.get(connected_account_id) if not response: return None # Access Pydantic model attributes directly connection_data_obj = getattr(response, 'connection_data', None) if not connection_data_obj: connection_data_obj = getattr(response, 'connectionData', None) if connection_data_obj and hasattr(connection_data_obj, '__dict__'): connection_data_dict = connection_data_obj.__dict__ # Extract val field properly - it might be a Pydantic object val_obj = connection_data_dict.get('val', {}) val_dict = self._extract_val_dict(val_obj) connection_data = ConnectionState( auth_scheme=connection_data_dict.get('auth_scheme', 'OAUTH2'), val=val_dict ) else: connection_data = ConnectionState() # Handle the deprecated field properly deprecated_obj = getattr(response, 'deprecated', None) deprecated_value = self._extract_deprecated_value(deprecated_obj) return ConnectedAccount( id=response.id, status=response.status, redirect_url=getattr(response, 'redirect_url', None), redirect_uri=getattr(response, 'redirect_uri', None), connection_data=connection_data, auth_config_id=getattr(response, 'auth_config_id', ''), user_id=getattr(response, 'user_id', ''), deprecated=deprecated_value ) except Exception as e: logger.error(f"Failed to get connected account {connected_account_id}: {e}", exc_info=True) raise async def get_auth_status(self, connected_account_id: str) -> Dict[str, Any]: try: logger.info(f"Getting auth status for connected account: {connected_account_id}") connected_account = await self.get_connected_account(connected_account_id) if not connected_account: return {"status": "not_found", "message": "Connected account not found"} return { "status": connected_account.status, "redirect_url": connected_account.redirect_url, "connection_data": connected_account.connection_data.dict() } except Exception as e: logger.error(f"Failed to get auth status: {e}", exc_info=True) raise async def list_connected_accounts(self, auth_config_id: Optional[str] = None) -> List[ConnectedAccount]: try: logger.info(f"Listing connected accounts for auth_config: {auth_config_id}") if auth_config_id: response = self.client.connected_accounts.list(auth_config_id=auth_config_id) else: response = self.client.connected_accounts.list() connected_accounts = [] items = getattr(response, 'items', []) for item in items: connection_data_obj = getattr(item, 'connection_data', None) if not connection_data_obj: connection_data_obj = getattr(item, 'connectionData', None) if connection_data_obj and hasattr(connection_data_obj, '__dict__'): connection_data_dict = connection_data_obj.__dict__ # Extract val field properly - it might be a Pydantic object val_obj = connection_data_dict.get('val', {}) val_dict = self._extract_val_dict(val_obj) connection_data = ConnectionState( auth_scheme=connection_data_dict.get('auth_scheme', 'OAUTH2'), val=val_dict ) else: connection_data = ConnectionState() # Handle the deprecated field properly deprecated_obj = getattr(item, 'deprecated', None) deprecated_value = self._extract_deprecated_value(deprecated_obj) connected_account = ConnectedAccount( id=item.id, status=item.status, redirect_url=getattr(item, 'redirect_url', None), redirect_uri=getattr(item, 'redirect_uri', None), connection_data=connection_data, auth_config_id=getattr(item, 'auth_config_id', auth_config_id or ''), user_id=getattr(item, 'user_id', ''), deprecated=deprecated_value ) connected_accounts.append(connected_account) logger.info(f"Successfully listed {len(connected_accounts)} connected accounts") return connected_accounts except Exception as e: logger.error(f"Failed to list connected accounts: {e}", exc_info=True) raise