mirror of https://github.com/kortix-ai/suna.git
Add Sonnet 4.5 and centralize model configuration system
- Add Claude Sonnet 4.5 (global.anthropic.claude-sonnet-4-5-20250929-v1:0) to registry
- Update all Anthropic models to use bedrock/converse/ endpoint with full ARNs
- Create comprehensive ModelConfig class for centralized provider settings
- Add alias system with raw ARNs for proper LiteLLM response resolution
- Refactor response processor to preserve exact LiteLLM response objects
- Simplify LLM service by merging prepare_params into make_llm_api_call
- Set stream_options include_usage as universal default for all models
- Remove scattered configuration functions in favor of registry-driven approach
- Fix pricing lookup by mapping provider model IDs back to registry entries
This commit is contained in:
parent 2b5b8cc0bc
commit d8100cb7a0
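Before the diff, a minimal sketch of the registry-driven flow this commit introduces, pieced together from the changed files below. The core.ai_models import path and the resolve_model_id / get_litellm_params calls appear in the diff; the "claude-sonnet-4.5" alias and the message content are illustrative placeholders.

# Sketch only: mirrors how make_llm_api_call builds parameters after this change.
from core.ai_models import model_manager  # import path as used in llm.py below

# Resolve a user-facing alias to the registry's canonical model id.
resolved = model_manager.resolve_model_id("claude-sonnet-4.5")  # alias is illustrative

# The registry returns complete LiteLLM params (model id, num_retries,
# stream_options={"include_usage": True}, any ModelConfig headers),
# merged with the runtime overrides passed here.
params = model_manager.get_litellm_params(
    resolved,
    messages=[{"role": "user", "content": "Hello"}],
    temperature=0,
    stream=True,
)
# llm.py then hands the dict straight to provider_router.acompletion(**params).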
@@ -1,5 +1,5 @@
from dataclasses import dataclass, field
from typing import List, Optional, Dict, Any
from typing import List, Optional, Dict, Any, Union
from enum import Enum

@@ -36,6 +36,36 @@ class ModelPricing:
        return self.output_cost_per_million_tokens / 1_000_000


@dataclass
class ModelConfig:
    """Essential model configuration - provider settings and API configuration only."""

    # === Provider & API Configuration ===
    api_base: Optional[str] = None
    api_version: Optional[str] = None
    base_url: Optional[str] = None  # Alternative to api_base
    deployment_id: Optional[str] = None  # Azure
    timeout: Optional[Union[float, int]] = None
    num_retries: Optional[int] = None

    # === Headers (Provider-Specific) ===
    headers: Optional[Dict[str, str]] = None
    extra_headers: Optional[Dict[str, str]] = None

    def __post_init__(self):
        """Set intelligent defaults and validate configuration."""
        # Merge headers if both are provided
        if self.headers and self.extra_headers:
            merged_headers = self.headers.copy()
            merged_headers.update(self.extra_headers)
            self.extra_headers = merged_headers
            self.headers = None  # Use extra_headers as the single source
        elif self.headers and not self.extra_headers:
            self.extra_headers = self.headers
            self.headers = None


@dataclass
class Model:
    id: str
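A quick illustration of the __post_init__ header merge above; a sketch, where the x-request-source header is made up and the anthropic-beta flag is taken from the registry changes later in this diff. The import path is assumed from the registry's relative import of ModelConfig.

from core.ai_models.ai_models import ModelConfig  # module path assumed

cfg = ModelConfig(
    headers={"x-request-source": "suna"},  # hypothetical header
    extra_headers={"anthropic-beta": "context-1m-2025-08-07"},
)
# __post_init__ merges both dicts into extra_headers (extra_headers wins on
# conflicts) and clears headers, leaving a single source of truth:
assert cfg.headers is None
assert cfg.extra_headers == {
    "x-request-source": "suna",
    "anthropic-beta": "context-1m-2025-08-07",
}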
@@ -53,6 +83,9 @@ class Model:
    priority: int = 0
    recommended: bool = False

    # NEW: Centralized model configuration
    config: Optional[ModelConfig] = None

    def __post_init__(self):
        if ModelCapability.CHAT not in self.capabilities:
            self.capabilities.insert(0, ModelCapability.CHAT)
@@ -79,6 +112,50 @@ class Model:
    def is_free_tier(self) -> bool:
        return "free" in self.tier_availability

    def get_litellm_params(self, **override_params) -> Dict[str, Any]:
        """Get complete LiteLLM parameters for this model, including all configuration."""
        # Start with intelligent defaults
        params = {
            "model": self.id,
            "num_retries": 3,
            "stream_options": {"include_usage": True},  # Default for all models
        }

        # Apply model-specific configuration if available
        if self.config:
            # Provider & API configuration parameters
            api_params = [
                'api_base', 'api_version', 'base_url', 'deployment_id',
                'timeout', 'num_retries'
            ]

            # Apply configured parameters
            for param_name in api_params:
                param_value = getattr(self.config, param_name, None)
                if param_value is not None:
                    params[param_name] = param_value

            # Handle headers specially
            if self.config.extra_headers:
                params["extra_headers"] = self.config.extra_headers.copy()
            elif self.config.headers:
                params["extra_headers"] = self.config.headers.copy()

        # Apply any runtime overrides
        for key, value in override_params.items():
            if value is not None:
                # Handle extra_headers merging
                if key == "extra_headers" and "extra_headers" in params:
                    if isinstance(params["extra_headers"], dict) and isinstance(value, dict):
                        params["extra_headers"].update(value)
                    else:
                        params[key] = value
                else:
                    params[key] = value

        return params

    def to_dict(self) -> Dict[str, Any]:
        return {
            "id": self.id,
@@ -58,6 +58,23 @@ class ModelManager:
    def get_models_for_tier(self, tier: str) -> List[Model]:
        return self.registry.get_by_tier(tier, enabled_only=True)

    def get_litellm_params(self, model_id: str, **override_params) -> Dict[str, Any]:
        """Get complete LiteLLM parameters for a model from the registry."""
        model = self.get_model(model_id)
        if not model:
            logger.warning(f"Model '{model_id}' not found in registry, using basic params")
            return {
                "model": model_id,
                "num_retries": 3,
                **override_params
            }

        # Get the complete configuration from the model
        params = model.get_litellm_params(**override_params)
        # logger.debug(f"Generated LiteLLM params for {model.name}: {list(params.keys())}")

        return params

    def get_models_with_capability(self, capability: ModelCapability) -> List[Model]:
        return self.registry.get_by_capability(capability, enabled_only=True)
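The warning-and-fallback branch above means an id that is not in the registry still yields a usable parameter dict rather than raising. A sketch with a deliberately fake id, using the model_manager instance from the sketch near the top:

# An unregistered id logs a warning and returns basic params plus the overrides.
params = model_manager.get_litellm_params(
    "example/not-in-registry",   # fake id for illustration
    temperature=0,
)
assert params == {"model": "example/not-in-registry", "num_retries": 3, "temperature": 0}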
@@ -1,5 +1,5 @@
from typing import Dict, List, Optional, Set
from .ai_models import Model, ModelProvider, ModelCapability, ModelPricing
from .ai_models import Model, ModelProvider, ModelCapability, ModelPricing, ModelConfig
from core.utils.config import config, EnvMode

FREE_MODEL_ID = "moonshotai/kimi-k2"
@@ -39,7 +39,12 @@ class ModelRegistry:
            tier_availability=["paid"],
            priority=101,
            recommended=True,
            enabled=True
            enabled=True,
            config=ModelConfig(
                extra_headers={
                    "anthropic-beta": "context-1m-2025-08-07"
                },
            )
        ))

        self.register(Model(
@@ -61,7 +66,12 @@ class ModelRegistry:
            tier_availability=["paid"],
            priority=100,
            recommended=True,
            enabled=True
            enabled=True,
            config=ModelConfig(
                extra_headers={
                    "anthropic-beta": "context-1m-2025-08-07"
                },
            )
        ))

        self.register(Model(
@@ -81,7 +91,12 @@ class ModelRegistry:
            ),
            tier_availability=["paid"],
            priority=99,
            enabled=True
            enabled=True,
            config=ModelConfig(
                extra_headers={
                    "anthropic-beta": "prompt-caching-2024-07-31"
                },
            )
        ))

        self.register(Model(
@@ -186,25 +201,30 @@ class ModelRegistry:
        ))


        self.register(Model(
            id="openrouter/moonshotai/kimi-k2",
            name="Kimi K2",
            provider=ModelProvider.MOONSHOTAI,
            aliases=["kimi-k2", "Kimi K2", "moonshotai/kimi-k2"],
            context_window=200_000,
            capabilities=[
                ModelCapability.CHAT,
                ModelCapability.FUNCTION_CALLING,
            ],
            pricing=ModelPricing(
                input_cost_per_million_tokens=1.00,
                output_cost_per_million_tokens=3.00
            ),
            tier_availability=["free", "paid"],
            priority=94,
            enabled=True
        ))

        # self.register(Model(
        #     id="openrouter/moonshotai/kimi-k2",
        #     name="Kimi K2",
        #     provider=ModelProvider.MOONSHOTAI,
        #     aliases=["kimi-k2", "Kimi K2", "moonshotai/kimi-k2"],
        #     context_window=200_000,
        #     capabilities=[
        #         ModelCapability.CHAT,
        #         ModelCapability.FUNCTION_CALLING,
        #     ],
        #     pricing=ModelPricing(
        #         input_cost_per_million_tokens=1.00,
        #         output_cost_per_million_tokens=3.00
        #     ),
        #     tier_availability=["free", "paid"],
        #     priority=94,
        #     enabled=True,
        #     config=ModelConfig(
        #         extra_headers={
        #             "HTTP-Referer": config.OR_SITE_URL if hasattr(config, 'OR_SITE_URL') and config.OR_SITE_URL else "",
        #             "X-Title": config.OR_APP_NAME if hasattr(config, 'OR_APP_NAME') and config.OR_APP_NAME else ""
        #         }
        #     )
        # ))

        # # DeepSeek Models
        # self.register(Model(
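The commented-out Kimi K2 registration above hints at how the OpenRouter referer and title headers, which the old _configure_openrouter helper in llm.py injected globally (removed later in this diff), could instead live on the model's ModelConfig. A minimal sketch of that idea, assuming the config object imported at the top of this file exposes OR_SITE_URL and OR_APP_NAME exactly as the old helper did:

# Per-model OpenRouter headers, mirroring the commented-out registration above.
openrouter_headers = {}
if getattr(config, "OR_SITE_URL", None):
    openrouter_headers["HTTP-Referer"] = config.OR_SITE_URL
if getattr(config, "OR_APP_NAME", None):
    openrouter_headers["X-Title"] = config.OR_APP_NAME

kimi_config = ModelConfig(extra_headers=openrouter_headers or None)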
@@ -83,49 +83,6 @@ def setup_provider_router(openai_compatible_api_key: str = None, openai_compatib
    ]
    provider_router = Router(model_list=model_list)

def _configure_token_limits(params: Dict[str, Any], model_name: str, max_tokens: Optional[int]) -> None:
    """Configure token limits based on model type."""
    # Only set max_tokens if explicitly provided - let providers use their defaults otherwise
    if max_tokens is None:
        # logger.debug(f"No max_tokens specified, using provider defaults for model: {model_name}")
        return

    is_openai_o_series = 'o1' in model_name
    is_openai_gpt5 = 'gpt-5' in model_name
    param_name = "max_completion_tokens" if (is_openai_o_series or is_openai_gpt5) else "max_tokens"
    params[param_name] = max_tokens
    # logger.debug(f"Set {param_name}={max_tokens} for model: {model_name}")

def _configure_anthropic(params: Dict[str, Any], model_name: str) -> None:
    """Configure Anthropic-specific parameters."""
    if not ("claude" in model_name.lower() or "anthropic" in model_name.lower()):
        return

    # Include prompt caching and context-1m beta features
    params["extra_headers"] = {
        "anthropic-beta": "prompt-caching-2024-07-31,context-1m-2025-08-07"
    }
    logger.debug(f"Added Anthropic-specific headers for prompt caching and 1M context window")

def _configure_openrouter(params: Dict[str, Any], model_name: str) -> None:
    """Configure OpenRouter-specific parameters."""
    if not model_name.startswith("openrouter/"):
        return

    # logger.debug(f"Preparing OpenRouter parameters for model: {model_name}")

    # Add optional site URL and app name from config
    site_url = config.OR_SITE_URL
    app_name = config.OR_APP_NAME
    if site_url or app_name:
        extra_headers = params.get("extra_headers", {})
        if site_url:
            extra_headers["HTTP-Referer"] = site_url
        if app_name:
            extra_headers["X-Title"] = app_name
        params["extra_headers"] = extra_headers
        # logger.debug(f"Added OpenRouter site URL and app name to headers")

def _configure_openai_compatible(params: Dict[str, Any], model_name: str, api_key: Optional[str], api_base: Optional[str]) -> None:
    """Configure OpenAI-compatible provider setup."""
    if not model_name.startswith("openai-compatible/"):
@@ -142,29 +99,6 @@ def _configure_openai_compatible(params: Dict[str, Any], model_name: str, api_ke
    setup_provider_router(api_key, api_base)
    logger.debug(f"Configured OpenAI-compatible provider with custom API base")

def _configure_thinking(params: Dict[str, Any], model_name: str) -> None:
    """Configure reasoning/thinking parameters automatically based on model capabilities."""
    # Check if model supports thinking/reasoning
    is_anthropic = "anthropic" in model_name.lower() or "claude" in model_name.lower()
    is_xai = "xai" in model_name.lower() or model_name.startswith("xai/")
    is_bedrock_anthropic = "bedrock" in model_name.lower() and "anthropic" in model_name.lower()

    # Enable thinking for supported models
    if is_anthropic or is_xai or is_bedrock_anthropic:
        # Use higher effort for premium models
        if "sonnet-4" in model_name.lower() or "claude-4" in model_name.lower():
            effort_level = "medium"
        else:
            effort_level = "low"

        if is_anthropic or is_bedrock_anthropic:
            params["reasoning_effort"] = effort_level
            params["temperature"] = 1.0  # Required by Anthropic when reasoning_effort is used
            logger.info(f"Anthropic thinking auto-enabled with reasoning_effort='{effort_level}' for model: {model_name}")
        elif is_xai:
            params["reasoning_effort"] = effort_level
            logger.info(f"xAI thinking auto-enabled with reasoning_effort='{effort_level}' for model: {model_name}")

def _add_tools_config(params: Dict[str, Any], tools: Optional[List[Dict[str, Any]]], tool_choice: str) -> None:
    """Add tools configuration to parameters."""
    if tools is None:
@@ -176,52 +110,6 @@ def _add_tools_config(params: Dict[str, Any], tools: Optional[List[Dict[str, Any
        })
    # logger.debug(f"Added {len(tools)} tools to API parameters")

def prepare_params(
    messages: List[Dict[str, Any]],
    model_name: str,
    temperature: float = 0,
    max_tokens: Optional[int] = None,
    response_format: Optional[Any] = None,
    tools: Optional[List[Dict[str, Any]]] = None,
    tool_choice: str = "auto",
    api_key: Optional[str] = None,
    api_base: Optional[str] = None,
    stream: bool = True,  # Always stream for better UX
    top_p: Optional[float] = None,
    model_id: Optional[str] = None,
) -> Dict[str, Any]:
    from core.ai_models import model_manager
    resolved_model_name = model_manager.resolve_model_id(model_name)
    # logger.debug(f"Model resolution: '{model_name}' -> '{resolved_model_name}'")

    params = {
        "model": resolved_model_name,
        "messages": messages,
        "temperature": temperature,
        "response_format": response_format,
        "top_p": top_p,
        "stream": stream,
        "num_retries": MAX_RETRIES,
    }

    if stream:
        params["stream_options"] = {"include_usage": True}
    if api_key:
        params["api_key"] = api_key
    if api_base:
        params["api_base"] = api_base
    if model_id:
        params["model_id"] = model_id

    _configure_openai_compatible(params, model_name, api_key, api_base)
    _configure_token_limits(params, resolved_model_name, max_tokens)
    _add_tools_config(params, tools, tool_choice)
    _configure_anthropic(params, resolved_model_name)
    _configure_openrouter(params, resolved_model_name)
    # _configure_thinking(params, resolved_model_name)

    return params

async def make_llm_api_call(
    messages: List[Dict[str, Any]],
    model_name: str,
@@ -239,46 +127,33 @@ async def make_llm_api_call(
    """Make an API call to a language model using LiteLLM."""
    logger.info(f"Making LLM API call to model: {model_name} with {len(messages)} messages")

    # DEBUG: Log if any messages have cache_control
    # cache_messages = [i for i, msg in enumerate(messages) if
    #                   isinstance(msg.get('content'), list) and
    #                   msg['content'] and
    #                   isinstance(msg['content'][0], dict) and
    #                   'cache_control' in msg['content'][0]]
    # if cache_messages:
    #     logger.info(f"🔥 CACHE CONTROL: Found cache_control in messages at positions: {cache_messages}")
    # else:
    #     logger.info(f"❌ NO CACHE CONTROL: No cache_control found in any messages")
    # Prepare parameters using centralized model configuration
    from core.ai_models import model_manager
    resolved_model_name = model_manager.resolve_model_id(model_name)
    # logger.debug(f"Model resolution: '{model_name}' -> '{resolved_model_name}'")

    # Check token count for context window issues
    # try:
    #     from litellm import token_counter
    #     total_tokens = token_counter(model=model_name, messages=messages)
    #     logger.debug(f"Estimated input tokens: {total_tokens}")

    #     if total_tokens > 200000:
    #         logger.warning(f"High token count detected: {total_tokens}")
    # except Exception:
    #     pass  # Token counting is optional

    # Prepare parameters
    params = prepare_params(
    # Get centralized model configuration from registry
    params = model_manager.get_litellm_params(
        resolved_model_name,
        messages=messages,
        model_name=model_name,
        temperature=temperature,
        max_tokens=max_tokens,
        response_format=response_format,
        tools=tools,
        tool_choice=tool_choice,
        api_key=api_key,
        api_base=api_base,
        stream=stream,
        top_p=top_p,
        model_id=model_id,
        stream=stream,
        api_key=api_key,
        api_base=api_base
    )

    # Add model_id separately if provided (to avoid duplicate argument error)
    if model_id:
        params["model_id"] = model_id

    # Apply additional configurations that aren't in the model config yet
    _configure_openai_compatible(params, model_name, api_key, api_base)
    _add_tools_config(params, tools, tool_choice)

    try:
        # logger.debug(f"Calling LiteLLM acompletion for {model_name}")
        # logger.debug(f"Calling LiteLLM acompletion for {resolved_model_name}")
        response = await provider_router.acompletion(**params)

        # For streaming responses, we need to handle errors that occur during iteration
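Finally, a sketch of calling the simplified make_llm_api_call. The parameter names come from the signatures in this diff; the import path, model alias, message content, and the way the streaming response is consumed are illustrative assumptions (in Suna the response processor owns the real iteration, per the commit message).

import asyncio
# from core.services.llm import make_llm_api_call  # import path assumed, not shown in this diff

async def main() -> None:
    response = await make_llm_api_call(
        messages=[{"role": "user", "content": "Hello"}],
        model_name="claude-sonnet-4.5",  # alias is illustrative; resolve_model_id maps it to a registry id
        temperature=0,
        stream=True,                     # stream_options include_usage is now a registry-wide default
    )
    # With stream=True, LiteLLM's acompletion returns an async iterator of chunks;
    # shown here only to illustrate the shape of the result.
    async for chunk in response:
        _ = chunk

asyncio.run(main())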