suna/backend/services/llm.py

"""
LLM API interface for making calls to various language models.
This module provides a unified interface for making API calls to different LLM providers
(OpenAI, Anthropic, Groq, xAI, etc.) using LiteLLM. It includes support for:
- Streaming responses
- Tool calls and function calling
- Retry logic with exponential backoff
- Model-specific configurations
- Comprehensive error handling and logging
"""
from typing import Union, Dict, Any, Optional, AsyncGenerator, List
import os
import json
import asyncio
from openai import OpenAIError
import litellm
from litellm.files.main import ModelResponse
from utils.logger import logger
from utils.config import config

# litellm.set_verbose=True
litellm.modify_params = True

# Constants
MAX_RETRIES = 2
RATE_LIMIT_DELAY = 30
RETRY_DELAY = 0.1


class LLMError(Exception):
    """Base exception for LLM-related errors."""
    pass


class LLMRetryError(LLMError):
    """Exception raised when retries are exhausted."""
    pass


def setup_api_keys() -> None:
    """Set up API keys from environment variables."""
    providers = ['OPENAI', 'ANTHROPIC', 'GROQ', 'OPENROUTER', 'XAI']
    for provider in providers:
        key = getattr(config, f'{provider}_API_KEY')
        if key:
            logger.debug(f"API key set for provider: {provider}")
        else:
            logger.warning(f"No API key found for provider: {provider}")

    # Set up OpenRouter API base if not already set
    if config.OPENROUTER_API_KEY and config.OPENROUTER_API_BASE:
        os.environ['OPENROUTER_API_BASE'] = config.OPENROUTER_API_BASE
        logger.debug(f"Set OPENROUTER_API_BASE to {config.OPENROUTER_API_BASE}")

    # Set up AWS Bedrock credentials
    aws_access_key = config.AWS_ACCESS_KEY_ID
    aws_secret_key = config.AWS_SECRET_ACCESS_KEY
    aws_region = config.AWS_REGION_NAME
    if aws_access_key and aws_secret_key and aws_region:
        logger.debug(f"AWS credentials set for Bedrock in region: {aws_region}")
        # Configure LiteLLM to use AWS credentials
        os.environ['AWS_ACCESS_KEY_ID'] = aws_access_key
        os.environ['AWS_SECRET_ACCESS_KEY'] = aws_secret_key
        os.environ['AWS_REGION_NAME'] = aws_region
    else:
        logger.warning(f"Missing AWS credentials for Bedrock integration - access_key: {bool(aws_access_key)}, secret_key: {bool(aws_secret_key)}, region: {aws_region}")


def get_openrouter_fallback(model_name: str) -> Optional[str]:
    """Get OpenRouter fallback model for a given model name."""
    # Skip if already using OpenRouter
    if model_name.startswith("openrouter/"):
        return None

    # Map models to their OpenRouter equivalents
    fallback_mapping = {
        "anthropic/claude-3-7-sonnet-latest": "openrouter/anthropic/claude-3.7-sonnet",
        "anthropic/claude-sonnet-4-20250514": "openrouter/anthropic/claude-sonnet-4",
        "xai/grok-4": "openrouter/x-ai/grok-4",
    }

    # Check for an exact match first
    if model_name in fallback_mapping:
        return fallback_mapping[model_name]

    # Check for partial matches (e.g., Bedrock models)
    for key, value in fallback_mapping.items():
        if key in model_name:
            return value

    # Default fallbacks by provider
    if "claude" in model_name.lower() or "anthropic" in model_name.lower():
        return "openrouter/anthropic/claude-sonnet-4"
    elif "xai" in model_name.lower() or "grok" in model_name.lower():
        return "openrouter/x-ai/grok-4"

    return None
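

# Illustrative mappings produced by get_openrouter_fallback(), following the
# table and provider defaults above (model name strings are examples only):
#
#     get_openrouter_fallback("anthropic/claude-sonnet-4-20250514")
#         -> "openrouter/anthropic/claude-sonnet-4"      # exact match
#     get_openrouter_fallback("bedrock/anthropic.claude-3-7-sonnet-20250219-v1:0")
#         -> "openrouter/anthropic/claude-sonnet-4"      # provider default for Claude
#     get_openrouter_fallback("openrouter/x-ai/grok-4")
#         -> None                                        # already on OpenRouter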


async def handle_error(error: Exception, attempt: int, max_attempts: int) -> None:
    """Handle API errors with appropriate delays and logging."""
    delay = RATE_LIMIT_DELAY if isinstance(error, litellm.exceptions.RateLimitError) else RETRY_DELAY
    logger.warning(f"Error on attempt {attempt + 1}/{max_attempts}: {str(error)}")
    logger.debug(f"Waiting {delay} seconds before retry...")
    await asyncio.sleep(delay)


def detect_error_and_suggest_fallback(error: Exception, current_model: str) -> tuple[bool, str, str]:
    """
    Detect specific error types and suggest appropriate fallback strategies.

    Args:
        error: The exception that occurred
        current_model: The current model being used

    Returns:
        tuple[bool, str, str]: (should_fallback, fallback_model, error_type)
            - should_fallback: Whether to attempt a fallback
            - fallback_model: The suggested fallback model
            - error_type: The type of error detected
    """
    error_str = str(error).lower()

    # Anthropic-specific errors
    if "anthropicexception - overloaded" in error_str:
        if not current_model.startswith("openrouter/"):
            fallback_model = f"openrouter/{current_model}"
            return True, fallback_model, "anthropic_overloaded"
        return False, "", "anthropic_overloaded"

    # OpenRouter-specific errors
    if "openrouter" in current_model.lower():
        if "connection" in error_str or "timeout" in error_str:
            # Try a different OpenRouter model
            if "claude" in current_model.lower():
                return True, "openrouter/anthropic/claude-sonnet-4", "openrouter_connection"
            elif "gpt" in current_model.lower():
                return True, "openrouter/openai/gpt-4o", "openrouter_connection"
            elif "grok" in current_model.lower() or "xai" in current_model.lower():
                return True, "openrouter/x-ai/grok-4", "openrouter_connection"
        elif "rate limit" in error_str or "quota" in error_str:
            # Try a different OpenRouter model for rate limiting
            if "claude" in current_model.lower():
                return True, "openrouter/anthropic/claude-3-5-sonnet", "openrouter_rate_limit"
            elif "gpt" in current_model.lower():
                return True, "openrouter/openai/gpt-4-turbo", "openrouter_rate_limit"

    # OpenAI-specific errors
    if "openai" in current_model.lower() or "gpt" in current_model.lower():
        if "rate limit" in error_str or "quota" in error_str:
            return True, "openrouter/openai/gpt-4o", "openai_rate_limit"
        elif "connection" in error_str or "timeout" in error_str:
            return True, "openrouter/openai/gpt-4o", "openai_connection"
        elif "service unavailable" in error_str or "internal server error" in error_str:
            return True, "openrouter/openai/gpt-4o", "openai_service_unavailable"

    # xAI-specific errors
    if "xai" in current_model.lower() or "grok" in current_model.lower():
        if "rate limit" in error_str or "quota" in error_str:
            return True, "openrouter/x-ai/grok-4", "xai_rate_limit"
        elif "connection" in error_str or "timeout" in error_str:
            return True, "openrouter/x-ai/grok-4", "xai_connection"

    # Generic connection/timeout errors
    if "connection" in error_str or "timeout" in error_str:
        if not current_model.startswith("openrouter/"):
            # Try OpenRouter as a fallback for connection issues
            if "claude" in current_model.lower():
                return True, "openrouter/anthropic/claude-sonnet-4", "connection_timeout"
            elif "gpt" in current_model.lower():
                return True, "openrouter/openai/gpt-4o", "connection_timeout"
            elif "grok" in current_model.lower() or "xai" in current_model.lower():
                return True, "openrouter/x-ai/grok-4", "connection_timeout"

    # Generic rate limiting
    if "rate limit" in error_str or "quota" in error_str:
        if not current_model.startswith("openrouter/"):
            # Try OpenRouter as a fallback for rate limiting
            if "claude" in current_model.lower():
                return True, "openrouter/anthropic/claude-sonnet-4", "rate_limit"
            elif "gpt" in current_model.lower():
                return True, "openrouter/openai/gpt-4o", "rate_limit"
            elif "grok" in current_model.lower() or "xai" in current_model.lower():
                return True, "openrouter/x-ai/grok-4", "rate_limit"

    # Service unavailable errors
    if "service unavailable" in error_str or "internal server error" in error_str or "bad gateway" in error_str:
        if not current_model.startswith("openrouter/"):
            # Try OpenRouter as a fallback for service issues
            if "claude" in current_model.lower():
                return True, "openrouter/anthropic/claude-sonnet-4", "service_unavailable"
            elif "gpt" in current_model.lower():
                return True, "openrouter/openai/gpt-4o", "service_unavailable"
            elif "grok" in current_model.lower() or "xai" in current_model.lower():
                return True, "openrouter/x-ai/grok-4", "service_unavailable"

    return False, "", "unknown"
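

# Illustrative caller-side use of the tuple returned above. This retry loop is a
# sketch only; the actual caller lives outside this module, and `exc` /
# `model_name` are hypothetical names for the caught exception and current model:
#
#     should_fallback, fallback_model, error_type = detect_error_and_suggest_fallback(exc, model_name)
#     if should_fallback:
#         logger.warning(f"{error_type}: retrying with {fallback_model}")
#         response = await make_llm_api_call(messages, fallback_model)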


def prepare_params(
    messages: List[Dict[str, Any]],
    model_name: str,
    temperature: float = 0,
    max_tokens: Optional[int] = None,
    response_format: Optional[Any] = None,
    tools: Optional[List[Dict[str, Any]]] = None,
    tool_choice: str = "auto",
    api_key: Optional[str] = None,
    api_base: Optional[str] = None,
    stream: bool = False,
    top_p: Optional[float] = None,
    model_id: Optional[str] = None,
    enable_thinking: Optional[bool] = False,
    reasoning_effort: Optional[str] = 'low'
) -> Dict[str, Any]:
    """Prepare parameters for the API call."""
    params = {
        "model": model_name,
        "messages": messages,
        "temperature": temperature,
        "response_format": response_format,
        "top_p": top_p,
        "stream": stream,
    }
    if api_key:
        params["api_key"] = api_key
    if api_base:
        params["api_base"] = api_base
    if model_id:
        params["model_id"] = model_id

    # Handle token limits
    if max_tokens is not None:
        # For Claude 3.7 in Bedrock, do not set max_tokens or max_tokens_to_sample
        # as it causes errors with inference profiles
        if model_name.startswith("bedrock/") and "claude-3-7" in model_name:
            logger.debug(f"Skipping max_tokens for Claude 3.7 model: {model_name}")
            # Do not add any max_tokens parameter for Claude 3.7
        else:
            param_name = "max_completion_tokens" if 'o1' in model_name else "max_tokens"
            params[param_name] = max_tokens

    # Add tools if provided
    if tools:
        params.update({
            "tools": tools,
            "tool_choice": tool_choice
        })
        logger.debug(f"Added {len(tools)} tools to API parameters")

    # Add Claude-specific headers
    if "claude" in model_name.lower() or "anthropic" in model_name.lower():
        params["extra_headers"] = {
            # "anthropic-beta": "max-tokens-3-5-sonnet-2024-07-15"
            "anthropic-beta": "output-128k-2025-02-19"
        }
        # params["mock_testing_fallback"] = True
        logger.debug("Added Claude-specific headers")

    # Add OpenRouter-specific parameters
    if model_name.startswith("openrouter/"):
        logger.debug(f"Preparing OpenRouter parameters for model: {model_name}")
        # Add optional site URL and app name from config
        site_url = config.OR_SITE_URL
        app_name = config.OR_APP_NAME
        if site_url or app_name:
            extra_headers = params.get("extra_headers", {})
            if site_url:
                extra_headers["HTTP-Referer"] = site_url
            if app_name:
                extra_headers["X-Title"] = app_name
            params["extra_headers"] = extra_headers
            logger.debug("Added OpenRouter site URL and app name to headers")

    # Add Bedrock-specific parameters
    if model_name.startswith("bedrock/"):
        logger.debug(f"Preparing AWS Bedrock parameters for model: {model_name}")
        if not model_id and "anthropic.claude-3-7-sonnet" in model_name:
            params["model_id"] = "arn:aws:bedrock:us-west-2:935064898258:inference-profile/us.anthropic.claude-3-7-sonnet-20250219-v1:0"
            logger.debug(f"Auto-set model_id for Claude 3.7 Sonnet: {params['model_id']}")

    fallback_model = get_openrouter_fallback(model_name)
    if fallback_model:
        params["fallbacks"] = [{
            "model": fallback_model,
            "messages": messages,
        }]
        logger.debug(f"Added OpenRouter fallback for model: {model_name} to {fallback_model}")

    # Apply Anthropic prompt caching (minimal implementation)
    # Check model name *after* potential modifications (like adding bedrock/ prefix)
    effective_model_name = params.get("model", model_name)  # Use model from params if set, else original
    if "claude" in effective_model_name.lower() or "anthropic" in effective_model_name.lower():
        messages = params["messages"]  # Direct reference, modification affects params

        # Ensure messages is a list
        if not isinstance(messages, list):
            return params  # Return early if messages format is unexpected

        # Apply cache control to at most max_cache_control_blocks text blocks across all messages
        cache_control_count = 0
        max_cache_control_blocks = 3
        for message in messages:
            if cache_control_count >= max_cache_control_blocks:
                break
            content = message.get("content")
            if isinstance(content, str):
                message["content"] = [
                    {"type": "text", "text": content, "cache_control": {"type": "ephemeral"}}
                ]
                cache_control_count += 1
            elif isinstance(content, list):
                for item in content:
                    if cache_control_count >= max_cache_control_blocks:
                        break
                    if isinstance(item, dict) and item.get("type") == "text" and "cache_control" not in item:
                        item["cache_control"] = {"type": "ephemeral"}
                        cache_control_count += 1

    # Add reasoning_effort for Anthropic models if enabled
    use_thinking = enable_thinking if enable_thinking is not None else False
    is_anthropic = "anthropic" in effective_model_name.lower() or "claude" in effective_model_name.lower()
    is_xai = "xai" in effective_model_name.lower() or model_name.startswith("xai/")
    is_kimi_k2 = "kimi-k2" in effective_model_name.lower() or model_name.startswith("moonshotai/kimi-k2")

    if is_kimi_k2:
        params["provider"] = {
            "order": ["baseten/fp8", "together/fp8", "novita/fp8", "moonshotai", "groq"]
        }

    if is_anthropic and use_thinking:
        effort_level = reasoning_effort if reasoning_effort else 'low'
        params["reasoning_effort"] = effort_level
        params["temperature"] = 1.0  # Required by Anthropic when reasoning_effort is used
        logger.info(f"Anthropic thinking enabled with reasoning_effort='{effort_level}'")

    # Add reasoning_effort for xAI models if enabled
    if is_xai and use_thinking:
        effort_level = reasoning_effort if reasoning_effort else 'low'
        params["reasoning_effort"] = effort_level
        logger.info(f"xAI thinking enabled with reasoning_effort='{effort_level}'")

    # Add xAI-specific parameters
    if model_name.startswith("xai/"):
        logger.debug(f"Preparing xAI parameters for model: {model_name}")
        # xAI models support standard parameters, no special handling needed beyond reasoning_effort

    return params
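

# Illustrative result of prepare_params() for an Anthropic model. The assertions
# follow the branches above; the call is a sketch, not captured output:
#
#     params = prepare_params(
#         messages=[{"role": "user", "content": "Hello"}],
#         model_name="anthropic/claude-sonnet-4-20250514",
#         max_tokens=1024,
#     )
#     # params["max_tokens"] == 1024
#     # params["extra_headers"]["anthropic-beta"] == "output-128k-2025-02-19"
#     # params["fallbacks"][0]["model"] == "openrouter/anthropic/claude-sonnet-4"
#     # params["messages"][0]["content"][0]["cache_control"] == {"type": "ephemeral"}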


async def make_llm_api_call(
    messages: List[Dict[str, Any]],
    model_name: str,
    response_format: Optional[Any] = None,
    temperature: float = 0,
    max_tokens: Optional[int] = None,
    tools: Optional[List[Dict[str, Any]]] = None,
    tool_choice: str = "auto",
    api_key: Optional[str] = None,
    api_base: Optional[str] = None,
    stream: bool = False,
    top_p: Optional[float] = None,
    model_id: Optional[str] = None,
    enable_thinking: Optional[bool] = False,
    reasoning_effort: Optional[str] = 'low'
) -> Union[Dict[str, Any], AsyncGenerator, ModelResponse]:
    """
    Make an API call to a language model using LiteLLM.

    Args:
        messages: List of message dictionaries for the conversation
        model_name: Name of the model to use (e.g., "gpt-4", "claude-3",
            "openrouter/openai/gpt-4", "bedrock/anthropic.claude-3-sonnet-20240229-v1:0")
        response_format: Desired format for the response
        temperature: Sampling temperature (0-1)
        max_tokens: Maximum tokens in the response
        tools: List of tool definitions for function calling
        tool_choice: How to select tools ("auto" or "none")
        api_key: Override default API key
        api_base: Override default API base URL
        stream: Whether to stream the response
        top_p: Top-p sampling parameter
        model_id: Optional ARN for Bedrock inference profiles
        enable_thinking: Whether to enable thinking
        reasoning_effort: Level of reasoning effort

    Returns:
        Union[Dict[str, Any], AsyncGenerator]: API response or stream

    Raises:
        LLMRetryError: If API call fails after retries
        LLMError: For other API-related errors
    """
    # debug <timestamp>.json messages
    logger.info(f"Making LLM API call to model: {model_name} (Thinking: {enable_thinking}, Effort: {reasoning_effort})")
    logger.info(f"📡 API Call: Using model {model_name}")
    params = prepare_params(
        messages=messages,
        model_name=model_name,
        temperature=temperature,
        max_tokens=max_tokens,
        response_format=response_format,
        tools=tools,
        tool_choice=tool_choice,
        api_key=api_key,
        api_base=api_base,
        stream=stream,
        top_p=top_p,
        model_id=model_id,
        enable_thinking=enable_thinking,
        reasoning_effort=reasoning_effort
    )

    last_error = None
    for attempt in range(MAX_RETRIES):
        try:
            logger.debug(f"Attempt {attempt + 1}/{MAX_RETRIES}")
            # logger.debug(f"API request parameters: {json.dumps(params, indent=2)}")
            response = await litellm.acompletion(**params)
            logger.debug(f"Successfully received API response from {model_name}")
            # logger.debug(f"Response: {response}")
            return response
        except (litellm.exceptions.RateLimitError, OpenAIError, json.JSONDecodeError) as e:
            last_error = e
            await handle_error(e, attempt, MAX_RETRIES)
        except Exception as e:
            # Check if this is a fallback-eligible error
            should_fallback, fallback_model, error_type = detect_error_and_suggest_fallback(e, model_name)
            if should_fallback and attempt == MAX_RETRIES - 1:  # Only on last attempt
                logger.warning(f"{error_type} detected on final attempt, suggesting fallback to {fallback_model}: {str(e)}")
                # Don't retry, let the caller handle the fallback
                raise e
            last_error = e
            logger.error(f"Unexpected error during API call: {str(e)}", exc_info=True)
            await handle_error(e, attempt, MAX_RETRIES)

    error_msg = f"Failed to make API call after {MAX_RETRIES} attempts"
    if last_error:
        error_msg += f". Last error: {str(last_error)}"
    logger.error(error_msg, exc_info=True)
    raise LLMRetryError(error_msg)
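

# Illustrative usage (a sketch; the calling agent code lives elsewhere in the
# backend, the model name is just one of the supported formats, and the chunk
# shape assumed here is the OpenAI-style streaming delta that LiteLLM returns):
#
#     async def example():
#         response = await make_llm_api_call(
#             messages=[{"role": "user", "content": "Summarize this repo."}],
#             model_name="openrouter/anthropic/claude-sonnet-4",
#             max_tokens=1024,
#             stream=True,
#         )
#         async for chunk in response:
#             delta = chunk.choices[0].delta
#             if delta.content:
#                 print(delta.content, end="")
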
# Initialize API keys on module import
setup_api_keys()