mirror of https://github.com/kortix-ai/suna.git
427 lines
16 KiB
Python
427 lines
16 KiB
Python
"""
|
|
Secure MCP Credential Manager
|
|
|
|
This module handles:
|
|
1. Encrypting and storing MCP credentials securely
|
|
2. Retrieving and decrypting credentials for runtime use
|
|
3. Managing credential lifecycle and validation
|
|
4. Auditing credential usage
|
|
"""
|
|
|
|
import os
|
|
import json
|
|
import hashlib
|
|
import base64
|
|
from typing import Dict, List, Any, Optional, Tuple
|
|
from dataclasses import dataclass
|
|
from datetime import datetime, timezone
|
|
|
|
from cryptography.fernet import Fernet
|
|
from cryptography.hazmat.primitives import hashes
|
|
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
|
|
|
|
from utils.logger import logger
|
|
from services.supabase import DBConnection
|
|
|
|
db = DBConnection()
|
|
|
|
|
|
@dataclass
|
|
class MCPCredential:
|
|
"""Represents an MCP credential"""
|
|
credential_id: str
|
|
account_id: str
|
|
mcp_qualified_name: str
|
|
display_name: str
|
|
config: Dict[str, Any]
|
|
is_active: bool
|
|
last_used_at: Optional[datetime] = None
|
|
created_at: Optional[datetime] = None
|
|
updated_at: Optional[datetime] = None
|
|
|
|
|
|
@dataclass
|
|
class MCPRequirement:
|
|
"""Represents an MCP requirement from a template"""
|
|
qualified_name: str
|
|
display_name: str
|
|
enabled_tools: List[str]
|
|
required_config: List[str]
|
|
custom_type: Optional[str] = None # 'sse' or 'http' for custom MCP servers
|
|
|
|
|
|
class CredentialManager:
|
|
"""Manages secure storage and retrieval of MCP credentials"""
|
|
|
|
def __init__(self):
|
|
self.encryption_key = self._get_or_create_encryption_key()
|
|
self.cipher = Fernet(self.encryption_key)
|
|
|
|
def _get_or_create_encryption_key(self) -> bytes:
|
|
"""Get or create encryption key for credentials"""
|
|
# Use environment variable or fallback to a default key for development
|
|
key_env = os.getenv("MCP_CREDENTIAL_ENCRYPTION_KEY", "Rgv3pwjsXetY0KAesiMk-OKbC8C2roTx2S8qC0OYBb0=")
|
|
|
|
try:
|
|
# If the key is already in the correct format (base64 encoded), it should work directly
|
|
# Fernet expects the key as bytes, so we need to encode the string
|
|
if isinstance(key_env, str):
|
|
# The key should be a base64-encoded string, encode it to bytes for Fernet
|
|
return key_env.encode('utf-8')
|
|
else:
|
|
return key_env
|
|
|
|
except Exception as e:
|
|
logger.error(f"Invalid encryption key: {e}")
|
|
# Generate a new key as fallback
|
|
logger.warning("Generating new encryption key for this session")
|
|
key = Fernet.generate_key()
|
|
logger.info(f"Generated new encryption key. Set this in your environment:")
|
|
logger.info(f"MCP_CREDENTIAL_ENCRYPTION_KEY={key.decode()}")
|
|
return key
|
|
|
|
def _encrypt_config(self, config: Dict[str, Any]) -> Tuple[bytes, str]:
|
|
"""Encrypt configuration and return encrypted data + hash"""
|
|
config_json = json.dumps(config, sort_keys=True)
|
|
config_bytes = config_json.encode('utf-8')
|
|
|
|
# Create hash for integrity checking
|
|
config_hash = hashlib.sha256(config_bytes).hexdigest()
|
|
|
|
# Encrypt the config
|
|
encrypted_config = self.cipher.encrypt(config_bytes)
|
|
|
|
return encrypted_config, config_hash
|
|
|
|
def _decrypt_config(self, encrypted_config: bytes, expected_hash: str) -> Dict[str, Any]:
|
|
"""Decrypt configuration and verify integrity"""
|
|
try:
|
|
decrypted_bytes = self.cipher.decrypt(encrypted_config)
|
|
|
|
# Verify hash
|
|
actual_hash = hashlib.sha256(decrypted_bytes).hexdigest()
|
|
if actual_hash != expected_hash:
|
|
raise ValueError("Credential integrity check failed")
|
|
|
|
config_json = decrypted_bytes.decode('utf-8')
|
|
return json.loads(config_json)
|
|
|
|
except Exception as e:
|
|
logger.error(f"Failed to decrypt credential: {e}")
|
|
raise ValueError("Failed to decrypt credential")
|
|
|
|
async def store_credential(
|
|
self,
|
|
account_id: str,
|
|
mcp_qualified_name: str,
|
|
display_name: str,
|
|
config: Dict[str, Any]
|
|
) -> str:
|
|
"""
|
|
Store encrypted MCP credentials for a user
|
|
|
|
Args:
|
|
account_id: User's account ID
|
|
mcp_qualified_name: MCP server qualified name (e.g., "exa", "@smithery-ai/github")
|
|
display_name: Human-readable name for the MCP
|
|
config: Configuration dictionary with API keys and settings
|
|
|
|
Returns:
|
|
credential_id: UUID of the stored credential
|
|
"""
|
|
logger.info(f"Storing credential for {mcp_qualified_name} for user {account_id}")
|
|
|
|
try:
|
|
# Encrypt the configuration
|
|
encrypted_config, config_hash = self._encrypt_config(config)
|
|
|
|
client = await db.client
|
|
|
|
# Upsert the credential (encode bytes as base64 for database storage)
|
|
encoded_config = base64.b64encode(encrypted_config).decode('utf-8')
|
|
logger.debug(f"Encoded config length: {len(encoded_config)}, content preview: {encoded_config[:50]}...")
|
|
|
|
result = await client.table('user_mcp_credentials').upsert({
|
|
'account_id': account_id,
|
|
'mcp_qualified_name': mcp_qualified_name,
|
|
'display_name': display_name,
|
|
'encrypted_config': encoded_config,
|
|
'config_hash': config_hash,
|
|
'is_active': True,
|
|
'updated_at': datetime.now(timezone.utc).isoformat()
|
|
}, on_conflict='account_id,mcp_qualified_name').execute()
|
|
|
|
if not result.data:
|
|
raise ValueError("Failed to store credential")
|
|
|
|
credential_id = result.data[0]['credential_id']
|
|
logger.info(f"Successfully stored credential {credential_id} for {mcp_qualified_name}")
|
|
|
|
return credential_id
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error storing credential for {mcp_qualified_name}: {str(e)}")
|
|
raise
|
|
|
|
async def get_credential(self, account_id: str, mcp_qualified_name: str) -> Optional[MCPCredential]:
|
|
"""
|
|
Retrieve and decrypt MCP credentials for a user
|
|
|
|
Args:
|
|
account_id: User's account ID
|
|
mcp_qualified_name: MCP server qualified name
|
|
|
|
Returns:
|
|
MCPCredential object or None if not found
|
|
"""
|
|
try:
|
|
client = await db.client
|
|
|
|
result = await client.table('user_mcp_credentials').select('*')\
|
|
.eq('account_id', account_id)\
|
|
.eq('mcp_qualified_name', mcp_qualified_name)\
|
|
.eq('is_active', True)\
|
|
.execute()
|
|
|
|
if not result.data:
|
|
return None
|
|
|
|
cred_data = result.data[0]
|
|
|
|
# Decrypt the configuration (handle both old and new formats)
|
|
encrypted_config = cred_data['encrypted_config']
|
|
logger.debug(f"Retrieved config type: {type(encrypted_config)}, length: {len(encrypted_config) if encrypted_config else 0}")
|
|
|
|
if isinstance(encrypted_config, str):
|
|
# New format: base64 encoded string
|
|
logger.debug(f"Decoding base64 string of length {len(encrypted_config)}")
|
|
try:
|
|
encrypted_config_bytes = base64.b64decode(encrypted_config.encode('utf-8'))
|
|
except Exception as e:
|
|
logger.error(f"Failed to decode base64 credential: {e}, string: {encrypted_config[:50]}...")
|
|
return None
|
|
else:
|
|
# Old format: raw bytes (backward compatibility)
|
|
encrypted_config_bytes = encrypted_config
|
|
|
|
config = self._decrypt_config(
|
|
encrypted_config_bytes,
|
|
cred_data['config_hash']
|
|
)
|
|
|
|
# Update last used timestamp
|
|
await client.table('user_mcp_credentials')\
|
|
.update({'last_used_at': datetime.now(timezone.utc).isoformat()})\
|
|
.eq('credential_id', cred_data['credential_id'])\
|
|
.execute()
|
|
|
|
return MCPCredential(
|
|
credential_id=cred_data['credential_id'],
|
|
account_id=cred_data['account_id'],
|
|
mcp_qualified_name=cred_data['mcp_qualified_name'],
|
|
display_name=cred_data['display_name'],
|
|
config=config,
|
|
is_active=cred_data['is_active'],
|
|
last_used_at=cred_data.get('last_used_at'),
|
|
created_at=cred_data.get('created_at'),
|
|
updated_at=cred_data.get('updated_at')
|
|
)
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error retrieving credential for {mcp_qualified_name}: {str(e)}")
|
|
return None
|
|
|
|
async def get_user_credentials(self, account_id: str) -> List[MCPCredential]:
|
|
"""Get all active credentials for a user"""
|
|
try:
|
|
client = await db.client
|
|
|
|
result = await client.table('user_mcp_credentials').select('*')\
|
|
.eq('account_id', account_id)\
|
|
.eq('is_active', True)\
|
|
.order('created_at', desc=True)\
|
|
.execute()
|
|
|
|
credentials = []
|
|
for cred_data in result.data:
|
|
try:
|
|
# Decrypt the configuration (handle both old and new formats)
|
|
encrypted_config = cred_data['encrypted_config']
|
|
if isinstance(encrypted_config, str):
|
|
# New format: base64 encoded string
|
|
try:
|
|
encrypted_config_bytes = base64.b64decode(encrypted_config.encode('utf-8'))
|
|
except Exception as e:
|
|
logger.error(f"Failed to decode base64 credential {cred_data['credential_id']}: {e}")
|
|
continue
|
|
else:
|
|
# Old format: raw bytes (backward compatibility)
|
|
encrypted_config_bytes = encrypted_config
|
|
|
|
config = self._decrypt_config(
|
|
encrypted_config_bytes,
|
|
cred_data['config_hash']
|
|
)
|
|
|
|
credentials.append(MCPCredential(
|
|
credential_id=cred_data['credential_id'],
|
|
account_id=cred_data['account_id'],
|
|
mcp_qualified_name=cred_data['mcp_qualified_name'],
|
|
display_name=cred_data['display_name'],
|
|
config=config,
|
|
is_active=cred_data['is_active'],
|
|
last_used_at=cred_data.get('last_used_at'),
|
|
created_at=cred_data.get('created_at'),
|
|
updated_at=cred_data.get('updated_at')
|
|
))
|
|
except Exception as e:
|
|
logger.error(f"Failed to decrypt credential {cred_data['credential_id']}: {e}")
|
|
continue
|
|
|
|
return credentials
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error retrieving user credentials: {str(e)}")
|
|
return []
|
|
|
|
async def delete_credential(self, account_id: str, mcp_qualified_name: str) -> bool:
|
|
"""Delete (deactivate) a credential"""
|
|
try:
|
|
client = await db.client
|
|
|
|
result = await client.table('user_mcp_credentials')\
|
|
.update({'is_active': False})\
|
|
.eq('account_id', account_id)\
|
|
.eq('mcp_qualified_name', mcp_qualified_name)\
|
|
.execute()
|
|
|
|
return len(result.data) > 0
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error deleting credential for {mcp_qualified_name}: {str(e)}")
|
|
return False
|
|
|
|
async def test_credential(self, account_id: str, mcp_qualified_name: str) -> bool:
|
|
"""Test if a credential is valid by attempting to connect"""
|
|
try:
|
|
credential = await self.get_credential(account_id, mcp_qualified_name)
|
|
if not credential:
|
|
return False
|
|
|
|
# Import here to avoid circular imports
|
|
from .client import MCPManager
|
|
|
|
# Create a test MCP configuration
|
|
test_config = {
|
|
"name": credential.display_name,
|
|
"qualifiedName": credential.mcp_qualified_name,
|
|
"config": credential.config,
|
|
"enabledTools": [] # Empty for testing
|
|
}
|
|
|
|
# Try to connect
|
|
mcp_manager = MCPManager()
|
|
try:
|
|
await mcp_manager.connect_server(test_config)
|
|
await self._log_credential_usage(
|
|
credential.credential_id,
|
|
None,
|
|
"test_connection",
|
|
True
|
|
)
|
|
return True
|
|
except Exception as e:
|
|
await self._log_credential_usage(
|
|
credential.credential_id,
|
|
None,
|
|
"test_connection",
|
|
False,
|
|
str(e)
|
|
)
|
|
return False
|
|
finally:
|
|
await mcp_manager.disconnect_all()
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error testing credential for {mcp_qualified_name}: {str(e)}")
|
|
return False
|
|
|
|
async def _log_credential_usage(
|
|
self,
|
|
credential_id: str,
|
|
instance_id: Optional[str],
|
|
action: str,
|
|
success: bool,
|
|
error_message: Optional[str] = None,
|
|
metadata: Optional[Dict[str, Any]] = None
|
|
):
|
|
"""Log credential usage for auditing"""
|
|
try:
|
|
client = await db.client
|
|
|
|
await client.table('credential_usage_log').insert({
|
|
'credential_id': credential_id,
|
|
'instance_id': instance_id,
|
|
'action': action,
|
|
'success': success,
|
|
'error_message': error_message,
|
|
'metadata': metadata or {}
|
|
}).execute()
|
|
|
|
except Exception as e:
|
|
logger.error(f"Failed to log credential usage: {e}")
|
|
|
|
async def get_missing_credentials_for_requirements(
|
|
self,
|
|
account_id: str,
|
|
requirements: List[MCPRequirement]
|
|
) -> List[MCPRequirement]:
|
|
"""Get list of missing credentials for given requirements"""
|
|
user_credentials = await self.get_user_credentials(account_id)
|
|
user_mcp_names = {cred.mcp_qualified_name for cred in user_credentials}
|
|
|
|
missing = []
|
|
for req in requirements:
|
|
if req.custom_type:
|
|
custom_pattern = f"custom_{req.custom_type}_"
|
|
found = any(
|
|
cred_name.startswith(custom_pattern) and
|
|
req.display_name.lower().replace(' ', '_') in cred_name
|
|
for cred_name in user_mcp_names
|
|
)
|
|
if not found:
|
|
missing.append(req)
|
|
else:
|
|
if req.qualified_name not in user_mcp_names:
|
|
missing.append(req)
|
|
|
|
return missing
|
|
|
|
async def build_credential_mappings(
|
|
self,
|
|
account_id: str,
|
|
requirements: List[MCPRequirement]
|
|
) -> Dict[str, str]:
|
|
"""Build credential mappings for agent instance"""
|
|
mappings = {}
|
|
|
|
for req in requirements:
|
|
if req.custom_type:
|
|
user_credentials = await self.get_user_credentials(account_id)
|
|
custom_pattern = f"custom_{req.custom_type}_"
|
|
|
|
for cred in user_credentials:
|
|
if (cred.mcp_qualified_name.startswith(custom_pattern) and
|
|
req.display_name.lower().replace(' ', '_') in cred.mcp_qualified_name):
|
|
mappings[req.qualified_name] = cred.credential_id
|
|
break
|
|
else:
|
|
credential = await self.get_credential(account_id, req.qualified_name)
|
|
if credential:
|
|
mappings[req.qualified_name] = credential.credential_id
|
|
|
|
return mappings
|
|
|
|
|
|
# Global credential manager instance
|
|
credential_manager = CredentialManager() |