mirror of https://github.com/kortix-ai/suna.git
580 lines
20 KiB
Python
580 lines
20 KiB
Python
"""
API Keys Service

This module provides functionality for managing API keys including:
- Creating new API keys with UUIDs
- Validating API keys for authentication
- Managing expiration and revocation
- CRUD operations for user API keys
"""
|
|
|
|
import asyncio
|
|
from datetime import datetime, timezone, timedelta
|
|
from typing import Optional, List, Dict
|
|
from uuid import UUID, uuid4
|
|
import secrets
|
|
import string
|
|
import hmac
|
|
import hashlib
|
|
import time
|
|
from pydantic import BaseModel, Field, field_validator
|
|
from fastapi import HTTPException
|
|
from utils.logger import logger
|
|
from services.supabase import DBConnection
|
|
from services import redis
|
|
from utils.config import config
|
|
|
|
|
|
class APIKeyStatus:
    """String constants used for the ``status`` column of an API key row."""

    # Written on creation and required by validate_api_key for a key to be accepted.
    ACTIVE = "active"
    # Written by revoke_api_key.
    REVOKED = "revoked"
    # Defined for completeness; not written anywhere in the visible code.
    EXPIRED = "expired"
|
|
|
|
|
|
class APIKeyCreateRequest(BaseModel):
    """Payload a caller submits to create a new API key."""

    # Required display name, 1-255 characters; normalised by validate_title below.
    title: str = Field(
        ...,
        min_length=1,
        max_length=255,
        description="Human-readable title for the API key",
    )
    # Free-form note about the key; may be omitted.
    description: Optional[str] = Field(
        default=None, description="Optional description for the API key"
    )
    # Lifetime in days (1-365); omitted means the key gets no expiration date.
    expires_in_days: Optional[int] = Field(
        default=None,
        gt=0,
        le=365,
        description="Number of days until expiration (max 365)",
    )

    @field_validator("title")
    @classmethod
    def validate_title(cls, v):
        """Reject blank titles and return the title without surrounding whitespace."""
        trimmed = v.strip() if v else ""
        if not trimmed:
            raise ValueError("Title cannot be empty")
        return trimmed
|
|
|
|
|
|
class APIKeyResponse(BaseModel):
    """Public view of a stored API key; the secret key is never part of it."""

    # Identifier of the key row.
    key_id: UUID
    # Public half of the key pair.
    public_key: str
    # Caller-supplied display name.
    title: str
    # Caller-supplied note; must be provided but may be None.
    description: Optional[str]
    # Status string as stored (see APIKeyStatus for the values this module writes).
    status: str
    # Expiration timestamp, or None when the key does not expire.
    expires_at: Optional[datetime]
    # Last recorded use, or None when no use has been recorded.
    last_used_at: Optional[datetime]
    # Creation timestamp of the key row.
    created_at: datetime
|
|
|
|
|
|
class APIKeyCreateResponse(BaseModel):
    """Result of creating an API key; the only model that carries the secret key."""

    # Identifier of the newly inserted key row.
    key_id: UUID
    # Public half of the key pair.
    public_key: str
    # Plaintext secret half. Only returned on creation; only its hash is stored.
    secret_key: str
    # Caller-supplied display name.
    title: str
    # Caller-supplied note; must be provided but may be None.
    description: Optional[str]
    # Status string as stored.
    status: str
    # Expiration timestamp, or None when the key does not expire.
    expires_at: Optional[datetime]
    # Creation timestamp of the key row.
    created_at: datetime
|
|
|
|
|
|
class APIKeyValidationResult(BaseModel):
    """Outcome of validating a public/secret key pair."""

    # True only when the pair was accepted.
    is_valid: bool
    # Owning account; populated on success, None otherwise.
    account_id: Optional[UUID] = None
    # Identifier of the matched key row; populated on success, None otherwise.
    key_id: Optional[UUID] = None
    # Human-readable reason for a rejection; None on success.
    error_message: Optional[str] = None
|
|
|
|
|
|
class APIKeyService:
    """
    Service for managing API keys with performance optimizations

    What the implementation does to keep validation cheap:
    - Secrets are hashed with HMAC-SHA256 keyed by a server-side secret
    - Validation results are cached in Redis (successful results for 2 minutes)
    - last_used_at writes are throttled per key
      (interval taken from config.API_KEY_LAST_USED_THROTTLE_SECONDS)
    - The last_used_at write is scheduled as a background asyncio task
    - An in-memory throttle is used as a fallback when Redis is unavailable

    NOTE(review): earlier notes on this class also mention cached user lookups
    (5min TTL) and a trigger-free database schema; neither is visible in this
    file - confirm before relying on them.
    """

    # Shared by all instances: maps key_id -> time.time() of the last
    # last_used_at write. Only consulted when Redis cannot be used.
    _throttle_cache: Dict[str, float] = {}
|
|
|
|
def __init__(self, db: DBConnection):
    """
    Store the database connection used by every query in this service.

    Args:
        db: Connection wrapper whose awaited ``client`` attribute is used to query
    """
    self.db = db
|
|
|
|
def _generate_key_pair(self) -> tuple[str, str]:
    """
    Build a fresh public/secret key pair.

    Each key is a fixed prefix followed by 32 characters drawn from ASCII
    letters and digits using the ``secrets`` module.

    Returns:
        tuple: (public_key, secret_key); the public key starts with 'pk_' and
        the secret key starts with 'sk_'
    """
    alphabet = string.ascii_letters + string.digits

    def random_suffix() -> str:
        # 32 independently chosen characters from the alphabet above.
        return "".join(secrets.choice(alphabet) for _ in range(32))

    # Tuple elements evaluate left to right: public suffix first, then secret.
    return f"pk_{random_suffix()}", f"sk_{random_suffix()}"
|
|
|
|
def _get_secret_key(self) -> str:
    """Return the server-side key used for HMAC hashing, read from config.API_KEY_SECRET."""
    hmac_key = config.API_KEY_SECRET
    return hmac_key
|
|
|
|
def _hash_secret_key(self, secret_key: str) -> str:
    """
    Compute the HMAC-SHA256 of a secret key, keyed by the server-side secret.

    Args:
        secret_key: The plaintext secret key to hash

    Returns:
        str: Lowercase hex digest (64 characters) of the HMAC
    """
    mac_key = self._get_secret_key().encode("utf-8")
    message = secret_key.encode("utf-8")
    mac = hmac.new(key=mac_key, msg=message, digestmod=hashlib.sha256)
    return mac.hexdigest()
|
|
|
|
def _verify_secret_key(self, secret_key: str, hashed_key: str) -> bool:
    """
    Check a plaintext secret key against a stored hash.

    The candidate is hashed with _hash_secret_key and compared with
    hmac.compare_digest, so the comparison is constant-time.

    Args:
        secret_key: The plaintext secret key presented by the caller
        hashed_key: The stored hash to compare against

    Returns:
        bool: True if the hashes match; False on mismatch or on any error
        raised while hashing or comparing
    """
    try:
        candidate_hash = self._hash_secret_key(secret_key)
        matches = hmac.compare_digest(candidate_hash, hashed_key)
    except Exception:
        # Any failure (for example a stored hash that is not a string) is
        # treated as a mismatch rather than propagated.
        return False
    return matches
|
|
|
|
async def create_api_key(
    self, account_id: UUID, request: APIKeyCreateRequest
) -> APIKeyCreateResponse:
    """
    Create a new API key for the specified account

    A key pair is generated, the secret is hashed, and a row is inserted into
    the ``api_keys`` table with status ACTIVE. The plaintext secret is
    returned to the caller here and is not stored.

    Args:
        account_id: The account ID to create the key for
        request: The API key creation request

    Returns:
        APIKeyCreateResponse containing the new API key details including both keys

    Raises:
        HTTPException: 500 if the insert returns no row or any other error occurs
    """
    try:
        # Calculate expiration date if specified
        expires_at = None
        if request.expires_in_days:
            expires_at = datetime.now(timezone.utc) + timedelta(
                days=request.expires_in_days
            )

        # Generate the pair and hash the secret; only the hash is persisted.
        public_key, secret_key = self._generate_key_pair()
        secret_key_hash = self._hash_secret_key(secret_key)

        client = await self.db.client
        result = (
            await client.table("api_keys")
            .insert(
                {
                    "public_key": public_key,
                    "secret_key_hash": secret_key_hash,
                    "account_id": str(account_id),
                    "title": request.title,
                    "description": request.description,
                    "expires_at": expires_at.isoformat() if expires_at else None,
                    "status": APIKeyStatus.ACTIVE,
                }
            )
            .execute()
        )

        if not result.data:
            raise HTTPException(status_code=500, detail="Failed to create API key")

        key_data = result.data[0]

        logger.info(
            "API key created successfully",
            account_id=str(account_id),
            key_id=key_data["key_id"],
            public_key=public_key,
            title=request.title,
        )

        return APIKeyCreateResponse(
            key_id=UUID(key_data["key_id"]),
            public_key=public_key,
            secret_key=secret_key,  # Only returned on creation
            title=key_data["title"],
            description=key_data["description"],
            status=key_data["status"],
            expires_at=(
                datetime.fromisoformat(key_data["expires_at"])
                if key_data["expires_at"]
                else None
            ),
            created_at=datetime.fromisoformat(key_data["created_at"]),
        )

    except HTTPException:
        # Deliberate HTTP errors raised above pass through untouched, matching
        # revoke_api_key and delete_api_key, instead of being logged as
        # unexpected failures and re-wrapped.
        raise
    except Exception as e:
        logger.error(f"Error creating API key: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail="Failed to create API key")
|
|
|
|
async def list_api_keys(self, account_id: UUID) -> List[APIKeyResponse]:
    """
    Return every API key belonging to an account, newest first.

    Only non-secret columns are selected; the stored secret hash is never read.

    Args:
        account_id: The account ID to list keys for

    Returns:
        List of APIKeyResponse objects ordered by created_at descending

    Raises:
        HTTPException: 500 if the query or the row conversion fails
    """

    def parse_optional(raw):
        # Nullable timestamp columns come back as None/empty or an ISO string.
        return datetime.fromisoformat(raw) if raw else None

    try:
        client = await self.db.client
        query = (
            client.table("api_keys")
            .select(
                "key_id, public_key, title, description, status, expires_at, last_used_at, created_at"
            )
            .eq("account_id", str(account_id))
            .order("created_at", desc=True)
        )
        result = await query.execute()

        return [
            APIKeyResponse(
                key_id=UUID(row["key_id"]),
                public_key=row["public_key"],
                title=row["title"],
                description=row["description"],
                status=row["status"],
                expires_at=parse_optional(row["expires_at"]),
                last_used_at=parse_optional(row["last_used_at"]),
                created_at=datetime.fromisoformat(row["created_at"]),
            )
            for row in result.data
        ]

    except Exception as e:
        logger.error(f"Error listing API keys: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail="Failed to list API keys")
|
|
|
|
async def revoke_api_key(self, account_id: UUID, key_id: UUID) -> bool:
    """
    Mark an API key as revoked.

    The update is filtered on both key_id and account_id, so a key is only
    revoked when it belongs to the given account.

    Args:
        account_id: The account ID that owns the key
        key_id: The ID of the key to revoke

    Returns:
        True once the key has been marked revoked

    Raises:
        HTTPException: 404 when the update returns no rows, 500 on any other failure
    """
    try:
        client = await self.db.client
        update_query = (
            client.table("api_keys")
            .update({"status": APIKeyStatus.REVOKED})
            .eq("key_id", str(key_id))
            .eq("account_id", str(account_id))
        )
        result = await update_query.execute()

        if result.data:
            logger.info(
                "API key revoked successfully",
                account_id=str(account_id),
                key_id=str(key_id),
            )
            return True

        raise HTTPException(status_code=404, detail="API key not found")

    except HTTPException:
        # The 404 above must reach the caller unchanged.
        raise
    except Exception as e:
        logger.error(f"Error revoking API key: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail="Failed to revoke API key")
|
|
|
|
async def validate_api_key(
    self, public_key: str, secret_key: str
) -> APIKeyValidationResult:
    """
    Validate an API key pair with Redis caching for performance

    Checks run in this order: key prefixes, cached result, database lookup by
    public key, expiration, status, and finally the secret hash. Every outcome
    except a format error or an internal failure is written to the cache.

    Args:
        public_key: The public key (starts with 'pk_')
        secret_key: The secret key (starts with 'sk_')

    Returns:
        APIKeyValidationResult with validation status and account info. On
        failure is_valid is False and error_message gives the reason.
        Unexpected exceptions are logged and reported as
        "Internal server error" instead of being raised.
    """
    import json

    try:
        # Validate key format
        if not public_key.startswith("pk_") or not secret_key.startswith("sk_"):
            return APIKeyValidationResult(
                is_valid=False, error_message="Invalid API key format"
            )

        # The cache key carries the FULL HMAC of the presented secret. A cache
        # hit returns the stored verdict without re-checking the secret, so a
        # truncated digest would let any secret sharing the short prefix
        # inherit a cached "valid" result.
        cache_key = f"api_key:{public_key}:{self._hash_secret_key(secret_key)}"

        async def reject(message: str, ttl: int):
            # Build a negative result, cache it for `ttl` seconds, and return it.
            outcome = APIKeyValidationResult(is_valid=False, error_message=message)
            await self._cache_validation_result(cache_key, outcome, ttl=ttl)
            return outcome

        try:
            redis_client = await redis.get_client()
            cached_result = await redis_client.get(cache_key)
            if cached_result:
                cached_data = json.loads(cached_result)
                logger.debug(f"API key validation cache hit for {public_key}")
                return APIKeyValidationResult(
                    is_valid=cached_data["is_valid"],
                    account_id=(
                        UUID(cached_data["account_id"])
                        if cached_data["account_id"]
                        else None
                    ),
                    key_id=(
                        UUID(cached_data["key_id"])
                        if cached_data["key_id"]
                        else None
                    ),
                    error_message=cached_data.get("error_message"),
                )
        except Exception as e:
            logger.warning(f"Redis cache lookup failed: {e}")
            # Continue without cache

        client = await self.db.client

        # Look the key up by its public half.
        result = (
            await client.table("api_keys")
            .select("key_id, account_id, status, expires_at, secret_key_hash")
            .eq("public_key", public_key)
            .execute()
        )

        if not result.data:
            # Cache negative results for 5 min
            return await reject("API key not found", 300)

        key_data = result.data[0]

        # Check expiration before status.
        if key_data["expires_at"]:
            expires_at = datetime.fromisoformat(key_data["expires_at"])
            if expires_at.tzinfo is None:
                # A naive timestamp cannot be compared with an aware "now".
                # create_api_key writes UTC, so a naive value is read as UTC.
                # NOTE(review): confirm the column is always stored in UTC.
                expires_at = expires_at.replace(tzinfo=timezone.utc)
            if expires_at < datetime.now(timezone.utc):
                # Cache expired for 1 hour
                return await reject("API key expired", 3600)

        # Check if key is active
        if key_data["status"] != APIKeyStatus.ACTIVE:
            # Cache inactive for 1 hour
            return await reject(f"API key is {key_data['status']}", 3600)

        # Verify the secret key against the stored hash
        if not self._verify_secret_key(secret_key, key_data["secret_key_hash"]):
            # Cache invalid for 5 min
            return await reject("Invalid secret key", 300)

        # Success case
        validation_result = APIKeyValidationResult(
            is_valid=True,
            account_id=UUID(key_data["account_id"]),
            key_id=UUID(key_data["key_id"]),
        )

        # Cache successful validation for 2 minutes
        await self._cache_validation_result(cache_key, validation_result, ttl=120)

        # Record the use in the background; the write is throttled per key by
        # _update_last_used_throttled.
        # NOTE(review): the task reference is not retained; asyncio's docs
        # recommend keeping one so the task cannot be garbage-collected early.
        asyncio.create_task(self._update_last_used_throttled(key_data["key_id"]))

        return validation_result

    except Exception as e:
        logger.error(f"Error validating API key: {e}", exc_info=True)
        return APIKeyValidationResult(
            is_valid=False, error_message="Internal server error"
        )
|
|
|
|
async def _cache_validation_result(
    self, cache_key: str, result: APIKeyValidationResult, ttl: int = 120
):
    """
    Store a validation result in Redis as JSON under cache_key.

    Failures are logged as warnings and otherwise ignored, so caching never
    affects the outcome of a validation.

    Args:
        cache_key: Redis key to write
        result: The validation result to serialise
        ttl: Expiry of the cache entry in seconds
    """
    try:
        redis_client = await redis.get_client()
        import json

        account_id = result.account_id
        key_id = result.key_id
        # UUIDs are stored as strings; missing IDs are stored as null.
        payload = json.dumps(
            {
                "is_valid": result.is_valid,
                "account_id": str(account_id) if account_id else None,
                "key_id": str(key_id) if key_id else None,
                "error_message": result.error_message,
            }
        )
        await redis_client.setex(cache_key, ttl, payload)
    except Exception as e:
        logger.warning(f"Failed to cache validation result: {e}")
|
|
|
|
async def _update_last_used_throttled(self, key_id: str):
    """
    Update last used timestamp with throttling to reduce DB load

    At most one write per key is attempted per throttle interval
    (config.API_KEY_LAST_USED_THROTTLE_SECONDS). Redis holds the throttle
    flag; if any Redis call fails, the class-level in-memory cache is used
    instead. A failed database update is logged and not raised.

    Args:
        key_id: Identifier of the key whose last_used_at should be refreshed
    """
    throttle_interval = config.API_KEY_LAST_USED_THROTTLE_SECONDS
    current_time = time.time()

    # Try Redis first
    try:
        redis_client = await redis.get_client()
        throttle_key = f"last_used_throttle:{key_id}"

        # SET ... NX EX claims the throttle slot atomically: only the caller
        # that creates the flag proceeds. A separate GET followed by SETEX
        # lets concurrent callers all pass the check before any sets the flag.
        claimed = await redis_client.set(
            throttle_key, "1", ex=throttle_interval, nx=True
        )
        if not claimed:
            # Already updated within throttle interval, skip
            return

    except Exception as redis_error:
        # Fallback to in-memory throttling when Redis unavailable
        logger.debug(
            f"Redis unavailable for throttling, using in-memory fallback: {redis_error}"
        )

        throttle_cache = self._throttle_cache

        # Prune once the cache holds more than 1000 entries. Stale keys are
        # removed IN PLACE: rebinding self._throttle_cache would create an
        # instance attribute that shadows the shared class-level dict.
        if len(throttle_cache) > 1000:
            # Keep extra buffer
            cutoff_time = current_time - (throttle_interval * 2)
            stale_keys = [k for k, v in throttle_cache.items() if v <= cutoff_time]
            for stale_key in stale_keys:
                throttle_cache.pop(stale_key, None)

        # Check in-memory throttle
        last_update_time = throttle_cache.get(key_id, 0)
        if current_time - last_update_time < throttle_interval:
            # Already updated within throttle interval, skip
            return

        # Set in-memory throttle
        throttle_cache[key_id] = current_time

    # Update database
    try:
        client = await self.db.client
        await client.table("api_keys").update(
            {"last_used_at": datetime.now(timezone.utc).isoformat()}
        ).eq("key_id", key_id).execute()

        logger.debug(f"Updated last_used_at for key {key_id}")

    except Exception as e:
        logger.warning(f"Failed to update last_used_at for key {key_id}: {e}")
|
|
|
|
async def _update_last_used_async(self, key_id: str):
    """
    Legacy entry point kept for backwards compatibility.

    Delegates directly to _update_last_used_throttled.

    Args:
        key_id: Identifier of the key whose last_used_at should be refreshed
    """
    await self._update_last_used_throttled(key_id)
|
|
|
|
async def _clear_throttle(self, key_id: str):
    """
    Clear the throttle for a specific key (useful for testing)

    Both throttle stores are cleared: the in-memory fallback entry and the
    Redis flag. A Redis failure is logged as a warning and not raised.

    Args:
        key_id: Identifier of the key whose throttle should be reset
    """
    # Clear the in-memory fallback first so the reset still takes effect when
    # Redis is unavailable, which is exactly when that store is consulted.
    self._throttle_cache.pop(key_id, None)

    try:
        redis_client = await redis.get_client()
        throttle_key = f"last_used_throttle:{key_id}"
        await redis_client.delete(throttle_key)
        logger.debug(f"Cleared throttle for key {key_id}")
    except Exception as e:
        logger.warning(f"Failed to clear throttle for key {key_id}: {e}")
|
|
|
|
async def delete_api_key(self, account_id: UUID, key_id: UUID) -> bool:
    """
    Permanently remove an API key row.

    The delete is filtered on both key_id and account_id, so a key is only
    removed when it belongs to the given account.

    Args:
        account_id: The account ID that owns the key
        key_id: The ID of the key to delete

    Returns:
        True once the key has been deleted

    Raises:
        HTTPException: 404 when the delete returns no rows, 500 on any other failure
    """
    try:
        client = await self.db.client
        delete_query = (
            client.table("api_keys")
            .delete()
            .eq("key_id", str(key_id))
            .eq("account_id", str(account_id))
        )
        result = await delete_query.execute()

        if result.data:
            logger.info(
                "API key deleted successfully",
                account_id=str(account_id),
                key_id=str(key_id),
            )
            return True

        raise HTTPException(status_code=404, detail="API key not found")

    except HTTPException:
        # The 404 above must reach the caller unchanged.
        raise
    except Exception as e:
        logger.error(f"Error deleting API key: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail="Failed to delete API key")
|