mirror of https://github.com/kortix-ai/suna.git
Merge pull request #1242 from theshyPika/feat/openai-compatible-provider

feat: add OpenAI-compatible providers and SiliconFlow support

commit acd3dc9bab
@@ -28,6 +28,9 @@ OPENROUTER_API_KEY=
 GEMINI_API_KEY=
 MORPH_API_KEY=
 
+OPENAI_COMPATIBLE_API_KEY=
+OPENAI_COMPATIBLE_API_BASE=
+
 # DATA APIS
 RAPID_API_KEY=
 
@@ -13,6 +13,7 @@ This module provides a unified interface for making API calls to different LLM providers.
 from typing import Union, Dict, Any, Optional, AsyncGenerator, List
 import os
 import litellm
+from litellm.router import Router
 from litellm.files.main import ModelResponse
 from utils.logger import logger
 from utils.config import config
@@ -24,15 +25,27 @@ litellm.drop_params = True
 
 # Constants
 MAX_RETRIES = 3
+provider_router = None
 
 
 class LLMError(Exception):
     """Base exception for LLM-related errors."""
     pass
 
 def setup_api_keys() -> None:
     """Set up API keys from environment variables."""
-    providers = ['OPENAI', 'ANTHROPIC', 'GROQ', 'OPENROUTER', 'XAI', 'MORPH', 'GEMINI']
+    providers = [
+        "OPENAI",
+        "ANTHROPIC",
+        "GROQ",
+        "OPENROUTER",
+        "XAI",
+        "MORPH",
+        "GEMINI",
+        "OPENAI_COMPATIBLE",
+    ]
     for provider in providers:
-        key = getattr(config, f'{provider}_API_KEY')
+        key = getattr(config, f"{provider}_API_KEY")
         if key:
             logger.debug(f"API key set for provider: {provider}")
         else:
@@ -40,9 +53,10 @@ def setup_api_keys() -> None:
 
     # Set up OpenRouter API base if not already set
     if config.OPENROUTER_API_KEY and config.OPENROUTER_API_BASE:
-        os.environ['OPENROUTER_API_BASE'] = config.OPENROUTER_API_BASE
+        os.environ["OPENROUTER_API_BASE"] = config.OPENROUTER_API_BASE
         logger.debug(f"Set OPENROUTER_API_BASE to {config.OPENROUTER_API_BASE}")
 
 
    # Set up AWS Bedrock credentials
     aws_access_key = config.AWS_ACCESS_KEY_ID
     aws_secret_key = config.AWS_SECRET_ACCESS_KEY
@@ -51,12 +65,34 @@ def setup_api_keys() -> None:
     if aws_access_key and aws_secret_key and aws_region:
         logger.debug(f"AWS credentials set for Bedrock in region: {aws_region}")
         # Configure LiteLLM to use AWS credentials
-        os.environ['AWS_ACCESS_KEY_ID'] = aws_access_key
-        os.environ['AWS_SECRET_ACCESS_KEY'] = aws_secret_key
-        os.environ['AWS_REGION_NAME'] = aws_region
+        os.environ["AWS_ACCESS_KEY_ID"] = aws_access_key
+        os.environ["AWS_SECRET_ACCESS_KEY"] = aws_secret_key
+        os.environ["AWS_REGION_NAME"] = aws_region
     else:
         logger.warning(f"Missing AWS credentials for Bedrock integration - access_key: {bool(aws_access_key)}, secret_key: {bool(aws_secret_key)}, region: {aws_region}")
 
 
+def setup_provider_router(openai_compatible_api_key: str = None, openai_compatible_api_base: str = None):
+    global provider_router
+    model_list = [
+        {
+            "model_name": "openai-compatible/*",  # support OpenAI-compatible LLM providers
+            "litellm_params": {
+                "model": "openai/*",
+                "api_key": openai_compatible_api_key or config.OPENAI_COMPATIBLE_API_KEY,
+                "api_base": openai_compatible_api_base or config.OPENAI_COMPATIBLE_API_BASE,
+            },
+        },
+        {
+            "model_name": "*",  # any other LLM provider supported by LiteLLM
+            "litellm_params": {
+                "model": "*",
+            },
+        },
+    ]
+    provider_router = Router(model_list=model_list)
+
+
 def get_openrouter_fallback(model_name: str) -> Optional[str]:
     """Get OpenRouter fallback model for a given model name."""
     # Skip if already using OpenRouter
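The first Router entry relies on LiteLLM's wildcard routing: a request for any model named "openai-compatible/<name>" is rewritten to the generic "openai/<name>" provider and pointed at the configured base URL, while the catch-all "*" entry leaves every other model to LiteLLM's normal provider resolution. A minimal, self-contained sketch of the same pattern; the endpoint, key, and model id below are hypothetical, for illustration only:

import asyncio
from litellm.router import Router

# Hypothetical OpenAI-compatible endpoint and key, for illustration only.
EXAMPLE_API_BASE = "https://api.example.com/v1"
EXAMPLE_API_KEY = "sk-example"

router = Router(model_list=[
    {
        # "openai-compatible/<name>" is served by the generic "openai/<name>"
        # provider, sent to the custom base URL with the custom key.
        "model_name": "openai-compatible/*",
        "litellm_params": {
            "model": "openai/*",
            "api_key": EXAMPLE_API_KEY,
            "api_base": EXAMPLE_API_BASE,
        },
    },
    # Everything else falls through to LiteLLM's built-in provider routing.
    {"model_name": "*", "litellm_params": {"model": "*"}},
])

async def main() -> None:
    response = await router.acompletion(
        model="openai-compatible/my-model",  # hypothetical model id
        messages=[{"role": "user", "content": "Hello"}],
    )
    print(response.choices[0].message.content)

asyncio.run(main())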
@@ -255,7 +291,7 @@ def prepare_params(
     top_p: Optional[float] = None,
     model_id: Optional[str] = None,
     enable_thinking: Optional[bool] = False,
-    reasoning_effort: Optional[str] = 'low'
+    reasoning_effort: Optional[str] = "low",
 ) -> Dict[str, Any]:
     from models import model_manager
     resolved_model_name = model_manager.resolve_model_id(model_name)
@@ -278,6 +314,17 @@ def prepare_params(
     if model_id:
         params["model_id"] = model_id
 
+    if model_name.startswith("openai-compatible/"):
+        # Check that the required config is present, either from parameters or the environment
+        if (not api_key and not config.OPENAI_COMPATIBLE_API_KEY) or (
+            not api_base and not config.OPENAI_COMPATIBLE_API_BASE
+        ):
+            raise LLMError(
+                "OPENAI_COMPATIBLE_API_KEY and OPENAI_COMPATIBLE_API_BASE are required for openai-compatible models. If you just updated the environment variables, wait a few minutes or restart the service to ensure they are loaded."
+            )
+
+        setup_provider_router(api_key, api_base)
+
     # Handle token limits
     _configure_token_limits(params, resolved_model_name, max_tokens)
     # Add tools if provided
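With this guard, prepare_params both validates the credentials and rebuilds the router with any per-call overrides, so a caller only has to use the "openai-compatible/" model prefix. A hedged usage sketch; the import path and model id are assumptions, not taken from this diff:

import asyncio
from services.llm import make_llm_api_call  # assumed module path

async def main() -> None:
    # The "openai-compatible/" prefix triggers the validation above and routes
    # the call through the provider router instead of a built-in provider.
    response = await make_llm_api_call(
        messages=[{"role": "user", "content": "ping"}],
        model_name="openai-compatible/my-model",  # hypothetical model id
    )
    print(response.choices[0].message.content)

asyncio.run(main())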
@@ -312,7 +359,7 @@ async def make_llm_api_call(
     top_p: Optional[float] = None,
     model_id: Optional[str] = None,
     enable_thinking: Optional[bool] = False,
-    reasoning_effort: Optional[str] = 'low'
+    reasoning_effort: Optional[str] = "low",
 ) -> Union[Dict[str, Any], AsyncGenerator, ModelResponse]:
     """
     Make an API call to a language model using LiteLLM.
@@ -357,10 +404,10 @@ async def make_llm_api_call(
         top_p=top_p,
         model_id=model_id,
         enable_thinking=enable_thinking,
-        reasoning_effort=reasoning_effort
+        reasoning_effort=reasoning_effort,
     )
     try:
-        response = await litellm.acompletion(**params)
+        response = await provider_router.acompletion(**params)
         logger.debug(f"Successfully received API response from {model_name}")
         # logger.debug(f"Response: {response}")
         return response
@@ -371,3 +418,4 @@ async def make_llm_api_call(
 
 # Initialize API keys on module import
 setup_api_keys()
+setup_provider_router()
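Note that provider_router is a module-level global: the no-argument setup_provider_router() call here binds whatever OPENAI_COMPATIBLE_* values config holds at import time, and any later openai-compatible request that passes explicit credentials through prepare_params rebuilds the router for the whole process.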
@@ -236,6 +236,8 @@ class Configuration:
     MORPH_API_KEY: Optional[str] = None
     GEMINI_API_KEY: Optional[str] = None
     OPENROUTER_API_BASE: Optional[str] = "https://openrouter.ai/api/v1"
+    OPENAI_COMPATIBLE_API_KEY: Optional[str] = None
+    OPENAI_COMPATIBLE_API_BASE: Optional[str] = None
     OR_SITE_URL: Optional[str] = "https://kortix.ai"
     OR_APP_NAME: Optional[str] = "Kortix AI"
 
@@ -162,6 +162,8 @@ OPENAI_API_KEY=your-openai-key
 OPENROUTER_API_KEY=your-openrouter-key
 GEMINI_API_KEY=your-gemini-api-key
 MORPH_API_KEY=
+OPENAI_COMPATIBLE_API_KEY=your-openai-compatible-api-key
+OPENAI_COMPATIBLE_API_BASE=your-openai-compatible-api-base
 
 
 # WEB SEARCH
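The commit title also mentions SiliconFlow; any endpoint that speaks the OpenAI API shape can sit behind this pair of variables. An illustrative configuration in the same .env style (the base URL is an assumption; confirm it against the provider's documentation):

OPENAI_COMPATIBLE_API_KEY=sk-example-key
OPENAI_COMPATIBLE_API_BASE=https://api.siliconflow.cn/v1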