mirror of https://github.com/kortix-ai/suna.git
multi models backend support
This commit is contained in:
parent
ec4c85a433
commit
50b71c064a
|
@ -1,4 +1,4 @@
|
|||
from fastapi import APIRouter, HTTPException, Depends, Request
|
||||
from fastapi import APIRouter, HTTPException, Depends, Request, Body
|
||||
from fastapi.responses import StreamingResponse
|
||||
import asyncio
|
||||
import json
|
||||
|
@ -7,6 +7,7 @@ from datetime import datetime, timezone
|
|||
import uuid
|
||||
from typing import Optional, List, Dict, Any
|
||||
import jwt
|
||||
from pydantic import BaseModel
|
||||
|
||||
from agentpress.thread_manager import ThreadManager
|
||||
from services.supabase import DBConnection
|
||||
|
@ -15,7 +16,13 @@ from agent.run import run_agent
|
|||
from utils.auth_utils import get_current_user_id, get_user_id_from_stream_auth, verify_thread_access
|
||||
from utils.logger import logger
|
||||
from utils.billing import check_billing_status, get_account_id_from_thread
|
||||
from utils.db import update_agent_run_status
|
||||
# NOTE: update_agent_run_status is defined locally in this module, so it is intentionally not imported from utils.db.
|
||||
|
||||
# Define request model for starting agent
|
||||
class AgentStartRequest(BaseModel):
    """Request body accepted by the agent start endpoint.

    Every field has a default, so a client may send an empty JSON object
    and still get a fully populated settings object.
    """

    # LLM identifier forwarded to the background agent run.
    # Declared as plain ``str`` rather than ``Optional[str]``: the background
    # runner annotates this value as ``str`` and the LLM parameter preparation
    # calls ``.lower()`` on the model name, so an explicit JSON ``null`` is
    # rejected at request validation instead of failing later inside the
    # background task.
    model_name: str = "anthropic/claude-3-7-sonnet-latest"

    # Whether thinking/reasoning should be requested. ``None`` is tolerated:
    # the parameter preparation treats a missing value as disabled.
    enable_thinking: Optional[bool] = False

    # Requested reasoning effort level. ``None`` or an empty string is
    # tolerated: the parameter preparation falls back to 'low'.
    reasoning_effort: Optional[str] = 'low'
|
||||
|
||||
# Initialize shared resources
|
||||
router = APIRouter()
|
||||
|
@ -236,9 +243,13 @@ async def _cleanup_agent_run(agent_run_id: str):
|
|||
# Non-fatal error, can continue
|
||||
|
||||
@router.post("/thread/{thread_id}/agent/start")
|
||||
async def start_agent(thread_id: str, user_id: str = Depends(get_current_user_id)):
|
||||
"""Start an agent for a specific thread in the background."""
|
||||
logger.info(f"Starting new agent for thread: {thread_id}")
|
||||
async def start_agent(
|
||||
thread_id: str,
|
||||
body: AgentStartRequest = Body(...), # Accept request body
|
||||
user_id: str = Depends(get_current_user_id)
|
||||
):
|
||||
"""Start an agent for a specific thread in the background with dynamic settings."""
|
||||
logger.info(f"Starting new agent for thread: {thread_id} with model: {body.model_name}, thinking: {body.enable_thinking}, effort: {body.reasoning_effort}")
|
||||
client = await db.client
|
||||
|
||||
# Verify user has access to this thread
|
||||
|
@ -291,11 +302,19 @@ async def start_agent(thread_id: str, user_id: str = Depends(get_current_user_id
|
|||
except Exception as e:
|
||||
logger.warning(f"Failed to register agent run in Redis, continuing without Redis tracking: {str(e)}")
|
||||
|
||||
# Run the agent in the background
|
||||
# Run the agent in the background, passing the dynamic settings
|
||||
task = asyncio.create_task(
|
||||
run_agent_background(agent_run_id, thread_id, instance_id, project_id)
|
||||
run_agent_background(
|
||||
agent_run_id=agent_run_id,
|
||||
thread_id=thread_id,
|
||||
instance_id=instance_id,
|
||||
project_id=project_id,
|
||||
model_name=body.model_name,
|
||||
enable_thinking=body.enable_thinking,
|
||||
reasoning_effort=body.reasoning_effort
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
# Set a callback to clean up when task is done
|
||||
task.add_done_callback(
|
||||
lambda _: asyncio.create_task(
|
||||
|
@ -420,9 +439,17 @@ async def stream_agent_run(
|
|||
}
|
||||
)
|
||||
|
||||
async def run_agent_background(agent_run_id: str, thread_id: str, instance_id: str, project_id: str):
|
||||
async def run_agent_background(
|
||||
agent_run_id: str,
|
||||
thread_id: str,
|
||||
instance_id: str,
|
||||
project_id: str,
|
||||
model_name: str, # Add model_name parameter
|
||||
enable_thinking: Optional[bool], # Add enable_thinking parameter
|
||||
reasoning_effort: Optional[str] # Add reasoning_effort parameter
|
||||
):
|
||||
"""Run the agent in the background and handle status updates."""
|
||||
logger.debug(f"Starting background agent run: {agent_run_id} for thread: {thread_id} (instance: {instance_id})")
|
||||
logger.debug(f"Starting background agent run: {agent_run_id} for thread: {thread_id} (instance: {instance_id}) with model: {model_name}")
|
||||
client = await db.client
|
||||
|
||||
# Tracking variables
|
||||
|
@ -538,11 +565,18 @@ async def run_agent_background(agent_run_id: str, thread_id: str, instance_id: s
|
|||
logger.warning(f"No stop signal checker for agent run: {agent_run_id} - pubsub unavailable")
|
||||
|
||||
try:
|
||||
# Run the agent
|
||||
# Run the agent, passing the dynamic settings
|
||||
logger.debug(f"Initializing agent generator for thread: {thread_id} (instance: {instance_id})")
|
||||
agent_gen = run_agent(thread_id, stream=True,
|
||||
thread_manager=thread_manager, project_id=project_id)
|
||||
|
||||
agent_gen = run_agent(
|
||||
thread_id=thread_id,
|
||||
project_id=project_id,
|
||||
stream=True,
|
||||
thread_manager=thread_manager,
|
||||
model_name=model_name, # Pass model_name
|
||||
enable_thinking=enable_thinking, # Pass enable_thinking
|
||||
reasoning_effort=reasoning_effort # Pass reasoning_effort
|
||||
)
|
||||
|
||||
# Collect all responses to save to database
|
||||
all_responses = []
|
||||
|
||||
|
@ -654,4 +688,4 @@ async def run_agent_background(agent_run_id: str, thread_id: str, instance_id: s
|
|||
except Exception as e:
|
||||
logger.warning(f"Error deleting active run key: {str(e)}")
|
||||
|
||||
logger.info(f"Agent run background task fully completed for: {agent_run_id} (instance: {instance_id})")
|
||||
logger.info(f"Agent run background task fully completed for: {agent_run_id} (instance: {instance_id})")
|
||||
|
|
|
@ -20,9 +20,19 @@ from utils.billing import check_billing_status, get_account_id_from_thread
|
|||
|
||||
load_dotenv()
|
||||
|
||||
async def run_agent(thread_id: str, project_id: str, stream: bool = True, thread_manager: Optional[ThreadManager] = None, native_max_auto_continues: int = 25, max_iterations: int = 150):
|
||||
async def run_agent(
|
||||
thread_id: str,
|
||||
project_id: str,
|
||||
stream: bool = True,
|
||||
thread_manager: Optional[ThreadManager] = None,
|
||||
native_max_auto_continues: int = 25,
|
||||
max_iterations: int = 150,
|
||||
model_name: str = "anthropic/claude-3-7-sonnet-latest", # Add model_name parameter with default
|
||||
enable_thinking: Optional[bool] = False, # Add enable_thinking parameter
|
||||
reasoning_effort: Optional[str] = 'low' # Add reasoning_effort parameter
|
||||
):
|
||||
"""Run the development agent with specified configuration."""
|
||||
|
||||
|
||||
if not thread_manager:
|
||||
thread_manager = ThreadManager()
|
||||
client = await thread_manager.db.client
|
||||
|
@ -161,31 +171,29 @@ async def run_agent(thread_id: str, project_id: str, stream: bool = True, thread
|
|||
print(f"Error parsing browser state: {e}")
|
||||
# print(latest_browser_state.data[0])
|
||||
|
||||
# Determine model and max tokens
|
||||
model_to_use = os.getenv("MODEL_TO_USE", "anthropic/claude-3-7-sonnet-latest")
|
||||
# Determine max tokens based on the passed model_name
|
||||
max_tokens = None
|
||||
if model_to_use == "anthropic/claude-3-7-sonnet-latest":
|
||||
max_tokens = 64000
|
||||
if model_name == "anthropic/claude-3-7-sonnet-latest":
|
||||
max_tokens = 64000 # Example: Set max tokens for a specific model
|
||||
|
||||
# Run Thread
|
||||
# Run Thread, passing the dynamic settings
|
||||
response = await thread_manager.run_thread(
|
||||
thread_id=thread_id,
|
||||
system_prompt=system_message, # Pass the constructed message
|
||||
stream=stream,
|
||||
# stream=False,
|
||||
llm_model=model_to_use,
|
||||
# llm_temperature=0.1,
|
||||
llm_temperature=1,
|
||||
llm_model=model_name, # Use the passed model_name
|
||||
llm_temperature=1, # Example temperature
|
||||
llm_max_tokens=max_tokens, # Use the determined value
|
||||
tool_choice="auto",
|
||||
max_xml_tool_calls=1,
|
||||
temporary_message=temporary_message,
|
||||
processor_config=processor_config, # Pass the config object
|
||||
native_max_auto_continues=native_max_auto_continues,
|
||||
# Explicitly set include_xml_examples to False here
|
||||
include_xml_examples=False,
|
||||
include_xml_examples=False, # Explicitly set include_xml_examples to False here
|
||||
enable_thinking=enable_thinking, # Pass enable_thinking
|
||||
reasoning_effort=reasoning_effort # Pass reasoning_effort
|
||||
)
|
||||
|
||||
|
||||
if isinstance(response, dict) and "status" in response and response["status"] == "error":
|
||||
yield response
|
||||
break
|
||||
|
|
|
@ -154,9 +154,10 @@ class ThreadManager:
|
|||
native_max_auto_continues: int = 25,
|
||||
max_xml_tool_calls: int = 0,
|
||||
include_xml_examples: bool = False,
|
||||
enable_thinking: Optional[bool] = False, # Add enable_thinking parameter
|
||||
reasoning_effort: Optional[str] = 'low' # Add reasoning_effort parameter
|
||||
) -> Union[Dict[str, Any], AsyncGenerator]:
|
||||
"""Run a conversation thread with LLM integration and tool execution.
|
||||
|
||||
Args:
|
||||
thread_id: The ID of the thread to run
|
||||
system_prompt: System message to set the assistant's behavior
|
||||
|
@ -279,7 +280,9 @@ class ThreadManager:
|
|||
max_tokens=llm_max_tokens,
|
||||
tools=openapi_tool_schemas,
|
||||
tool_choice=tool_choice if processor_config.native_tool_calling else None,
|
||||
stream=stream
|
||||
stream=stream,
|
||||
enable_thinking=enable_thinking, # Pass enable_thinking
|
||||
reasoning_effort=reasoning_effort # Pass reasoning_effort
|
||||
)
|
||||
logger.debug("Successfully received raw LLM API response stream/object")
|
||||
|
||||
|
|
|
@ -86,7 +86,10 @@ def prepare_params(
|
|||
api_base: Optional[str] = None,
|
||||
stream: bool = False,
|
||||
top_p: Optional[float] = None,
|
||||
model_id: Optional[str] = None
|
||||
model_id: Optional[str] = None,
|
||||
# Add parameters for thinking/reasoning
|
||||
enable_thinking: Optional[bool] = None,
|
||||
reasoning_effort: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Prepare parameters for the API call."""
|
||||
params = {
|
||||
|
@ -157,17 +160,19 @@ def prepare_params(
|
|||
logger.debug(f"Auto-set model_id for Claude 3.7 Sonnet: {params['model_id']}")
|
||||
|
||||
# --- Add Anthropic Thinking/Reasoning Effort ---
|
||||
# Read environment variables for thinking/reasoning
|
||||
enable_thinking_env = os.environ.get('ENABLE_THINKING', 'false').lower() == 'true'
|
||||
# Determine if thinking should be enabled based on the passed parameter
|
||||
use_thinking = enable_thinking if enable_thinking is not None else False
|
||||
|
||||
# Check if the model is Anthropic
|
||||
is_anthropic = "sonnet-3-7" in model_name.lower() or "anthropic" in model_name.lower()
|
||||
|
||||
# Add reasoning_effort parameter if enabled and applicable
|
||||
if is_anthropic and enable_thinking_env:
|
||||
reasoning_effort_env = os.environ.get('REASONING_EFFORT', 'low') # Default to 'low'
|
||||
params["reasoning_effort"] = reasoning_effort_env
|
||||
logger.info(f"Anthropic thinking enabled with reasoning_effort='{reasoning_effort_env}'")
|
||||
if is_anthropic and use_thinking:
|
||||
# Determine reasoning effort based on the passed parameter, defaulting to 'low'
|
||||
effort_level = reasoning_effort if reasoning_effort else 'low'
|
||||
|
||||
params["reasoning_effort"] = effort_level
|
||||
logger.info(f"Anthropic thinking enabled with reasoning_effort='{effort_level}'")
|
||||
|
||||
# Anthropic requires temperature=1 when thinking/reasoning_effort is enabled
|
||||
params["temperature"] = 1.0
|
||||
|
@ -186,7 +191,10 @@ async def make_llm_api_call(
|
|||
api_base: Optional[str] = None,
|
||||
stream: bool = False,
|
||||
top_p: Optional[float] = None,
|
||||
model_id: Optional[str] = None
|
||||
model_id: Optional[str] = None,
|
||||
# Add parameters for thinking/reasoning
|
||||
enable_thinking: Optional[bool] = None,
|
||||
reasoning_effort: Optional[str] = None
|
||||
) -> Union[Dict[str, Any], AsyncGenerator]:
|
||||
"""
|
||||
Make an API call to a language model using LiteLLM.
|
||||
|
@ -225,7 +233,10 @@ async def make_llm_api_call(
|
|||
api_base=api_base,
|
||||
stream=stream,
|
||||
top_p=top_p,
|
||||
model_id=model_id
|
||||
model_id=model_id,
|
||||
# Add parameters for thinking/reasoning
|
||||
enable_thinking=enable_thinking,
|
||||
reasoning_effort=reasoning_effort
|
||||
)
|
||||
|
||||
# Apply Anthropic prompt caching (minimal implementation)
|
||||
|
|
Loading…
Reference in New Issue