# Mirror of https://github.com/kortix-ai/suna.git (327 lines, 12 KiB, Python).
from typing import Dict, List, Optional, Set
|
|
from .models import Model, ModelProvider, ModelCapability, ModelPricing
|
|
|
|
# Default model names. Both strings are aliases registered in ModelRegistry
# ("Kimi K2" for the Kimi K2 entry, "Claude Sonnet 4" for the Claude Sonnet 4
# entry), so they resolve through ModelRegistry.get().
DEFAULT_FREE_MODEL: str = "Kimi K2"
DEFAULT_PREMIUM_MODEL: str = "Claude Sonnet 4"
|
|
|
|
class ModelRegistry:
    """In-memory catalogue of the LLM models this application knows about.

    Models are stored by canonical ``id`` and can also be looked up by any of
    their ``aliases`` (display names, provider-prefixed ids, short names).
    The built-in model table is registered from ``__init__``.
    """

    def __init__(self):
        # Canonical model id -> Model.
        self._models: Dict[str, Model] = {}
        # Alias (any alternative name) -> canonical model id.
        self._aliases: Dict[str, str] = {}
        self._initialize_models()

    def _initialize_models(self):
        """Register the built-in model table.

        If two models ever share an alias, the one registered later wins.
        """
        self.register(Model(
            id="anthropic/claude-sonnet-4-20250514",
            name="Claude Sonnet 4",
            provider=ModelProvider.ANTHROPIC,
            aliases=["claude-sonnet-4", "anthropic/claude-sonnet-4", "Claude Sonnet 4", "claude-sonnet-4-20250514"],
            context_window=200_000,
            capabilities=[
                ModelCapability.CHAT,
                ModelCapability.FUNCTION_CALLING,
                ModelCapability.VISION,
                ModelCapability.THINKING,
            ],
            pricing=ModelPricing(
                input_cost_per_million_tokens=3.00,
                output_cost_per_million_tokens=15.00
            ),
            tier_availability=["paid"],
            priority=100,
            recommended=True,
            enabled=True
        ))

        self.register(Model(
            id="anthropic/claude-3-7-sonnet-latest",
            name="Claude 3.7 Sonnet",
            provider=ModelProvider.ANTHROPIC,
            aliases=["sonnet-3.7", "claude-3.7", "Claude 3.7 Sonnet", "claude-3-7-sonnet-latest"],
            context_window=200_000,
            capabilities=[
                ModelCapability.CHAT,
                ModelCapability.FUNCTION_CALLING,
                ModelCapability.VISION,
            ],
            pricing=ModelPricing(
                input_cost_per_million_tokens=3.00,
                output_cost_per_million_tokens=15.00
            ),
            tier_availability=["paid"],
            priority=93,
            enabled=True
        ))

        self.register(Model(
            id="anthropic/claude-3-5-sonnet-latest",
            name="Claude 3.5 Sonnet",
            provider=ModelProvider.ANTHROPIC,
            aliases=["sonnet-3.5", "claude-3.5", "Claude 3.5 Sonnet", "claude-3-5-sonnet-latest"],
            context_window=200_000,
            capabilities=[
                ModelCapability.CHAT,
                ModelCapability.FUNCTION_CALLING,
                ModelCapability.VISION,
            ],
            pricing=ModelPricing(
                input_cost_per_million_tokens=3.00,
                output_cost_per_million_tokens=15.00
            ),
            tier_availability=["paid"],
            priority=90,
            enabled=True
        ))

        self.register(Model(
            id="openai/gpt-5",
            name="GPT-5",
            provider=ModelProvider.OPENAI,
            aliases=["gpt-5", "GPT-5"],
            context_window=400_000,
            capabilities=[
                ModelCapability.CHAT,
                ModelCapability.FUNCTION_CALLING,
                ModelCapability.VISION,
                ModelCapability.STRUCTURED_OUTPUT,
            ],
            pricing=ModelPricing(
                input_cost_per_million_tokens=1.25,
                output_cost_per_million_tokens=10.00
            ),
            tier_availability=["paid"],
            priority=99,
            enabled=True
        ))

        self.register(Model(
            id="openai/gpt-5-mini",
            name="GPT-5 Mini",
            provider=ModelProvider.OPENAI,
            aliases=["gpt-5-mini", "GPT-5 Mini"],
            context_window=400_000,
            capabilities=[
                ModelCapability.CHAT,
                ModelCapability.FUNCTION_CALLING,
                ModelCapability.STRUCTURED_OUTPUT,
            ],
            pricing=ModelPricing(
                input_cost_per_million_tokens=0.25,
                output_cost_per_million_tokens=2.00
            ),
            tier_availability=["free", "paid"],
            priority=85,
            enabled=True
        ))

        self.register(Model(
            id="gemini/gemini-2.5-pro",
            name="Gemini 2.5 Pro",
            provider=ModelProvider.GEMINI,
            aliases=["google/gemini-2.5-pro", "gemini-2.5-pro", "Gemini 2.5 Pro"],
            # NOTE(review): Google's model docs list a 1,048,576-token input
            # limit for Gemini 2.5 Pro; confirm 2_000_000 is intended.
            context_window=2_000_000,
            capabilities=[
                ModelCapability.CHAT,
                ModelCapability.FUNCTION_CALLING,
                ModelCapability.VISION,
                ModelCapability.STRUCTURED_OUTPUT,
            ],
            pricing=ModelPricing(
                input_cost_per_million_tokens=1.25,
                output_cost_per_million_tokens=10.00
            ),
            tier_availability=["paid"],
            priority=96,
            enabled=True
        ))

        self.register(Model(
            id="xai/grok-4",
            name="Grok 4",
            provider=ModelProvider.XAI,
            aliases=["grok-4", "x-ai/grok-4", "openrouter/x-ai/grok-4", "Grok 4"],
            context_window=128_000,
            capabilities=[
                ModelCapability.CHAT,
                ModelCapability.FUNCTION_CALLING,
            ],
            pricing=ModelPricing(
                input_cost_per_million_tokens=5.00,
                output_cost_per_million_tokens=15.00
            ),
            tier_availability=["paid"],
            priority=94,
            enabled=True
        ))

        self.register(Model(
            id="openrouter/moonshotai/kimi-k2",
            name="Kimi K2",
            provider=ModelProvider.MOONSHOTAI,
            aliases=["moonshotai/kimi-k2", "kimi-k2", "Kimi K2"],
            context_window=200_000,
            capabilities=[
                ModelCapability.CHAT,
                ModelCapability.FUNCTION_CALLING,
            ],
            pricing=ModelPricing(
                input_cost_per_million_tokens=1.00,
                output_cost_per_million_tokens=3.00
            ),
            tier_availability=["free", "paid"],
            priority=100,
            enabled=True
        ))

        # Disabled models kept for reference only; they are NOT registered.
        # Registering them (even with enabled=False) would make their aliases
        # ("deepseek", "qwen3", ...) resolve through get().
        #
        # # DeepSeek Models
        # self.register(Model(
        #     id="openrouter/deepseek/deepseek-chat",
        #     name="DeepSeek Chat",
        #     provider=ModelProvider.OPENROUTER,
        #     aliases=["deepseek", "deepseek-chat"],
        #     context_window=128_000,
        #     capabilities=[
        #         ModelCapability.CHAT,
        #         ModelCapability.FUNCTION_CALLING
        #     ],
        #     pricing=ModelPricing(
        #         input_cost_per_million_tokens=0.38,
        #         output_cost_per_million_tokens=0.89
        #     ),
        #     tier_availability=["free", "paid"],
        #     priority=95,
        #     enabled=False  # Currently disabled
        # ))
        #
        # # Qwen Models
        # self.register(Model(
        #     id="openrouter/qwen/qwen3-235b-a22b",
        #     name="Qwen3 235B",
        #     provider=ModelProvider.OPENROUTER,
        #     aliases=["qwen3", "qwen-3"],
        #     context_window=128_000,
        #     capabilities=[
        #         ModelCapability.CHAT,
        #         ModelCapability.FUNCTION_CALLING
        #     ],
        #     pricing=ModelPricing(
        #         input_cost_per_million_tokens=0.13,
        #         output_cost_per_million_tokens=0.60
        #     ),
        #     tier_availability=["free", "paid"],
        #     priority=90,
        #     enabled=False  # Currently disabled
        # ))

    def register(self, model: Model) -> None:
        """Add ``model`` to the registry, replacing any entry with the same id.

        Every alias of ``model`` is pointed at ``model.id``. An alias already
        used by another model is re-pointed, so the last registration wins.
        When an id is re-registered, aliases that belonged to the previous
        definition (and still point at this id) are removed, so names dropped
        from the model no longer resolve to it.
        """
        previous = self._models.get(model.id)
        if previous is not None:
            for alias in previous.aliases:
                # Leave aliases that have since been claimed by another model.
                if self._aliases.get(alias) == model.id:
                    del self._aliases[alias]

        self._models[model.id] = model
        for alias in model.aliases:
            self._aliases[alias] = model.id

    def get(self, model_id: str) -> Optional[Model]:
        """Return the model whose canonical id or alias is ``model_id``.

        Canonical ids take precedence over aliases. Returns ``None`` when the
        name is unknown. Disabled models are still returned.
        """
        if model_id in self._models:
            return self._models[model_id]

        if model_id in self._aliases:
            actual_id = self._aliases[model_id]
            return self._models.get(actual_id)

        return None

    def get_all(self, enabled_only: bool = True) -> List[Model]:
        """Return registered models in registration order.

        Args:
            enabled_only: When true (the default), skip models whose
                ``enabled`` flag is false.
        """
        models = list(self._models.values())
        if enabled_only:
            models = [m for m in models if m.enabled]
        return models

    def get_by_tier(self, tier: str, enabled_only: bool = True) -> List[Model]:
        """Return models whose ``tier_availability`` contains ``tier``."""
        models = self.get_all(enabled_only)
        return [m for m in models if tier in m.tier_availability]

    def get_by_provider(self, provider: ModelProvider, enabled_only: bool = True) -> List[Model]:
        """Return models served by ``provider``."""
        models = self.get_all(enabled_only)
        return [m for m in models if m.provider == provider]

    def get_by_capability(self, capability: ModelCapability, enabled_only: bool = True) -> List[Model]:
        """Return models whose ``capabilities`` include ``capability``."""
        models = self.get_all(enabled_only)
        return [m for m in models if capability in m.capabilities]

    def resolve_model_id(self, model_id: str) -> Optional[str]:
        """Map an id or alias to the canonical model id, or ``None`` if unknown."""
        model = self.get(model_id)
        return model.id if model else None

    def get_aliases(self, model_id: str) -> List[str]:
        """Return a copy of the aliases of the model named by ``model_id``.

        Returns an empty list for unknown names. A copy is returned so callers
        cannot desynchronise the registry's alias index by mutating it.
        """
        model = self.get(model_id)
        return list(model.aliases) if model else []

    def enable_model(self, model_id: str) -> bool:
        """Set ``enabled = True`` on the named model.

        Returns:
            True if the model was found, False otherwise.
        """
        model = self.get(model_id)
        if model:
            model.enabled = True
            return True
        return False

    def disable_model(self, model_id: str) -> bool:
        """Set ``enabled = False`` on the named model.

        Returns:
            True if the model was found, False otherwise.
        """
        model = self.get(model_id)
        if model:
            model.enabled = False
            return True
        return False

    def get_context_window(self, model_id: str, default: int = 31_000) -> int:
        """Return the model's ``context_window``, or ``default`` if the model is unknown."""
        model = self.get(model_id)
        return model.context_window if model else default

    def get_pricing(self, model_id: str) -> Optional[ModelPricing]:
        """Return the model's pricing object, or ``None`` if the model is unknown."""
        model = self.get(model_id)
        return model.pricing if model else None

    def to_legacy_format(self) -> Dict:
        """Export enabled models in the dict layout expected by legacy code.

        Returns:
            A dict with these keys:

            * ``MODELS``: id -> {aliases, pricing, context_window,
              tier_availability}.
            * ``MODEL_NAME_ALIASES``: alias -> id.
            * ``HARDCODED_MODEL_PRICES``: id -> per-million token costs. Only
              models with pricing appear here.
            * ``MODEL_CONTEXT_WINDOWS``: id -> context window.
            * ``FREE_TIER_MODELS`` and ``PAID_TIER_MODELS``: lists of ids.

            Lists and dicts are fresh copies, so mutating the export does not
            touch the registry's ``Model`` objects.
        """
        models_dict = {}
        aliases_dict = {}
        pricing_dict = {}
        context_windows_dict = {}

        for model in self.get_all(enabled_only=True):
            pricing = None
            if model.pricing:
                pricing = {
                    "input_cost_per_million_tokens": model.pricing.input_cost_per_million_tokens,
                    "output_cost_per_million_tokens": model.pricing.output_cost_per_million_tokens,
                }
                # Independent copy, so MODELS[id]["pricing"] and
                # HARDCODED_MODEL_PRICES[id] are distinct objects.
                pricing_dict[model.id] = dict(pricing)

            models_dict[model.id] = {
                "aliases": list(model.aliases),
                "pricing": pricing,
                "context_window": model.context_window,
                "tier_availability": list(model.tier_availability),
            }

            for alias in model.aliases:
                aliases_dict[alias] = model.id

            context_windows_dict[model.id] = model.context_window

        free_models = [m.id for m in self.get_by_tier("free")]
        paid_models = [m.id for m in self.get_by_tier("paid")]

        # Debug logging. The project logger is imported locally, as in the
        # original; the reason (e.g. avoiding an import cycle) is unverified.
        from utils.logger import logger
        logger.debug(f"Legacy format generation: {len(free_models)} free models, {len(paid_models)} paid models")
        logger.debug(f"Free models: {free_models}")
        logger.debug(f"Paid models: {paid_models}")

        return {
            "MODELS": models_dict,
            "MODEL_NAME_ALIASES": aliases_dict,
            "HARDCODED_MODEL_PRICES": pricing_dict,
            "MODEL_CONTEXT_WINDOWS": context_windows_dict,
            "FREE_TIER_MODELS": free_models,
            "PAID_TIER_MODELS": paid_models,
        }
|
|
|
|
# Shared registry instance. Constructed at import time; ModelRegistry.__init__
# registers the built-in model table.
registry = ModelRegistry()