suna/backend/core/ai_models/ai_models.py

178 lines
6.1 KiB
Python

from dataclasses import dataclass, field
from typing import List, Optional, Dict, Any, Union
from enum import Enum
class ModelProvider(Enum):
OPENAI = "openai"
ANTHROPIC = "anthropic"
BEDROCK = "bedrock"
OPENROUTER = "openrouter"
GOOGLE = "google"
XAI = "xai"
MOONSHOTAI = "moonshotai"
class ModelCapability(Enum):
CHAT = "chat"
FUNCTION_CALLING = "function_calling"
VISION = "vision"
CODE_INTERPRETER = "code_interpreter"
WEB_SEARCH = "web_search"
THINKING = "thinking"
STRUCTURED_OUTPUT = "structured_output"
@dataclass
class ModelPricing:
input_cost_per_million_tokens: float
output_cost_per_million_tokens: float
@property
def input_cost_per_token(self) -> float:
return self.input_cost_per_million_tokens / 1_000_000
@property
def output_cost_per_token(self) -> float:
return self.output_cost_per_million_tokens / 1_000_000
@dataclass
class ModelConfig:
"""Essential model configuration - provider settings and API configuration only."""
# === Provider & API Configuration ===
api_base: Optional[str] = None
api_version: Optional[str] = None
base_url: Optional[str] = None # Alternative to api_base
deployment_id: Optional[str] = None # Azure
timeout: Optional[Union[float, int]] = None
num_retries: Optional[int] = None
# === Headers (Provider-Specific) ===
headers: Optional[Dict[str, str]] = None
extra_headers: Optional[Dict[str, str]] = None
def __post_init__(self):
"""Set intelligent defaults and validate configuration."""
# Merge headers if both are provided
if self.headers and self.extra_headers:
merged_headers = self.headers.copy()
merged_headers.update(self.extra_headers)
self.extra_headers = merged_headers
self.headers = None # Use extra_headers as the single source
elif self.headers and not self.extra_headers:
self.extra_headers = self.headers
self.headers = None
@dataclass
class Model:
id: str
name: str
provider: ModelProvider
aliases: List[str] = field(default_factory=list)
context_window: int = 128_000
max_output_tokens: Optional[int] = None
capabilities: List[ModelCapability] = field(default_factory=list)
pricing: Optional[ModelPricing] = None
enabled: bool = True
beta: bool = False
tier_availability: List[str] = field(default_factory=lambda: ["paid"])
metadata: Dict[str, Any] = field(default_factory=dict)
priority: int = 0
recommended: bool = False
# NEW: Centralized model configuration
config: Optional[ModelConfig] = None
def __post_init__(self):
if ModelCapability.CHAT not in self.capabilities:
self.capabilities.insert(0, ModelCapability.CHAT)
@property
def full_id(self) -> str:
if "/" in self.id:
return self.id
return f"{self.provider.value}/{self.id}"
@property
def supports_thinking(self) -> bool:
return ModelCapability.THINKING in self.capabilities
@property
def supports_functions(self) -> bool:
return ModelCapability.FUNCTION_CALLING in self.capabilities
@property
def supports_vision(self) -> bool:
return ModelCapability.VISION in self.capabilities
@property
def is_free_tier(self) -> bool:
return "free" in self.tier_availability
def get_litellm_params(self, **override_params) -> Dict[str, Any]:
"""Get complete LiteLLM parameters for this model, including all configuration."""
# Start with intelligent defaults
params = {
"model": self.id,
"num_retries": 3,
"stream_options": {"include_usage": True}, # Default for all models
}
# Apply model-specific configuration if available
if self.config:
# Provider & API configuration parameters
api_params = [
'api_base', 'api_version', 'base_url', 'deployment_id',
'timeout', 'num_retries'
]
# Apply configured parameters
for param_name in api_params:
param_value = getattr(self.config, param_name, None)
if param_value is not None:
params[param_name] = param_value
# Handle headers specially
if self.config.extra_headers:
params["extra_headers"] = self.config.extra_headers.copy()
elif self.config.headers:
params["extra_headers"] = self.config.headers.copy()
# Apply any runtime overrides
for key, value in override_params.items():
if value is not None:
# Handle extra_headers merging
if key == "extra_headers" and "extra_headers" in params:
if isinstance(params["extra_headers"], dict) and isinstance(value, dict):
params["extra_headers"].update(value)
else:
params[key] = value
else:
params[key] = value
return params
def to_dict(self) -> Dict[str, Any]:
return {
"id": self.id,
"name": self.name,
"provider": self.provider.value,
"aliases": self.aliases,
"context_window": self.context_window,
"max_output_tokens": self.max_output_tokens,
"capabilities": [cap.value for cap in self.capabilities],
"pricing": {
"input_cost_per_million_tokens": self.pricing.input_cost_per_million_tokens,
"output_cost_per_million_tokens": self.pricing.output_cost_per_million_tokens,
} if self.pricing else None,
"enabled": self.enabled,
"beta": self.beta,
"tier_availability": self.tier_availability,
"metadata": self.metadata,
"priority": self.priority,
"recommended": self.recommended,
}