mirror of https://github.com/kortix-ai/suna.git
multi models backend support
This commit is contained in:
parent
ec4c85a433
commit
50b71c064a
|
@ -1,4 +1,4 @@
|
||||||
from fastapi import APIRouter, HTTPException, Depends, Request
|
from fastapi import APIRouter, HTTPException, Depends, Request, Body
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
|
@ -7,6 +7,7 @@ from datetime import datetime, timezone
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Optional, List, Dict, Any
|
from typing import Optional, List, Dict, Any
|
||||||
import jwt
|
import jwt
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from agentpress.thread_manager import ThreadManager
|
from agentpress.thread_manager import ThreadManager
|
||||||
from services.supabase import DBConnection
|
from services.supabase import DBConnection
|
||||||
|
@ -15,7 +16,13 @@ from agent.run import run_agent
|
||||||
from utils.auth_utils import get_current_user_id, get_user_id_from_stream_auth, verify_thread_access
|
from utils.auth_utils import get_current_user_id, get_user_id_from_stream_auth, verify_thread_access
|
||||||
from utils.logger import logger
|
from utils.logger import logger
|
||||||
from utils.billing import check_billing_status, get_account_id_from_thread
|
from utils.billing import check_billing_status, get_account_id_from_thread
|
||||||
from utils.db import update_agent_run_status
|
# Removed duplicate import of update_agent_run_status from utils.db as it's defined locally
|
||||||
|
|
||||||
|
# Define request model for starting agent
|
||||||
|
class AgentStartRequest(BaseModel):
|
||||||
|
model_name: Optional[str] = "anthropic/claude-3-7-sonnet-latest"
|
||||||
|
enable_thinking: Optional[bool] = False
|
||||||
|
reasoning_effort: Optional[str] = 'low'
|
||||||
|
|
||||||
# Initialize shared resources
|
# Initialize shared resources
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
@ -236,9 +243,13 @@ async def _cleanup_agent_run(agent_run_id: str):
|
||||||
# Non-fatal error, can continue
|
# Non-fatal error, can continue
|
||||||
|
|
||||||
@router.post("/thread/{thread_id}/agent/start")
|
@router.post("/thread/{thread_id}/agent/start")
|
||||||
async def start_agent(thread_id: str, user_id: str = Depends(get_current_user_id)):
|
async def start_agent(
|
||||||
"""Start an agent for a specific thread in the background."""
|
thread_id: str,
|
||||||
logger.info(f"Starting new agent for thread: {thread_id}")
|
body: AgentStartRequest = Body(...), # Accept request body
|
||||||
|
user_id: str = Depends(get_current_user_id)
|
||||||
|
):
|
||||||
|
"""Start an agent for a specific thread in the background with dynamic settings."""
|
||||||
|
logger.info(f"Starting new agent for thread: {thread_id} with model: {body.model_name}, thinking: {body.enable_thinking}, effort: {body.reasoning_effort}")
|
||||||
client = await db.client
|
client = await db.client
|
||||||
|
|
||||||
# Verify user has access to this thread
|
# Verify user has access to this thread
|
||||||
|
@ -291,11 +302,19 @@ async def start_agent(thread_id: str, user_id: str = Depends(get_current_user_id
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Failed to register agent run in Redis, continuing without Redis tracking: {str(e)}")
|
logger.warning(f"Failed to register agent run in Redis, continuing without Redis tracking: {str(e)}")
|
||||||
|
|
||||||
# Run the agent in the background
|
# Run the agent in the background, passing the dynamic settings
|
||||||
task = asyncio.create_task(
|
task = asyncio.create_task(
|
||||||
run_agent_background(agent_run_id, thread_id, instance_id, project_id)
|
run_agent_background(
|
||||||
|
agent_run_id=agent_run_id,
|
||||||
|
thread_id=thread_id,
|
||||||
|
instance_id=instance_id,
|
||||||
|
project_id=project_id,
|
||||||
|
model_name=body.model_name,
|
||||||
|
enable_thinking=body.enable_thinking,
|
||||||
|
reasoning_effort=body.reasoning_effort
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Set a callback to clean up when task is done
|
# Set a callback to clean up when task is done
|
||||||
task.add_done_callback(
|
task.add_done_callback(
|
||||||
lambda _: asyncio.create_task(
|
lambda _: asyncio.create_task(
|
||||||
|
@ -420,9 +439,17 @@ async def stream_agent_run(
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
async def run_agent_background(agent_run_id: str, thread_id: str, instance_id: str, project_id: str):
|
async def run_agent_background(
|
||||||
|
agent_run_id: str,
|
||||||
|
thread_id: str,
|
||||||
|
instance_id: str,
|
||||||
|
project_id: str,
|
||||||
|
model_name: str, # Add model_name parameter
|
||||||
|
enable_thinking: Optional[bool], # Add enable_thinking parameter
|
||||||
|
reasoning_effort: Optional[str] # Add reasoning_effort parameter
|
||||||
|
):
|
||||||
"""Run the agent in the background and handle status updates."""
|
"""Run the agent in the background and handle status updates."""
|
||||||
logger.debug(f"Starting background agent run: {agent_run_id} for thread: {thread_id} (instance: {instance_id})")
|
logger.debug(f"Starting background agent run: {agent_run_id} for thread: {thread_id} (instance: {instance_id}) with model: {model_name}")
|
||||||
client = await db.client
|
client = await db.client
|
||||||
|
|
||||||
# Tracking variables
|
# Tracking variables
|
||||||
|
@ -538,11 +565,18 @@ async def run_agent_background(agent_run_id: str, thread_id: str, instance_id: s
|
||||||
logger.warning(f"No stop signal checker for agent run: {agent_run_id} - pubsub unavailable")
|
logger.warning(f"No stop signal checker for agent run: {agent_run_id} - pubsub unavailable")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Run the agent
|
# Run the agent, passing the dynamic settings
|
||||||
logger.debug(f"Initializing agent generator for thread: {thread_id} (instance: {instance_id})")
|
logger.debug(f"Initializing agent generator for thread: {thread_id} (instance: {instance_id})")
|
||||||
agent_gen = run_agent(thread_id, stream=True,
|
agent_gen = run_agent(
|
||||||
thread_manager=thread_manager, project_id=project_id)
|
thread_id=thread_id,
|
||||||
|
project_id=project_id,
|
||||||
|
stream=True,
|
||||||
|
thread_manager=thread_manager,
|
||||||
|
model_name=model_name, # Pass model_name
|
||||||
|
enable_thinking=enable_thinking, # Pass enable_thinking
|
||||||
|
reasoning_effort=reasoning_effort # Pass reasoning_effort
|
||||||
|
)
|
||||||
|
|
||||||
# Collect all responses to save to database
|
# Collect all responses to save to database
|
||||||
all_responses = []
|
all_responses = []
|
||||||
|
|
||||||
|
@ -654,4 +688,4 @@ async def run_agent_background(agent_run_id: str, thread_id: str, instance_id: s
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Error deleting active run key: {str(e)}")
|
logger.warning(f"Error deleting active run key: {str(e)}")
|
||||||
|
|
||||||
logger.info(f"Agent run background task fully completed for: {agent_run_id} (instance: {instance_id})")
|
logger.info(f"Agent run background task fully completed for: {agent_run_id} (instance: {instance_id})")
|
||||||
|
|
|
@ -20,9 +20,19 @@ from utils.billing import check_billing_status, get_account_id_from_thread
|
||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
async def run_agent(thread_id: str, project_id: str, stream: bool = True, thread_manager: Optional[ThreadManager] = None, native_max_auto_continues: int = 25, max_iterations: int = 150):
|
async def run_agent(
|
||||||
|
thread_id: str,
|
||||||
|
project_id: str,
|
||||||
|
stream: bool = True,
|
||||||
|
thread_manager: Optional[ThreadManager] = None,
|
||||||
|
native_max_auto_continues: int = 25,
|
||||||
|
max_iterations: int = 150,
|
||||||
|
model_name: str = "anthropic/claude-3-7-sonnet-latest", # Add model_name parameter with default
|
||||||
|
enable_thinking: Optional[bool] = False, # Add enable_thinking parameter
|
||||||
|
reasoning_effort: Optional[str] = 'low' # Add reasoning_effort parameter
|
||||||
|
):
|
||||||
"""Run the development agent with specified configuration."""
|
"""Run the development agent with specified configuration."""
|
||||||
|
|
||||||
if not thread_manager:
|
if not thread_manager:
|
||||||
thread_manager = ThreadManager()
|
thread_manager = ThreadManager()
|
||||||
client = await thread_manager.db.client
|
client = await thread_manager.db.client
|
||||||
|
@ -161,31 +171,29 @@ async def run_agent(thread_id: str, project_id: str, stream: bool = True, thread
|
||||||
print(f"Error parsing browser state: {e}")
|
print(f"Error parsing browser state: {e}")
|
||||||
# print(latest_browser_state.data[0])
|
# print(latest_browser_state.data[0])
|
||||||
|
|
||||||
# Determine model and max tokens
|
# Determine max tokens based on the passed model_name
|
||||||
model_to_use = os.getenv("MODEL_TO_USE", "anthropic/claude-3-7-sonnet-latest")
|
|
||||||
max_tokens = None
|
max_tokens = None
|
||||||
if model_to_use == "anthropic/claude-3-7-sonnet-latest":
|
if model_name == "anthropic/claude-3-7-sonnet-latest":
|
||||||
max_tokens = 64000
|
max_tokens = 64000 # Example: Set max tokens for a specific model
|
||||||
|
|
||||||
# Run Thread
|
# Run Thread, passing the dynamic settings
|
||||||
response = await thread_manager.run_thread(
|
response = await thread_manager.run_thread(
|
||||||
thread_id=thread_id,
|
thread_id=thread_id,
|
||||||
system_prompt=system_message, # Pass the constructed message
|
system_prompt=system_message, # Pass the constructed message
|
||||||
stream=stream,
|
stream=stream,
|
||||||
# stream=False,
|
llm_model=model_name, # Use the passed model_name
|
||||||
llm_model=model_to_use,
|
llm_temperature=1, # Example temperature
|
||||||
# llm_temperature=0.1,
|
|
||||||
llm_temperature=1,
|
|
||||||
llm_max_tokens=max_tokens, # Use the determined value
|
llm_max_tokens=max_tokens, # Use the determined value
|
||||||
tool_choice="auto",
|
tool_choice="auto",
|
||||||
max_xml_tool_calls=1,
|
max_xml_tool_calls=1,
|
||||||
temporary_message=temporary_message,
|
temporary_message=temporary_message,
|
||||||
processor_config=processor_config, # Pass the config object
|
processor_config=processor_config, # Pass the config object
|
||||||
native_max_auto_continues=native_max_auto_continues,
|
native_max_auto_continues=native_max_auto_continues,
|
||||||
# Explicitly set include_xml_examples to False here
|
include_xml_examples=False, # Explicitly set include_xml_examples to False here
|
||||||
include_xml_examples=False,
|
enable_thinking=enable_thinking, # Pass enable_thinking
|
||||||
|
reasoning_effort=reasoning_effort # Pass reasoning_effort
|
||||||
)
|
)
|
||||||
|
|
||||||
if isinstance(response, dict) and "status" in response and response["status"] == "error":
|
if isinstance(response, dict) and "status" in response and response["status"] == "error":
|
||||||
yield response
|
yield response
|
||||||
break
|
break
|
||||||
|
|
|
@ -154,9 +154,10 @@ class ThreadManager:
|
||||||
native_max_auto_continues: int = 25,
|
native_max_auto_continues: int = 25,
|
||||||
max_xml_tool_calls: int = 0,
|
max_xml_tool_calls: int = 0,
|
||||||
include_xml_examples: bool = False,
|
include_xml_examples: bool = False,
|
||||||
|
enable_thinking: Optional[bool] = False, # Add enable_thinking parameter
|
||||||
|
reasoning_effort: Optional[str] = 'low' # Add reasoning_effort parameter
|
||||||
) -> Union[Dict[str, Any], AsyncGenerator]:
|
) -> Union[Dict[str, Any], AsyncGenerator]:
|
||||||
"""Run a conversation thread with LLM integration and tool execution.
|
"""Run a conversation thread with LLM integration and tool execution.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
thread_id: The ID of the thread to run
|
thread_id: The ID of the thread to run
|
||||||
system_prompt: System message to set the assistant's behavior
|
system_prompt: System message to set the assistant's behavior
|
||||||
|
@ -279,7 +280,9 @@ class ThreadManager:
|
||||||
max_tokens=llm_max_tokens,
|
max_tokens=llm_max_tokens,
|
||||||
tools=openapi_tool_schemas,
|
tools=openapi_tool_schemas,
|
||||||
tool_choice=tool_choice if processor_config.native_tool_calling else None,
|
tool_choice=tool_choice if processor_config.native_tool_calling else None,
|
||||||
stream=stream
|
stream=stream,
|
||||||
|
enable_thinking=enable_thinking, # Pass enable_thinking
|
||||||
|
reasoning_effort=reasoning_effort # Pass reasoning_effort
|
||||||
)
|
)
|
||||||
logger.debug("Successfully received raw LLM API response stream/object")
|
logger.debug("Successfully received raw LLM API response stream/object")
|
||||||
|
|
||||||
|
|
|
@ -86,7 +86,10 @@ def prepare_params(
|
||||||
api_base: Optional[str] = None,
|
api_base: Optional[str] = None,
|
||||||
stream: bool = False,
|
stream: bool = False,
|
||||||
top_p: Optional[float] = None,
|
top_p: Optional[float] = None,
|
||||||
model_id: Optional[str] = None
|
model_id: Optional[str] = None,
|
||||||
|
# Add parameters for thinking/reasoning
|
||||||
|
enable_thinking: Optional[bool] = None,
|
||||||
|
reasoning_effort: Optional[str] = None
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""Prepare parameters for the API call."""
|
"""Prepare parameters for the API call."""
|
||||||
params = {
|
params = {
|
||||||
|
@ -157,17 +160,19 @@ def prepare_params(
|
||||||
logger.debug(f"Auto-set model_id for Claude 3.7 Sonnet: {params['model_id']}")
|
logger.debug(f"Auto-set model_id for Claude 3.7 Sonnet: {params['model_id']}")
|
||||||
|
|
||||||
# --- Add Anthropic Thinking/Reasoning Effort ---
|
# --- Add Anthropic Thinking/Reasoning Effort ---
|
||||||
# Read environment variables for thinking/reasoning
|
# Determine if thinking should be enabled based on the passed parameter
|
||||||
enable_thinking_env = os.environ.get('ENABLE_THINKING', 'false').lower() == 'true'
|
use_thinking = enable_thinking if enable_thinking is not None else False
|
||||||
|
|
||||||
# Check if the model is Anthropic
|
# Check if the model is Anthropic
|
||||||
is_anthropic = "sonnet-3-7" in model_name.lower() or "anthropic" in model_name.lower()
|
is_anthropic = "sonnet-3-7" in model_name.lower() or "anthropic" in model_name.lower()
|
||||||
|
|
||||||
# Add reasoning_effort parameter if enabled and applicable
|
# Add reasoning_effort parameter if enabled and applicable
|
||||||
if is_anthropic and enable_thinking_env:
|
if is_anthropic and use_thinking:
|
||||||
reasoning_effort_env = os.environ.get('REASONING_EFFORT', 'low') # Default to 'low'
|
# Determine reasoning effort based on the passed parameter, defaulting to 'low'
|
||||||
params["reasoning_effort"] = reasoning_effort_env
|
effort_level = reasoning_effort if reasoning_effort else 'low'
|
||||||
logger.info(f"Anthropic thinking enabled with reasoning_effort='{reasoning_effort_env}'")
|
|
||||||
|
params["reasoning_effort"] = effort_level
|
||||||
|
logger.info(f"Anthropic thinking enabled with reasoning_effort='{effort_level}'")
|
||||||
|
|
||||||
# Anthropic requires temperature=1 when thinking/reasoning_effort is enabled
|
# Anthropic requires temperature=1 when thinking/reasoning_effort is enabled
|
||||||
params["temperature"] = 1.0
|
params["temperature"] = 1.0
|
||||||
|
@ -186,7 +191,10 @@ async def make_llm_api_call(
|
||||||
api_base: Optional[str] = None,
|
api_base: Optional[str] = None,
|
||||||
stream: bool = False,
|
stream: bool = False,
|
||||||
top_p: Optional[float] = None,
|
top_p: Optional[float] = None,
|
||||||
model_id: Optional[str] = None
|
model_id: Optional[str] = None,
|
||||||
|
# Add parameters for thinking/reasoning
|
||||||
|
enable_thinking: Optional[bool] = None,
|
||||||
|
reasoning_effort: Optional[str] = None
|
||||||
) -> Union[Dict[str, Any], AsyncGenerator]:
|
) -> Union[Dict[str, Any], AsyncGenerator]:
|
||||||
"""
|
"""
|
||||||
Make an API call to a language model using LiteLLM.
|
Make an API call to a language model using LiteLLM.
|
||||||
|
@ -225,7 +233,10 @@ async def make_llm_api_call(
|
||||||
api_base=api_base,
|
api_base=api_base,
|
||||||
stream=stream,
|
stream=stream,
|
||||||
top_p=top_p,
|
top_p=top_p,
|
||||||
model_id=model_id
|
model_id=model_id,
|
||||||
|
# Add parameters for thinking/reasoning
|
||||||
|
enable_thinking=enable_thinking,
|
||||||
|
reasoning_effort=reasoning_effort
|
||||||
)
|
)
|
||||||
|
|
||||||
# Apply Anthropic prompt caching (minimal implementation)
|
# Apply Anthropic prompt caching (minimal implementation)
|
||||||
|
|
Loading…
Reference in New Issue