suna/backend/mcp_local/credential_manager.py

427 lines
16 KiB
Python

"""
Secure MCP Credential Manager
This module handles:
1. Encrypting and storing MCP credentials securely
2. Retrieving and decrypting credentials for runtime use
3. Managing credential lifecycle and validation
4. Auditing credential usage
"""
import os
import json
import hashlib
import base64
from typing import Dict, List, Any, Optional, Tuple
from dataclasses import dataclass
from datetime import datetime, timezone
from cryptography.fernet import Fernet
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
from utils.logger import logger
from services.supabase import DBConnection
db = DBConnection()
@dataclass
class MCPCredential:
"""Represents an MCP credential"""
credential_id: str
account_id: str
mcp_qualified_name: str
display_name: str
config: Dict[str, Any]
is_active: bool
last_used_at: Optional[datetime] = None
created_at: Optional[datetime] = None
updated_at: Optional[datetime] = None
@dataclass
class MCPRequirement:
"""Represents an MCP requirement from a template"""
qualified_name: str
display_name: str
enabled_tools: List[str]
required_config: List[str]
custom_type: Optional[str] = None # 'sse' or 'http' for custom MCP servers
class CredentialManager:
"""Manages secure storage and retrieval of MCP credentials"""
def __init__(self):
self.encryption_key = self._get_or_create_encryption_key()
self.cipher = Fernet(self.encryption_key)
def _get_or_create_encryption_key(self) -> bytes:
"""Get or create encryption key for credentials"""
# Use environment variable or fallback to a default key for development
key_env = os.getenv("MCP_CREDENTIAL_ENCRYPTION_KEY", "Rgv3pwjsXetY0KAesiMk-OKbC8C2roTx2S8qC0OYBb0=")
try:
# If the key is already in the correct format (base64 encoded), it should work directly
# Fernet expects the key as bytes, so we need to encode the string
if isinstance(key_env, str):
# The key should be a base64-encoded string, encode it to bytes for Fernet
return key_env.encode('utf-8')
else:
return key_env
except Exception as e:
logger.error(f"Invalid encryption key: {e}")
# Generate a new key as fallback
logger.warning("Generating new encryption key for this session")
key = Fernet.generate_key()
logger.info(f"Generated new encryption key. Set this in your environment:")
logger.info(f"MCP_CREDENTIAL_ENCRYPTION_KEY={key.decode()}")
return key
def _encrypt_config(self, config: Dict[str, Any]) -> Tuple[bytes, str]:
"""Encrypt configuration and return encrypted data + hash"""
config_json = json.dumps(config, sort_keys=True)
config_bytes = config_json.encode('utf-8')
# Create hash for integrity checking
config_hash = hashlib.sha256(config_bytes).hexdigest()
# Encrypt the config
encrypted_config = self.cipher.encrypt(config_bytes)
return encrypted_config, config_hash
def _decrypt_config(self, encrypted_config: bytes, expected_hash: str) -> Dict[str, Any]:
"""Decrypt configuration and verify integrity"""
try:
decrypted_bytes = self.cipher.decrypt(encrypted_config)
# Verify hash
actual_hash = hashlib.sha256(decrypted_bytes).hexdigest()
if actual_hash != expected_hash:
raise ValueError("Credential integrity check failed")
config_json = decrypted_bytes.decode('utf-8')
return json.loads(config_json)
except Exception as e:
logger.error(f"Failed to decrypt credential: {e}")
raise ValueError("Failed to decrypt credential")
async def store_credential(
self,
account_id: str,
mcp_qualified_name: str,
display_name: str,
config: Dict[str, Any]
) -> str:
"""
Store encrypted MCP credentials for a user
Args:
account_id: User's account ID
mcp_qualified_name: MCP server qualified name (e.g., "exa", "@smithery-ai/github")
display_name: Human-readable name for the MCP
config: Configuration dictionary with API keys and settings
Returns:
credential_id: UUID of the stored credential
"""
logger.info(f"Storing credential for {mcp_qualified_name} for user {account_id}")
try:
# Encrypt the configuration
encrypted_config, config_hash = self._encrypt_config(config)
client = await db.client
# Upsert the credential (encode bytes as base64 for database storage)
encoded_config = base64.b64encode(encrypted_config).decode('utf-8')
logger.debug(f"Encoded config length: {len(encoded_config)}, content preview: {encoded_config[:50]}...")
result = await client.table('user_mcp_credentials').upsert({
'account_id': account_id,
'mcp_qualified_name': mcp_qualified_name,
'display_name': display_name,
'encrypted_config': encoded_config,
'config_hash': config_hash,
'is_active': True,
'updated_at': datetime.now(timezone.utc).isoformat()
}, on_conflict='account_id,mcp_qualified_name').execute()
if not result.data:
raise ValueError("Failed to store credential")
credential_id = result.data[0]['credential_id']
logger.info(f"Successfully stored credential {credential_id} for {mcp_qualified_name}")
return credential_id
except Exception as e:
logger.error(f"Error storing credential for {mcp_qualified_name}: {str(e)}")
raise
async def get_credential(self, account_id: str, mcp_qualified_name: str) -> Optional[MCPCredential]:
"""
Retrieve and decrypt MCP credentials for a user
Args:
account_id: User's account ID
mcp_qualified_name: MCP server qualified name
Returns:
MCPCredential object or None if not found
"""
try:
client = await db.client
result = await client.table('user_mcp_credentials').select('*')\
.eq('account_id', account_id)\
.eq('mcp_qualified_name', mcp_qualified_name)\
.eq('is_active', True)\
.execute()
if not result.data:
return None
cred_data = result.data[0]
# Decrypt the configuration (handle both old and new formats)
encrypted_config = cred_data['encrypted_config']
logger.debug(f"Retrieved config type: {type(encrypted_config)}, length: {len(encrypted_config) if encrypted_config else 0}")
if isinstance(encrypted_config, str):
# New format: base64 encoded string
logger.debug(f"Decoding base64 string of length {len(encrypted_config)}")
try:
encrypted_config_bytes = base64.b64decode(encrypted_config.encode('utf-8'))
except Exception as e:
logger.error(f"Failed to decode base64 credential: {e}, string: {encrypted_config[:50]}...")
return None
else:
# Old format: raw bytes (backward compatibility)
encrypted_config_bytes = encrypted_config
config = self._decrypt_config(
encrypted_config_bytes,
cred_data['config_hash']
)
# Update last used timestamp
await client.table('user_mcp_credentials')\
.update({'last_used_at': datetime.now(timezone.utc).isoformat()})\
.eq('credential_id', cred_data['credential_id'])\
.execute()
return MCPCredential(
credential_id=cred_data['credential_id'],
account_id=cred_data['account_id'],
mcp_qualified_name=cred_data['mcp_qualified_name'],
display_name=cred_data['display_name'],
config=config,
is_active=cred_data['is_active'],
last_used_at=cred_data.get('last_used_at'),
created_at=cred_data.get('created_at'),
updated_at=cred_data.get('updated_at')
)
except Exception as e:
logger.error(f"Error retrieving credential for {mcp_qualified_name}: {str(e)}")
return None
async def get_user_credentials(self, account_id: str) -> List[MCPCredential]:
"""Get all active credentials for a user"""
try:
client = await db.client
result = await client.table('user_mcp_credentials').select('*')\
.eq('account_id', account_id)\
.eq('is_active', True)\
.order('created_at', desc=True)\
.execute()
credentials = []
for cred_data in result.data:
try:
# Decrypt the configuration (handle both old and new formats)
encrypted_config = cred_data['encrypted_config']
if isinstance(encrypted_config, str):
# New format: base64 encoded string
try:
encrypted_config_bytes = base64.b64decode(encrypted_config.encode('utf-8'))
except Exception as e:
logger.error(f"Failed to decode base64 credential {cred_data['credential_id']}: {e}")
continue
else:
# Old format: raw bytes (backward compatibility)
encrypted_config_bytes = encrypted_config
config = self._decrypt_config(
encrypted_config_bytes,
cred_data['config_hash']
)
credentials.append(MCPCredential(
credential_id=cred_data['credential_id'],
account_id=cred_data['account_id'],
mcp_qualified_name=cred_data['mcp_qualified_name'],
display_name=cred_data['display_name'],
config=config,
is_active=cred_data['is_active'],
last_used_at=cred_data.get('last_used_at'),
created_at=cred_data.get('created_at'),
updated_at=cred_data.get('updated_at')
))
except Exception as e:
logger.error(f"Failed to decrypt credential {cred_data['credential_id']}: {e}")
continue
return credentials
except Exception as e:
logger.error(f"Error retrieving user credentials: {str(e)}")
return []
async def delete_credential(self, account_id: str, mcp_qualified_name: str) -> bool:
"""Delete (deactivate) a credential"""
try:
client = await db.client
result = await client.table('user_mcp_credentials')\
.update({'is_active': False})\
.eq('account_id', account_id)\
.eq('mcp_qualified_name', mcp_qualified_name)\
.execute()
return len(result.data) > 0
except Exception as e:
logger.error(f"Error deleting credential for {mcp_qualified_name}: {str(e)}")
return False
async def test_credential(self, account_id: str, mcp_qualified_name: str) -> bool:
"""Test if a credential is valid by attempting to connect"""
try:
credential = await self.get_credential(account_id, mcp_qualified_name)
if not credential:
return False
# Import here to avoid circular imports
from .client import MCPManager
# Create a test MCP configuration
test_config = {
"name": credential.display_name,
"qualifiedName": credential.mcp_qualified_name,
"config": credential.config,
"enabledTools": [] # Empty for testing
}
# Try to connect
mcp_manager = MCPManager()
try:
await mcp_manager.connect_server(test_config)
await self._log_credential_usage(
credential.credential_id,
None,
"test_connection",
True
)
return True
except Exception as e:
await self._log_credential_usage(
credential.credential_id,
None,
"test_connection",
False,
str(e)
)
return False
finally:
await mcp_manager.disconnect_all()
except Exception as e:
logger.error(f"Error testing credential for {mcp_qualified_name}: {str(e)}")
return False
async def _log_credential_usage(
self,
credential_id: str,
instance_id: Optional[str],
action: str,
success: bool,
error_message: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None
):
"""Log credential usage for auditing"""
try:
client = await db.client
await client.table('credential_usage_log').insert({
'credential_id': credential_id,
'instance_id': instance_id,
'action': action,
'success': success,
'error_message': error_message,
'metadata': metadata or {}
}).execute()
except Exception as e:
logger.error(f"Failed to log credential usage: {e}")
async def get_missing_credentials_for_requirements(
self,
account_id: str,
requirements: List[MCPRequirement]
) -> List[MCPRequirement]:
"""Get list of missing credentials for given requirements"""
user_credentials = await self.get_user_credentials(account_id)
user_mcp_names = {cred.mcp_qualified_name for cred in user_credentials}
missing = []
for req in requirements:
if req.custom_type:
custom_pattern = f"custom_{req.custom_type}_"
found = any(
cred_name.startswith(custom_pattern) and
req.display_name.lower().replace(' ', '_') in cred_name
for cred_name in user_mcp_names
)
if not found:
missing.append(req)
else:
if req.qualified_name not in user_mcp_names:
missing.append(req)
return missing
async def build_credential_mappings(
self,
account_id: str,
requirements: List[MCPRequirement]
) -> Dict[str, str]:
"""Build credential mappings for agent instance"""
mappings = {}
for req in requirements:
if req.custom_type:
user_credentials = await self.get_user_credentials(account_id)
custom_pattern = f"custom_{req.custom_type}_"
for cred in user_credentials:
if (cred.mcp_qualified_name.startswith(custom_pattern) and
req.display_name.lower().replace(' ', '_') in cred.mcp_qualified_name):
mappings[req.qualified_name] = cred.credential_id
break
else:
credential = await self.get_credential(account_id, req.qualified_name)
if credential:
mappings[req.qualified_name] = credential.credential_id
return mappings
# Global credential manager instance
credential_manager = CredentialManager()