suna/backend/services/llm.py

"""
LLM API interface for making calls to various language models.
This module provides a unified interface for making API calls to different LLM providers
(OpenAI, Anthropic, Groq, etc.) using LiteLLM. It includes support for:
- Streaming responses
- Tool calls and function calling
- Retry logic with rate-limit-aware backoff delays
- Model-specific configurations
- Comprehensive error handling and logging
"""
from typing import Union, Dict, Any, Optional, AsyncGenerator, List
import os
import json
import asyncio

from openai import OpenAIError
import litellm

from backend.utils.logger import logger

# Constants
MAX_RETRIES = 3
RATE_LIMIT_DELAY = 30  # seconds to wait after a rate-limit error
RETRY_DELAY = 5        # seconds to wait after other retryable errors


class LLMError(Exception):
    """Base exception for LLM-related errors."""
    pass


class LLMRetryError(LLMError):
    """Exception raised when retries are exhausted."""
    pass


def setup_api_keys() -> None:
    """Check that provider API keys are present in the environment.

    LiteLLM reads these environment variables directly, so this function
    only logs which providers are (or are not) configured.
    """
    providers = ['OPENAI', 'ANTHROPIC', 'GROQ']
    for provider in providers:
        key = os.environ.get(f'{provider}_API_KEY')
        if key:
            logger.debug(f"API key found for provider: {provider}")
        else:
            logger.warning(f"No API key found for provider: {provider}")
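

# Usage sketch (illustrative values): keys are expected in the environment
# before startup, e.g. `export OPENAI_API_KEY="sk-..."` in the shell, or
# injected in test setup before this module is imported:
#
#   os.environ.setdefault("OPENAI_API_KEY", "sk-test")  # hypothetical key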


async def handle_error(error: Exception, attempt: int, max_attempts: int) -> None:
    """Handle a retryable API error with an appropriate delay and logging."""
    # Rate-limit errors get a longer backoff than other transient failures.
    delay = RATE_LIMIT_DELAY if isinstance(error, litellm.exceptions.RateLimitError) else RETRY_DELAY
    logger.warning(f"Error on attempt {attempt + 1}/{max_attempts}: {str(error)}")
    logger.debug(f"Waiting {delay} seconds before retry...")
    await asyncio.sleep(delay)


def prepare_params(
    messages: List[Dict[str, Any]],
    model_name: str,
    temperature: float = 0,
    max_tokens: Optional[int] = None,
    response_format: Optional[Any] = None,
    tools: Optional[List[Dict[str, Any]]] = None,
    tool_choice: str = "auto",
    api_key: Optional[str] = None,
    api_base: Optional[str] = None,
    stream: bool = False,
    top_p: Optional[float] = None
) -> Dict[str, Any]:
    """Prepare the keyword arguments passed to litellm.acompletion()."""
    params = {
        "model": model_name,
        "messages": messages,
        "temperature": temperature,
        "response_format": response_format,
        "top_p": top_p,
        "stream": stream,
    }

    if api_key:
        params["api_key"] = api_key
    if api_base:
        params["api_base"] = api_base

    # Handle token limits: OpenAI o1-series models expect
    # max_completion_tokens instead of max_tokens.
    if max_tokens is not None:
        param_name = "max_completion_tokens" if 'o1' in model_name else "max_tokens"
        params[param_name] = max_tokens

    # Add tools if provided
    if tools:
        params.update({
            "tools": tools,
            "tool_choice": tool_choice
        })
        logger.debug(f"Added {len(tools)} tools to API parameters")

    # Claude-specific beta header (raises the max output token limit
    # for Claude 3.5 Sonnet).
    if "claude" in model_name.lower() or "anthropic" in model_name.lower():
        params["extra_headers"] = {
            "anthropic-beta": "max-tokens-3-5-sonnet-2024-07-15"
        }
        logger.debug("Added Claude-specific headers")

    return params
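

# Illustrative example: preparing a tool-calling request. The "get_weather"
# schema below is hypothetical and exists only to show the shape of the
# "tools" parameter.
#
#   params = prepare_params(
#       messages=[{"role": "user", "content": "Weather in Paris?"}],
#       model_name="claude-3-5-sonnet-20240620",
#       max_tokens=1024,
#       tools=[{
#           "type": "function",
#           "function": {
#               "name": "get_weather",
#               "description": "Get current weather for a city",
#               "parameters": {
#                   "type": "object",
#                   "properties": {"city": {"type": "string"}},
#                   "required": ["city"],
#               },
#           },
#       }],
#   )
#   # params now includes "tools", "tool_choice": "auto", "max_tokens": 1024,
#   # and the Claude-specific "anthropic-beta" extra header.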


async def make_llm_api_call(
    messages: List[Dict[str, Any]],
    model_name: str,
    response_format: Optional[Any] = None,
    temperature: float = 0,
    max_tokens: Optional[int] = None,
    tools: Optional[List[Dict[str, Any]]] = None,
    tool_choice: str = "auto",
    api_key: Optional[str] = None,
    api_base: Optional[str] = None,
    stream: bool = False,
    top_p: Optional[float] = None
) -> Union[Dict[str, Any], AsyncGenerator]:
    """
    Make an API call to a language model using LiteLLM.

    Args:
        messages: List of message dictionaries for the conversation
        model_name: Name of the model to use (e.g., "gpt-4", "claude-3")
        response_format: Desired format for the response
        temperature: Sampling temperature (0-1)
        max_tokens: Maximum tokens in the response
        tools: List of tool definitions for function calling
        tool_choice: How tools are selected ("auto" or "none")
        api_key: Override the default API key
        api_base: Override the default API base URL
        stream: Whether to stream the response
        top_p: Top-p sampling parameter

    Returns:
        Union[Dict[str, Any], AsyncGenerator]: API response or stream

    Raises:
        LLMRetryError: If the API call fails after all retries
        LLMError: For other API-related errors
    """
    logger.info(f"Making LLM API call to model: {model_name}")
    params = prepare_params(
        messages=messages,
        model_name=model_name,
        temperature=temperature,
        max_tokens=max_tokens,
        response_format=response_format,
        tools=tools,
        tool_choice=tool_choice,
        api_key=api_key,
        api_base=api_base,
        stream=stream,
        top_p=top_p
    )

    last_error = None
    for attempt in range(MAX_RETRIES):
        try:
            logger.debug(f"Attempt {attempt + 1}/{MAX_RETRIES}")
            # default=str keeps the debug log from raising on values that
            # are not JSON-serializable (e.g. a response_format class).
            logger.debug(f"API request parameters: {json.dumps(params, indent=2, default=str)}")

            response = await litellm.acompletion(**params)
            logger.info(f"Successfully received API response from {model_name}")
            logger.debug(f"Response: {response}")
            return response

        except (litellm.exceptions.RateLimitError, OpenAIError, json.JSONDecodeError) as e:
            last_error = e
            # Only back off if another attempt remains; the final failure
            # is reported via LLMRetryError below.
            if attempt < MAX_RETRIES - 1:
                await handle_error(e, attempt, MAX_RETRIES)
            else:
                logger.warning(f"Error on final attempt {attempt + 1}/{MAX_RETRIES}: {str(e)}")
        except Exception as e:
            logger.error(f"Unexpected error during API call: {str(e)}", exc_info=True)
            raise LLMError(f"API call failed: {str(e)}") from e

    error_msg = f"Failed to make API call after {MAX_RETRIES} attempts"
    if last_error:
        error_msg += f". Last error: {str(last_error)}"
    logger.error(error_msg)
    raise LLMRetryError(error_msg) from last_error


# Initialize API keys on module import
setup_api_keys()
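

if __name__ == "__main__":
    # Minimal usage sketch with illustrative values: the model name and
    # prompt are examples only, and OPENAI_API_KEY is assumed to be set
    # in the environment.
    async def _demo() -> None:
        response = await make_llm_api_call(
            messages=[{"role": "user", "content": "Say hello."}],
            model_name="gpt-4o",
            max_tokens=50,
        )
        # litellm returns an OpenAI-style ModelResponse object.
        print(response.choices[0].message.content)

    asyncio.run(_demo())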