# mirror of https://github.com/kortix-ai/suna.git
"""
|
|
LLM API interface for making calls to various language models.
|
|
|
|
This module provides a unified interface for making API calls to different LLM providers
|
|
(OpenAI, Anthropic, Groq, xAI, etc.) using LiteLLM. It includes support for:
|
|
- Streaming responses
|
|
- Tool calls and function calling
|
|
- Retry logic with exponential backoff
|
|
- Model-specific configurations
|
|
- Comprehensive error handling and logging
|
|
"""
|
|
|
|
from typing import Union, Dict, Any, Optional, AsyncGenerator, List
import os
import litellm
from litellm.files.main import ModelResponse
from utils.logger import logger
from utils.config import config

# litellm.set_verbose=True
# Let LiteLLM auto-adjust params and drop unsupported ones (e.g., GPT-5 temperature!=1)
litellm.modify_params = True
litellm.drop_params = True

# Constants
MAX_RETRIES = 3

class LLMError(Exception):
    """Base exception for LLM-related errors."""
    pass

def setup_api_keys() -> None:
    """Set up API keys from environment variables."""
    providers = ['OPENAI', 'ANTHROPIC', 'GROQ', 'OPENROUTER', 'XAI', 'MORPH', 'GEMINI']
    for provider in providers:
        key = getattr(config, f'{provider}_API_KEY')
        if key:
            logger.debug(f"API key set for provider: {provider}")
        else:
            logger.warning(f"No API key found for provider: {provider}")

    # Set up OpenRouter API base if not already set
    if config.OPENROUTER_API_KEY and config.OPENROUTER_API_BASE:
        os.environ['OPENROUTER_API_BASE'] = config.OPENROUTER_API_BASE
        logger.debug(f"Set OPENROUTER_API_BASE to {config.OPENROUTER_API_BASE}")

    # Set up AWS Bedrock credentials
    aws_access_key = config.AWS_ACCESS_KEY_ID
    aws_secret_key = config.AWS_SECRET_ACCESS_KEY
    aws_region = config.AWS_REGION_NAME

    if aws_access_key and aws_secret_key and aws_region:
        logger.debug(f"AWS credentials set for Bedrock in region: {aws_region}")
        # Configure LiteLLM to use AWS credentials
        os.environ['AWS_ACCESS_KEY_ID'] = aws_access_key
        os.environ['AWS_SECRET_ACCESS_KEY'] = aws_secret_key
        os.environ['AWS_REGION_NAME'] = aws_region
    else:
        logger.warning(f"Missing AWS credentials for Bedrock integration - access_key: {bool(aws_access_key)}, secret_key: {bool(aws_secret_key)}, region: {aws_region}")

def get_openrouter_fallback(model_name: str) -> Optional[str]:
    """Get OpenRouter fallback model for a given model name."""
    # Skip if already using OpenRouter
    if model_name.startswith("openrouter/"):
        return None

    # Map models to their OpenRouter equivalents
    fallback_mapping = {
        "anthropic/claude-3-7-sonnet-latest": "openrouter/anthropic/claude-3.7-sonnet",
        "anthropic/claude-sonnet-4-20250514": "openrouter/anthropic/claude-sonnet-4",
        "xai/grok-4": "openrouter/x-ai/grok-4",
        "gemini/gemini-2.5-pro": "openrouter/google/gemini-2.5-pro",
    }

    # Check for exact match first
    if model_name in fallback_mapping:
        return fallback_mapping[model_name]

    # Check for partial matches (e.g., bedrock models)
    for key, value in fallback_mapping.items():
        if key in model_name:
            return value

    # Default fallbacks by provider
    if "claude" in model_name.lower() or "anthropic" in model_name.lower():
        return "openrouter/anthropic/claude-sonnet-4"
    elif "xai" in model_name.lower() or "grok" in model_name.lower():
        return "openrouter/x-ai/grok-4"

    return None

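# Illustrative behaviour (a sketch based on the mapping above):
#     get_openrouter_fallback("anthropic/claude-sonnet-4-20250514")  # -> "openrouter/anthropic/claude-sonnet-4"
#     get_openrouter_fallback("openrouter/x-ai/grok-4")              # -> None (already routed via OpenRouter)
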
def _configure_token_limits(params: Dict[str, Any], model_name: str, max_tokens: Optional[int]) -> None:
    """Configure token limits based on model type."""
    if max_tokens is None:
        return

    if model_name.startswith("bedrock/") and "claude-3-7" in model_name:
        # For Claude 3.7 in Bedrock, do not set max_tokens or max_tokens_to_sample
        # as it causes errors with inference profiles
        logger.debug(f"Skipping max_tokens for Claude 3.7 model: {model_name}")
        return

    is_openai_o_series = 'o1' in model_name
    is_openai_gpt5 = 'gpt-5' in model_name
    param_name = "max_completion_tokens" if (is_openai_o_series or is_openai_gpt5) else "max_tokens"
    params[param_name] = max_tokens

def _apply_anthropic_caching(messages: List[Dict[str, Any]]) -> None:
    """Apply Anthropic caching to the messages."""

    # Apply cache control to at most `max_cache_control_blocks` text blocks across all messages
    cache_control_count = 0
    max_cache_control_blocks = 3

    for message in messages:
        if cache_control_count >= max_cache_control_blocks:
            break

        content = message.get("content")

        if isinstance(content, str):
            message["content"] = [
                {"type": "text", "text": content, "cache_control": {"type": "ephemeral"}}
            ]
            cache_control_count += 1
        elif isinstance(content, list):
            for item in content:
                if cache_control_count >= max_cache_control_blocks:
                    break
                if isinstance(item, dict) and item.get("type") == "text" and "cache_control" not in item:
                    item["cache_control"] = {"type": "ephemeral"}
                    cache_control_count += 1

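# Illustrative effect of the caching helper above (assumed message shapes):
# a plain string message such as
#     {"role": "system", "content": "You are a helpful assistant."}
# is rewritten in place to
#     {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant.",
#                                     "cache_control": {"type": "ephemeral"}}]}
# so Anthropic can reuse the cached prompt prefix on later calls.
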
def _configure_anthropic(params: Dict[str, Any], model_name: str, messages: List[Dict[str, Any]]) -> None:
    """Configure Anthropic-specific parameters."""
    if not ("claude" in model_name.lower() or "anthropic" in model_name.lower()):
        return

    params["extra_headers"] = {
        "anthropic-beta": "output-128k-2025-02-19"
    }
    logger.debug("Added Anthropic-specific headers")
    _apply_anthropic_caching(messages)

def _configure_openrouter(params: Dict[str, Any], model_name: str) -> None:
    """Configure OpenRouter-specific parameters."""
    if not model_name.startswith("openrouter/"):
        return

    logger.debug(f"Preparing OpenRouter parameters for model: {model_name}")

    # Add optional site URL and app name from config
    site_url = config.OR_SITE_URL
    app_name = config.OR_APP_NAME
    if site_url or app_name:
        extra_headers = params.get("extra_headers", {})
        if site_url:
            extra_headers["HTTP-Referer"] = site_url
        if app_name:
            extra_headers["X-Title"] = app_name
        params["extra_headers"] = extra_headers
        logger.debug("Added OpenRouter site URL and app name to headers")

def _configure_bedrock(params: Dict[str, Any], model_name: str, model_id: Optional[str]) -> None:
    """Configure Bedrock-specific parameters."""
    if not model_name.startswith("bedrock/"):
        return

    logger.debug(f"Preparing AWS Bedrock parameters for model: {model_name}")

    # Auto-set model_id for Claude 3.7 Sonnet if not provided
    if not model_id and "anthropic.claude-3-7-sonnet" in model_name:
        params["model_id"] = "arn:aws:bedrock:us-west-2:935064898258:inference-profile/us.anthropic.claude-3-7-sonnet-20250219-v1:0"
        logger.debug(f"Auto-set model_id for Claude 3.7 Sonnet: {params['model_id']}")

def _configure_openai_gpt5(params: Dict[str, Any], model_name: str) -> None:
    """Configure OpenAI GPT-5 specific parameters."""
    if "gpt-5" not in model_name:
        return

    # Drop unsupported temperature param (only default 1 allowed)
    if "temperature" in params and params["temperature"] != 1:
        params.pop("temperature", None)

    # Request priority service tier when calling OpenAI directly.
    # Pass via both top-level and extra_body for LiteLLM compatibility.
    if not model_name.startswith("openrouter/"):
        params["service_tier"] = "priority"
        extra_body = params.get("extra_body", {})
        if "service_tier" not in extra_body:
            extra_body["service_tier"] = "priority"
        params["extra_body"] = extra_body

def _configure_kimi_k2(params: Dict[str, Any], model_name: str) -> None:
    """Configure Kimi K2-specific parameters."""
    is_kimi_k2 = "kimi-k2" in model_name.lower() or model_name.startswith("moonshotai/kimi-k2")
    if not is_kimi_k2:
        return

    params["provider"] = {
        "order": ["groq", "moonshotai"]  # other candidates: "together/fp8", "novita/fp8", "baseten/fp8"
    }

def _configure_thinking(params: Dict[str, Any], model_name: str, enable_thinking: Optional[bool], reasoning_effort: Optional[str]) -> None:
    """Configure reasoning/thinking parameters for supported models."""
    if not enable_thinking:
        return

    effort_level = reasoning_effort or 'low'
    is_anthropic = "anthropic" in model_name.lower() or "claude" in model_name.lower()
    is_xai = "xai" in model_name.lower() or model_name.startswith("xai/")

    if is_anthropic:
        params["reasoning_effort"] = effort_level
        params["temperature"] = 1.0  # Required by Anthropic when reasoning_effort is used
        logger.info(f"Anthropic thinking enabled with reasoning_effort='{effort_level}'")
    elif is_xai:
        params["reasoning_effort"] = effort_level
        logger.info(f"xAI thinking enabled with reasoning_effort='{effort_level}'")

def _add_fallback_model(params: Dict[str, Any], model_name: str, messages: List[Dict[str, Any]]) -> None:
    """Add fallback model to the parameters."""
    fallback_model = get_openrouter_fallback(model_name)
    if fallback_model:
        params["fallbacks"] = [{
            "model": fallback_model,
            "messages": messages,
        }]
        logger.debug(f"Added OpenRouter fallback for model: {model_name} to {fallback_model}")

def _add_tools_config(params: Dict[str, Any], tools: Optional[List[Dict[str, Any]]], tool_choice: str) -> None:
    """Add tools configuration to parameters."""
    if tools is None:
        return

    params.update({
        "tools": tools,
        "tool_choice": tool_choice
    })
    logger.debug(f"Added {len(tools)} tools to API parameters")

def prepare_params(
    messages: List[Dict[str, Any]],
    model_name: str,
    temperature: float = 0,
    max_tokens: Optional[int] = None,
    response_format: Optional[Any] = None,
    tools: Optional[List[Dict[str, Any]]] = None,
    tool_choice: str = "auto",
    api_key: Optional[str] = None,
    api_base: Optional[str] = None,
    stream: bool = False,
    top_p: Optional[float] = None,
    model_id: Optional[str] = None,
    enable_thinking: Optional[bool] = False,
    reasoning_effort: Optional[str] = 'low'
) -> Dict[str, Any]:
    """Build the LiteLLM completion parameters, applying provider-specific configuration."""
    from models import model_manager
    resolved_model_name = model_manager.resolve_model_id(model_name)
    logger.debug(f"Model resolution: '{model_name}' -> '{resolved_model_name}'")

    params = {
        "model": resolved_model_name,
        "messages": messages,
        "temperature": temperature,
        "response_format": response_format,
        "top_p": top_p,
        "stream": stream,
        "num_retries": MAX_RETRIES,
    }

    if api_key:
        params["api_key"] = api_key
    if api_base:
        params["api_base"] = api_base
    if model_id:
        params["model_id"] = model_id

    # Handle token limits
    _configure_token_limits(params, resolved_model_name, max_tokens)
    # Add tools if provided
    _add_tools_config(params, tools, tool_choice)
    # Add Anthropic-specific parameters
    _configure_anthropic(params, resolved_model_name, params["messages"])
    # Add OpenRouter-specific parameters
    _configure_openrouter(params, resolved_model_name)
    # Add Bedrock-specific parameters
    _configure_bedrock(params, resolved_model_name, model_id)

    _add_fallback_model(params, resolved_model_name, messages)
    # Add OpenAI GPT-5 specific parameters
    _configure_openai_gpt5(params, resolved_model_name)
    # Add Kimi K2-specific parameters
    _configure_kimi_k2(params, resolved_model_name)
    _configure_thinking(params, resolved_model_name, enable_thinking, reasoning_effort)

    return params

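# For illustration (a sketch of the result): prepare_params(messages,
# "anthropic/claude-sonnet-4-20250514", max_tokens=8192) yields a dict containing the
# resolved model name, the messages (with ephemeral cache_control applied), "max_tokens",
# the Anthropic beta header, and an OpenRouter fallback entry, ready to be passed to
# litellm.acompletion(**params).
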
async def make_llm_api_call(
    messages: List[Dict[str, Any]],
    model_name: str,
    response_format: Optional[Any] = None,
    temperature: float = 0,
    max_tokens: Optional[int] = None,
    tools: Optional[List[Dict[str, Any]]] = None,
    tool_choice: str = "auto",
    api_key: Optional[str] = None,
    api_base: Optional[str] = None,
    stream: bool = False,
    top_p: Optional[float] = None,
    model_id: Optional[str] = None,
    enable_thinking: Optional[bool] = False,
    reasoning_effort: Optional[str] = 'low'
) -> Union[Dict[str, Any], AsyncGenerator, ModelResponse]:
    """
    Make an API call to a language model using LiteLLM.

    Args:
        messages: List of message dictionaries for the conversation
        model_name: Name of the model to use (e.g., "gpt-4", "claude-3", "openrouter/openai/gpt-4", "bedrock/anthropic.claude-3-sonnet-20240229-v1:0")
        response_format: Desired format for the response
        temperature: Sampling temperature (0-1)
        max_tokens: Maximum tokens in the response
        tools: List of tool definitions for function calling
        tool_choice: How to select tools ("auto" or "none")
        api_key: Override default API key
        api_base: Override default API base URL
        stream: Whether to stream the response
        top_p: Top-p sampling parameter
        model_id: Optional ARN for Bedrock inference profiles
        enable_thinking: Whether to enable thinking/reasoning for supported models
        reasoning_effort: Level of reasoning effort (e.g., "low", "medium", "high")

    Returns:
        Union[Dict[str, Any], AsyncGenerator, ModelResponse]: API response or stream

    Raises:
        LLMError: If the API call fails (including after LiteLLM's retries)
    """
    logger.debug(f"Making LLM API call to model: {model_name} (Thinking: {enable_thinking}, Effort: {reasoning_effort})")
    logger.debug(f"📡 API Call: Using model {model_name}")
    params = prepare_params(
        messages=messages,
        model_name=model_name,
        temperature=temperature,
        max_tokens=max_tokens,
        response_format=response_format,
        tools=tools,
        tool_choice=tool_choice,
        api_key=api_key,
        api_base=api_base,
        stream=stream,
        top_p=top_p,
        model_id=model_id,
        enable_thinking=enable_thinking,
        reasoning_effort=reasoning_effort
    )
    try:
        response = await litellm.acompletion(**params)
        logger.debug(f"Successfully received API response from {model_name}")
        # logger.debug(f"Response: {response}")
        return response

    except Exception as e:
        logger.error(f"Unexpected error during API call: {str(e)}", exc_info=True)
        raise LLMError(f"API call failed: {str(e)}")

# Initialize API keys on module import
setup_api_keys()
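
# --- Illustrative usage (a minimal sketch, not part of the module's public surface) ---
# Assumes the relevant provider API key is configured via utils.config / the environment;
# the model name below is an example placeholder and may need to be adjusted.
if __name__ == "__main__":
    import asyncio

    async def _example() -> None:
        # Make a single non-streaming call and print the raw LiteLLM response
        response = await make_llm_api_call(
            messages=[{"role": "user", "content": "Say hello in one sentence."}],
            model_name="openai/gpt-4o",  # example model name (assumption), not a project default
            max_tokens=64,
        )
        print(response)

    asyncio.run(_example())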