suna/backend/models/manager.py

224 lines
8.2 KiB
Python

from typing import Optional, List, Dict, Any, Tuple
from .registry import registry
from .models import Model, ModelCapability
from utils.logger import logger
from .registry import DEFAULT_PREMIUM_MODEL, DEFAULT_FREE_MODEL
class ModelManager:
    """Convenience layer over the model registry.

    Wraps the module-level ``registry`` with helpers for looking up and
    validating models, estimating cost, selecting a model for a tier, and
    choosing the default model for a user.
    """

    def __init__(self):
        """Bind this manager to the module-level ``registry``."""
        self.registry = registry

    def get_model(self, model_id: str) -> Optional[Model]:
        """Return whatever ``registry.get`` yields for ``model_id``."""
        return self.registry.get(model_id)

    def resolve_model_id(self, model_id: str) -> str:
        """Resolve ``model_id`` through the registry.

        Returns:
            The registry's resolved id when it is truthy; otherwise the
            input ``model_id`` unchanged (a warning is logged in that case).
        """
        logger.debug(f"resolve_model_id called with: '{model_id}' (type: {type(model_id)})")
        resolved = self.registry.resolve_model_id(model_id)
        if resolved:
            logger.debug(f"Resolved model '{model_id}' to '{resolved}'")
            return resolved

        # ``_aliases`` is a private registry attribute used only to enrich the
        # warning; read it defensively so a diagnostic log can never raise.
        all_aliases = list(getattr(self.registry, "_aliases", {}))
        logger.warning(f"Could not resolve model ID: '{model_id}'. Available aliases: {all_aliases[:10]}...")
        return model_id

    def validate_model(self, model_id: str) -> Tuple[bool, str]:
        """Check that a model exists and is enabled.

        Returns:
            ``(True, "")`` when usable, otherwise ``(False, reason)``.
        """
        model = self.get_model(model_id)
        if not model:
            return False, f"Model '{model_id}' not found"
        if not model.enabled:
            return False, f"Model '{model.name}' is currently disabled"
        return True, ""

    def calculate_cost(
        self,
        model_id: str,
        input_tokens: int,
        output_tokens: int
    ) -> Optional[float]:
        """Compute the cost of a request from the model's per-token pricing.

        Returns:
            ``input_tokens * input_cost_per_token +
            output_tokens * output_cost_per_token``, or ``None`` (with a
            warning logged) when the model or its pricing is missing.
        """
        model = self.get_model(model_id)
        if not model or not model.pricing:
            logger.warning(f"No pricing available for model: {model_id}")
            return None

        input_cost = input_tokens * model.pricing.input_cost_per_token
        output_cost = output_tokens * model.pricing.output_cost_per_token
        total_cost = input_cost + output_cost
        logger.debug(
            f"Cost calculation for {model.name}: "
            f"{input_tokens} input tokens (${input_cost:.6f}) + "
            f"{output_tokens} output tokens (${output_cost:.6f}) = "
            f"${total_cost:.6f}"
        )
        return total_cost

    def get_models_for_tier(self, tier: str) -> List[Model]:
        """Return the registry's enabled models for ``tier``."""
        return self.registry.get_by_tier(tier, enabled_only=True)

    def get_models_with_capability(self, capability: ModelCapability) -> List[Model]:
        """Return the registry's enabled models that have ``capability``."""
        return self.registry.get_by_capability(capability, enabled_only=True)

    def select_best_model(
        self,
        tier: str,
        required_capabilities: Optional[List[ModelCapability]] = None,
        min_context_window: Optional[int] = None,
        prefer_cheaper: bool = False
    ) -> Optional[Model]:
        """Pick one enabled model for ``tier`` that satisfies the filters.

        Args:
            tier: Tier name passed to :meth:`get_models_for_tier`.
            required_capabilities: If given, a model must have all of them.
            min_context_window: If truthy, minimum ``context_window``.
            prefer_cheaper: If true and at least one candidate has pricing,
                return the priced candidate with the lowest
                ``input_cost_per_million_tokens`` (unpriced ones are ignored).

        Returns:
            The chosen model, or ``None`` when no candidate remains. Without
            ``prefer_cheaper`` (or when nothing is priced) the highest
            ``priority`` wins, with ``recommended`` models preferred on ties.
        """
        candidates = self.get_models_for_tier(tier)
        if required_capabilities:
            candidates = [
                m for m in candidates
                if all(cap in m.capabilities for cap in required_capabilities)
            ]
        if min_context_window:
            candidates = [m for m in candidates if m.context_window >= min_context_window]
        if not candidates:
            return None

        if prefer_cheaper:
            priced = [m for m in candidates if m.pricing]
            if priced:
                # min() keeps the first of equally cheap models, matching a
                # stable sort followed by taking element 0.
                return min(priced, key=lambda m: m.pricing.input_cost_per_million_tokens)

        return min(candidates, key=lambda m: (-m.priority, not m.recommended))

    def get_default_model(self, tier: str = "free") -> Optional[Model]:
        """Return the highest-priority model for ``tier``.

        Recommended models are preferred; if none is recommended, all enabled
        models of the tier are considered. Returns ``None`` for an empty tier.
        """
        models = self.get_models_for_tier(tier)
        if not models:
            return None
        recommended = [m for m in models if m.recommended]
        # max() returns the first item among equal priorities, same as the
        # first element of a stable descending sort.
        return max(recommended or models, key=lambda m: m.priority)

    def get_context_window(self, model_id: str, default: int = 31_000) -> int:
        """Delegate to ``registry.get_context_window(model_id, default)``."""
        return self.registry.get_context_window(model_id, default)

    def check_token_limit(
        self,
        model_id: str,
        token_count: int,
        is_input: bool = True
    ) -> Tuple[bool, int]:
        """Check ``token_count`` against the model's limit.

        For input the limit is ``context_window``; for output it is
        ``max_output_tokens``, falling back to ``context_window`` when that
        is falsy.

        Returns:
            ``(within_limit, limit)``; ``(False, 0)`` for an unknown model.
        """
        model = self.get_model(model_id)
        if not model:
            return False, 0
        if is_input:
            max_allowed = model.context_window
        else:
            max_allowed = model.max_output_tokens or model.context_window
        return token_count <= max_allowed, max_allowed

    @staticmethod
    def _describe_model(model) -> Dict[str, Any]:
        """Build the plain-dict description of an already-loaded model."""
        pricing = None
        if model.pricing:
            pricing = {
                "input_per_million": model.pricing.input_cost_per_million_tokens,
                "output_per_million": model.pricing.output_cost_per_million_tokens,
            }
        return {
            "id": model.id,
            "name": model.name,
            "provider": model.provider.value,
            "context_window": model.context_window,
            "max_output_tokens": model.max_output_tokens,
            "capabilities": [cap.value for cap in model.capabilities],
            "pricing": pricing,
            "enabled": model.enabled,
            "beta": model.beta,
            "tier_availability": model.tier_availability,
            "priority": model.priority,
            "recommended": model.recommended,
        }

    def format_model_info(self, model_id: str) -> Dict[str, Any]:
        """Return a dict describing ``model_id``.

        Returns:
            The model description, or ``{"error": ...}`` when the model is
            not found.
        """
        model = self.get_model(model_id)
        if not model:
            return {"error": f"Model '{model_id}' not found"}
        return self._describe_model(model)

    def list_available_models(
        self,
        tier: Optional[str] = None,
        include_disabled: bool = False
    ) -> List[Dict[str, Any]]:
        """List model descriptions, optionally restricted to one tier.

        Args:
            tier: Tier to filter by; a falsy value lists all models.
            include_disabled: When true, disabled models are included.

        Returns:
            Descriptions ordered with free-tier models first, then by
            descending priority, then by name.
        """
        logger.debug(f"list_available_models called with tier='{tier}', include_disabled={include_disabled}")
        if tier:
            models = self.registry.get_by_tier(tier, enabled_only=not include_disabled)
            logger.debug(f"Found {len(models)} models for tier '{tier}'")
        else:
            models = self.registry.get_all(enabled_only=not include_disabled)
            logger.debug(f"Found {len(models)} total models")

        if models:
            model_names = [m.name for m in models]
            logger.debug(f"Models: {model_names}")
        else:
            logger.warning(f"No models found for tier '{tier}' - this might indicate a configuration issue")

        models = sorted(
            models,
            key=lambda m: (not m.is_free_tier, -m.priority, m.name)
        )
        # Format the objects already in hand instead of looking each one up
        # again by id.
        return [self._describe_model(m) for m in models]

    def get_legacy_constants(self) -> Dict:
        """Return ``registry.to_legacy_format()``."""
        return self.registry.to_legacy_format()

    @staticmethod
    def _extract_price_id(subscription) -> Optional[str]:
        """Return the price id of the first subscription item, else ``price_id``."""
        items = subscription.get('items')
        line_items = items.get('data') if items else None
        if line_items:
            return line_items[0]['price']['id']
        return subscription.get('price_id')

    async def get_default_model_for_user(self, client, user_id: str) -> str:
        """Choose the default model id for a user based on their subscription.

        Args:
            client: Not used by this method; kept for interface compatibility.
            user_id: User whose subscription is looked up.

        Returns:
            ``DEFAULT_PREMIUM_MODEL`` in local mode or for a subscription
            whose tier name is not ``'free'``; ``DEFAULT_FREE_MODEL``
            otherwise, including when any error occurs (best effort).
        """
        try:
            # Imported at call time, inside the try, so an import failure
            # also falls back to the free model.
            # NOTE(review): presumably lazy to avoid an import cycle — confirm.
            from utils.config import config, EnvMode
            if config.ENV_MODE == EnvMode.LOCAL:
                return DEFAULT_PREMIUM_MODEL

            from services.billing import get_user_subscription, SUBSCRIPTION_TIERS
            subscription = await get_user_subscription(user_id)

            is_paid_tier = False
            if subscription:
                price_id = self._extract_price_id(subscription)
                tier_info = SUBSCRIPTION_TIERS.get(price_id)
                if tier_info and tier_info['name'] != 'free':
                    is_paid_tier = True

            # NOTE(review): the model names in these log messages are
            # hard-coded and may drift from the DEFAULT_* constants.
            if is_paid_tier:
                logger.debug(f"Setting Claude Sonnet 4 as default for paid user {user_id}")
                return DEFAULT_PREMIUM_MODEL
            else:
                logger.debug(f"Setting Kimi K2 as default for free user {user_id}")
                return DEFAULT_FREE_MODEL
        except Exception as e:
            # Deliberate best effort: any failure degrades to the free default.
            logger.warning(f"Failed to determine user tier for {user_id}: {e}")
            return DEFAULT_FREE_MODEL


# Module-level ModelManager instance, created when this module is imported.
# NOTE(review): presumably imported elsewhere as a shared singleton — confirm.
model_manager = ModelManager()