suna/backend/services/llm.py

"""
LLM API interface for making calls to various language models.
This module provides a unified interface for making API calls to different LLM providers
2025-07-10 18:58:10 +08:00
(OpenAI, Anthropic, Groq, xAI, etc.) using LiteLLM. It includes support for:
2025-03-30 14:48:57 +08:00
- Streaming responses
- Tool calls and function calling
- Retry logic with exponential backoff
- Model-specific configurations
2025-04-01 09:26:52 +08:00
- Comprehensive error handling and logging
2025-03-30 14:48:57 +08:00
"""

from typing import Union, Dict, Any, Optional, AsyncGenerator, List
import os
import json
import asyncio

from openai import OpenAIError
import litellm
from litellm.files.main import ModelResponse

from utils.logger import logger
from utils.config import config

# litellm.set_verbose = True
litellm.modify_params = True

# Constants
MAX_RETRIES = 2
RATE_LIMIT_DELAY = 30
RETRY_DELAY = 0.1


class LLMError(Exception):
    """Base exception for LLM-related errors."""
    pass


class LLMRetryError(LLMError):
    """Exception raised when retries are exhausted."""
    pass


def setup_api_keys() -> None:
    """Set up API keys from environment variables."""
    providers = ['OPENAI', 'ANTHROPIC', 'GROQ', 'OPENROUTER', 'XAI']
    for provider in providers:
        key = getattr(config, f'{provider}_API_KEY')
        if key:
            logger.debug(f"API key set for provider: {provider}")
        else:
            logger.warning(f"No API key found for provider: {provider}")

    # Set up the OpenRouter API base if not already set
    if config.OPENROUTER_API_KEY and config.OPENROUTER_API_BASE:
        os.environ['OPENROUTER_API_BASE'] = config.OPENROUTER_API_BASE
        logger.debug(f"Set OPENROUTER_API_BASE to {config.OPENROUTER_API_BASE}")

    # Set up AWS Bedrock credentials
    aws_access_key = config.AWS_ACCESS_KEY_ID
    aws_secret_key = config.AWS_SECRET_ACCESS_KEY
    aws_region = config.AWS_REGION_NAME

    if aws_access_key and aws_secret_key and aws_region:
        logger.debug(f"AWS credentials set for Bedrock in region: {aws_region}")
        # Configure LiteLLM to use the AWS credentials
        os.environ['AWS_ACCESS_KEY_ID'] = aws_access_key
        os.environ['AWS_SECRET_ACCESS_KEY'] = aws_secret_key
        os.environ['AWS_REGION_NAME'] = aws_region
    else:
        logger.warning(
            "Missing AWS credentials for Bedrock integration - "
            f"access_key: {bool(aws_access_key)}, secret_key: {bool(aws_secret_key)}, region: {aws_region}"
        )


def get_openrouter_fallback(model_name: str) -> Optional[str]:
    """Get the OpenRouter fallback model for a given model name."""
    # Skip if already using OpenRouter
    if model_name.startswith("openrouter/"):
        return None

    # Map models to their OpenRouter equivalents
    fallback_mapping = {
        "anthropic/claude-3-7-sonnet-latest": "openrouter/anthropic/claude-3.7-sonnet",
        "anthropic/claude-sonnet-4-20250514": "openrouter/anthropic/claude-sonnet-4",
        "xai/grok-4": "openrouter/x-ai/grok-4",
    }

    # Check for an exact match first
    if model_name in fallback_mapping:
        return fallback_mapping[model_name]

    # Check for partial matches (e.g., Bedrock model names)
    for key, value in fallback_mapping.items():
        if key in model_name:
            return value

    # Default fallbacks by provider
    if "claude" in model_name.lower() or "anthropic" in model_name.lower():
        return "openrouter/anthropic/claude-sonnet-4"
    elif "xai" in model_name.lower() or "grok" in model_name.lower():
        return "openrouter/x-ai/grok-4"

    return None

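
# Illustrative examples (added documentation, not part of the original logic):
# expected behaviour of get_openrouter_fallback given the mapping above. The
# Bedrock model id is only an assumed example of a name containing "claude".
#
#   get_openrouter_fallback("anthropic/claude-sonnet-4-20250514")
#       -> "openrouter/anthropic/claude-sonnet-4"   (exact match)
#   get_openrouter_fallback("bedrock/anthropic.claude-3-7-sonnet-20250219-v1:0")
#       -> "openrouter/anthropic/claude-sonnet-4"   (provider default for "claude")
#   get_openrouter_fallback("openrouter/x-ai/grok-4")
#       -> None                                     (already on OpenRouter)
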

async def handle_error(error: Exception, attempt: int, max_attempts: int) -> None:
    """Handle API errors with appropriate delays and logging."""
    delay = RATE_LIMIT_DELAY if isinstance(error, litellm.exceptions.RateLimitError) else RETRY_DELAY
    logger.warning(f"Error on attempt {attempt + 1}/{max_attempts}: {str(error)}")
    logger.debug(f"Waiting {delay} seconds before retry...")
    await asyncio.sleep(delay)


def prepare_params(
    messages: List[Dict[str, Any]],
    model_name: str,
    temperature: float = 0,
    max_tokens: Optional[int] = None,
    response_format: Optional[Any] = None,
    tools: Optional[List[Dict[str, Any]]] = None,
    tool_choice: str = "auto",
    api_key: Optional[str] = None,
    api_base: Optional[str] = None,
    stream: bool = False,
    top_p: Optional[float] = None,
    model_id: Optional[str] = None,
    enable_thinking: Optional[bool] = False,
    reasoning_effort: Optional[str] = 'low'
) -> Dict[str, Any]:
    """Prepare parameters for the API call."""
    params = {
        "model": model_name,
        "messages": messages,
        "temperature": temperature,
        "response_format": response_format,
        "top_p": top_p,
        "stream": stream,
    }

    if api_key:
        params["api_key"] = api_key
    if api_base:
        params["api_base"] = api_base
    if model_id:
        params["model_id"] = model_id

    # Handle token limits
    if max_tokens is not None:
        # For Claude 3.7 on Bedrock, do not set max_tokens or max_tokens_to_sample,
        # as it causes errors with inference profiles.
        if model_name.startswith("bedrock/") and "claude-3-7" in model_name:
            logger.debug(f"Skipping max_tokens for Claude 3.7 model: {model_name}")
            # Do not add any max_tokens parameter for Claude 3.7
        else:
            param_name = "max_completion_tokens" if 'o1' in model_name else "max_tokens"
            params[param_name] = max_tokens

    # Add tools if provided
    if tools:
        params.update({
            "tools": tools,
            "tool_choice": tool_choice
        })
        logger.debug(f"Added {len(tools)} tools to API parameters")

    # Add Claude-specific headers
    if "claude" in model_name.lower() or "anthropic" in model_name.lower():
        params["extra_headers"] = {
            # "anthropic-beta": "max-tokens-3-5-sonnet-2024-07-15"
            "anthropic-beta": "output-128k-2025-02-19"
        }
        # params["mock_testing_fallback"] = True
        logger.debug("Added Claude-specific headers")

    # Add OpenRouter-specific parameters
    if model_name.startswith("openrouter/"):
        logger.debug(f"Preparing OpenRouter parameters for model: {model_name}")

        # Add the optional site URL and app name from config
        site_url = config.OR_SITE_URL
        app_name = config.OR_APP_NAME
        if site_url or app_name:
            extra_headers = params.get("extra_headers", {})
            if site_url:
                extra_headers["HTTP-Referer"] = site_url
            if app_name:
                extra_headers["X-Title"] = app_name
            params["extra_headers"] = extra_headers
            logger.debug("Added OpenRouter site URL and app name to headers")

    # Add Bedrock-specific parameters
    if model_name.startswith("bedrock/"):
        logger.debug(f"Preparing AWS Bedrock parameters for model: {model_name}")
        if not model_id and "anthropic.claude-3-7-sonnet" in model_name:
            params["model_id"] = "arn:aws:bedrock:us-west-2:935064898258:inference-profile/us.anthropic.claude-3-7-sonnet-20250219-v1:0"
            logger.debug(f"Auto-set model_id for Claude 3.7 Sonnet: {params['model_id']}")

    # Register an OpenRouter fallback for non-OpenRouter models where one is available
    fallback_model = get_openrouter_fallback(model_name)
    if fallback_model:
        params["fallbacks"] = [{
            "model": fallback_model,
            "messages": messages,
        }]
        logger.debug(f"Added OpenRouter fallback for model {model_name}: {fallback_model}")

    # Apply Anthropic prompt caching (minimal implementation)
    # Check the model name *after* potential modifications (like adding the bedrock/ prefix)
    effective_model_name = params.get("model", model_name)  # Use the model from params if set, else the original
    if "claude" in effective_model_name.lower() or "anthropic" in effective_model_name.lower():
        messages = params["messages"]  # Direct reference; modifications affect params

        # Ensure messages is a list
        if not isinstance(messages, list):
            return params  # Return early if the messages format is unexpected

        # Apply cache control to at most max_cache_control_blocks text blocks across all messages
        cache_control_count = 0
        max_cache_control_blocks = 3

        for message in messages:
            if cache_control_count >= max_cache_control_blocks:
                break

            content = message.get("content")

            if isinstance(content, str):
                message["content"] = [
                    {"type": "text", "text": content, "cache_control": {"type": "ephemeral"}}
                ]
                cache_control_count += 1
            elif isinstance(content, list):
                for item in content:
                    if cache_control_count >= max_cache_control_blocks:
                        break
                    if isinstance(item, dict) and item.get("type") == "text" and "cache_control" not in item:
                        item["cache_control"] = {"type": "ephemeral"}
                        cache_control_count += 1

    use_thinking = enable_thinking if enable_thinking is not None else False
    is_anthropic = "anthropic" in effective_model_name.lower() or "claude" in effective_model_name.lower()
    is_xai = "xai" in effective_model_name.lower() or model_name.startswith("xai/")
    is_kimi_k2 = "kimi-k2" in effective_model_name.lower() or model_name.startswith("moonshotai/kimi-k2")

    if is_kimi_k2:
        params["provider"] = {
            "order": ["groq", "together/fp8"]
        }

    # Add reasoning_effort for Anthropic models if enabled
    if is_anthropic and use_thinking:
        effort_level = reasoning_effort if reasoning_effort else 'low'
        params["reasoning_effort"] = effort_level
        params["temperature"] = 1.0  # Required by Anthropic when reasoning_effort is used
        logger.info(f"Anthropic thinking enabled with reasoning_effort='{effort_level}'")

    # Add reasoning_effort for xAI models if enabled
    if is_xai and use_thinking:
        effort_level = reasoning_effort if reasoning_effort else 'low'
        params["reasoning_effort"] = effort_level
        logger.info(f"xAI thinking enabled with reasoning_effort='{effort_level}'")

    # Add xAI-specific parameters
    if model_name.startswith("xai/"):
        logger.debug(f"Preparing xAI parameters for model: {model_name}")
        # xAI models support standard parameters; no special handling is needed beyond reasoning_effort

    return params
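

# Illustrative sketch (added documentation, not executed): how the Anthropic
# prompt-caching branch in prepare_params rewrites a plain string message into
# a text block with cache control. The message text is a made-up example.
#
#   before: {"role": "system", "content": "You are a helpful assistant."}
#   after:  {"role": "system", "content": [
#               {"type": "text",
#                "text": "You are a helpful assistant.",
#                "cache_control": {"type": "ephemeral"}}
#           ]}
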

async def make_llm_api_call(
    messages: List[Dict[str, Any]],
    model_name: str,
    response_format: Optional[Any] = None,
    temperature: float = 0,
    max_tokens: Optional[int] = None,
    tools: Optional[List[Dict[str, Any]]] = None,
    tool_choice: str = "auto",
    api_key: Optional[str] = None,
    api_base: Optional[str] = None,
    stream: bool = False,
    top_p: Optional[float] = None,
    model_id: Optional[str] = None,
    enable_thinking: Optional[bool] = False,
    reasoning_effort: Optional[str] = 'low'
) -> Union[Dict[str, Any], AsyncGenerator, ModelResponse]:
"""
Make an API call to a language model using LiteLLM.
2025-03-30 14:48:57 +08:00
Args:
messages: List of message dictionaries for the conversation
model_name: Name of the model to use (e.g., "gpt-4", "claude-3", "openrouter/openai/gpt-4", "bedrock/anthropic.claude-3-sonnet-20240229-v1:0")
2025-03-30 14:48:57 +08:00
response_format: Desired format for the response
temperature: Sampling temperature (0-1)
max_tokens: Maximum tokens in the response
tools: List of tool definitions for function calling
tool_choice: How to select tools ("auto" or "none")
api_key: Override default API key
api_base: Override default API base URL
stream: Whether to stream the response
top_p: Top-p sampling parameter
model_id: Optional ARN for Bedrock inference profiles
2025-04-18 12:49:41 +08:00
enable_thinking: Whether to enable thinking
reasoning_effort: Level of reasoning effort
2025-03-30 14:48:57 +08:00
Returns:
Union[Dict[str, Any], AsyncGenerator]: API response or stream
2025-03-30 14:48:57 +08:00
Raises:
LLMRetryError: If API call fails after retries
LLMError: For other API-related errors
"""
# debug <timestamp>.json messages
logger.info(f"Making LLM API call to model: {model_name} (Thinking: {enable_thinking}, Effort: {reasoning_effort})")
logger.info(f"📡 API Call: Using model {model_name}")

    params = prepare_params(
        messages=messages,
        model_name=model_name,
        temperature=temperature,
        max_tokens=max_tokens,
        response_format=response_format,
        tools=tools,
        tool_choice=tool_choice,
        api_key=api_key,
        api_base=api_base,
        stream=stream,
        top_p=top_p,
        model_id=model_id,
        enable_thinking=enable_thinking,
        reasoning_effort=reasoning_effort
    )

    last_error = None
    for attempt in range(MAX_RETRIES):
        try:
            logger.debug(f"Attempt {attempt + 1}/{MAX_RETRIES}")
            # logger.debug(f"API request parameters: {json.dumps(params, indent=2)}")

            response = await litellm.acompletion(**params)
            logger.debug(f"Successfully received API response from {model_name}")
            # logger.debug(f"Response: {response}")
            return response

        except (litellm.exceptions.RateLimitError, OpenAIError, json.JSONDecodeError) as e:
            last_error = e
            await handle_error(e, attempt, MAX_RETRIES)

        except Exception as e:
            logger.error(f"Unexpected error during API call: {str(e)}", exc_info=True)
            raise LLMError(f"API call failed: {str(e)}")

    error_msg = f"Failed to make API call after {MAX_RETRIES} attempts"
    if last_error:
        error_msg += f". Last error: {str(last_error)}"
    logger.error(error_msg, exc_info=True)
    raise LLMRetryError(error_msg)


# Initialize API keys on module import
setup_api_keys()
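

if __name__ == "__main__":
    # Minimal usage sketch (added for illustration; not part of the service's normal
    # import path). The model name is an assumption (any model id supported by LiteLLM
    # and the configured API keys will work), and the prompt is a placeholder.
    async def _example() -> None:
        response = await make_llm_api_call(
            messages=[{"role": "user", "content": "Say hello in one short sentence."}],
            model_name="anthropic/claude-sonnet-4-20250514",  # assumed model id
            temperature=0,
            max_tokens=64,
        )
        # Non-streaming calls return an OpenAI-style response object
        print(response.choices[0].message.content)

    asyncio.run(_example())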