mirror of https://github.com/kortix-ai/suna.git
224 lines
8.2 KiB
Python
224 lines
8.2 KiB
Python
from typing import Optional, List, Dict, Any, Tuple
|
|
from .registry import registry
|
|
from .models import Model, ModelCapability
|
|
from utils.logger import logger
|
|
from .registry import DEFAULT_PREMIUM_MODEL, DEFAULT_FREE_MODEL
|
|
|
|
class ModelManager:
    """Convenience facade over the module-level model ``registry``.

    Every lookup is delegated to ``self.registry``; this class adds validation,
    cost arithmetic, model selection and dict formatting on top of it.
    """

    def __init__(self):
        """Bind this manager to the module-level ``registry`` object."""
        self.registry = registry

    def get_model(self, model_id: str) -> Optional["Model"]:
        """Return whatever ``self.registry.get(model_id)`` yields (falsy if unknown)."""
        return self.registry.get(model_id)

    def resolve_model_id(self, model_id: str) -> str:
        """Resolve ``model_id`` through the registry.

        Returns:
            The registry's resolved id when it is truthy; otherwise the input
            ``model_id`` unchanged (after logging a warning).
        """
        logger.debug(f"resolve_model_id called with: '{model_id}' (type: {type(model_id)})")

        resolved = self.registry.resolve_model_id(model_id)
        if resolved:
            logger.debug(f"Resolved model '{model_id}' to '{resolved}'")
            return resolved

        # `_aliases` is a private registry attribute used only to enrich the
        # warning below; read it defensively so logging can never raise.
        all_aliases = list(getattr(self.registry, "_aliases", {}))
        logger.warning(f"Could not resolve model ID: '{model_id}'. Available aliases: {all_aliases[:10]}...")
        return model_id

    def validate_model(self, model_id: str) -> Tuple[bool, str]:
        """Check that a model exists and is enabled.

        Returns:
            ``(True, "")`` when usable, otherwise ``(False, reason)``.
        """
        model = self.get_model(model_id)

        if not model:
            return False, f"Model '{model_id}' not found"

        if not model.enabled:
            return False, f"Model '{model.name}' is currently disabled"

        return True, ""

    def calculate_cost(
        self,
        model_id: str,
        input_tokens: int,
        output_tokens: int
    ) -> Optional[float]:
        """Compute the cost of a request from the model's per-token pricing.

        Returns:
            ``input_tokens * input_cost_per_token + output_tokens *
            output_cost_per_token``, or ``None`` when the model is unknown or
            has no pricing attached.
        """
        model = self.get_model(model_id)
        if not model or not model.pricing:
            logger.warning(f"No pricing available for model: {model_id}")
            return None

        input_cost = input_tokens * model.pricing.input_cost_per_token
        output_cost = output_tokens * model.pricing.output_cost_per_token
        total_cost = input_cost + output_cost

        logger.debug(
            f"Cost calculation for {model.name}: "
            f"{input_tokens} input tokens (${input_cost:.6f}) + "
            f"{output_tokens} output tokens (${output_cost:.6f}) = "
            f"${total_cost:.6f}"
        )

        return total_cost

    def get_models_for_tier(self, tier: str) -> List["Model"]:
        """Return the registry's enabled models for ``tier``."""
        return self.registry.get_by_tier(tier, enabled_only=True)

    def get_models_with_capability(self, capability: "ModelCapability") -> List["Model"]:
        """Return the registry's enabled models that have ``capability``."""
        return self.registry.get_by_capability(capability, enabled_only=True)

    def select_best_model(
        self,
        tier: str,
        required_capabilities: Optional[List["ModelCapability"]] = None,
        min_context_window: Optional[int] = None,
        prefer_cheaper: bool = False
    ) -> Optional["Model"]:
        """Pick one enabled model of ``tier`` that satisfies the given filters.

        Args:
            tier: Tier name passed to ``get_models_for_tier``.
            required_capabilities: If truthy, every listed capability must be
                in the model's ``capabilities``.
            min_context_window: If truthy, the model's ``context_window`` must
                be at least this large.
            prefer_cheaper: If true and at least one candidate has pricing, the
                priced candidate with the lowest
                ``input_cost_per_million_tokens`` wins (unpriced candidates are
                ignored in that case).

        Returns:
            The chosen model, or ``None`` when no candidate survives filtering.
            Without ``prefer_cheaper`` (or when nothing is priced) the winner
            is the highest ``priority``, with ``recommended`` breaking ties.
        """
        candidates = self.get_models_for_tier(tier)

        if required_capabilities:
            candidates = [
                m for m in candidates
                if all(cap in m.capabilities for cap in required_capabilities)
            ]

        if min_context_window:
            candidates = [m for m in candidates if m.context_window >= min_context_window]

        if not candidates:
            return None

        if prefer_cheaper:
            priced = [m for m in candidates if m.pricing]
            if priced:
                # min() returns the first minimal element, matching a stable
                # ascending sort followed by [0].
                return min(priced, key=lambda m: m.pricing.input_cost_per_million_tokens)

        return min(candidates, key=lambda m: (-m.priority, not m.recommended))

    def get_default_model(self, tier: str = "free") -> Optional["Model"]:
        """Return the highest-priority recommended model for ``tier``.

        Falls back to the highest-priority model of the tier when none is
        recommended, and to ``None`` when the tier has no enabled models.
        """
        models = self.get_models_for_tier(tier)

        pool = [m for m in models if m.recommended] or models
        if not pool:
            return None

        # max() returns the first maximal element, matching a stable
        # descending-priority sort followed by [0].
        return max(pool, key=lambda m: m.priority)

    def get_context_window(self, model_id: str, default: int = 31_000) -> int:
        """Delegate to ``self.registry.get_context_window(model_id, default)``."""
        return self.registry.get_context_window(model_id, default)

    def check_token_limit(
        self,
        model_id: str,
        token_count: int,
        is_input: bool = True
    ) -> Tuple[bool, int]:
        """Check ``token_count`` against the model's limit.

        For input the limit is ``context_window``; for output it is
        ``max_output_tokens``, falling back to ``context_window`` when that is
        falsy.

        Returns:
            ``(within_limit, limit)``; ``(False, 0)`` for an unknown model.
        """
        model = self.get_model(model_id)
        if not model:
            return False, 0

        if is_input:
            max_allowed = model.context_window
        else:
            max_allowed = model.max_output_tokens or model.context_window

        return token_count <= max_allowed, max_allowed

    @staticmethod
    def _format_model(model: "Model") -> Dict[str, Any]:
        """Render a model object as a plain dict of its public attributes."""
        pricing = model.pricing
        return {
            "id": model.id,
            "name": model.name,
            "provider": model.provider.value,
            "context_window": model.context_window,
            "max_output_tokens": model.max_output_tokens,
            "capabilities": [cap.value for cap in model.capabilities],
            "pricing": {
                "input_per_million": pricing.input_cost_per_million_tokens,
                "output_per_million": pricing.output_cost_per_million_tokens,
            } if pricing else None,
            "enabled": model.enabled,
            "beta": model.beta,
            "tier_availability": model.tier_availability,
            "priority": model.priority,
            "recommended": model.recommended,
        }

    def format_model_info(self, model_id: str) -> Dict[str, Any]:
        """Return a dict describing ``model_id``, or ``{"error": ...}`` if unknown."""
        model = self.get_model(model_id)
        if not model:
            return {"error": f"Model '{model_id}' not found"}

        return self._format_model(model)

    def list_available_models(
        self,
        tier: Optional[str] = None,
        include_disabled: bool = False
    ) -> List[Dict[str, Any]]:
        """List models as dicts, optionally restricted to one tier.

        Args:
            tier: When truthy, only models of this tier are listed; otherwise
                all registry models are.
            include_disabled: When true, disabled models are included.

        Returns:
            Formatted model dicts ordered by free-tier first, then descending
            ``priority``, then ``name``.
        """
        logger.debug(f"list_available_models called with tier='{tier}', include_disabled={include_disabled}")

        if tier:
            models = self.registry.get_by_tier(tier, enabled_only=not include_disabled)
            logger.debug(f"Found {len(models)} models for tier '{tier}'")
        else:
            models = self.registry.get_all(enabled_only=not include_disabled)
            logger.debug(f"Found {len(models)} total models")

        if models:
            model_names = [m.name for m in models]
            logger.debug(f"Models: {model_names}")
        else:
            logger.warning(f"No models found for tier '{tier}' - this might indicate a configuration issue")

        models = sorted(
            models,
            key=lambda m: (not m.is_free_tier, -m.priority, m.name)
        )

        # Format the objects we already hold instead of re-querying the
        # registry by id for each one.
        return [self._format_model(m) for m in models]

    def get_legacy_constants(self) -> Dict:
        """Return ``self.registry.to_legacy_format()``."""
        return self.registry.to_legacy_format()

    @staticmethod
    def _extract_price_id(subscription) -> Optional[str]:
        """Pull the price id out of a subscription mapping.

        Uses the first entry of ``subscription['items']['data']`` when that
        list is non-empty, otherwise the top-level ``price_id`` key. A
        malformed first item raises (``KeyError``/``TypeError``), which the
        caller treats as "tier unknown".
        """
        items = subscription.get('items')
        if items and items.get('data') and len(items['data']) > 0:
            return items['data'][0]['price']['id']
        return subscription.get('price_id')

    async def get_default_model_for_user(self, client, user_id: str) -> str:
        """Choose the default model id for a user from their subscription.

        Args:
            client: Unused by this method; kept for interface compatibility.
            user_id: Id passed to ``get_user_subscription``.

        Returns:
            ``DEFAULT_PREMIUM_MODEL`` in local env mode or when the user's
            subscription maps to a tier whose name is not ``'free'``;
            ``DEFAULT_FREE_MODEL`` otherwise, including on any error
            (best-effort: failures are logged, never raised).
        """
        try:
            # Imported lazily inside the method, as in the original code.
            from utils.config import config, EnvMode
            if config.ENV_MODE == EnvMode.LOCAL:
                return DEFAULT_PREMIUM_MODEL

            from services.billing import get_user_subscription, SUBSCRIPTION_TIERS

            subscription = await get_user_subscription(user_id)

            is_paid_tier = False
            if subscription:
                price_id = self._extract_price_id(subscription)
                tier_info = SUBSCRIPTION_TIERS.get(price_id)
                if tier_info and tier_info['name'] != 'free':
                    is_paid_tier = True

            # NOTE: the log text below names specific models, but the value
            # actually returned is whatever the DEFAULT_* constants hold.
            if is_paid_tier:
                logger.debug(f"Setting Claude Sonnet 4 as default for paid user {user_id}")
                return DEFAULT_PREMIUM_MODEL
            else:
                logger.debug(f"Setting Kimi K2 as default for free user {user_id}")
                return DEFAULT_FREE_MODEL

        except Exception as e:
            logger.warning(f"Failed to determine user tier for {user_id}: {e}")
            return DEFAULT_FREE_MODEL
|
|
|
|
|
|
# Module-level shared ModelManager instance, created at import time.
model_manager = ModelManager()