"""
|
|
LLM API interface for making calls to various language models.
|
|
|
|
This module provides a unified interface for making API calls to different LLM providers
|
|
(OpenAI, Anthropic, Groq, etc.) using LiteLLM. It includes support for:
|
|
- Streaming responses
|
|
- Tool calls and function calling
|
|
- Retry logic with exponential backoff
|
|
- Model-specific configurations
|
|
- Comprehensive error handling and logging
|
|
"""
|
|
|
|
from typing import Union, Dict, Any, Optional, AsyncGenerator, List
import os
import json
import asyncio
import time  # Added for timestamp
from openai import OpenAIError
import litellm
from utils.logger import logger

# litellm.set_verbose=True
litellm.modify_params = True

# Constants
MAX_RETRIES = 3
RATE_LIMIT_DELAY = 30
RETRY_DELAY = 5

# Define debug log directory relative to this file's location
DEBUG_LOG_DIR = os.path.join(os.path.dirname(__file__), '..', 'debug_logs')  # Assumes backend/debug_logs

class LLMError(Exception):
    """Base exception for LLM-related errors."""
    pass

class LLMRetryError(LLMError):
    """Exception raised when retries are exhausted."""
    pass

def setup_api_keys() -> None:
    """Set up API keys from environment variables."""
    providers = ['OPENAI', 'ANTHROPIC', 'GROQ', 'OPENROUTER']
    for provider in providers:
        key = os.environ.get(f'{provider}_API_KEY')
        if key:
            logger.debug(f"API key set for provider: {provider}")
        else:
            logger.warning(f"No API key found for provider: {provider}")

    # Set up OpenRouter API base if not already set
    if os.environ.get('OPENROUTER_API_KEY') and not os.environ.get('OPENROUTER_API_BASE'):
        os.environ['OPENROUTER_API_BASE'] = 'https://openrouter.ai/api/v1'
        logger.debug("Set default OPENROUTER_API_BASE to https://openrouter.ai/api/v1")

    # Set up AWS Bedrock credentials
    aws_access_key = os.environ.get('AWS_ACCESS_KEY_ID')
    aws_secret_key = os.environ.get('AWS_SECRET_ACCESS_KEY')
    aws_region = os.environ.get('AWS_REGION_NAME')

    if aws_access_key and aws_secret_key and aws_region:
        logger.debug(f"AWS credentials set for Bedrock in region: {aws_region}")
        # Configure LiteLLM to use AWS credentials
        os.environ['AWS_ACCESS_KEY_ID'] = aws_access_key
        os.environ['AWS_SECRET_ACCESS_KEY'] = aws_secret_key
        os.environ['AWS_REGION_NAME'] = aws_region
    else:
        logger.warning(f"Missing AWS credentials for Bedrock integration - access_key: {bool(aws_access_key)}, secret_key: {bool(aws_secret_key)}, region: {aws_region}")

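# Quick reference (a sketch derived from the code in this module, not an
# exhaustive list) of the environment variables it reads:
#
#   OPENAI_API_KEY / ANTHROPIC_API_KEY / GROQ_API_KEY / OPENROUTER_API_KEY
#   OPENROUTER_API_BASE                                  # defaults to https://openrouter.ai/api/v1
#   AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, AWS_REGION_NAME   # required together for Bedrock
#   OR_SITE_URL, OR_APP_NAME                             # optional OpenRouter attribution headers
#   ENABLE_LLM_DEBUG_LOGGING                             # "true" writes request/response JSON to DEBUG_LOG_DIR
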
async def handle_error(error: Exception, attempt: int, max_attempts: int) -> None:
    """Handle API errors with appropriate delays and logging."""
    delay = RATE_LIMIT_DELAY if isinstance(error, litellm.exceptions.RateLimitError) else RETRY_DELAY
    logger.warning(f"Error on attempt {attempt + 1}/{max_attempts}: {str(error)}")
    logger.debug(f"Waiting {delay} seconds before retry...")
    await asyncio.sleep(delay)

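# Worked example of the retry timing (based on the constants above): with
# MAX_RETRIES = 3, a call that keeps failing waits RATE_LIMIT_DELAY (30s) after
# each RateLimitError and RETRY_DELAY (5s) after other retryable errors, so the
# worst case is three attempts separated by fixed (not exponential) delays
# before make_llm_api_call below raises LLMRetryError.
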
def prepare_params(
    messages: List[Dict[str, Any]],
    model_name: str,
    temperature: float = 0,
    max_tokens: Optional[int] = None,
    response_format: Optional[Any] = None,
    tools: Optional[List[Dict[str, Any]]] = None,
    tool_choice: str = "auto",
    api_key: Optional[str] = None,
    api_base: Optional[str] = None,
    stream: bool = False,
    top_p: Optional[float] = None,
    model_id: Optional[str] = None
) -> Dict[str, Any]:
    """Prepare parameters for the API call."""
    params = {
        "model": model_name,
        "messages": messages,
        "temperature": temperature,
        "response_format": response_format,
        "top_p": top_p,
        "stream": stream,
    }

    if api_key:
        params["api_key"] = api_key
    if api_base:
        params["api_base"] = api_base
    if model_id:
        params["model_id"] = model_id

    # Handle token limits
    if max_tokens is not None:
        # For Claude 3.7 in Bedrock, do not set max_tokens or max_tokens_to_sample
        # as it causes errors with inference profiles
        if model_name.startswith("bedrock/") and "claude-3-7" in model_name:
            logger.debug(f"Skipping max_tokens for Claude 3.7 model: {model_name}")
            # Do not add any max_tokens parameter for Claude 3.7
        else:
            param_name = "max_completion_tokens" if 'o1' in model_name else "max_tokens"
            params[param_name] = max_tokens

    # Add tools if provided
    if tools:
        params.update({
            "tools": tools,
            "tool_choice": tool_choice
        })
        logger.debug(f"Added {len(tools)} tools to API parameters")

    # Add Claude-specific headers
    if "claude" in model_name.lower() or "anthropic" in model_name.lower():
        params["extra_headers"] = {
            # "anthropic-beta": "max-tokens-3-5-sonnet-2024-07-15"
            "anthropic-beta": "output-128k-2025-02-19"
        }
        logger.debug("Added Claude-specific headers")

    # Add OpenRouter-specific parameters
    if model_name.startswith("openrouter/"):
        logger.debug(f"Preparing OpenRouter parameters for model: {model_name}")

        # Add optional site URL and app name if set in environment
        site_url = os.environ.get("OR_SITE_URL")
        app_name = os.environ.get("OR_APP_NAME")
        if site_url or app_name:
            extra_headers = params.get("extra_headers", {})
            if site_url:
                extra_headers["HTTP-Referer"] = site_url
            if app_name:
                extra_headers["X-Title"] = app_name
            params["extra_headers"] = extra_headers
            logger.debug("Added OpenRouter site URL and app name to headers")

    # Add Bedrock-specific parameters
    if model_name.startswith("bedrock/"):
        logger.debug(f"Preparing AWS Bedrock parameters for model: {model_name}")

        if not model_id and "anthropic.claude-3-7-sonnet" in model_name:
            params["model_id"] = "arn:aws:bedrock:us-west-2:935064898258:inference-profile/us.anthropic.claude-3-7-sonnet-20250219-v1:0"
            logger.debug(f"Auto-set model_id for Claude 3.7 Sonnet: {params['model_id']}")

    return params

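# Illustrative sketch (not executed): for a hypothetical Anthropic call such as
#
#   prepare_params(
#       messages=[{"role": "user", "content": "Hi"}],
#       model_name="anthropic/claude-3-5-sonnet",
#       max_tokens=1024,
#   )
#
# the returned dict would look roughly like:
#
#   {"model": "anthropic/claude-3-5-sonnet", "messages": [...], "temperature": 0,
#    "response_format": None, "top_p": None, "stream": False, "max_tokens": 1024,
#    "extra_headers": {"anthropic-beta": "output-128k-2025-02-19"}}
#
# The model name is an assumption chosen for illustration; any LiteLLM model
# string follows the same flow.
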
async def make_llm_api_call(
    messages: List[Dict[str, Any]],
    model_name: str,
    response_format: Optional[Any] = None,
    temperature: float = 0,
    max_tokens: Optional[int] = None,
    tools: Optional[List[Dict[str, Any]]] = None,
    tool_choice: str = "auto",
    api_key: Optional[str] = None,
    api_base: Optional[str] = None,
    stream: bool = False,
    top_p: Optional[float] = None,
    model_id: Optional[str] = None
) -> Union[Dict[str, Any], AsyncGenerator]:
    """
    Make an API call to a language model using LiteLLM.

    Args:
        messages: List of message dictionaries for the conversation
        model_name: Name of the model to use (e.g., "gpt-4", "claude-3", "openrouter/openai/gpt-4", "bedrock/anthropic.claude-3-sonnet-20240229-v1:0")
        response_format: Desired format for the response
        temperature: Sampling temperature (0-1)
        max_tokens: Maximum tokens in the response
        tools: List of tool definitions for function calling
        tool_choice: How to select tools ("auto" or "none")
        api_key: Override default API key
        api_base: Override default API base URL
        stream: Whether to stream the response
        top_p: Top-p sampling parameter
        model_id: Optional ARN for Bedrock inference profiles

    Returns:
        Union[Dict[str, Any], AsyncGenerator]: API response or stream

    Raises:
        LLMRetryError: If API call fails after retries
        LLMError: For other API-related errors
    """
    logger.debug(f"Making LLM API call to model: {model_name}")
    params = prepare_params(
        messages=messages,
        model_name=model_name,
        temperature=temperature,
        max_tokens=max_tokens,
        response_format=response_format,
        tools=tools,
        tool_choice=tool_choice,
        api_key=api_key,
        api_base=api_base,
        stream=stream,
        top_p=top_p,
        model_id=model_id
    )

    # Apply Anthropic prompt caching (minimal implementation)
    if params["model"].startswith("anthropic/"):
        logger.debug("Applying minimal Anthropic prompt caching.")
        messages = params["messages"]  # Direct reference
        modified = False  # Set below when cache_control is applied (informational only)

        # 1. Process the first message if it's a system prompt with string content
        if messages and messages[0].get("role") == "system":
            content = messages[0].get("content")
            if isinstance(content, str):
                messages[0]["content"] = [
                    {"type": "text", "text": content, "cache_control": {"type": "ephemeral"}}
                ]
                logger.debug("Applied cache_control to system message.")
                modified = True
            elif not isinstance(content, list):
                logger.warning("System message content is not a string or list, skipping cache_control.")
            # else: content is already a list, do nothing

        # 2. Find and process the last user message
        last_user_idx = -1
        for i in range(len(messages) - 1, -1, -1):
            if messages[i].get("role") == "user":
                last_user_idx = i
                break

        if last_user_idx != -1:
            last_user_message = messages[last_user_idx]
            content = last_user_message.get("content")
            applied_to_user = False

            if isinstance(content, str):
                last_user_message["content"] = [
                    {"type": "text", "text": content, "cache_control": {"type": "ephemeral"}}
                ]
                logger.debug(f"Applied cache_control to last user message (string content, index {last_user_idx}).")
                applied_to_user = True
            elif isinstance(content, list):
                # Modify text blocks within the list directly
                found_text_block = False
                for item in content:
                    if isinstance(item, dict) and item.get("type") == "text":
                        # Add cache_control if not already present (avoids adding it multiple times)
                        if "cache_control" not in item:
                            item["cache_control"] = {"type": "ephemeral"}
                            found_text_block = True  # Mark modification only if added

                if found_text_block:
                    logger.debug(f"Applied cache_control to text part(s) of last user message (list content, index {last_user_idx}).")
                    applied_to_user = True
                # else: No text block found or cache_control already present, do nothing
            else:
                logger.warning(f"Last user message (index {last_user_idx}) content is not a string or list ({type(content)}), skipping cache_control.")

            if applied_to_user:
                modified = True

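    # Illustrative sketch of the transformation above (assumed message shapes,
    # not executed): a system message {"role": "system", "content": "You are helpful."}
    # becomes {"role": "system", "content": [{"type": "text", "text": "You are helpful.",
    #                                         "cache_control": {"type": "ephemeral"}}]}
    # which marks that prefix as cacheable for Anthropic prompt caching.
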
# --- Debug Logging Setup ---
|
|
# Initialize log path to None, it will be set only if logging is enabled
|
|
response_log_path = None
|
|
enable_debug_logging = os.environ.get('ENABLE_LLM_DEBUG_LOGGING', 'false').lower() == 'true'
|
|
|
|
if enable_debug_logging:
|
|
try:
|
|
os.makedirs(DEBUG_LOG_DIR, exist_ok=True)
|
|
timestamp = time.strftime("%Y%m%d_%H%M%S")
|
|
# Use a unique ID or counter if calls can happen in the same second
|
|
# For simplicity, using timestamp only for now
|
|
request_log_path = os.path.join(DEBUG_LOG_DIR, f"llm_request_{timestamp}.json")
|
|
response_log_path = os.path.join(DEBUG_LOG_DIR, f"llm_response_{timestamp}.json") # Set here if enabled
|
|
|
|
# Log the request parameters just before the attempt loop
|
|
logger.debug(f"Logging LLM request parameters to {request_log_path}")
|
|
with open(request_log_path, 'w') as f:
|
|
# Use default=str for potentially non-serializable items in params if needed
|
|
json.dump(params, f, indent=2, default=str)
|
|
|
|
except Exception as log_err:
|
|
logger.error(f"Failed to set up or write LLM debug request log: {log_err}", exc_info=True)
|
|
# Reset response path to None if setup failed, even if logging was enabled
|
|
response_log_path = None
|
|
else:
|
|
logger.debug("LLM debug logging is disabled via environment variable.")
|
|
# --- End Debug Logging Setup ---
|
|
|
|
last_error = None
|
|
for attempt in range(MAX_RETRIES):
|
|
try:
|
|
logger.debug(f"Attempt {attempt + 1}/{MAX_RETRIES}")
|
|
|
|
response = await litellm.acompletion(**params)
|
|
logger.debug(f"Successfully received API response from {model_name}")
|
|
|
|
# --- Debug Logging Response ---
|
|
if response_log_path: # Only log if request logging setup succeeded
|
|
try:
|
|
logger.debug(f"Logging LLM response object to {response_log_path}")
|
|
# Check if it's a streaming response (AsyncGenerator)
|
|
if isinstance(response, AsyncGenerator):
|
|
with open(response_log_path, 'w') as f:
|
|
json.dump({"status": "streaming_response", "message": "Full response logged chunk by chunk where consumed."}, f, indent=2)
|
|
else:
|
|
# Assume it's a LiteLLM ModelResponse object, convert to dict
|
|
response_dict = response.dict()
|
|
with open(response_log_path, 'w') as f:
|
|
# Use default=str for potentially non-serializable items like datetime
|
|
json.dump(response_dict, f, indent=2, default=str)
|
|
except Exception as log_err:
|
|
logger.error(f"Failed to write LLM debug response log: {log_err}", exc_info=True)
|
|
# --- End Debug Logging Response ---
|
|
|
|
return response
|
|
|
|
except (litellm.exceptions.RateLimitError, OpenAIError, json.JSONDecodeError) as e:
|
|
last_error = e
|
|
await handle_error(e, attempt, MAX_RETRIES)
|
|
|
|
except Exception as e:
|
|
logger.error(f"Unexpected error during API call: {str(e)}", exc_info=True)
|
|
raise LLMError(f"API call failed: {str(e)}")
|
|
|
|
error_msg = f"Failed to make API call after {MAX_RETRIES} attempts"
|
|
if last_error:
|
|
error_msg += f". Last error: {str(last_error)}"
|
|
logger.error(error_msg, exc_info=True)
|
|
raise LLMRetryError(error_msg)
|
|
|
|
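# Example usage (an illustrative sketch, not executed on import):
#
#   response = await make_llm_api_call(
#       messages=[{"role": "user", "content": "Hello"}],
#       model_name="gpt-4",
#       max_tokens=256,
#   )
#   print(response.choices[0].message.content)
#
#   # With stream=True the call returns an async generator of OpenAI-style chunks:
#   stream = await make_llm_api_call(messages=[...], model_name="gpt-4", stream=True)
#   async for chunk in stream:
#       delta = chunk.choices[0].delta.content or ""
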
# Initialize API keys on module import
setup_api_keys()

# Test code for OpenRouter integration
async def test_openrouter():
    """Test the OpenRouter integration with a simple query."""
    test_messages = [
        {"role": "user", "content": "Hello, can you give me a quick test response?"}
    ]

    try:
        # Test with standard OpenRouter model
        print("\n--- Testing standard OpenRouter model ---")
        response = await make_llm_api_call(
            model_name="openrouter/openai/gpt-3.5-turbo",
            messages=test_messages,
            temperature=0.7,
            max_tokens=100
        )
        print(f"Response: {response.choices[0].message.content}")

        # Test with deepseek model
        print("\n--- Testing deepseek model ---")
        response = await make_llm_api_call(
            model_name="openrouter/deepseek/deepseek-r1-distill-llama-70b",
            messages=test_messages,
            temperature=0.7,
            max_tokens=100
        )
        print(f"Response: {response.choices[0].message.content}")
        print(f"Model used: {response.model}")

        # Test with Mistral model
        print("\n--- Testing Mistral model ---")
        response = await make_llm_api_call(
            model_name="openrouter/mistralai/mixtral-8x7b-instruct",
            messages=test_messages,
            temperature=0.7,
            max_tokens=100
        )
        print(f"Response: {response.choices[0].message.content}")
        print(f"Model used: {response.model}")

        return True
    except Exception as e:
        print(f"Error testing OpenRouter: {str(e)}")
        return False

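# Note: the __main__ block below only runs test_bedrock(); to exercise
# test_openrouter() instead, swap it into the asyncio.run(...) call there.
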
async def test_bedrock():
    """Test the AWS Bedrock integration with a simple query."""
    test_messages = [
        {"role": "user", "content": "Hello, can you give me a quick test response?"}
    ]

    try:
        response = await make_llm_api_call(
            model_name="bedrock/anthropic.claude-3-7-sonnet-20250219-v1:0",
            model_id="arn:aws:bedrock:us-west-2:935064898258:inference-profile/us.anthropic.claude-3-7-sonnet-20250219-v1:0",
            messages=test_messages,
            temperature=0.7,
            # Claude 3.7 has issues with max_tokens, so omit it
            # max_tokens=100
        )
        print(f"Response: {response.choices[0].message.content}")
        print(f"Model used: {response.model}")

        return True
    except Exception as e:
        print(f"Error testing Bedrock: {str(e)}")
        return False

if __name__ == "__main__":
    test_success = asyncio.run(test_bedrock())

    if test_success:
        print("\n✅ Bedrock integration test completed successfully!")
    else:
        print("\n❌ Bedrock integration test failed!")