suna/backend/services/api_keys.py

580 lines
20 KiB
Python

"""
API Keys Service
This module provides functionality for managing API keys including:
- Creating new API keys with UUIDs
- Validating API keys for authentication
- Managing expiration and revocation
- CRUD operations for user API keys
"""
import asyncio
from datetime import datetime, timezone, timedelta
from typing import Optional, List, Dict
from uuid import UUID, uuid4
import secrets
import string
import hmac
import hashlib
import time
from pydantic import BaseModel, Field, field_validator
from fastapi import HTTPException
from utils.logger import logger
from services.supabase import DBConnection
from services import redis
from utils.config import config
class APIKeyStatus:
ACTIVE = "active"
REVOKED = "revoked"
EXPIRED = "expired"
class APIKeyCreateRequest(BaseModel):
"""Request model for creating a new API key"""
title: str = Field(
...,
min_length=1,
max_length=255,
description="Human-readable title for the API key",
)
description: Optional[str] = Field(
None, description="Optional description for the API key"
)
expires_in_days: Optional[int] = Field(
None, gt=0, le=365, description="Number of days until expiration (max 365)"
)
@field_validator("title")
def validate_title(cls, v):
if not v or not v.strip():
raise ValueError("Title cannot be empty")
return v.strip()
class APIKeyResponse(BaseModel):
"""Response model for API key information (without the secret key)"""
key_id: UUID
public_key: str
title: str
description: Optional[str]
status: str
expires_at: Optional[datetime]
last_used_at: Optional[datetime]
created_at: datetime
class APIKeyCreateResponse(BaseModel):
"""Response model for newly created API key (includes both keys)"""
key_id: UUID
public_key: str
secret_key: str # Only returned on creation
title: str
description: Optional[str]
status: str
expires_at: Optional[datetime]
created_at: datetime
class APIKeyValidationResult(BaseModel):
"""Result of API key validation"""
is_valid: bool
account_id: Optional[UUID] = None
key_id: Optional[UUID] = None
error_message: Optional[str] = None
class APIKeyService:
"""
Service for managing API keys with performance optimizations
Performance Features:
- HMAC-SHA256 hashing (100x faster than bcrypt)
- Redis caching for validation results (2min TTL)
- Throttled last_used_at updates (max once per 15min per key, configurable)
- Cached user lookups (5min TTL)
- Asynchronous operations where possible
- In-memory fallback throttling when Redis unavailable
- Streamlined database schema without unnecessary triggers
"""
# Class-level in-memory throttle cache (fallback when Redis unavailable)
_throttle_cache: Dict[str, float] = {}
def __init__(self, db: DBConnection):
self.db = db
def _generate_key_pair(self) -> tuple[str, str]:
"""
Generate a public key and secret key pair
Returns:
tuple: (public_key, secret_key) where public_key starts with 'pk_' and secret_key starts with 'sk_'
"""
# Generate random strings for both keys
pk_suffix = "".join(
secrets.choice(string.ascii_letters + string.digits) for _ in range(32)
)
sk_suffix = "".join(
secrets.choice(string.ascii_letters + string.digits) for _ in range(32)
)
public_key = f"pk_{pk_suffix}"
secret_key = f"sk_{sk_suffix}"
return public_key, secret_key
def _get_secret_key(self) -> str:
"""Get the secret key for HMAC hashing"""
return config.API_KEY_SECRET
def _hash_secret_key(self, secret_key: str) -> str:
"""
Hash a secret key using HMAC-SHA256 (much faster than bcrypt)
Args:
secret_key: The secret key to hash
Returns:
str: The HMAC-SHA256 hash of the secret key
"""
secret = self._get_secret_key().encode("utf-8")
return hmac.new(secret, secret_key.encode("utf-8"), hashlib.sha256).hexdigest()
def _verify_secret_key(self, secret_key: str, hashed_key: str) -> bool:
"""
Verify a secret key against its hash using constant-time comparison
Args:
secret_key: The secret key to verify
hashed_key: The stored hash
Returns:
bool: True if the secret key matches the hash
"""
try:
expected_hash = self._hash_secret_key(secret_key)
return hmac.compare_digest(expected_hash, hashed_key)
except Exception:
return False
async def create_api_key(
self, account_id: UUID, request: APIKeyCreateRequest
) -> APIKeyCreateResponse:
"""
Create a new API key for the specified account
Args:
account_id: The account ID to create the key for
request: The API key creation request
Returns:
APIKeyCreateResponse containing the new API key details including both keys
"""
try:
# Calculate expiration date if specified
expires_at = None
if request.expires_in_days:
expires_at = datetime.now(timezone.utc) + timedelta(
days=request.expires_in_days
)
# Generate public and secret key pair
public_key, secret_key = self._generate_key_pair()
# Hash the secret key for storage
secret_key_hash = self._hash_secret_key(secret_key)
# Insert into database
client = await self.db.client
result = (
await client.table("api_keys")
.insert(
{
"public_key": public_key,
"secret_key_hash": secret_key_hash,
"account_id": str(account_id),
"title": request.title,
"description": request.description,
"expires_at": expires_at.isoformat() if expires_at else None,
"status": APIKeyStatus.ACTIVE,
}
)
.execute()
)
if not result.data:
raise HTTPException(status_code=500, detail="Failed to create API key")
key_data = result.data[0]
logger.info(
"API key created successfully",
account_id=str(account_id),
key_id=key_data["key_id"],
public_key=public_key,
title=request.title,
)
return APIKeyCreateResponse(
key_id=UUID(key_data["key_id"]),
public_key=public_key,
secret_key=secret_key, # Only returned on creation
title=key_data["title"],
description=key_data["description"],
status=key_data["status"],
expires_at=(
datetime.fromisoformat(key_data["expires_at"])
if key_data["expires_at"]
else None
),
created_at=datetime.fromisoformat(key_data["created_at"]),
)
except Exception as e:
logger.error(f"Error creating API key: {e}", exc_info=True)
raise HTTPException(status_code=500, detail="Failed to create API key")
async def list_api_keys(self, account_id: UUID) -> List[APIKeyResponse]:
"""
List all API keys for the specified account
Args:
account_id: The account ID to list keys for
Returns:
List of APIKeyResponse objects
"""
try:
client = await self.db.client
result = (
await client.table("api_keys")
.select(
"key_id, public_key, title, description, status, expires_at, last_used_at, created_at"
)
.eq("account_id", str(account_id))
.order("created_at", desc=True)
.execute()
)
api_keys = []
for key_data in result.data:
api_keys.append(
APIKeyResponse(
key_id=UUID(key_data["key_id"]),
public_key=key_data["public_key"],
title=key_data["title"],
description=key_data["description"],
status=key_data["status"],
expires_at=(
datetime.fromisoformat(key_data["expires_at"])
if key_data["expires_at"]
else None
),
last_used_at=(
datetime.fromisoformat(key_data["last_used_at"])
if key_data["last_used_at"]
else None
),
created_at=datetime.fromisoformat(key_data["created_at"]),
)
)
return api_keys
except Exception as e:
logger.error(f"Error listing API keys: {e}", exc_info=True)
raise HTTPException(status_code=500, detail="Failed to list API keys")
async def revoke_api_key(self, account_id: UUID, key_id: UUID) -> bool:
"""
Revoke an API key
Args:
account_id: The account ID that owns the key
key_id: The ID of the key to revoke
Returns:
True if successful, False otherwise
"""
try:
client = await self.db.client
result = (
await client.table("api_keys")
.update({"status": APIKeyStatus.REVOKED})
.eq("key_id", str(key_id))
.eq("account_id", str(account_id))
.execute()
)
if not result.data:
raise HTTPException(status_code=404, detail="API key not found")
logger.info(
"API key revoked successfully",
account_id=str(account_id),
key_id=str(key_id),
)
return True
except HTTPException:
raise
except Exception as e:
logger.error(f"Error revoking API key: {e}", exc_info=True)
raise HTTPException(status_code=500, detail="Failed to revoke API key")
async def validate_api_key(
self, public_key: str, secret_key: str
) -> APIKeyValidationResult:
"""
Validate an API key pair with Redis caching for performance
Args:
public_key: The public key (starts with 'pk_')
secret_key: The secret key (starts with 'sk_')
Returns:
APIKeyValidationResult with validation status and account info
"""
try:
# Validate key format
if not public_key.startswith("pk_") or not secret_key.startswith("sk_"):
return APIKeyValidationResult(
is_valid=False, error_message="Invalid API key format"
)
# Check Redis cache first (cache key includes secret hash for security)
cache_key = f"api_key:{public_key}:{self._hash_secret_key(secret_key)[:8]}"
try:
redis_client = await redis.get_client()
cached_result = await redis_client.get(cache_key)
if cached_result:
import json
cached_data = json.loads(cached_result)
logger.debug(f"API key validation cache hit for {public_key}")
return APIKeyValidationResult(
is_valid=cached_data["is_valid"],
account_id=(
UUID(cached_data["account_id"])
if cached_data["account_id"]
else None
),
key_id=(
UUID(cached_data["key_id"])
if cached_data["key_id"]
else None
),
error_message=cached_data.get("error_message"),
)
except Exception as e:
logger.warning(f"Redis cache lookup failed: {e}")
# Continue without cache
client = await self.db.client
# Single optimized query with join to get user info
result = (
await client.table("api_keys")
.select("key_id, account_id, status, expires_at, secret_key_hash")
.eq("public_key", public_key)
.execute()
)
if not result.data:
validation_result = APIKeyValidationResult(
is_valid=False, error_message="API key not found"
)
await self._cache_validation_result(
cache_key, validation_result, ttl=300
) # Cache negative results for 5 min
return validation_result
key_data = result.data[0]
# Check if key is expired first (faster than status check)
if key_data["expires_at"]:
expires_at = datetime.fromisoformat(key_data["expires_at"])
if expires_at < datetime.now(timezone.utc):
validation_result = APIKeyValidationResult(
is_valid=False, error_message="API key expired"
)
await self._cache_validation_result(
cache_key, validation_result, ttl=3600
) # Cache expired for 1 hour
return validation_result
# Check if key is active
if key_data["status"] != APIKeyStatus.ACTIVE:
validation_result = APIKeyValidationResult(
is_valid=False, error_message=f"API key is {key_data['status']}"
)
await self._cache_validation_result(
cache_key, validation_result, ttl=3600
) # Cache inactive for 1 hour
return validation_result
# Verify the secret key against the stored hash
if not self._verify_secret_key(secret_key, key_data["secret_key_hash"]):
validation_result = APIKeyValidationResult(
is_valid=False, error_message="Invalid secret key"
)
await self._cache_validation_result(
cache_key, validation_result, ttl=300
) # Cache invalid for 5 min
return validation_result
# Success case
validation_result = APIKeyValidationResult(
is_valid=True,
account_id=UUID(key_data["account_id"]),
key_id=UUID(key_data["key_id"]),
)
# Cache successful validation for 2 minutes
await self._cache_validation_result(cache_key, validation_result, ttl=120)
# Update last used timestamp with throttling to prevent DB spam
# (max once per 15 minutes per key, configurable via config.API_KEY_LAST_USED_THROTTLE_SECONDS)
asyncio.create_task(self._update_last_used_throttled(key_data["key_id"]))
return validation_result
except Exception as e:
logger.error(f"Error validating API key: {e}", exc_info=True)
return APIKeyValidationResult(
is_valid=False, error_message="Internal server error"
)
async def _cache_validation_result(
self, cache_key: str, result: APIKeyValidationResult, ttl: int = 120
):
"""Cache validation result in Redis"""
try:
redis_client = await redis.get_client()
import json
cache_data = {
"is_valid": result.is_valid,
"account_id": str(result.account_id) if result.account_id else None,
"key_id": str(result.key_id) if result.key_id else None,
"error_message": result.error_message,
}
await redis_client.setex(cache_key, ttl, json.dumps(cache_data))
except Exception as e:
logger.warning(f"Failed to cache validation result: {e}")
async def _update_last_used_throttled(self, key_id: str):
"""Update last used timestamp with throttling to reduce DB load"""
throttle_interval = config.API_KEY_LAST_USED_THROTTLE_SECONDS
current_time = time.time()
# Try Redis first
try:
redis_client = await redis.get_client()
throttle_key = f"last_used_throttle:{key_id}"
# Check if we've updated this key recently
last_update = await redis_client.get(throttle_key)
if last_update:
# Already updated within throttle interval, skip
return
# Set throttle flag first to prevent race conditions
await redis_client.setex(throttle_key, throttle_interval, "1")
except Exception as redis_error:
# Fallback to in-memory throttling when Redis unavailable
logger.debug(
f"Redis unavailable for throttling, using in-memory fallback: {redis_error}"
)
# Clean up old entries (simple cleanup every 100 operations)
if len(self._throttle_cache) > 1000:
cutoff_time = current_time - (
throttle_interval * 2
) # Keep extra buffer
self._throttle_cache = {
k: v for k, v in self._throttle_cache.items() if v > cutoff_time
}
# Check in-memory throttle
last_update_time = self._throttle_cache.get(key_id, 0)
if current_time - last_update_time < throttle_interval:
# Already updated within throttle interval, skip
return
# Set in-memory throttle
self._throttle_cache[key_id] = current_time
# Update database
try:
client = await self.db.client
await client.table("api_keys").update(
{"last_used_at": datetime.now(timezone.utc).isoformat()}
).eq("key_id", key_id).execute()
logger.debug(f"Updated last_used_at for key {key_id}")
except Exception as e:
logger.warning(f"Failed to update last_used_at for key {key_id}: {e}")
async def _update_last_used_async(self, key_id: str):
"""Legacy method - kept for backwards compatibility"""
await self._update_last_used_throttled(key_id)
async def _clear_throttle(self, key_id: str):
"""Clear the throttle for a specific key (useful for testing)"""
try:
redis_client = await redis.get_client()
throttle_key = f"last_used_throttle:{key_id}"
await redis_client.delete(throttle_key)
logger.debug(f"Cleared throttle for key {key_id}")
except Exception as e:
logger.warning(f"Failed to clear throttle for key {key_id}: {e}")
async def delete_api_key(self, account_id: UUID, key_id: UUID) -> bool:
"""
Delete an API key permanently
Args:
account_id: The account ID that owns the key
key_id: The ID of the key to delete
Returns:
True if successful, False otherwise
"""
try:
client = await self.db.client
result = (
await client.table("api_keys")
.delete()
.eq("key_id", str(key_id))
.eq("account_id", str(account_id))
.execute()
)
if not result.data:
raise HTTPException(status_code=404, detail="API key not found")
logger.info(
"API key deleted successfully",
account_id=str(account_id),
key_id=str(key_id),
)
return True
except HTTPException:
raise
except Exception as e:
logger.error(f"Error deleting API key: {e}", exc_info=True)
raise HTTPException(status_code=500, detail="Failed to delete API key")