mirror of https://github.com/kortix-ai/suna.git
Add Sonnet 4.5 and centralize model configuration system
- Add Claude Sonnet 4.5 (global.anthropic.claude-sonnet-4-5-20250929-v1:0) to registry
- Update all Anthropic models to use bedrock/converse/ endpoint with full ARNs
- Create comprehensive ModelConfig class for centralized provider settings
- Add alias system with raw ARNs for proper LiteLLM response resolution
- Refactor response processor to preserve exact LiteLLM response objects
- Simplify LLM service by merging prepare_params into make_llm_api_call
- Set stream_options include_usage as universal default for all models
- Remove scattered configuration functions in favor of registry-driven approach
- Fix pricing lookup by mapping provider model IDs back to registry entries
This commit is contained in:
parent 2b5b8cc0bc
commit d8100cb7a0
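In short: make_llm_api_call no longer assembles provider settings through scattered _configure_* helpers; it asks the model registry for them. A minimal sketch of the new flow, assuming the Bedrock model id from the message above is registered (exact registry ids and aliases may differ):

from core.ai_models import model_manager

# Endpoints, retries, and beta headers now come from the model's registry entry.
params = model_manager.get_litellm_params(
    "global.anthropic.claude-sonnet-4-5-20250929-v1:0",
    messages=[{"role": "user", "content": "Hello"}],
    stream=True,
)
# Every model gets num_retries=3 and stream_options={"include_usage": True}
# as universal defaults; the per-model ModelConfig supplies the rest.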
@@ -1,5 +1,5 @@
 from dataclasses import dataclass, field
-from typing import List, Optional, Dict, Any
+from typing import List, Optional, Dict, Any, Union
 from enum import Enum


@@ -36,6 +36,36 @@ class ModelPricing:
         return self.output_cost_per_million_tokens / 1_000_000


+@dataclass
+class ModelConfig:
+    """Essential model configuration - provider settings and API configuration only."""
+
+    # === Provider & API Configuration ===
+    api_base: Optional[str] = None
+    api_version: Optional[str] = None
+    base_url: Optional[str] = None  # Alternative to api_base
+    deployment_id: Optional[str] = None  # Azure
+    timeout: Optional[Union[float, int]] = None
+    num_retries: Optional[int] = None
+
+    # === Headers (Provider-Specific) ===
+    headers: Optional[Dict[str, str]] = None
+    extra_headers: Optional[Dict[str, str]] = None
+
+
+    def __post_init__(self):
+        """Set intelligent defaults and validate configuration."""
+        # Merge headers if both are provided
+        if self.headers and self.extra_headers:
+            merged_headers = self.headers.copy()
+            merged_headers.update(self.extra_headers)
+            self.extra_headers = merged_headers
+            self.headers = None  # Use extra_headers as the single source
+        elif self.headers and not self.extra_headers:
+            self.extra_headers = self.headers
+            self.headers = None
+
+
 @dataclass
 class Model:
     id: str
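The __post_init__ hook normalizes the two header fields so downstream code only ever reads extra_headers. A small sketch of the merge behavior (the x-request-id header is a made-up example):

cfg = ModelConfig(
    headers={"anthropic-beta": "prompt-caching-2024-07-31"},
    extra_headers={"x-request-id": "abc-123"},  # hypothetical header
)
# headers is folded into extra_headers and cleared, leaving one source of truth:
assert cfg.headers is None
assert cfg.extra_headers == {
    "anthropic-beta": "prompt-caching-2024-07-31",
    "x-request-id": "abc-123",
}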
@@ -53,6 +83,9 @@ class Model:
     priority: int = 0
     recommended: bool = False

+    # NEW: Centralized model configuration
+    config: Optional[ModelConfig] = None
+
     def __post_init__(self):
         if ModelCapability.CHAT not in self.capabilities:
             self.capabilities.insert(0, ModelCapability.CHAT)
@@ -79,6 +112,50 @@ class Model:
     def is_free_tier(self) -> bool:
         return "free" in self.tier_availability

+    def get_litellm_params(self, **override_params) -> Dict[str, Any]:
+        """Get complete LiteLLM parameters for this model, including all configuration."""
+        # Start with intelligent defaults
+        params = {
+            "model": self.id,
+            "num_retries": 3,
+            "stream_options": {"include_usage": True},  # Default for all models
+        }
+
+        # Apply model-specific configuration if available
+        if self.config:
+            # Provider & API configuration parameters
+            api_params = [
+                'api_base', 'api_version', 'base_url', 'deployment_id',
+                'timeout', 'num_retries'
+            ]
+
+            # Apply configured parameters
+            for param_name in api_params:
+                param_value = getattr(self.config, param_name, None)
+                if param_value is not None:
+                    params[param_name] = param_value
+
+            # Handle headers specially
+            if self.config.extra_headers:
+                params["extra_headers"] = self.config.extra_headers.copy()
+            elif self.config.headers:
+                params["extra_headers"] = self.config.headers.copy()
+
+
+        # Apply any runtime overrides
+        for key, value in override_params.items():
+            if value is not None:
+                # Handle extra_headers merging
+                if key == "extra_headers" and "extra_headers" in params:
+                    if isinstance(params["extra_headers"], dict) and isinstance(value, dict):
+                        params["extra_headers"].update(value)
+                    else:
+                        params[key] = value
+                else:
+                    params[key] = value
+
+        return params
+
     def to_dict(self) -> Dict[str, Any]:
         return {
             "id": self.id,
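Runtime overrides layer on top of the registered config instead of clobbering it: None values are skipped, and dict-valued extra_headers are merged key by key. A sketch, assuming model is a registry entry whose ModelConfig sets the context-1m beta header (the x-trace-id header is illustrative):

params = model.get_litellm_params(
    temperature=0.2,
    extra_headers={"x-trace-id": "run-42"},  # hypothetical header
    api_key=None,  # None overrides are ignored, not written into params
)
# params["extra_headers"] now contains both the registered "anthropic-beta"
# value and the runtime "x-trace-id" value; params["temperature"] == 0.2.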
@@ -58,6 +58,23 @@ class ModelManager:
     def get_models_for_tier(self, tier: str) -> List[Model]:
         return self.registry.get_by_tier(tier, enabled_only=True)

+    def get_litellm_params(self, model_id: str, **override_params) -> Dict[str, Any]:
+        """Get complete LiteLLM parameters for a model from the registry."""
+        model = self.get_model(model_id)
+        if not model:
+            logger.warning(f"Model '{model_id}' not found in registry, using basic params")
+            return {
+                "model": model_id,
+                "num_retries": 3,
+                **override_params
+            }
+
+        # Get the complete configuration from the model
+        params = model.get_litellm_params(**override_params)
+        # logger.debug(f"Generated LiteLLM params for {model.name}: {list(params.keys())}")
+
+        return params
+
     def get_models_with_capability(self, capability: ModelCapability) -> List[Model]:
         return self.registry.get_by_capability(capability, enabled_only=True)
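Unknown ids degrade to a bare parameter set with a logged warning rather than raising, so ad-hoc model names keep working:

params = model_manager.get_litellm_params("some/unregistered-model", stream=True)
# -> {"model": "some/unregistered-model", "num_retries": 3, "stream": True}
# Note the fallback path skips the stream_options include_usage default.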
@@ -1,5 +1,5 @@
 from typing import Dict, List, Optional, Set
-from .ai_models import Model, ModelProvider, ModelCapability, ModelPricing
+from .ai_models import Model, ModelProvider, ModelCapability, ModelPricing, ModelConfig
 from core.utils.config import config, EnvMode

 FREE_MODEL_ID = "moonshotai/kimi-k2"
@@ -39,7 +39,12 @@ class ModelRegistry:
             tier_availability=["paid"],
             priority=101,
             recommended=True,
-            enabled=True
+            enabled=True,
+            config=ModelConfig(
+                extra_headers={
+                    "anthropic-beta": "context-1m-2025-08-07"
+                },
+            )
         ))

         self.register(Model(
@@ -61,7 +66,12 @@ class ModelRegistry:
             tier_availability=["paid"],
             priority=100,
             recommended=True,
-            enabled=True
+            enabled=True,
+            config=ModelConfig(
+                extra_headers={
+                    "anthropic-beta": "context-1m-2025-08-07"
+                },
+            )
         ))

         self.register(Model(
@@ -81,7 +91,12 @@ class ModelRegistry:
             ),
             tier_availability=["paid"],
             priority=99,
-            enabled=True
+            enabled=True,
+            config=ModelConfig(
+                extra_headers={
+                    "anthropic-beta": "prompt-caching-2024-07-31"
+                },
+            )
         ))

         self.register(Model(
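The per-model configs make the beta flags granular: the two entries above opt into the 1M-context beta, while this lower-priority Claude entry enables only prompt caching; the removed _configure_anthropic helper (further down) used to apply both flags to every Claude model at once. Assuming the Sonnet 4.5 id from the commit message resolves to one of the 1M-context entries:

params = model_manager.get_litellm_params(
    "global.anthropic.claude-sonnet-4-5-20250929-v1:0"
)
# params["extra_headers"] == {"anthropic-beta": "context-1m-2025-08-07"}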
@@ -186,25 +201,30 @@ class ModelRegistry:
         ))


-        self.register(Model(
-            id="openrouter/moonshotai/kimi-k2",
-            name="Kimi K2",
-            provider=ModelProvider.MOONSHOTAI,
-            aliases=["kimi-k2", "Kimi K2", "moonshotai/kimi-k2"],
-            context_window=200_000,
-            capabilities=[
-                ModelCapability.CHAT,
-                ModelCapability.FUNCTION_CALLING,
-            ],
-            pricing=ModelPricing(
-                input_cost_per_million_tokens=1.00,
-                output_cost_per_million_tokens=3.00
-            ),
-            tier_availability=["free", "paid"],
-            priority=94,
-            enabled=True
-        ))
+        # self.register(Model(
+        #     id="openrouter/moonshotai/kimi-k2",
+        #     name="Kimi K2",
+        #     provider=ModelProvider.MOONSHOTAI,
+        #     aliases=["kimi-k2", "Kimi K2", "moonshotai/kimi-k2"],
+        #     context_window=200_000,
+        #     capabilities=[
+        #         ModelCapability.CHAT,
+        #         ModelCapability.FUNCTION_CALLING,
+        #     ],
+        #     pricing=ModelPricing(
+        #         input_cost_per_million_tokens=1.00,
+        #         output_cost_per_million_tokens=3.00
+        #     ),
+        #     tier_availability=["free", "paid"],
+        #     priority=94,
+        #     enabled=True,
+        #     config=ModelConfig(
+        #         extra_headers={
+        #             "HTTP-Referer": config.OR_SITE_URL if hasattr(config, 'OR_SITE_URL') and config.OR_SITE_URL else "",
+        #             "X-Title": config.OR_APP_NAME if hasattr(config, 'OR_APP_NAME') and config.OR_APP_NAME else ""
+        #         }
+        #     )
+        # ))

         # # DeepSeek Models
         # self.register(Model(
@@ -83,49 +83,6 @@ def setup_provider_router(openai_compatible_api_key: str = None, openai_compatib
     ]
     provider_router = Router(model_list=model_list)

-def _configure_token_limits(params: Dict[str, Any], model_name: str, max_tokens: Optional[int]) -> None:
-    """Configure token limits based on model type."""
-    # Only set max_tokens if explicitly provided - let providers use their defaults otherwise
-    if max_tokens is None:
-        # logger.debug(f"No max_tokens specified, using provider defaults for model: {model_name}")
-        return
-
-    is_openai_o_series = 'o1' in model_name
-    is_openai_gpt5 = 'gpt-5' in model_name
-    param_name = "max_completion_tokens" if (is_openai_o_series or is_openai_gpt5) else "max_tokens"
-    params[param_name] = max_tokens
-    # logger.debug(f"Set {param_name}={max_tokens} for model: {model_name}")
-
-def _configure_anthropic(params: Dict[str, Any], model_name: str) -> None:
-    """Configure Anthropic-specific parameters."""
-    if not ("claude" in model_name.lower() or "anthropic" in model_name.lower()):
-        return
-
-    # Include prompt caching and context-1m beta features
-    params["extra_headers"] = {
-        "anthropic-beta": "prompt-caching-2024-07-31,context-1m-2025-08-07"
-    }
-    logger.debug(f"Added Anthropic-specific headers for prompt caching and 1M context window")
-
-def _configure_openrouter(params: Dict[str, Any], model_name: str) -> None:
-    """Configure OpenRouter-specific parameters."""
-    if not model_name.startswith("openrouter/"):
-        return
-
-    # logger.debug(f"Preparing OpenRouter parameters for model: {model_name}")
-
-    # Add optional site URL and app name from config
-    site_url = config.OR_SITE_URL
-    app_name = config.OR_APP_NAME
-    if site_url or app_name:
-        extra_headers = params.get("extra_headers", {})
-        if site_url:
-            extra_headers["HTTP-Referer"] = site_url
-        if app_name:
-            extra_headers["X-Title"] = app_name
-        params["extra_headers"] = extra_headers
-        # logger.debug(f"Added OpenRouter site URL and app name to headers")
-
 def _configure_openai_compatible(params: Dict[str, Any], model_name: str, api_key: Optional[str], api_base: Optional[str]) -> None:
     """Configure OpenAI-compatible provider setup."""
     if not model_name.startswith("openai-compatible/"):
@@ -142,29 +99,6 @@ def _configure_openai_compatible(params: Dict[str, Any], model_name: str, api_ke
     setup_provider_router(api_key, api_base)
     logger.debug(f"Configured OpenAI-compatible provider with custom API base")

-def _configure_thinking(params: Dict[str, Any], model_name: str) -> None:
-    """Configure reasoning/thinking parameters automatically based on model capabilities."""
-    # Check if model supports thinking/reasoning
-    is_anthropic = "anthropic" in model_name.lower() or "claude" in model_name.lower()
-    is_xai = "xai" in model_name.lower() or model_name.startswith("xai/")
-    is_bedrock_anthropic = "bedrock" in model_name.lower() and "anthropic" in model_name.lower()
-
-    # Enable thinking for supported models
-    if is_anthropic or is_xai or is_bedrock_anthropic:
-        # Use higher effort for premium models
-        if "sonnet-4" in model_name.lower() or "claude-4" in model_name.lower():
-            effort_level = "medium"
-        else:
-            effort_level = "low"
-
-        if is_anthropic or is_bedrock_anthropic:
-            params["reasoning_effort"] = effort_level
-            params["temperature"] = 1.0  # Required by Anthropic when reasoning_effort is used
-            logger.info(f"Anthropic thinking auto-enabled with reasoning_effort='{effort_level}' for model: {model_name}")
-        elif is_xai:
-            params["reasoning_effort"] = effort_level
-            logger.info(f"xAI thinking auto-enabled with reasoning_effort='{effort_level}' for model: {model_name}")
-
 def _add_tools_config(params: Dict[str, Any], tools: Optional[List[Dict[str, Any]]], tool_choice: str) -> None:
     """Add tools configuration to parameters."""
     if tools is None:
@@ -176,52 +110,6 @@ def _add_tools_config(params: Dict[str, Any], tools: Optional[List[Dict[str, Any
     })
     # logger.debug(f"Added {len(tools)} tools to API parameters")

-def prepare_params(
-    messages: List[Dict[str, Any]],
-    model_name: str,
-    temperature: float = 0,
-    max_tokens: Optional[int] = None,
-    response_format: Optional[Any] = None,
-    tools: Optional[List[Dict[str, Any]]] = None,
-    tool_choice: str = "auto",
-    api_key: Optional[str] = None,
-    api_base: Optional[str] = None,
-    stream: bool = True,  # Always stream for better UX
-    top_p: Optional[float] = None,
-    model_id: Optional[str] = None,
-) -> Dict[str, Any]:
-    from core.ai_models import model_manager
-    resolved_model_name = model_manager.resolve_model_id(model_name)
-    # logger.debug(f"Model resolution: '{model_name}' -> '{resolved_model_name}'")
-
-    params = {
-        "model": resolved_model_name,
-        "messages": messages,
-        "temperature": temperature,
-        "response_format": response_format,
-        "top_p": top_p,
-        "stream": stream,
-        "num_retries": MAX_RETRIES,
-    }
-
-    if stream:
-        params["stream_options"] = {"include_usage": True}
-    if api_key:
-        params["api_key"] = api_key
-    if api_base:
-        params["api_base"] = api_base
-    if model_id:
-        params["model_id"] = model_id
-
-    _configure_openai_compatible(params, model_name, api_key, api_base)
-    _configure_token_limits(params, resolved_model_name, max_tokens)
-    _add_tools_config(params, tools, tool_choice)
-    _configure_anthropic(params, resolved_model_name)
-    _configure_openrouter(params, resolved_model_name)
-    # _configure_thinking(params, resolved_model_name)
-
-    return params
-
 async def make_llm_api_call(
     messages: List[Dict[str, Any]],
     model_name: str,
@@ -239,46 +127,33 @@ async def make_llm_api_call(
     """Make an API call to a language model using LiteLLM."""
     logger.info(f"Making LLM API call to model: {model_name} with {len(messages)} messages")

-    # DEBUG: Log if any messages have cache_control
-    # cache_messages = [i for i, msg in enumerate(messages) if
-    #                   isinstance(msg.get('content'), list) and
-    #                   msg['content'] and
-    #                   isinstance(msg['content'][0], dict) and
-    #                   'cache_control' in msg['content'][0]]
-    # if cache_messages:
-    #     logger.info(f"🔥 CACHE CONTROL: Found cache_control in messages at positions: {cache_messages}")
-    # else:
-    #     logger.info(f"❌ NO CACHE CONTROL: No cache_control found in any messages")
+    # Prepare parameters using centralized model configuration
+    from core.ai_models import model_manager
+    resolved_model_name = model_manager.resolve_model_id(model_name)
+    # logger.debug(f"Model resolution: '{model_name}' -> '{resolved_model_name}'")

-    # Check token count for context window issues
-    # try:
-    #     from litellm import token_counter
-    #     total_tokens = token_counter(model=model_name, messages=messages)
-    #     logger.debug(f"Estimated input tokens: {total_tokens}")
-
-    #     if total_tokens > 200000:
-    #         logger.warning(f"High token count detected: {total_tokens}")
-    # except Exception:
-    #     pass  # Token counting is optional
-
-    # Prepare parameters
-    params = prepare_params(
+    # Get centralized model configuration from registry
+    params = model_manager.get_litellm_params(
+        resolved_model_name,
         messages=messages,
-        model_name=model_name,
         temperature=temperature,
-        max_tokens=max_tokens,
         response_format=response_format,
-        tools=tools,
-        tool_choice=tool_choice,
-        api_key=api_key,
-        api_base=api_base,
-        stream=stream,
         top_p=top_p,
-        model_id=model_id,
+        stream=stream,
+        api_key=api_key,
+        api_base=api_base
     )

+    # Add model_id separately if provided (to avoid duplicate argument error)
+    if model_id:
+        params["model_id"] = model_id
+
+    # Apply additional configurations that aren't in the model config yet
+    _configure_openai_compatible(params, model_name, api_key, api_base)
+    _add_tools_config(params, tools, tool_choice)
+
     try:
-        # logger.debug(f"Calling LiteLLM acompletion for {model_name}")
+        # logger.debug(f"Calling LiteLLM acompletion for {resolved_model_name}")
         response = await provider_router.acompletion(**params)

         # For streaming responses, we need to handle errors that occur during iteration
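End to end the caller-facing API is unchanged; a hedged usage sketch (the alias is illustrative; real aliases are defined on the registry entries):

response = await make_llm_api_call(
    messages=[{"role": "user", "content": "ping"}],
    model_name="sonnet-4.5",  # hypothetical alias; resolve_model_id maps aliases to full ids
    stream=True,
)
async for chunk in response:
    ...  # with include_usage on by default, the final chunk reports token usage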