mirror of https://github.com/kortix-ai/suna.git
commit c84ee59dc6
parent adc8036615

    thinking & reasoning
@@ -1,4 +1,4 @@
-from fastapi import APIRouter, HTTPException, Depends, Request
+from fastapi import APIRouter, HTTPException, Depends, Request, Body
 from fastapi.responses import StreamingResponse
 import asyncio
 import json
@@ -7,6 +7,7 @@ from datetime import datetime, timezone
 import uuid
 from typing import Optional, List, Dict, Any
 import jwt
+from pydantic import BaseModel

 from agentpress.thread_manager import ThreadManager
 from services.supabase import DBConnection
@@ -26,6 +27,12 @@ db = None
 # In-memory storage for active agent runs and their responses
 active_agent_runs: Dict[str, List[Any]] = {}

+class AgentStartRequest(BaseModel):
+    model_name: Optional[str] = "anthropic/claude-3-7-sonnet-latest"
+    enable_thinking: Optional[bool] = False
+    reasoning_effort: Optional[str] = 'low'
+    stream: Optional[bool] = False  # Default stream to False for API
+
 def initialize(
     _thread_manager: ThreadManager,
     _db: DBConnection,
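Note: every field on the new AgentStartRequest has a default, so the endpoint still works when the client sends an empty JSON body. A standalone sketch of how the body parses (assumes only pydantic is installed; the print calls are illustrative and not part of this commit):

    from typing import Optional
    from pydantic import BaseModel

    class AgentStartRequest(BaseModel):
        model_name: Optional[str] = "anthropic/claude-3-7-sonnet-latest"
        enable_thinking: Optional[bool] = False
        reasoning_effort: Optional[str] = 'low'
        stream: Optional[bool] = False

    # An empty request body falls back to the defaults above.
    print(AgentStartRequest())
    # Any field can be overridden per request.
    print(AgentStartRequest(enable_thinking=True, stream=True))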
@@ -237,9 +244,13 @@ async def _cleanup_agent_run(agent_run_id: str):
         # Non-fatal error, can continue

 @router.post("/thread/{thread_id}/agent/start")
-async def start_agent(thread_id: str, user_id: str = Depends(get_current_user_id)):
+async def start_agent(
+    thread_id: str,
+    body: AgentStartRequest = Body(...),  # Accept request body
+    user_id: str = Depends(get_current_user_id)
+):
     """Start an agent for a specific thread in the background."""
-    logger.info(f"Starting new agent for thread: {thread_id}")
+    logger.info(f"Starting new agent for thread: {thread_id} with config: model={body.model_name}, thinking={body.enable_thinking}, effort={body.reasoning_effort}, stream={body.stream}")
     client = await db.client

     # Verify user has access to this thread
@@ -314,7 +325,17 @@ async def start_agent(thread_id: str, user_id: str = Depends(get_current_user_id

     # Run the agent in the background
     task = asyncio.create_task(
-        run_agent_background(agent_run_id, thread_id, instance_id, project_id, sandbox)
+        run_agent_background(
+            agent_run_id=agent_run_id,
+            thread_id=thread_id,
+            instance_id=instance_id,
+            project_id=project_id,
+            sandbox=sandbox,
+            model_name=body.model_name,
+            enable_thinking=body.enable_thinking,
+            reasoning_effort=body.reasoning_effort,
+            stream=body.stream  # Pass stream parameter
+        )
     )

     # Set a callback to clean up when task is done
@@ -441,9 +462,19 @@ async def stream_agent_run(
         }
     )

-async def run_agent_background(agent_run_id: str, thread_id: str, instance_id: str, project_id: str, sandbox):
+async def run_agent_background(
+    agent_run_id: str,
+    thread_id: str,
+    instance_id: str,
+    project_id: str,
+    sandbox,
+    model_name: str,
+    enable_thinking: Optional[bool],
+    reasoning_effort: Optional[str],
+    stream: bool  # Add stream parameter
+):
     """Run the agent in the background and handle status updates."""
-    logger.debug(f"Starting background agent run: {agent_run_id} for thread: {thread_id} (instance: {instance_id})")
+    logger.debug(f"Starting background agent run: {agent_run_id} for thread: {thread_id} (instance: {instance_id}) with model={model_name}, thinking={enable_thinking}, effort={reasoning_effort}, stream={stream}")
     client = await db.client

     # Tracking variables
@@ -561,9 +592,16 @@ async def run_agent_background(agent_run_id: str, thread_id: str, instance_id: s
     try:
         # Run the agent
        logger.debug(f"Initializing agent generator for thread: {thread_id} (instance: {instance_id})")
-        agent_gen = run_agent(thread_id, stream=True,
-                              thread_manager=thread_manager, project_id=project_id,
-                              sandbox=sandbox)
+        agent_gen = run_agent(
+            thread_id=thread_id,
+            project_id=project_id,
+            stream=stream,  # Pass stream parameter from API request
+            thread_manager=thread_manager,
+            sandbox=sandbox,
+            model_name=model_name,  # Pass model_name
+            enable_thinking=enable_thinking,  # Pass enable_thinking
+            reasoning_effort=reasoning_effort  # Pass reasoning_effort
+        )

         # Collect all responses to save to database
         all_responses = []
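Note: with this commit the start endpoint reads its run configuration from the request body instead of using hard-coded defaults. A rough client-side sketch under assumed conditions (the base URL, any router prefix, the bearer token, and the thread id are placeholders; none of them are specified in this diff):

    import requests

    BASE_URL = "http://localhost:8000"   # placeholder
    TOKEN = "<access-token>"             # placeholder; the endpoint resolves the user via get_current_user_id
    THREAD_ID = "<thread-uuid>"          # placeholder

    resp = requests.post(
        f"{BASE_URL}/thread/{THREAD_ID}/agent/start",
        json={
            "model_name": "anthropic/claude-3-7-sonnet-latest",
            "enable_thinking": True,
            "reasoning_effort": "low",
            "stream": False,
        },
        headers={"Authorization": f"Bearer {TOKEN}"},
    )
    resp.raise_for_status()
    print(resp.json())

Because every body field is optional, omitting fields simply applies the defaults declared on AgentStartRequest.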
@@ -21,7 +21,18 @@ from utils.billing import check_billing_status, get_account_id_from_thread

 load_dotenv()

-async def run_agent(thread_id: str, project_id: str, sandbox, stream: bool = True, thread_manager: Optional[ThreadManager] = None, native_max_auto_continues: int = 25, max_iterations: int = 150):
+async def run_agent(
+    thread_id: str,
+    project_id: str,
+    sandbox,
+    stream: bool,
+    thread_manager: Optional[ThreadManager] = None,
+    native_max_auto_continues: int = 25,
+    max_iterations: int = 150,
+    model_name: str = "anthropic/claude-3-7-sonnet-latest",
+    enable_thinking: Optional[bool] = False,
+    reasoning_effort: Optional[str] = 'low'
+):
     """Run the development agent with specified configuration."""

     if not thread_manager:
@@ -112,7 +123,7 @@ async def run_agent(thread_id: str, project_id: str, sandbox, stream: bool = Tru
             thread_id=thread_id,
             system_prompt=system_message,
             stream=stream,
-            llm_model=os.getenv("MODEL_TO_USE", "anthropic/claude-3-7-sonnet-latest"),
+            llm_model=model_name,
             llm_temperature=0,
             llm_max_tokens=64000,
             tool_choice="auto",
@@ -128,6 +139,8 @@ async def run_agent(thread_id: str, project_id: str, sandbox, stream: bool = Tru
             ),
             native_max_auto_continues=native_max_auto_continues,
             include_xml_examples=True,
+            enable_thinking=enable_thinking,
+            reasoning_effort=reasoning_effort
         )

         if isinstance(response, dict) and "status" in response and response["status"] == "error":
@@ -250,7 +263,15 @@ async def test_agent():

     print("\n👋 Test completed. Goodbye!")

-async def process_agent_response(thread_id: str, project_id: str, thread_manager: ThreadManager):
+async def process_agent_response(
+    thread_id: str,
+    project_id: str,
+    thread_manager: ThreadManager,
+    stream: bool = True,
+    model_name: str = "anthropic/claude-3-7-sonnet-latest",
+    enable_thinking: Optional[bool] = False,
+    reasoning_effort: Optional[str] = 'low'
+):
     """Process the streaming response from the agent."""
     chunk_counter = 0
     current_response = ""
@@ -259,9 +280,19 @@ async def process_agent_response(thread_id: str, project_id: str, thread_manager
     # Create a test sandbox for processing
     sandbox_pass = str(uuid4())
     sandbox = create_sandbox(sandbox_pass)
-    print(f"\033[91mTest sandbox created: {sandbox.get_preview_link(6080)}/vnc_lite.html?password={sandbox_pass}\033[0m")
+    print(f"\033[91mTest sandbox created: {str(sandbox.get_preview_link(6080))}/vnc_lite.html?password={sandbox_pass}\033[0m")

-    async for chunk in run_agent(thread_id=thread_id, project_id=project_id, sandbox=sandbox, stream=True, thread_manager=thread_manager, native_max_auto_continues=25):
+    async for chunk in run_agent(
+        thread_id=thread_id,
+        project_id=project_id,
+        sandbox=sandbox,
+        stream=stream,
+        thread_manager=thread_manager,
+        native_max_auto_continues=25,
+        model_name=model_name,
+        enable_thinking=enable_thinking,
+        reasoning_effort=reasoning_effort
+    ):
         chunk_counter += 1

         if chunk.get('type') == 'content' and 'content' in chunk:
@@ -98,6 +98,8 @@ class ResponseProcessor:
         llm_response: AsyncGenerator,
         thread_id: str,
         config: ProcessorConfig = ProcessorConfig(),
+        prompt_messages: Optional[List[Dict[str, Any]]] = None,
+        llm_model: Optional[str] = None
     ) -> AsyncGenerator:
         """Process a streaming LLM response, handling tool calls and execution.

@@ -105,6 +107,8 @@ class ResponseProcessor:
             llm_response: Streaming response from the LLM
             thread_id: ID of the conversation thread
             config: Configuration for parsing and execution
+            prompt_messages: List of messages used for cost calculation
+            llm_model: Name of the LLM model used for cost calculation

         Yields:
             Formatted chunks of the response including content and tool results
@@ -175,15 +179,20 @@ class ResponseProcessor:
                         accumulated_content += chunk_content
                         current_xml_content += chunk_content

-                        # Calculate cost using prompt and completion
-                        try:
-                            cost = completion_cost(model=chunk.model, prompt=accumulated_content, completion=chunk_content)
-                            tcount = token_counter(model=chunk.model, messages=[{"role": "user", "content": accumulated_content}])
-                            accumulated_cost += cost
-                            accumulated_token_count += tcount
-                            logger.debug(f"Cost: {cost:.6f}, Token count: {tcount}")
-                        except Exception as e:
-                            logger.error(f"Error calculating cost: {str(e)}")
+                        # Process reasoning content if present (Anthropic)
+                        if hasattr(delta, 'reasoning_content') and delta.reasoning_content:
+                            logger.info(f"[THINKING]: {delta.reasoning_content}")
+                            accumulated_content += delta.reasoning_content  # Append reasoning to main content
+
+                        # Calculate cost using prompt and completion - MOVED AFTER MESSAGE SAVE
+                        # try:
+                        #     cost = completion_cost(model=chunk.model, prompt=accumulated_content, completion=chunk_content)
+                        #     tcount = token_counter(model=chunk.model, messages=[{"role": "user", "content": accumulated_content}])
+                        #     accumulated_cost += cost
+                        #     accumulated_token_count += tcount
+                        #     logger.debug(f"Cost: {cost:.6f}, Token count: {tcount}")
+                        # except Exception as e:
+                        #     logger.error(f"Error calculating cost: {str(e)}")

                         # Check if we've reached the XML tool call limit before yielding content
                         if config.max_xml_tool_calls > 0 and xml_tool_call_count >= config.max_xml_tool_calls:
@@ -489,6 +498,35 @@ class ResponseProcessor:
                     "thread_run_id": thread_run_id
                 }

+                # --- Cost Calculation (moved here) ---
+                if prompt_messages and llm_model and accumulated_content:
+                    try:
+                        cost = completion_cost(
+                            model=llm_model,
+                            messages=prompt_messages,
+                            completion=accumulated_content
+                        )
+                        token_count = token_counter(
+                            model=llm_model,
+                            messages=prompt_messages + [{"role": "assistant", "content": accumulated_content}]
+                        )
+                        await self.add_message(
+                            thread_id=thread_id,
+                            type="cost",
+                            content={
+                                "cost": cost,
+                                "prompt_tokens": token_count - token_counter(model=llm_model, messages=[{"role": "assistant", "content": accumulated_content}]),  # Approx
+                                "completion_tokens": token_counter(model=llm_model, messages=[{"role": "assistant", "content": accumulated_content}]),  # Approx
+                                "total_tokens": token_count,
+                                "model_name": llm_model
+                            },
+                            is_llm_message=False
+                        )
+                        logger.info(f"Calculated cost for streaming response: {cost:.6f} using model {llm_model}")
+                    except Exception as e:
+                        logger.error(f"Error calculating cost: {str(e)}")
+                # --- End Cost Calculation ---
+
                 # --- Process All Tool Calls Now ---
                 if config.execute_tools:
                     final_tool_calls_to_process = []
@@ -626,23 +664,24 @@ class ResponseProcessor:
             yield {"type": "error", "message": str(e), "thread_run_id": thread_run_id if 'thread_run_id' in locals() else None}

         finally:
-            # Yield a finish signal including the final assistant message ID
-            if last_assistant_message_id:
-                # Yield the overall run end signal
+            # Yield the detected finish reason if one exists and wasn't suppressed
+            if finish_reason and finish_reason != "xml_tool_limit_reached":
                 yield {
-                    "type": "thread_run_end",
-                    "thread_run_id": thread_run_id
-                }
-            else:
-                # Yield the overall run end signal
-                yield {
-                    "type": "thread_run_end",
+                    "type": "finish",
+                    "finish_reason": finish_reason,
                     "thread_run_id": thread_run_id if 'thread_run_id' in locals() else None
                 }

+            # Yield a finish signal including the final assistant message ID
+            # Ensure thread_run_id is defined, even if an early error occurred
+            run_id = thread_run_id if 'thread_run_id' in locals() else str(uuid.uuid4())  # Fallback ID if needed
+            yield {
+                "type": "thread_run_end",
+                "thread_run_id": run_id
+            }
+
+            # Remove old cost calculation code
             pass
-            # track the cost and token count
-            # todo: there is a bug as it adds every chunk to db because finally will run every time even in yield
             # await self.add_message(
             #     thread_id=thread_id,
             #     type="cost",
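Note: with the hunk above the streaming processor yields an explicit "finish" chunk when a finish reason was detected (and not suppressed by the XML tool-call limit), and always closes with a "thread_run_end" chunk whose id falls back to a fresh UUID. A small self-contained sketch of a consumer branching on the chunk types that appear in this diff (the fake stream is only for illustration; real chunks may carry additional keys):

    import asyncio

    async def consume(chunks):
        # 'chunks' is any async generator yielding dicts shaped like the ones above.
        async for chunk in chunks:
            kind = chunk.get("type")
            if kind == "content":
                print(chunk.get("content", ""), end="")
            elif kind == "finish":
                print(f"\n[finish_reason={chunk.get('finish_reason')}]")
            elif kind == "thread_run_end":
                print(f"\n[thread_run_end: {chunk.get('thread_run_id')}]")
            elif kind == "error":
                print(f"\n[error: {chunk.get('message')}]")

    async def _demo():
        async def fake_stream():
            yield {"type": "content", "content": "Hello"}
            yield {"type": "finish", "finish_reason": "stop", "thread_run_id": "run-1"}
            yield {"type": "thread_run_end", "thread_run_id": "run-1"}
        await consume(fake_stream())

    asyncio.run(_demo())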
@@ -658,7 +697,9 @@ class ResponseProcessor:
         self,
         llm_response: Any,
         thread_id: str,
-        config: ProcessorConfig = ProcessorConfig()
+        config: ProcessorConfig = ProcessorConfig(),
+        prompt_messages: Optional[List[Dict[str, Any]]] = None,
+        llm_model: Optional[str] = None
     ) -> AsyncGenerator[Dict[str, Any], None]:
         """Process a non-streaming LLM response, handling tool calls and execution.

@@ -666,6 +707,8 @@ class ResponseProcessor:
             llm_response: Response from the LLM
             thread_id: ID of the conversation thread
             config: Configuration for parsing and execution
+            prompt_messages: List of messages used for cost calculation
+            llm_model: Name of the LLM model used for cost calculation

         Yields:
             Formatted response including content and tool results
@@ -861,6 +904,50 @@ class ResponseProcessor:
                     "thread_run_id": thread_run_id
                 }

+            # --- Cost Calculation (moved here) ---
+            if prompt_messages and llm_model:
+                cost = None
+                # Attempt to get cost from LiteLLM response first
+                if hasattr(llm_response, '_hidden_params') and 'response_cost' in llm_response._hidden_params:
+                    cost = llm_response._hidden_params['response_cost']
+                    logger.info(f"Using pre-calculated cost from LiteLLM: {cost:.6f}")
+
+                # If no pre-calculated cost, calculate manually
+                if cost is None:
+                    try:
+                        cost = completion_cost(
+                            model=llm_model,
+                            messages=prompt_messages,
+                            completion=content  # Use extracted content
+                        )
+                        logger.info(f"Manually calculated cost for non-streaming response: {cost:.6f} using model {llm_model}")
+                    except Exception as e:
+                        logger.error(f"Error calculating cost: {str(e)}")
+
+                # Add cost message if cost was determined
+                if cost is not None:
+                    try:
+                        # Approximate token counts
+                        completion_tokens = token_counter(model=llm_model, messages=[{"role": "assistant", "content": content}])
+                        prompt_tokens = token_counter(model=llm_model, messages=prompt_messages)
+                        total_tokens = prompt_tokens + completion_tokens
+
+                        await self.add_message(
+                            thread_id=thread_id,
+                            type="cost",
+                            content={
+                                "cost": cost,
+                                "prompt_tokens": prompt_tokens,
+                                "completion_tokens": completion_tokens,
+                                "total_tokens": total_tokens,
+                                "model_name": llm_model
+                            },
+                            is_llm_message=False
+                        )
+                    except Exception as e:
+                        logger.error(f"Error saving cost message: {str(e)}")
+            # --- End Cost Calculation ---
+
         except Exception as e:
             logger.error(f"Error processing response: {str(e)}", exc_info=True)
             yield {"type": "error", "message": str(e), "thread_run_id": thread_run_id if 'thread_run_id' in locals() else None}
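Note: both the streaming and non-streaming paths now price the run with LiteLLM's completion_cost and token_counter helpers and persist the result as a "cost" message. A standalone sketch of those two calls using the same keyword arguments as the diff (the model name and messages are only illustrative; requires the litellm package):

    from litellm import completion_cost, token_counter

    llm_model = "anthropic/claude-3-7-sonnet-latest"   # illustrative
    prompt_messages = [{"role": "user", "content": "Say hello."}]
    completion_text = "Hello!"

    cost = completion_cost(
        model=llm_model,
        messages=prompt_messages,
        completion=completion_text,
    )
    total_tokens = token_counter(
        model=llm_model,
        messages=prompt_messages + [{"role": "assistant", "content": completion_text}],
    )
    print(f"cost={cost:.6f} total_tokens={total_tokens}")

If the model alias is missing from LiteLLM's pricing map, completion_cost raises, which is why both code paths above wrap the calculation in try/except and fall back to logging the error.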
@@ -162,6 +162,8 @@ class ThreadManager:
         native_max_auto_continues: int = 25,
         max_xml_tool_calls: int = 0,
         include_xml_examples: bool = False,
+        enable_thinking: Optional[bool] = False,
+        reasoning_effort: Optional[str] = 'low'
     ) -> Union[Dict[str, Any], AsyncGenerator]:
         """Run a conversation thread with LLM integration and tool execution.

@@ -179,6 +181,8 @@ class ThreadManager:
                 finish_reason="tool_calls" (0 disables auto-continue)
             max_xml_tool_calls: Maximum number of XML tool calls to allow (0 = no limit)
             include_xml_examples: Whether to include XML tool examples in the system prompt
+            enable_thinking: Whether to enable thinking before making a decision
+            reasoning_effort: The effort level for reasoning

         Returns:
             An async generator yielding response chunks or error dict
@@ -313,7 +317,9 @@ Here are the XML tools available with examples:
                 max_tokens=llm_max_tokens,
                 tools=openapi_tool_schemas,
                 tool_choice=tool_choice if processor_config.native_tool_calling else None,
-                stream=stream
+                stream=stream,
+                enable_thinking=enable_thinking,
+                reasoning_effort=reasoning_effort
             )
             logger.debug("Successfully received raw LLM API response stream/object")

@@ -327,7 +333,9 @@ Here are the XML tools available with examples:
                 response_generator = self.response_processor.process_streaming_response(
                     llm_response=llm_response,
                     thread_id=thread_id,
-                    config=processor_config
+                    config=processor_config,
+                    prompt_messages=prepared_messages,
+                    llm_model=llm_model
                 )

                 return response_generator
@@ -338,7 +346,9 @@ Here are the XML tools available with examples:
                 response_generator = self.response_processor.process_non_streaming_response(
                     llm_response=llm_response,
                     thread_id=thread_id,
-                    config=processor_config
+                    config=processor_config,
+                    prompt_messages=prepared_messages,
+                    llm_model=llm_model
                 )
                 return response_generator  # Return the generator
         except Exception as e:
@@ -17,6 +17,8 @@ import asyncio
 from openai import OpenAIError
 import litellm
 from utils.logger import logger
+from datetime import datetime
+import traceback

 # litellm.set_verbose=True
 litellm.modify_params=True
@@ -82,7 +84,9 @@ def prepare_params(
     api_base: Optional[str] = None,
     stream: bool = False,
     top_p: Optional[float] = None,
-    model_id: Optional[str] = None
+    model_id: Optional[str] = None,
+    enable_thinking: Optional[bool] = False,
+    reasoning_effort: Optional[str] = 'low'
 ) -> Dict[str, Any]:
     """Prepare parameters for the API call."""
     params = {
@@ -211,6 +215,16 @@ def prepare_params(
         else:
             logger.warning(f"Last user message (index {last_user_idx}) content is not a string or list ({type(content)}), skipping cache_control.")

+    # Add reasoning_effort for Anthropic models if enabled
+    use_thinking = enable_thinking if enable_thinking is not None else False
+    is_anthropic = "anthropic" in effective_model_name.lower() or "claude" in effective_model_name.lower()
+
+    if is_anthropic and use_thinking:
+        effort_level = reasoning_effort if reasoning_effort else 'low'
+        params["reasoning_effort"] = effort_level
+        params["temperature"] = 1.0  # Required by Anthropic when reasoning_effort is used
+        logger.info(f"Anthropic thinking enabled with reasoning_effort='{effort_level}'")
+
     return params

 async def make_llm_api_call(
@@ -225,7 +239,9 @@ async def make_llm_api_call(
     api_base: Optional[str] = None,
     stream: bool = False,
     top_p: Optional[float] = None,
-    model_id: Optional[str] = None
+    model_id: Optional[str] = None,
+    enable_thinking: Optional[bool] = False,
+    reasoning_effort: Optional[str] = 'low'
 ) -> Union[Dict[str, Any], AsyncGenerator]:
     """
     Make an API call to a language model using LiteLLM.
@@ -243,6 +259,8 @@ async def make_llm_api_call(
         stream: Whether to stream the response
         top_p: Top-p sampling parameter
         model_id: Optional ARN for Bedrock inference profiles
+        enable_thinking: Whether to enable thinking
+        reasoning_effort: Level of reasoning effort

     Returns:
         Union[Dict[str, Any], AsyncGenerator]: API response or stream
@@ -251,7 +269,7 @@ async def make_llm_api_call(
         LLMRetryError: If API call fails after retries
         LLMError: For other API-related errors
     """
-    logger.debug(f"Making LLM API call to model: {model_name}")
+    logger.debug(f"Making LLM API call to model: {model_name} (Thinking: {enable_thinking}, Effort: {reasoning_effort})")
     params = prepare_params(
         messages=messages,
         model_name=model_name,
@@ -264,7 +282,9 @@ async def make_llm_api_call(
         api_base=api_base,
         stream=stream,
         top_p=top_p,
-        model_id=model_id
+        model_id=model_id,
+        enable_thinking=enable_thinking,
+        reasoning_effort=reasoning_effort
     )

     last_error = None
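Note: prepare_params only turns reasoning on for Anthropic/Claude model names, and it pins temperature to 1.0 when it does; every other model is left untouched. A minimal standalone sketch of that branch (the helper name and its arguments are stand-ins for the surrounding prepare_params code, which this diff shows only in part):

    from typing import Any, Dict, Optional

    def apply_thinking_params(
        params: Dict[str, Any],
        effective_model_name: str,
        enable_thinking: Optional[bool] = False,
        reasoning_effort: Optional[str] = 'low',
    ) -> Dict[str, Any]:
        # Mirrors the added branch: only Anthropic/Claude models get reasoning_effort.
        use_thinking = enable_thinking if enable_thinking is not None else False
        is_anthropic = "anthropic" in effective_model_name.lower() or "claude" in effective_model_name.lower()
        if is_anthropic and use_thinking:
            params["reasoning_effort"] = reasoning_effort if reasoning_effort else 'low'
            params["temperature"] = 1.0  # Required by Anthropic when reasoning_effort is used
        return params

    print(apply_thinking_params({}, "anthropic/claude-3-7-sonnet-latest", enable_thinking=True))
    # -> {'reasoning_effort': 'low', 'temperature': 1.0}
    print(apply_thinking_params({}, "openai/gpt-4o", enable_thinking=True))
    # -> {} (non-Anthropic models are left untouched)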