thinking & reasoning

LE Quoc Dat 2025-04-18 05:49:41 +01:00
parent adc8036615
commit c84ee59dc6
5 changed files with 230 additions and 44 deletions

View File

@@ -1,4 +1,4 @@
-from fastapi import APIRouter, HTTPException, Depends, Request
+from fastapi import APIRouter, HTTPException, Depends, Request, Body
 from fastapi.responses import StreamingResponse
 import asyncio
 import json
@@ -7,6 +7,7 @@ from datetime import datetime, timezone
 import uuid
 from typing import Optional, List, Dict, Any
 import jwt
+from pydantic import BaseModel
 from agentpress.thread_manager import ThreadManager
 from services.supabase import DBConnection
@@ -26,6 +27,12 @@ db = None
 # In-memory storage for active agent runs and their responses
 active_agent_runs: Dict[str, List[Any]] = {}
 
+class AgentStartRequest(BaseModel):
+    model_name: Optional[str] = "anthropic/claude-3-7-sonnet-latest"
+    enable_thinking: Optional[bool] = False
+    reasoning_effort: Optional[str] = 'low'
+    stream: Optional[bool] = False # Default stream to False for API
+
 def initialize(
     _thread_manager: ThreadManager,
     _db: DBConnection,
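Note: with this request model in place, POST /thread/{thread_id}/agent/start accepts a JSON body. A minimal client sketch, assuming the router is mounted at the application root on localhost:8000 and that `token` holds a valid JWT (both are placeholders, not part of this commit):

    import httpx

    API_BASE = "http://localhost:8000"   # hypothetical deployment URL
    token = "<jwt>"                      # obtained from your auth flow
    thread_id = "<thread-uuid>"

    payload = {
        "model_name": "anthropic/claude-3-7-sonnet-latest",
        "enable_thinking": True,         # turn on Anthropic extended thinking
        "reasoning_effort": "low",       # matches the AgentStartRequest default
        "stream": False,
    }

    resp = httpx.post(
        f"{API_BASE}/thread/{thread_id}/agent/start",
        json=payload,
        headers={"Authorization": f"Bearer {token}"},
    )
    print(resp.status_code, resp.json())

All fields have defaults, so sending an empty object {} keeps the previous behaviour (default model, no thinking, no streaming).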
@@ -237,9 +244,13 @@ async def _cleanup_agent_run(agent_run_id: str):
     # Non-fatal error, can continue
 
 @router.post("/thread/{thread_id}/agent/start")
-async def start_agent(thread_id: str, user_id: str = Depends(get_current_user_id)):
+async def start_agent(
+    thread_id: str,
+    body: AgentStartRequest = Body(...), # Accept request body
+    user_id: str = Depends(get_current_user_id)
+):
     """Start an agent for a specific thread in the background."""
-    logger.info(f"Starting new agent for thread: {thread_id}")
+    logger.info(f"Starting new agent for thread: {thread_id} with config: model={body.model_name}, thinking={body.enable_thinking}, effort={body.reasoning_effort}, stream={body.stream}")
     client = await db.client
 
     # Verify user has access to this thread
@@ -314,7 +325,17 @@ async def start_agent(thread_id: str, user_id: str = Depends(get_current_user_id)):
     # Run the agent in the background
     task = asyncio.create_task(
-        run_agent_background(agent_run_id, thread_id, instance_id, project_id, sandbox)
+        run_agent_background(
+            agent_run_id=agent_run_id,
+            thread_id=thread_id,
+            instance_id=instance_id,
+            project_id=project_id,
+            sandbox=sandbox,
+            model_name=body.model_name,
+            enable_thinking=body.enable_thinking,
+            reasoning_effort=body.reasoning_effort,
+            stream=body.stream # Pass stream parameter
+        )
     )
 
     # Set a callback to clean up when task is done
@@ -441,9 +462,19 @@ async def stream_agent_run(
         }
     )
 
-async def run_agent_background(agent_run_id: str, thread_id: str, instance_id: str, project_id: str, sandbox):
+async def run_agent_background(
+    agent_run_id: str,
+    thread_id: str,
+    instance_id: str,
+    project_id: str,
+    sandbox,
+    model_name: str,
+    enable_thinking: Optional[bool],
+    reasoning_effort: Optional[str],
+    stream: bool # Add stream parameter
+):
     """Run the agent in the background and handle status updates."""
-    logger.debug(f"Starting background agent run: {agent_run_id} for thread: {thread_id} (instance: {instance_id})")
+    logger.debug(f"Starting background agent run: {agent_run_id} for thread: {thread_id} (instance: {instance_id}) with model={model_name}, thinking={enable_thinking}, effort={reasoning_effort}, stream={stream}")
     client = await db.client
 
     # Tracking variables
@@ -561,9 +592,16 @@ async def run_agent_background(agent_run_id: str, thread_id: str, instance_id: str, project_id: str, sandbox):
     try:
         # Run the agent
         logger.debug(f"Initializing agent generator for thread: {thread_id} (instance: {instance_id})")
-        agent_gen = run_agent(thread_id, stream=True,
-                              thread_manager=thread_manager, project_id=project_id,
-                              sandbox=sandbox)
+        agent_gen = run_agent(
+            thread_id=thread_id,
+            project_id=project_id,
+            stream=stream, # Pass stream parameter from API request
+            thread_manager=thread_manager,
+            sandbox=sandbox,
+            model_name=model_name, # Pass model_name
+            enable_thinking=enable_thinking, # Pass enable_thinking
+            reasoning_effort=reasoning_effort # Pass reasoning_effort
+        )
 
         # Collect all responses to save to database
         all_responses = []

View File

@@ -21,7 +21,18 @@ from utils.billing import check_billing_status, get_account_id_from_thread
 load_dotenv()
 
-async def run_agent(thread_id: str, project_id: str, sandbox, stream: bool = True, thread_manager: Optional[ThreadManager] = None, native_max_auto_continues: int = 25, max_iterations: int = 150):
+async def run_agent(
+    thread_id: str,
+    project_id: str,
+    sandbox,
+    stream: bool,
+    thread_manager: Optional[ThreadManager] = None,
+    native_max_auto_continues: int = 25,
+    max_iterations: int = 150,
+    model_name: str = "anthropic/claude-3-7-sonnet-latest",
+    enable_thinking: Optional[bool] = False,
+    reasoning_effort: Optional[str] = 'low'
+):
     """Run the development agent with specified configuration."""
 
     if not thread_manager:
@@ -112,7 +123,7 @@ async def run_agent(thread_id: str, project_id: str, sandbox, stream: bool = True, thread_manager: Optional[ThreadManager] = None, native_max_auto_continues: int = 25, max_iterations: int = 150):
             thread_id=thread_id,
             system_prompt=system_message,
             stream=stream,
-            llm_model=os.getenv("MODEL_TO_USE", "anthropic/claude-3-7-sonnet-latest"),
+            llm_model=model_name,
             llm_temperature=0,
             llm_max_tokens=64000,
             tool_choice="auto",
@@ -128,6 +139,8 @@ async def run_agent(thread_id: str, project_id: str, sandbox, stream: bool = True, thread_manager: Optional[ThreadManager] = None, native_max_auto_continues: int = 25, max_iterations: int = 150):
             ),
             native_max_auto_continues=native_max_auto_continues,
             include_xml_examples=True,
+            enable_thinking=enable_thinking,
+            reasoning_effort=reasoning_effort
         )
 
         if isinstance(response, dict) and "status" in response and response["status"] == "error":
@@ -250,7 +263,15 @@ async def test_agent():
     print("\n👋 Test completed. Goodbye!")
 
-async def process_agent_response(thread_id: str, project_id: str, thread_manager: ThreadManager):
+async def process_agent_response(
+    thread_id: str,
+    project_id: str,
+    thread_manager: ThreadManager,
+    stream: bool = True,
+    model_name: str = "anthropic/claude-3-7-sonnet-latest",
+    enable_thinking: Optional[bool] = False,
+    reasoning_effort: Optional[str] = 'low'
+):
     """Process the streaming response from the agent."""
     chunk_counter = 0
     current_response = ""
@@ -259,9 +280,19 @@ async def process_agent_response(thread_id: str, project_id: str, thread_manager: ThreadManager):
     # Create a test sandbox for processing
     sandbox_pass = str(uuid4())
     sandbox = create_sandbox(sandbox_pass)
-    print(f"\033[91mTest sandbox created: {sandbox.get_preview_link(6080)}/vnc_lite.html?password={sandbox_pass}\033[0m")
+    print(f"\033[91mTest sandbox created: {str(sandbox.get_preview_link(6080))}/vnc_lite.html?password={sandbox_pass}\033[0m")
 
-    async for chunk in run_agent(thread_id=thread_id, project_id=project_id, sandbox=sandbox, stream=True, thread_manager=thread_manager, native_max_auto_continues=25):
+    async for chunk in run_agent(
+        thread_id=thread_id,
+        project_id=project_id,
+        sandbox=sandbox,
+        stream=stream,
+        thread_manager=thread_manager,
+        native_max_auto_continues=25,
+        model_name=model_name,
+        enable_thinking=enable_thinking,
+        reasoning_effort=reasoning_effort
+    ):
         chunk_counter += 1
         if chunk.get('type') == 'content' and 'content' in chunk:

View File

@@ -98,6 +98,8 @@ class ResponseProcessor:
         llm_response: AsyncGenerator,
         thread_id: str,
         config: ProcessorConfig = ProcessorConfig(),
+        prompt_messages: Optional[List[Dict[str, Any]]] = None,
+        llm_model: Optional[str] = None
     ) -> AsyncGenerator:
         """Process a streaming LLM response, handling tool calls and execution.
@@ -105,6 +107,8 @@
             llm_response: Streaming response from the LLM
             thread_id: ID of the conversation thread
             config: Configuration for parsing and execution
+            prompt_messages: List of messages used for cost calculation
+            llm_model: Name of the LLM model used for cost calculation
 
         Yields:
             Formatted chunks of the response including content and tool results
@@ -175,15 +179,20 @@
                     accumulated_content += chunk_content
                     current_xml_content += chunk_content
 
-                    # Calculate cost using prompt and completion
-                    try:
-                        cost = completion_cost(model=chunk.model, prompt=accumulated_content, completion=chunk_content)
-                        tcount = token_counter(model=chunk.model, messages=[{"role": "user", "content": accumulated_content}])
-                        accumulated_cost += cost
-                        accumulated_token_count += tcount
-                        logger.debug(f"Cost: {cost:.6f}, Token count: {tcount}")
-                    except Exception as e:
-                        logger.error(f"Error calculating cost: {str(e)}")
+                    # Process reasoning content if present (Anthropic)
+                    if hasattr(delta, 'reasoning_content') and delta.reasoning_content:
+                        logger.info(f"[THINKING]: {delta.reasoning_content}")
+                        accumulated_content += delta.reasoning_content # Append reasoning to main content
+
+                    # Calculate cost using prompt and completion - MOVED AFTER MESSAGE SAVE
+                    # try:
+                    #     cost = completion_cost(model=chunk.model, prompt=accumulated_content, completion=chunk_content)
+                    #     tcount = token_counter(model=chunk.model, messages=[{"role": "user", "content": accumulated_content}])
+                    #     accumulated_cost += cost
+                    #     accumulated_token_count += tcount
+                    #     logger.debug(f"Cost: {cost:.6f}, Token count: {tcount}")
+                    # except Exception as e:
+                    #     logger.error(f"Error calculating cost: {str(e)}")
 
                     # Check if we've reached the XML tool call limit before yielding content
                     if config.max_xml_tool_calls > 0 and xml_tool_call_count >= config.max_xml_tool_calls:
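Note: the `reasoning_content` attribute checked above is how LiteLLM surfaces Anthropic extended-thinking tokens on streaming deltas. A standalone sketch of the same pattern, assuming a recent LiteLLM version and configured Anthropic credentials (model and prompt are illustrative):

    import asyncio
    import litellm

    async def show_thinking():
        stream = await litellm.acompletion(
            model="anthropic/claude-3-7-sonnet-latest",
            messages=[{"role": "user", "content": "What is 17 * 24?"}],
            reasoning_effort="low",  # enables extended thinking for Anthropic models
            temperature=1.0,         # Anthropic requires temperature 1 with thinking
            stream=True,
        )
        async for chunk in stream:
            delta = chunk.choices[0].delta
            if getattr(delta, "reasoning_content", None):
                print("[THINKING]", delta.reasoning_content)
            elif getattr(delta, "content", None):
                print(delta.content, end="")

    asyncio.run(show_thinking())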
@@ -489,6 +498,35 @@
                     "thread_run_id": thread_run_id
                 }
 
+            # --- Cost Calculation (moved here) ---
+            if prompt_messages and llm_model and accumulated_content:
+                try:
+                    cost = completion_cost(
+                        model=llm_model,
+                        messages=prompt_messages,
+                        completion=accumulated_content
+                    )
+                    token_count = token_counter(
+                        model=llm_model,
+                        messages=prompt_messages + [{"role": "assistant", "content": accumulated_content}]
+                    )
+                    await self.add_message(
+                        thread_id=thread_id,
+                        type="cost",
+                        content={
+                            "cost": cost,
+                            "prompt_tokens": token_count - token_counter(model=llm_model, messages=[{"role": "assistant", "content": accumulated_content}]), # Approx
+                            "completion_tokens": token_counter(model=llm_model, messages=[{"role": "assistant", "content": accumulated_content}]), # Approx
+                            "total_tokens": token_count,
+                            "model_name": llm_model
+                        },
+                        is_llm_message=False
+                    )
+                    logger.info(f"Calculated cost for streaming response: {cost:.6f} using model {llm_model}")
+                except Exception as e:
+                    logger.error(f"Error calculating cost: {str(e)}")
+            # --- End Cost Calculation ---
+
             # --- Process All Tool Calls Now ---
             if config.execute_tools:
                 final_tool_calls_to_process = []
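Note: `completion_cost` and `token_counter` here are LiteLLM helpers; the prompt/completion split mirrors the approximation used in the block above. A rough standalone sketch with illustrative inputs:

    from litellm import completion_cost, token_counter

    llm_model = "anthropic/claude-3-7-sonnet-latest"   # illustrative
    prompt_messages = [{"role": "user", "content": "Summarize the plan."}]
    completion_text = "Here is the summary..."

    cost = completion_cost(model=llm_model, messages=prompt_messages, completion=completion_text)
    total_tokens = token_counter(
        model=llm_model,
        messages=prompt_messages + [{"role": "assistant", "content": completion_text}],
    )
    completion_tokens = token_counter(
        model=llm_model,
        messages=[{"role": "assistant", "content": completion_text}],
    )
    prompt_tokens = total_tokens - completion_tokens   # same approximation as the diff
    print(f"cost=${cost:.6f} prompt={prompt_tokens} completion={completion_tokens}")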
@@ -626,23 +664,24 @@
             yield {"type": "error", "message": str(e), "thread_run_id": thread_run_id if 'thread_run_id' in locals() else None}
 
         finally:
-            # Yield a finish signal including the final assistant message ID
-            if last_assistant_message_id:
-                # Yield the overall run end signal
-                yield {
-                    "type": "thread_run_end",
-                    "thread_run_id": thread_run_id
-                }
-            else:
-                # Yield the overall run end signal
-                yield {
-                    "type": "thread_run_end",
-                    "thread_run_id": thread_run_id if 'thread_run_id' in locals() else None
-                }
+            # Yield the detected finish reason if one exists and wasn't suppressed
+            if finish_reason and finish_reason != "xml_tool_limit_reached":
+                yield {
+                    "type": "finish",
+                    "finish_reason": finish_reason,
+                    "thread_run_id": thread_run_id if 'thread_run_id' in locals() else None
+                }
+
+            # Yield a finish signal including the final assistant message ID
+            # Ensure thread_run_id is defined, even if an early error occurred
+            run_id = thread_run_id if 'thread_run_id' in locals() else str(uuid.uuid4()) # Fallback ID if needed
+            yield {
+                "type": "thread_run_end",
+                "thread_run_id": run_id
+            }
 
+            # Remove old cost calculation code
             pass
-            # track the cost and token count
-            # todo: there is a bug as it adds every chunk to db because finally will run every time even in yield
             # await self.add_message(
             #     thread_id=thread_id,
             #     type="cost",
@@ -658,7 +697,9 @@
         self,
         llm_response: Any,
         thread_id: str,
-        config: ProcessorConfig = ProcessorConfig()
+        config: ProcessorConfig = ProcessorConfig(),
+        prompt_messages: Optional[List[Dict[str, Any]]] = None,
+        llm_model: Optional[str] = None
     ) -> AsyncGenerator[Dict[str, Any], None]:
         """Process a non-streaming LLM response, handling tool calls and execution.
@@ -666,6 +707,8 @@
             llm_response: Response from the LLM
             thread_id: ID of the conversation thread
             config: Configuration for parsing and execution
+            prompt_messages: List of messages used for cost calculation
+            llm_model: Name of the LLM model used for cost calculation
 
         Yields:
             Formatted response including content and tool results
@@ -861,6 +904,50 @@
                 "thread_run_id": thread_run_id
             }
 
+            # --- Cost Calculation (moved here) ---
+            if prompt_messages and llm_model:
+                cost = None
+                # Attempt to get cost from LiteLLM response first
+                if hasattr(llm_response, '_hidden_params') and 'response_cost' in llm_response._hidden_params:
+                    cost = llm_response._hidden_params['response_cost']
+                    logger.info(f"Using pre-calculated cost from LiteLLM: {cost:.6f}")
+
+                # If no pre-calculated cost, calculate manually
+                if cost is None:
+                    try:
+                        cost = completion_cost(
+                            model=llm_model,
+                            messages=prompt_messages,
+                            completion=content # Use extracted content
+                        )
+                        logger.info(f"Manually calculated cost for non-streaming response: {cost:.6f} using model {llm_model}")
+                    except Exception as e:
+                        logger.error(f"Error calculating cost: {str(e)}")
+
+                # Add cost message if cost was determined
+                if cost is not None:
+                    try:
+                        # Approximate token counts
+                        completion_tokens = token_counter(model=llm_model, messages=[{"role": "assistant", "content": content}])
+                        prompt_tokens = token_counter(model=llm_model, messages=prompt_messages)
+                        total_tokens = prompt_tokens + completion_tokens
+                        await self.add_message(
+                            thread_id=thread_id,
+                            type="cost",
+                            content={
+                                "cost": cost,
+                                "prompt_tokens": prompt_tokens,
+                                "completion_tokens": completion_tokens,
+                                "total_tokens": total_tokens,
+                                "model_name": llm_model
+                            },
+                            is_llm_message=False
+                        )
+                    except Exception as e:
+                        logger.error(f"Error saving cost message: {str(e)}")
+            # --- End Cost Calculation ---
+
         except Exception as e:
             logger.error(f"Error processing response: {str(e)}", exc_info=True)
             yield {"type": "error", "message": str(e), "thread_run_id": thread_run_id if 'thread_run_id' in locals() else None}

View File

@@ -162,6 +162,8 @@ class ThreadManager:
         native_max_auto_continues: int = 25,
         max_xml_tool_calls: int = 0,
         include_xml_examples: bool = False,
+        enable_thinking: Optional[bool] = False,
+        reasoning_effort: Optional[str] = 'low'
     ) -> Union[Dict[str, Any], AsyncGenerator]:
         """Run a conversation thread with LLM integration and tool execution.
@@ -179,6 +181,8 @@
                 finish_reason="tool_calls" (0 disables auto-continue)
             max_xml_tool_calls: Maximum number of XML tool calls to allow (0 = no limit)
             include_xml_examples: Whether to include XML tool examples in the system prompt
+            enable_thinking: Whether to enable thinking before making a decision
+            reasoning_effort: The effort level for reasoning
 
         Returns:
             An async generator yielding response chunks or error dict
@@ -313,7 +317,9 @@ Here are the XML tools available with examples:
                     max_tokens=llm_max_tokens,
                     tools=openapi_tool_schemas,
                     tool_choice=tool_choice if processor_config.native_tool_calling else None,
-                    stream=stream
+                    stream=stream,
+                    enable_thinking=enable_thinking,
+                    reasoning_effort=reasoning_effort
                 )
                 logger.debug("Successfully received raw LLM API response stream/object")
@@ -327,7 +333,9 @@ Here are the XML tools available with examples:
                     response_generator = self.response_processor.process_streaming_response(
                         llm_response=llm_response,
                         thread_id=thread_id,
-                        config=processor_config
+                        config=processor_config,
+                        prompt_messages=prepared_messages,
+                        llm_model=llm_model
                     )
 
                     return response_generator
@@ -338,7 +346,9 @@ Here are the XML tools available with examples:
                    response_generator = self.response_processor.process_non_streaming_response(
                        llm_response=llm_response,
                        thread_id=thread_id,
-                       config=processor_config
+                       config=processor_config,
+                       prompt_messages=prepared_messages,
+                       llm_model=llm_model
                    )
                    return response_generator # Return the generator
            except Exception as e:

View File

@@ -17,6 +17,8 @@ import asyncio
 from openai import OpenAIError
 import litellm
 from utils.logger import logger
+from datetime import datetime
+import traceback
 
 # litellm.set_verbose=True
 litellm.modify_params=True
@@ -82,7 +84,9 @@ def prepare_params(
     api_base: Optional[str] = None,
     stream: bool = False,
     top_p: Optional[float] = None,
-    model_id: Optional[str] = None
+    model_id: Optional[str] = None,
+    enable_thinking: Optional[bool] = False,
+    reasoning_effort: Optional[str] = 'low'
 ) -> Dict[str, Any]:
     """Prepare parameters for the API call."""
     params = {
@@ -211,6 +215,16 @@
     else:
         logger.warning(f"Last user message (index {last_user_idx}) content is not a string or list ({type(content)}), skipping cache_control.")
 
+    # Add reasoning_effort for Anthropic models if enabled
+    use_thinking = enable_thinking if enable_thinking is not None else False
+    is_anthropic = "anthropic" in effective_model_name.lower() or "claude" in effective_model_name.lower()
+
+    if is_anthropic and use_thinking:
+        effort_level = reasoning_effort if reasoning_effort else 'low'
+        params["reasoning_effort"] = effort_level
+        params["temperature"] = 1.0 # Required by Anthropic when reasoning_effort is used
+        logger.info(f"Anthropic thinking enabled with reasoning_effort='{effort_level}'")
+
     return params
 
 async def make_llm_api_call(
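Note: for an Anthropic/Claude model with thinking enabled, the net effect is roughly the following additions to the params dict handed to litellm (a sketch; other keys built earlier in prepare_params, such as api_key or model_id, are omitted and the message content is illustrative):

    params = {
        "model": "anthropic/claude-3-7-sonnet-latest",
        "messages": [{"role": "user", "content": "..."}],
        "stream": True,
        "reasoning_effort": "low",   # added when enable_thinking is truthy
        "temperature": 1.0,          # overrides whatever temperature was set earlier
    }

For non-Anthropic models, or when enable_thinking is false, the params are returned unchanged.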
@@ -225,7 +239,9 @@
     api_base: Optional[str] = None,
     stream: bool = False,
     top_p: Optional[float] = None,
-    model_id: Optional[str] = None
+    model_id: Optional[str] = None,
+    enable_thinking: Optional[bool] = False,
+    reasoning_effort: Optional[str] = 'low'
 ) -> Union[Dict[str, Any], AsyncGenerator]:
     """
     Make an API call to a language model using LiteLLM.
@@ -243,6 +259,8 @@
         stream: Whether to stream the response
         top_p: Top-p sampling parameter
         model_id: Optional ARN for Bedrock inference profiles
+        enable_thinking: Whether to enable thinking
+        reasoning_effort: Level of reasoning effort
 
     Returns:
         Union[Dict[str, Any], AsyncGenerator]: API response or stream
@@ -251,7 +269,7 @@
         LLMRetryError: If API call fails after retries
         LLMError: For other API-related errors
     """
-    logger.debug(f"Making LLM API call to model: {model_name}")
+    logger.debug(f"Making LLM API call to model: {model_name} (Thinking: {enable_thinking}, Effort: {reasoning_effort})")
     params = prepare_params(
         messages=messages,
         model_name=model_name,
@@ -264,7 +282,9 @@
         api_base=api_base,
         stream=stream,
         top_p=top_p,
-        model_id=model_id
+        model_id=model_id,
+        enable_thinking=enable_thinking,
+        reasoning_effort=reasoning_effort
     )
 
     last_error = None
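Note: end to end, callers can now opt into thinking per request. A minimal usage sketch of the updated helper, assuming it is importable as services.llm.make_llm_api_call (the module path is a guess from the imports elsewhere in this commit) and that provider credentials are configured:

    import asyncio

    from services.llm import make_llm_api_call   # assumed module path

    async def main():
        response = await make_llm_api_call(
            messages=[{"role": "user", "content": "Plan the next step."}],
            model_name="anthropic/claude-3-7-sonnet-latest",
            stream=False,
            enable_thinking=True,    # prepare_params adds reasoning_effort and temperature=1.0
            reasoning_effort="low",
        )
        print(response.choices[0].message.content)

    asyncio.run(main())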