"""
|
|
LLM API interface for making calls to various language models.
|
|
|
|
This module provides a unified interface for making API calls to different LLM providers
|
|
(OpenAI, Anthropic, Groq, etc.) using LiteLLM. It includes support for:
|
|
- Streaming responses
|
|
- Tool calls and function calling
|
|
- Retry logic with exponential backoff
|
|
- Model-specific configurations
|
|
- Comprehensive error handling and logging
|
|
"""
|
|
|
|
import asyncio
import json
import os
from typing import Union, Dict, Any, Optional, AsyncGenerator, List

import litellm
from openai import OpenAIError

from backend.utils.logger import logger

# Constants
MAX_RETRIES = 3
RATE_LIMIT_DELAY = 30  # seconds to wait after a rate-limit error
RETRY_DELAY = 5  # base delay in seconds for other retryable errors


class LLMError(Exception):
    """Base exception for LLM-related errors."""
    pass


class LLMRetryError(LLMError):
    """Exception raised when retries are exhausted."""
    pass


def setup_api_keys() -> None:
    """Check that provider API keys are present in the environment.

    LiteLLM reads the keys directly from environment variables, so this
    function only verifies their presence and logs the result.
    """
    providers = ['OPENAI', 'ANTHROPIC', 'GROQ']
    for provider in providers:
        key = os.environ.get(f'{provider}_API_KEY')
        if key:
            logger.debug(f"API key found for provider: {provider}")
        else:
            logger.warning(f"No API key found for provider: {provider}")

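
# Usage note: keys are exported before launch, e.g. OPENAI_API_KEY,
# ANTHROPIC_API_KEY, or GROQ_API_KEY; LiteLLM selects the matching key
# based on the model name passed to it.
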
async def handle_error(error: Exception, attempt: int, max_attempts: int) -> None:
    """Handle API errors with appropriate delays and logging."""
    if isinstance(error, litellm.exceptions.RateLimitError):
        delay = RATE_LIMIT_DELAY
    else:
        # Exponential backoff: 5s, 10s, 20s, ...
        delay = RETRY_DELAY * (2 ** attempt)
    logger.warning(f"Error on attempt {attempt + 1}/{max_attempts}: {str(error)}")
    logger.debug(f"Waiting {delay} seconds before retry...")
    await asyncio.sleep(delay)


def prepare_params(
    messages: List[Dict[str, Any]],
    model_name: str,
    temperature: float = 0,
    max_tokens: Optional[int] = None,
    response_format: Optional[Any] = None,
    tools: Optional[List[Dict[str, Any]]] = None,
    tool_choice: str = "auto",
    api_key: Optional[str] = None,
    api_base: Optional[str] = None,
    stream: bool = False,
    top_p: Optional[float] = None
) -> Dict[str, Any]:
    """Prepare parameters for the API call."""
    params = {
        "model": model_name,
        "messages": messages,
        "temperature": temperature,
        "response_format": response_format,
        "top_p": top_p,
        "stream": stream,
    }

    if api_key:
        params["api_key"] = api_key
    if api_base:
        params["api_base"] = api_base

    # Handle token limits (OpenAI o1 models expect `max_completion_tokens`
    # rather than `max_tokens`)
    if max_tokens is not None:
        param_name = "max_completion_tokens" if 'o1' in model_name else "max_tokens"
        params[param_name] = max_tokens

    # Add tools if provided
    if tools:
        params.update({
            "tools": tools,
            "tool_choice": tool_choice
        })
        logger.debug(f"Added {len(tools)} tools to API parameters")

    # Add Claude-specific headers; this beta flag raises the maximum output
    # length for Claude 3.5 Sonnet
    if "claude" in model_name.lower() or "anthropic" in model_name.lower():
        params["extra_headers"] = {
            "anthropic-beta": "max-tokens-3-5-sonnet-2024-07-15"
        }
        logger.debug("Added Claude-specific headers")

    return params

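
# Shape of the returned dict, for illustration (the model name is a placeholder):
#   prepare_params([{"role": "user", "content": "Hi"}], "gpt-4o", max_tokens=256)
# yields
#   {"model": "gpt-4o", "messages": [...], "temperature": 0,
#    "response_format": None, "top_p": None, "stream": False, "max_tokens": 256}
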
async def make_llm_api_call(
    messages: List[Dict[str, Any]],
    model_name: str,
    response_format: Optional[Any] = None,
    temperature: float = 0,
    max_tokens: Optional[int] = None,
    tools: Optional[List[Dict[str, Any]]] = None,
    tool_choice: str = "auto",
    api_key: Optional[str] = None,
    api_base: Optional[str] = None,
    stream: bool = False,
    top_p: Optional[float] = None
) -> Union[Dict[str, Any], AsyncGenerator]:
"""
|
|
Make an API call to a language model using LiteLLM.
|
|
|
|
Args:
|
|
messages: List of message dictionaries for the conversation
|
|
model_name: Name of the model to use (e.g., "gpt-4", "claude-3")
|
|
response_format: Desired format for the response
|
|
temperature: Sampling temperature (0-1)
|
|
max_tokens: Maximum tokens in the response
|
|
tools: List of tool definitions for function calling
|
|
tool_choice: How to select tools ("auto" or "none")
|
|
api_key: Override default API key
|
|
api_base: Override default API base URL
|
|
stream: Whether to stream the response
|
|
top_p: Top-p sampling parameter
|
|
|
|
Returns:
|
|
Union[Dict[str, Any], AsyncGenerator]: API response or stream
|
|
|
|
Raises:
|
|
LLMRetryError: If API call fails after retries
|
|
LLMError: For other API-related errors
|
|
"""
|
|
    logger.info(f"Making LLM API call to model: {model_name}")
    params = prepare_params(
        messages=messages,
        model_name=model_name,
        temperature=temperature,
        max_tokens=max_tokens,
        response_format=response_format,
        tools=tools,
        tool_choice=tool_choice,
        api_key=api_key,
        api_base=api_base,
        stream=stream,
        top_p=top_p
    )

    last_error = None
    for attempt in range(MAX_RETRIES):
        try:
            logger.debug(f"Attempt {attempt + 1}/{MAX_RETRIES}")
            # default=str keeps logging from failing on values that are not
            # JSON-serializable (e.g. a response_format class)
            logger.debug(f"API request parameters: {json.dumps(params, indent=2, default=str)}")

            response = await litellm.acompletion(**params)
            logger.info(f"Successfully received API response from {model_name}")
            logger.debug(f"Response: {response}")
            return response

        except (litellm.exceptions.RateLimitError, OpenAIError, json.JSONDecodeError) as e:
            last_error = e
            await handle_error(e, attempt, MAX_RETRIES)

        except Exception as e:
            logger.error(f"Unexpected error during API call: {str(e)}", exc_info=True)
            raise LLMError(f"API call failed: {str(e)}") from e

    error_msg = f"Failed to make API call after {MAX_RETRIES} attempts"
    if last_error:
        error_msg += f". Last error: {str(last_error)}"
    logger.error(error_msg, exc_info=True)
    raise LLMRetryError(error_msg)


# Check API keys on module import
setup_api_keys()
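
# Example usage (a minimal sketch; model names and prompts below are
# placeholders, not part of this module's API):
#
#     import asyncio
#
#     async def _demo() -> None:
#         # Non-streaming call
#         response = await make_llm_api_call(
#             messages=[{"role": "user", "content": "Say hello."}],
#             model_name="gpt-4o",
#         )
#         print(response.choices[0].message.content)
#
#         # Streaming call: iterate over chunks as they arrive
#         stream = await make_llm_api_call(
#             messages=[{"role": "user", "content": "Count to three."}],
#             model_name="claude-3-5-sonnet-20240620",
#             stream=True,
#         )
#         async for chunk in stream:
#             delta = chunk.choices[0].delta.content
#             if delta:
#                 print(delta, end="")
#
#     asyncio.run(_demo())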