mirror of https://github.com/kortix-ai/suna.git
374 lines
14 KiB
Python
374 lines
14 KiB
Python
from typing import Dict, List, Optional, Set
|
|
from .ai_models import Model, ModelProvider, ModelCapability, ModelPricing, ModelConfig
|
|
from core.utils.config import config, EnvMode
|
|
|
|
FREE_MODEL_ID = "moonshotai/kimi-k2"
|
|
|
|
# Set premium model ID based on environment
|
|
if config.ENV_MODE == EnvMode.LOCAL:
|
|
PREMIUM_MODEL_ID = "anthropic/claude-sonnet-4-20250514"
|
|
else: # STAGING or PRODUCTION
|
|
PREMIUM_MODEL_ID = "bedrock/anthropic.claude-sonnet-4-20250514-v1:0"
|
|
|
|
is_local = config.ENV_MODE == EnvMode.LOCAL
|
|
|
|
class ModelRegistry:
|
|
def __init__(self):
|
|
self._models: Dict[str, Model] = {}
|
|
self._aliases: Dict[str, str] = {}
|
|
self._initialize_models()
|
|
|
|
def _initialize_models(self):
|
|
|
|
self.register(Model(
|
|
id="anthropic/claude-sonnet-4-5-20250929" if is_local else "bedrock/converse/arn:aws:bedrock:us-west-2:935064898258:inference-profile/global.anthropic.claude-sonnet-4-5-20250929-v1:0",
|
|
name="Sonnet 4.5",
|
|
provider=ModelProvider.ANTHROPIC,
|
|
aliases=["claude-sonnet-4.5", "anthropic/claude-sonnet-4.5", "Claude Sonnet 4.5", "claude-sonnet-4-5-20250929", "global.anthropic.claude-sonnet-4-5-20250929-v1:0", "arn:aws:bedrock:us-west-2:935064898258:inference-profile/global.anthropic.claude-sonnet-4-5-20250929-v1:0", "bedrock/anthropic.claude-sonnet-4-5-20250929-v1:0"],
|
|
context_window=1_000_000,
|
|
capabilities=[
|
|
ModelCapability.CHAT,
|
|
ModelCapability.FUNCTION_CALLING,
|
|
ModelCapability.VISION,
|
|
ModelCapability.THINKING,
|
|
],
|
|
pricing=ModelPricing(
|
|
input_cost_per_million_tokens=3.00,
|
|
output_cost_per_million_tokens=15.00
|
|
),
|
|
tier_availability=["paid"],
|
|
priority=101,
|
|
recommended=True,
|
|
enabled=True,
|
|
config=ModelConfig(
|
|
extra_headers={
|
|
"anthropic-beta": "context-1m-2025-08-07"
|
|
},
|
|
)
|
|
))
|
|
|
|
self.register(Model(
|
|
id="anthropic/claude-sonnet-4-20250514" if is_local else "bedrock/converse/arn:aws:bedrock:us-west-2:935064898258:inference-profile/us.anthropic.claude-sonnet-4-20250514-v1:0",
|
|
name="Sonnet 4",
|
|
provider=ModelProvider.ANTHROPIC,
|
|
aliases=["claude-sonnet-4", "Claude Sonnet 4", "claude-sonnet-4-20250514", "arn:aws:bedrock:us-west-2:935064898258:inference-profile/us.anthropic.claude-sonnet-4-20250514-v1:0", "bedrock/anthropic.claude-sonnet-4-20250514-v1:0"],
|
|
context_window=1_000_000,
|
|
capabilities=[
|
|
ModelCapability.CHAT,
|
|
ModelCapability.FUNCTION_CALLING,
|
|
ModelCapability.VISION,
|
|
ModelCapability.THINKING,
|
|
],
|
|
pricing=ModelPricing(
|
|
input_cost_per_million_tokens=3.00,
|
|
output_cost_per_million_tokens=15.00
|
|
),
|
|
tier_availability=["paid"],
|
|
priority=100,
|
|
recommended=True,
|
|
enabled=True,
|
|
config=ModelConfig(
|
|
extra_headers={
|
|
"anthropic-beta": "context-1m-2025-08-07"
|
|
},
|
|
)
|
|
))
|
|
|
|
self.register(Model(
|
|
id="anthropic/claude-3-7-sonnet-latest" if is_local else "bedrock/converse/arn:aws:bedrock:us-west-2:935064898258:inference-profile/us.anthropic.claude-3-7-sonnet-20250219-v1:0",
|
|
name="Sonnet 3.7",
|
|
provider=ModelProvider.ANTHROPIC,
|
|
aliases=["claude-3.7", "Claude 3.7 Sonnet", "claude-3-7-sonnet-latest", "arn:aws:bedrock:us-west-2:935064898258:inference-profile/us.anthropic.claude-3-7-sonnet-20250219-v1:0", "bedrock/anthropic.claude-3-7-sonnet-20250219-v1:0"],
|
|
context_window=200_000,
|
|
capabilities=[
|
|
ModelCapability.CHAT,
|
|
ModelCapability.FUNCTION_CALLING,
|
|
ModelCapability.VISION,
|
|
],
|
|
pricing=ModelPricing(
|
|
input_cost_per_million_tokens=3.00,
|
|
output_cost_per_million_tokens=15.00
|
|
),
|
|
tier_availability=["paid"],
|
|
priority=99,
|
|
enabled=True,
|
|
config=ModelConfig(
|
|
extra_headers={
|
|
"anthropic-beta": "prompt-caching-2024-07-31"
|
|
},
|
|
)
|
|
))
|
|
|
|
self.register(Model(
|
|
id="xai/grok-4-fast-non-reasoning",
|
|
name="Grok 4 Fast",
|
|
provider=ModelProvider.XAI,
|
|
aliases=["grok-4-fast-non-reasoning", "Grok 4 Fast"],
|
|
context_window=2_000_000,
|
|
capabilities=[
|
|
ModelCapability.CHAT,
|
|
ModelCapability.FUNCTION_CALLING,
|
|
],
|
|
pricing=ModelPricing(
|
|
input_cost_per_million_tokens=0.20,
|
|
output_cost_per_million_tokens=0.50
|
|
),
|
|
tier_availability=["paid"],
|
|
priority=98,
|
|
enabled=True
|
|
))
|
|
|
|
# self.register(Model(
|
|
# id="anthropic/claude-3-5-sonnet-latest",
|
|
# name="Claude 3.5 Sonnet",
|
|
# provider=ModelProvider.ANTHROPIC,
|
|
# aliases=["sonnet-3.5", "claude-3.5", "Claude 3.5 Sonnet", "claude-3-5-sonnet-latest"],
|
|
# context_window=200_000,
|
|
# capabilities=[
|
|
# ModelCapability.CHAT,
|
|
# ModelCapability.FUNCTION_CALLING,
|
|
# ModelCapability.VISION,
|
|
# ],
|
|
# pricing=ModelPricing(
|
|
# input_cost_per_million_tokens=3.00,
|
|
# output_cost_per_million_tokens=15.00
|
|
# ),
|
|
# tier_availability=["paid"],
|
|
# priority=90,
|
|
# enabled=True
|
|
# ))
|
|
|
|
self.register(Model(
|
|
id="openai/gpt-5",
|
|
name="GPT-5",
|
|
provider=ModelProvider.OPENAI,
|
|
aliases=["gpt-5", "GPT-5"],
|
|
context_window=400_000,
|
|
capabilities=[
|
|
ModelCapability.CHAT,
|
|
ModelCapability.FUNCTION_CALLING,
|
|
ModelCapability.VISION,
|
|
ModelCapability.STRUCTURED_OUTPUT,
|
|
],
|
|
pricing=ModelPricing(
|
|
input_cost_per_million_tokens=1.25,
|
|
output_cost_per_million_tokens=10.00
|
|
),
|
|
tier_availability=["paid"],
|
|
priority=97,
|
|
enabled=True
|
|
))
|
|
|
|
self.register(Model(
|
|
id="openai/gpt-5-mini",
|
|
name="GPT-5 Mini",
|
|
provider=ModelProvider.OPENAI,
|
|
aliases=["gpt-5-mini", "GPT-5 Mini"],
|
|
context_window=400_000,
|
|
capabilities=[
|
|
ModelCapability.CHAT,
|
|
ModelCapability.FUNCTION_CALLING,
|
|
ModelCapability.STRUCTURED_OUTPUT,
|
|
],
|
|
pricing=ModelPricing(
|
|
input_cost_per_million_tokens=0.25,
|
|
output_cost_per_million_tokens=2.00
|
|
),
|
|
tier_availability=["free", "paid"],
|
|
priority=96,
|
|
enabled=True
|
|
))
|
|
|
|
self.register(Model(
|
|
id="gemini/gemini-2.5-pro",
|
|
name="Gemini 2.5 Pro",
|
|
provider=ModelProvider.GOOGLE,
|
|
aliases=["gemini-2.5-pro", "Gemini 2.5 Pro"],
|
|
context_window=2_000_000,
|
|
capabilities=[
|
|
ModelCapability.CHAT,
|
|
ModelCapability.FUNCTION_CALLING,
|
|
ModelCapability.VISION,
|
|
ModelCapability.STRUCTURED_OUTPUT,
|
|
],
|
|
pricing=ModelPricing(
|
|
input_cost_per_million_tokens=1.25,
|
|
output_cost_per_million_tokens=10.00
|
|
),
|
|
tier_availability=["paid"],
|
|
priority=95,
|
|
enabled=True
|
|
))
|
|
|
|
|
|
# self.register(Model(
|
|
# id="openrouter/moonshotai/kimi-k2",
|
|
# name="Kimi K2",
|
|
# provider=ModelProvider.MOONSHOTAI,
|
|
# aliases=["kimi-k2", "Kimi K2", "moonshotai/kimi-k2"],
|
|
# context_window=200_000,
|
|
# capabilities=[
|
|
# ModelCapability.CHAT,
|
|
# ModelCapability.FUNCTION_CALLING,
|
|
# ],
|
|
# pricing=ModelPricing(
|
|
# input_cost_per_million_tokens=1.00,
|
|
# output_cost_per_million_tokens=3.00
|
|
# ),
|
|
# tier_availability=["free", "paid"],
|
|
# priority=94,
|
|
# enabled=True,
|
|
# config=ModelConfig(
|
|
# extra_headers={
|
|
# "HTTP-Referer": config.OR_SITE_URL if hasattr(config, 'OR_SITE_URL') and config.OR_SITE_URL else "",
|
|
# "X-Title": config.OR_APP_NAME if hasattr(config, 'OR_APP_NAME') and config.OR_APP_NAME else ""
|
|
# }
|
|
# )
|
|
# ))
|
|
|
|
# # DeepSeek Models
|
|
# self.register(Model(
|
|
# id="openrouter/deepseek/deepseek-chat",
|
|
# name="DeepSeek Chat",
|
|
# provider=ModelProvider.OPENROUTER,
|
|
# aliases=["deepseek", "deepseek-chat"],
|
|
# context_window=128_000,
|
|
# capabilities=[
|
|
# ModelCapability.CHAT,
|
|
# ModelCapability.FUNCTION_CALLING
|
|
# ],
|
|
# pricing=ModelPricing(
|
|
# input_cost_per_million_tokens=0.38,
|
|
# output_cost_per_million_tokens=0.89
|
|
# ),
|
|
# tier_availability=["free", "paid"],
|
|
# priority=95,
|
|
# enabled=False # Currently disabled
|
|
# ))
|
|
|
|
# # Qwen Models
|
|
# self.register(Model(
|
|
# id="openrouter/qwen/qwen3-235b-a22b",
|
|
# name="Qwen3 235B",
|
|
# provider=ModelProvider.OPENROUTER,
|
|
# aliases=["qwen3", "qwen-3"],
|
|
# context_window=128_000,
|
|
# capabilities=[
|
|
# ModelCapability.CHAT,
|
|
# ModelCapability.FUNCTION_CALLING
|
|
# ],
|
|
# pricing=ModelPricing(
|
|
# input_cost_per_million_tokens=0.13,
|
|
# output_cost_per_million_tokens=0.60
|
|
# ),
|
|
# tier_availability=["free", "paid"],
|
|
# priority=90,
|
|
# enabled=False # Currently disabled
|
|
# ))
|
|
|
|
|
|
def register(self, model: Model) -> None:
|
|
self._models[model.id] = model
|
|
for alias in model.aliases:
|
|
self._aliases[alias] = model.id
|
|
|
|
def get(self, model_id: str) -> Optional[Model]:
|
|
if model_id in self._models:
|
|
return self._models[model_id]
|
|
|
|
if model_id in self._aliases:
|
|
actual_id = self._aliases[model_id]
|
|
return self._models.get(actual_id)
|
|
|
|
return None
|
|
|
|
def get_all(self, enabled_only: bool = True) -> List[Model]:
|
|
models = list(self._models.values())
|
|
if enabled_only:
|
|
models = [m for m in models if m.enabled]
|
|
return models
|
|
|
|
def get_by_tier(self, tier: str, enabled_only: bool = True) -> List[Model]:
|
|
models = self.get_all(enabled_only)
|
|
return [m for m in models if tier in m.tier_availability]
|
|
|
|
def get_by_provider(self, provider: ModelProvider, enabled_only: bool = True) -> List[Model]:
|
|
models = self.get_all(enabled_only)
|
|
return [m for m in models if m.provider == provider]
|
|
|
|
def get_by_capability(self, capability: ModelCapability, enabled_only: bool = True) -> List[Model]:
|
|
models = self.get_all(enabled_only)
|
|
return [m for m in models if capability in m.capabilities]
|
|
|
|
def resolve_model_id(self, model_id: str) -> Optional[str]:
|
|
model = self.get(model_id)
|
|
return model.id if model else None
|
|
|
|
|
|
def get_aliases(self, model_id: str) -> List[str]:
|
|
model = self.get(model_id)
|
|
return model.aliases if model else []
|
|
|
|
def enable_model(self, model_id: str) -> bool:
|
|
model = self.get(model_id)
|
|
if model:
|
|
model.enabled = True
|
|
return True
|
|
return False
|
|
|
|
def disable_model(self, model_id: str) -> bool:
|
|
model = self.get(model_id)
|
|
if model:
|
|
model.enabled = False
|
|
return True
|
|
return False
|
|
|
|
def get_context_window(self, model_id: str, default: int = 31_000) -> int:
|
|
model = self.get(model_id)
|
|
return model.context_window if model else default
|
|
|
|
def get_pricing(self, model_id: str) -> Optional[ModelPricing]:
|
|
model = self.get(model_id)
|
|
return model.pricing if model else None
|
|
|
|
def to_legacy_format(self) -> Dict:
|
|
models_dict = {}
|
|
pricing_dict = {}
|
|
context_windows_dict = {}
|
|
|
|
for model in self.get_all(enabled_only=True):
|
|
models_dict[model.id] = {
|
|
"pricing": {
|
|
"input_cost_per_million_tokens": model.pricing.input_cost_per_million_tokens,
|
|
"output_cost_per_million_tokens": model.pricing.output_cost_per_million_tokens,
|
|
} if model.pricing else None,
|
|
"context_window": model.context_window,
|
|
"tier_availability": model.tier_availability,
|
|
}
|
|
|
|
if model.pricing:
|
|
pricing_dict[model.id] = {
|
|
"input_cost_per_million_tokens": model.pricing.input_cost_per_million_tokens,
|
|
"output_cost_per_million_tokens": model.pricing.output_cost_per_million_tokens,
|
|
}
|
|
|
|
context_windows_dict[model.id] = model.context_window
|
|
|
|
free_models = [m.id for m in self.get_by_tier("free")]
|
|
paid_models = [m.id for m in self.get_by_tier("paid")]
|
|
|
|
# Debug logging
|
|
from core.utils.logger import logger
|
|
logger.debug(f"Legacy format generation: {len(free_models)} free models, {len(paid_models)} paid models")
|
|
logger.debug(f"Free models: {free_models}")
|
|
logger.debug(f"Paid models: {paid_models}")
|
|
|
|
return {
|
|
"MODELS": models_dict,
|
|
"HARDCODED_MODEL_PRICES": pricing_dict,
|
|
"MODEL_CONTEXT_WINDOWS": context_windows_dict,
|
|
"FREE_TIER_MODELS": free_models,
|
|
"PAID_TIER_MODELS": paid_models,
|
|
}
|
|
|
|
registry = ModelRegistry() |