thinking & reasoning

LE Quoc Dat 2025-04-18 05:49:41 +01:00
parent adc8036615
commit c84ee59dc6
5 changed files with 230 additions and 44 deletions

View File

@@ -1,4 +1,4 @@
from fastapi import APIRouter, HTTPException, Depends, Request
from fastapi import APIRouter, HTTPException, Depends, Request, Body
from fastapi.responses import StreamingResponse
import asyncio
import json
@@ -7,6 +7,7 @@ from datetime import datetime, timezone
import uuid
from typing import Optional, List, Dict, Any
import jwt
from pydantic import BaseModel
from agentpress.thread_manager import ThreadManager
from services.supabase import DBConnection
@@ -26,6 +27,12 @@ db = None
# In-memory storage for active agent runs and their responses
active_agent_runs: Dict[str, List[Any]] = {}
class AgentStartRequest(BaseModel):
model_name: Optional[str] = "anthropic/claude-3-7-sonnet-latest"
enable_thinking: Optional[bool] = False
reasoning_effort: Optional[str] = 'low'
stream: Optional[bool] = False # Default stream to False for API
def initialize(
_thread_manager: ThreadManager,
_db: DBConnection,
@@ -237,9 +244,13 @@ async def _cleanup_agent_run(agent_run_id: str):
# Non-fatal error, can continue
@router.post("/thread/{thread_id}/agent/start")
async def start_agent(thread_id: str, user_id: str = Depends(get_current_user_id)):
async def start_agent(
thread_id: str,
body: AgentStartRequest = Body(...), # Accept request body
user_id: str = Depends(get_current_user_id)
):
"""Start an agent for a specific thread in the background."""
logger.info(f"Starting new agent for thread: {thread_id}")
logger.info(f"Starting new agent for thread: {thread_id} with config: model={body.model_name}, thinking={body.enable_thinking}, effort={body.reasoning_effort}, stream={body.stream}")
client = await db.client
# Verify user has access to this thread
@@ -314,7 +325,17 @@ async def start_agent(thread_id: str, user_id: str = Depends(get_current_user_id
# Run the agent in the background
task = asyncio.create_task(
run_agent_background(agent_run_id, thread_id, instance_id, project_id, sandbox)
run_agent_background(
agent_run_id=agent_run_id,
thread_id=thread_id,
instance_id=instance_id,
project_id=project_id,
sandbox=sandbox,
model_name=body.model_name,
enable_thinking=body.enable_thinking,
reasoning_effort=body.reasoning_effort,
stream=body.stream # Pass stream parameter
)
)
# Set a callback to clean up when task is done
@@ -441,9 +462,19 @@ async def stream_agent_run(
}
)
async def run_agent_background(agent_run_id: str, thread_id: str, instance_id: str, project_id: str, sandbox):
async def run_agent_background(
agent_run_id: str,
thread_id: str,
instance_id: str,
project_id: str,
sandbox,
model_name: str,
enable_thinking: Optional[bool],
reasoning_effort: Optional[str],
stream: bool # Add stream parameter
):
"""Run the agent in the background and handle status updates."""
logger.debug(f"Starting background agent run: {agent_run_id} for thread: {thread_id} (instance: {instance_id})")
logger.debug(f"Starting background agent run: {agent_run_id} for thread: {thread_id} (instance: {instance_id}) with model={model_name}, thinking={enable_thinking}, effort={reasoning_effort}, stream={stream}")
client = await db.client
# Tracking variables
@@ -561,9 +592,16 @@ async def run_agent_background(agent_run_id: str, thread_id: str, instance_id: s
try:
# Run the agent
logger.debug(f"Initializing agent generator for thread: {thread_id} (instance: {instance_id})")
agent_gen = run_agent(thread_id, stream=True,
thread_manager=thread_manager, project_id=project_id,
sandbox=sandbox)
agent_gen = run_agent(
thread_id=thread_id,
project_id=project_id,
stream=stream, # Pass stream parameter from API request
thread_manager=thread_manager,
sandbox=sandbox,
model_name=model_name, # Pass model_name
enable_thinking=enable_thinking, # Pass enable_thinking
reasoning_effort=reasoning_effort # Pass reasoning_effort
)
# Collect all responses to save to database
all_responses = []
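For reference, the updated route accepts a JSON body validated against AgentStartRequest. A minimal client-side sketch, not part of this commit (the base URL, JWT, and thread ID below are placeholders, and the response shape is not shown in this diff):

import httpx

BASE_URL = "http://localhost:8000/api"   # assumed prefix; adjust to the actual deployment
THREAD_ID = "<thread-uuid>"              # hypothetical thread ID
TOKEN = "<jwt>"                          # hypothetical bearer token

resp = httpx.post(
    f"{BASE_URL}/thread/{THREAD_ID}/agent/start",
    headers={"Authorization": f"Bearer {TOKEN}"},
    json={
        "model_name": "anthropic/claude-3-7-sonnet-latest",
        "enable_thinking": True,
        "reasoning_effort": "medium",
        "stream": False,
    },
)
resp.raise_for_status()
print(resp.json())  # response contents depend on the endpoint, not shown in this diff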

View File

@@ -21,7 +21,18 @@ from utils.billing import check_billing_status, get_account_id_from_thread
load_dotenv()
async def run_agent(thread_id: str, project_id: str, sandbox, stream: bool = True, thread_manager: Optional[ThreadManager] = None, native_max_auto_continues: int = 25, max_iterations: int = 150):
async def run_agent(
thread_id: str,
project_id: str,
sandbox,
stream: bool,
thread_manager: Optional[ThreadManager] = None,
native_max_auto_continues: int = 25,
max_iterations: int = 150,
model_name: str = "anthropic/claude-3-7-sonnet-latest",
enable_thinking: Optional[bool] = False,
reasoning_effort: Optional[str] = 'low'
):
"""Run the development agent with specified configuration."""
if not thread_manager:
@@ -112,7 +123,7 @@ async def run_agent(thread_id: str, project_id: str, sandbox, stream: bool = Tru
thread_id=thread_id,
system_prompt=system_message,
stream=stream,
llm_model=os.getenv("MODEL_TO_USE", "anthropic/claude-3-7-sonnet-latest"),
llm_model=model_name,
llm_temperature=0,
llm_max_tokens=64000,
tool_choice="auto",
@@ -128,6 +139,8 @@ async def run_agent(thread_id: str, project_id: str, sandbox, stream: bool = Tru
),
native_max_auto_continues=native_max_auto_continues,
include_xml_examples=True,
enable_thinking=enable_thinking,
reasoning_effort=reasoning_effort
)
if isinstance(response, dict) and "status" in response and response["status"] == "error":
@@ -250,7 +263,15 @@ async def test_agent():
print("\n👋 Test completed. Goodbye!")
async def process_agent_response(thread_id: str, project_id: str, thread_manager: ThreadManager):
async def process_agent_response(
thread_id: str,
project_id: str,
thread_manager: ThreadManager,
stream: bool = True,
model_name: str = "anthropic/claude-3-7-sonnet-latest",
enable_thinking: Optional[bool] = False,
reasoning_effort: Optional[str] = 'low'
):
"""Process the streaming response from the agent."""
chunk_counter = 0
current_response = ""
@@ -259,9 +280,19 @@ async def process_agent_response(thread_id: str, project_id: str, thread_manager
# Create a test sandbox for processing
sandbox_pass = str(uuid4())
sandbox = create_sandbox(sandbox_pass)
print(f"\033[91mTest sandbox created: {sandbox.get_preview_link(6080)}/vnc_lite.html?password={sandbox_pass}\033[0m")
print(f"\033[91mTest sandbox created: {str(sandbox.get_preview_link(6080))}/vnc_lite.html?password={sandbox_pass}\033[0m")
async for chunk in run_agent(thread_id=thread_id, project_id=project_id, sandbox=sandbox, stream=True, thread_manager=thread_manager, native_max_auto_continues=25):
async for chunk in run_agent(
thread_id=thread_id,
project_id=project_id,
sandbox=sandbox,
stream=stream,
thread_manager=thread_manager,
native_max_auto_continues=25,
model_name=model_name,
enable_thinking=enable_thinking,
reasoning_effort=reasoning_effort
):
chunk_counter += 1
if chunk.get('type') == 'content' and 'content' in chunk:
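Note that run_agent no longer reads MODEL_TO_USE from the environment; the model now arrives through the model_name parameter. A caller that wants to keep the old env-var fallback could do something like this sketch (not part of the commit; thread_manager, sandbox, and the IDs are assumed to exist already):

import os

# Reproduce the previous behaviour: env var first, then the hard-coded default.
model_name = os.getenv("MODEL_TO_USE", "anthropic/claude-3-7-sonnet-latest")

agent_gen = run_agent(
    thread_id=thread_id,
    project_id=project_id,
    sandbox=sandbox,
    stream=True,
    thread_manager=thread_manager,
    model_name=model_name,
    enable_thinking=False,
    reasoning_effort="low",
)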

View File

@@ -98,6 +98,8 @@ class ResponseProcessor:
llm_response: AsyncGenerator,
thread_id: str,
config: ProcessorConfig = ProcessorConfig(),
prompt_messages: Optional[List[Dict[str, Any]]] = None,
llm_model: Optional[str] = None
) -> AsyncGenerator:
"""Process a streaming LLM response, handling tool calls and execution.
@@ -105,6 +107,8 @@ class ResponseProcessor:
llm_response: Streaming response from the LLM
thread_id: ID of the conversation thread
config: Configuration for parsing and execution
prompt_messages: List of messages used for cost calculation
llm_model: Name of the LLM model used for cost calculation
Yields:
Formatted chunks of the response including content and tool results
@@ -175,15 +179,20 @@ class ResponseProcessor:
accumulated_content += chunk_content
current_xml_content += chunk_content
# Calculate cost using prompt and completion
try:
cost = completion_cost(model=chunk.model, prompt=accumulated_content, completion=chunk_content)
tcount = token_counter(model=chunk.model, messages=[{"role": "user", "content": accumulated_content}])
accumulated_cost += cost
accumulated_token_count += tcount
logger.debug(f"Cost: {cost:.6f}, Token count: {tcount}")
except Exception as e:
logger.error(f"Error calculating cost: {str(e)}")
# Process reasoning content if present (Anthropic)
if hasattr(delta, 'reasoning_content') and delta.reasoning_content:
logger.info(f"[THINKING]: {delta.reasoning_content}")
accumulated_content += delta.reasoning_content # Append reasoning to main content
# Calculate cost using prompt and completion - MOVED AFTER MESSAGE SAVE
# try:
# cost = completion_cost(model=chunk.model, prompt=accumulated_content, completion=chunk_content)
# tcount = token_counter(model=chunk.model, messages=[{"role": "user", "content": accumulated_content}])
# accumulated_cost += cost
# accumulated_token_count += tcount
# logger.debug(f"Cost: {cost:.6f}, Token count: {tcount}")
# except Exception as e:
# logger.error(f"Error calculating cost: {str(e)}")
# Check if we've reached the XML tool call limit before yielding content
if config.max_xml_tool_calls > 0 and xml_tool_call_count >= config.max_xml_tool_calls:
@@ -489,6 +498,35 @@ class ResponseProcessor:
"thread_run_id": thread_run_id
}
# --- Cost Calculation (moved here) ---
if prompt_messages and llm_model and accumulated_content:
try:
cost = completion_cost(
model=llm_model,
messages=prompt_messages,
completion=accumulated_content
)
token_count = token_counter(
model=llm_model,
messages=prompt_messages + [{"role": "assistant", "content": accumulated_content}]
)
await self.add_message(
thread_id=thread_id,
type="cost",
content={
"cost": cost,
"prompt_tokens": token_count - token_counter(model=llm_model, messages=[{"role": "assistant", "content": accumulated_content}]), # Approx
"completion_tokens": token_counter(model=llm_model, messages=[{"role": "assistant", "content": accumulated_content}]), # Approx
"total_tokens": token_count,
"model_name": llm_model
},
is_llm_message=False
)
logger.info(f"Calculated cost for streaming response: {cost:.6f} using model {llm_model}")
except Exception as e:
logger.error(f"Error calculating cost: {str(e)}")
# --- End Cost Calculation ---
# --- Process All Tool Calls Now ---
if config.execute_tools:
final_tool_calls_to_process = []
@@ -626,23 +664,24 @@ class ResponseProcessor:
yield {"type": "error", "message": str(e), "thread_run_id": thread_run_id if 'thread_run_id' in locals() else None}
finally:
# Yield a finish signal including the final assistant message ID
if last_assistant_message_id:
# Yield the overall run end signal
# Yield the detected finish reason if one exists and wasn't suppressed
if finish_reason and finish_reason != "xml_tool_limit_reached":
yield {
"type": "thread_run_end",
"thread_run_id": thread_run_id
}
else:
# Yield the overall run end signal
yield {
"type": "thread_run_end",
"type": "finish",
"finish_reason": finish_reason,
"thread_run_id": thread_run_id if 'thread_run_id' in locals() else None
}
# Yield a finish signal including the final assistant message ID
# Ensure thread_run_id is defined, even if an early error occurred
run_id = thread_run_id if 'thread_run_id' in locals() else str(uuid.uuid4()) # Fallback ID if needed
yield {
"type": "thread_run_end",
"thread_run_id": run_id
}
# Remove old cost calculation code
pass
# track the cost and token count
# todo: there is a bug as it adds every chunk to db because finally will run every time even in yield
# await self.add_message(
# thread_id=thread_id,
# type="cost",
@@ -658,7 +697,9 @@ class ResponseProcessor:
self,
llm_response: Any,
thread_id: str,
config: ProcessorConfig = ProcessorConfig()
config: ProcessorConfig = ProcessorConfig(),
prompt_messages: Optional[List[Dict[str, Any]]] = None,
llm_model: Optional[str] = None
) -> AsyncGenerator[Dict[str, Any], None]:
"""Process a non-streaming LLM response, handling tool calls and execution.
@@ -666,6 +707,8 @@ class ResponseProcessor:
llm_response: Response from the LLM
thread_id: ID of the conversation thread
config: Configuration for parsing and execution
prompt_messages: List of messages used for cost calculation
llm_model: Name of the LLM model used for cost calculation
Yields:
Formatted response including content and tool results
@@ -861,6 +904,50 @@ class ResponseProcessor:
"thread_run_id": thread_run_id
}
# --- Cost Calculation (moved here) ---
if prompt_messages and llm_model:
cost = None
# Attempt to get cost from LiteLLM response first
if hasattr(llm_response, '_hidden_params') and 'response_cost' in llm_response._hidden_params:
cost = llm_response._hidden_params['response_cost']
logger.info(f"Using pre-calculated cost from LiteLLM: {cost:.6f}")
# If no pre-calculated cost, calculate manually
if cost is None:
try:
cost = completion_cost(
model=llm_model,
messages=prompt_messages,
completion=content # Use extracted content
)
logger.info(f"Manually calculated cost for non-streaming response: {cost:.6f} using model {llm_model}")
except Exception as e:
logger.error(f"Error calculating cost: {str(e)}")
# Add cost message if cost was determined
if cost is not None:
try:
# Approximate token counts
completion_tokens = token_counter(model=llm_model, messages=[{"role": "assistant", "content": content}])
prompt_tokens = token_counter(model=llm_model, messages=prompt_messages)
total_tokens = prompt_tokens + completion_tokens
await self.add_message(
thread_id=thread_id,
type="cost",
content={
"cost": cost,
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"total_tokens": total_tokens,
"model_name": llm_model
},
is_llm_message=False
)
except Exception as e:
logger.error(f"Error saving cost message: {str(e)}")
# --- End Cost Calculation ---
except Exception as e:
logger.error(f"Error processing response: {str(e)}", exc_info=True)
yield {"type": "error", "message": str(e), "thread_run_id": thread_run_id if 'thread_run_id' in locals() else None}

View File

@@ -162,6 +162,8 @@ class ThreadManager:
native_max_auto_continues: int = 25,
max_xml_tool_calls: int = 0,
include_xml_examples: bool = False,
enable_thinking: Optional[bool] = False,
reasoning_effort: Optional[str] = 'low'
) -> Union[Dict[str, Any], AsyncGenerator]:
"""Run a conversation thread with LLM integration and tool execution.
@@ -179,6 +181,8 @@ class ThreadManager:
finish_reason="tool_calls" (0 disables auto-continue)
max_xml_tool_calls: Maximum number of XML tool calls to allow (0 = no limit)
include_xml_examples: Whether to include XML tool examples in the system prompt
enable_thinking: Whether to enable thinking before making a decision
reasoning_effort: The effort level for reasoning
Returns:
An async generator yielding response chunks or error dict
@@ -313,7 +317,9 @@ Here are the XML tools available with examples:
max_tokens=llm_max_tokens,
tools=openapi_tool_schemas,
tool_choice=tool_choice if processor_config.native_tool_calling else None,
stream=stream
stream=stream,
enable_thinking=enable_thinking,
reasoning_effort=reasoning_effort
)
logger.debug("Successfully received raw LLM API response stream/object")
@@ -327,7 +333,9 @@ Here are the XML tools available with examples:
response_generator = self.response_processor.process_streaming_response(
llm_response=llm_response,
thread_id=thread_id,
config=processor_config
config=processor_config,
prompt_messages=prepared_messages,
llm_model=llm_model
)
return response_generator
@@ -338,7 +346,9 @@ Here are the XML tools available with examples:
response_generator = self.response_processor.process_non_streaming_response(
llm_response=llm_response,
thread_id=thread_id,
config=processor_config
config=processor_config,
prompt_messages=prepared_messages,
llm_model=llm_model
)
return response_generator # Return the generator
except Exception as e:

View File

@@ -17,6 +17,8 @@ import asyncio
from openai import OpenAIError
import litellm
from utils.logger import logger
from datetime import datetime
import traceback
# litellm.set_verbose=True
litellm.modify_params=True
@@ -82,7 +84,9 @@ def prepare_params(
api_base: Optional[str] = None,
stream: bool = False,
top_p: Optional[float] = None,
model_id: Optional[str] = None
model_id: Optional[str] = None,
enable_thinking: Optional[bool] = False,
reasoning_effort: Optional[str] = 'low'
) -> Dict[str, Any]:
"""Prepare parameters for the API call."""
params = {
@@ -211,6 +215,16 @@ def prepare_params(
else:
logger.warning(f"Last user message (index {last_user_idx}) content is not a string or list ({type(content)}), skipping cache_control.")
# Add reasoning_effort for Anthropic models if enabled
use_thinking = enable_thinking if enable_thinking is not None else False
is_anthropic = "anthropic" in effective_model_name.lower() or "claude" in effective_model_name.lower()
if is_anthropic and use_thinking:
effort_level = reasoning_effort if reasoning_effort else 'low'
params["reasoning_effort"] = effort_level
params["temperature"] = 1.0 # Required by Anthropic when reasoning_effort is used
logger.info(f"Anthropic thinking enabled with reasoning_effort='{effort_level}'")
return params
async def make_llm_api_call(
@@ -225,7 +239,9 @@ async def make_llm_api_call(
api_base: Optional[str] = None,
stream: bool = False,
top_p: Optional[float] = None,
model_id: Optional[str] = None
model_id: Optional[str] = None,
enable_thinking: Optional[bool] = False,
reasoning_effort: Optional[str] = 'low'
) -> Union[Dict[str, Any], AsyncGenerator]:
"""
Make an API call to a language model using LiteLLM.
@@ -243,6 +259,8 @@ async def make_llm_api_call(
stream: Whether to stream the response
top_p: Top-p sampling parameter
model_id: Optional ARN for Bedrock inference profiles
enable_thinking: Whether to enable thinking
reasoning_effort: Level of reasoning effort
Returns:
Union[Dict[str, Any], AsyncGenerator]: API response or stream
@@ -251,7 +269,7 @@ async def make_llm_api_call(
LLMRetryError: If API call fails after retries
LLMError: For other API-related errors
"""
logger.debug(f"Making LLM API call to model: {model_name}")
logger.debug(f"Making LLM API call to model: {model_name} (Thinking: {enable_thinking}, Effort: {reasoning_effort})")
params = prepare_params(
messages=messages,
model_name=model_name,
@@ -264,7 +282,9 @@ async def make_llm_api_call(
api_base=api_base,
stream=stream,
top_p=top_p,
model_id=model_id
model_id=model_id,
enable_thinking=enable_thinking,
reasoning_effort=reasoning_effort
)
last_error = None
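With the prepare_params change above, enabling thinking for an Anthropic model adds reasoning_effort and forces temperature to 1.0. A rough illustration (a sketch only; the keyword arguments not shown in this diff are assumed from the make_llm_api_call call site):

params = prepare_params(
    messages=[{"role": "user", "content": "Plan the refactor."}],
    model_name="anthropic/claude-3-7-sonnet-latest",
    temperature=0,
    stream=True,
    enable_thinking=True,
    reasoning_effort="medium",
)
# Expected: params["reasoning_effort"] == "medium" and params["temperature"] == 1.0,
# because the model name contains "anthropic" and thinking is enabled.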