From c84ee59dc613436634d67e6706c0f2603cbbb228 Mon Sep 17 00:00:00 2001 From: LE Quoc Dat Date: Fri, 18 Apr 2025 05:49:41 +0100 Subject: [PATCH] thinking & reasoning --- backend/agent/api.py | 56 ++++++++-- backend/agent/run.py | 41 ++++++- backend/agentpress/response_processor.py | 133 +++++++++++++++++++---- backend/agentpress/thread_manager.py | 16 ++- backend/services/llm.py | 28 ++++- 5 files changed, 230 insertions(+), 44 deletions(-) diff --git a/backend/agent/api.py b/backend/agent/api.py index edb54c35..d300ca9d 100644 --- a/backend/agent/api.py +++ b/backend/agent/api.py @@ -1,4 +1,4 @@ -from fastapi import APIRouter, HTTPException, Depends, Request +from fastapi import APIRouter, HTTPException, Depends, Request, Body from fastapi.responses import StreamingResponse import asyncio import json @@ -7,6 +7,7 @@ from datetime import datetime, timezone import uuid from typing import Optional, List, Dict, Any import jwt +from pydantic import BaseModel from agentpress.thread_manager import ThreadManager from services.supabase import DBConnection @@ -26,6 +27,12 @@ db = None # In-memory storage for active agent runs and their responses active_agent_runs: Dict[str, List[Any]] = {} +class AgentStartRequest(BaseModel): + model_name: Optional[str] = "anthropic/claude-3-7-sonnet-latest" + enable_thinking: Optional[bool] = False + reasoning_effort: Optional[str] = 'low' + stream: Optional[bool] = False # Default stream to False for API + def initialize( _thread_manager: ThreadManager, _db: DBConnection, @@ -237,9 +244,13 @@ async def _cleanup_agent_run(agent_run_id: str): # Non-fatal error, can continue @router.post("/thread/{thread_id}/agent/start") -async def start_agent(thread_id: str, user_id: str = Depends(get_current_user_id)): +async def start_agent( + thread_id: str, + body: AgentStartRequest = Body(...), # Accept request body + user_id: str = Depends(get_current_user_id) +): """Start an agent for a specific thread in the background.""" - logger.info(f"Starting new agent for thread: {thread_id}") + logger.info(f"Starting new agent for thread: {thread_id} with config: model={body.model_name}, thinking={body.enable_thinking}, effort={body.reasoning_effort}, stream={body.stream}") client = await db.client # Verify user has access to this thread @@ -314,7 +325,17 @@ async def start_agent(thread_id: str, user_id: str = Depends(get_current_user_id # Run the agent in the background task = asyncio.create_task( - run_agent_background(agent_run_id, thread_id, instance_id, project_id, sandbox) + run_agent_background( + agent_run_id=agent_run_id, + thread_id=thread_id, + instance_id=instance_id, + project_id=project_id, + sandbox=sandbox, + model_name=body.model_name, + enable_thinking=body.enable_thinking, + reasoning_effort=body.reasoning_effort, + stream=body.stream # Pass stream parameter + ) ) # Set a callback to clean up when task is done @@ -441,9 +462,19 @@ async def stream_agent_run( } ) -async def run_agent_background(agent_run_id: str, thread_id: str, instance_id: str, project_id: str, sandbox): +async def run_agent_background( + agent_run_id: str, + thread_id: str, + instance_id: str, + project_id: str, + sandbox, + model_name: str, + enable_thinking: Optional[bool], + reasoning_effort: Optional[str], + stream: bool # Add stream parameter +): """Run the agent in the background and handle status updates.""" - logger.debug(f"Starting background agent run: {agent_run_id} for thread: {thread_id} (instance: {instance_id})") + logger.debug(f"Starting background agent run: {agent_run_id} for 
thread: {thread_id} (instance: {instance_id}) with model={model_name}, thinking={enable_thinking}, effort={reasoning_effort}, stream={stream}") client = await db.client # Tracking variables @@ -561,9 +592,16 @@ async def run_agent_background(agent_run_id: str, thread_id: str, instance_id: s try: # Run the agent logger.debug(f"Initializing agent generator for thread: {thread_id} (instance: {instance_id})") - agent_gen = run_agent(thread_id, stream=True, - thread_manager=thread_manager, project_id=project_id, - sandbox=sandbox) + agent_gen = run_agent( + thread_id=thread_id, + project_id=project_id, + stream=stream, # Pass stream parameter from API request + thread_manager=thread_manager, + sandbox=sandbox, + model_name=model_name, # Pass model_name + enable_thinking=enable_thinking, # Pass enable_thinking + reasoning_effort=reasoning_effort # Pass reasoning_effort + ) # Collect all responses to save to database all_responses = [] diff --git a/backend/agent/run.py b/backend/agent/run.py index 87929225..d4fddc61 100644 --- a/backend/agent/run.py +++ b/backend/agent/run.py @@ -21,7 +21,18 @@ from utils.billing import check_billing_status, get_account_id_from_thread load_dotenv() -async def run_agent(thread_id: str, project_id: str, sandbox, stream: bool = True, thread_manager: Optional[ThreadManager] = None, native_max_auto_continues: int = 25, max_iterations: int = 150): +async def run_agent( + thread_id: str, + project_id: str, + sandbox, + stream: bool, + thread_manager: Optional[ThreadManager] = None, + native_max_auto_continues: int = 25, + max_iterations: int = 150, + model_name: str = "anthropic/claude-3-7-sonnet-latest", + enable_thinking: Optional[bool] = False, + reasoning_effort: Optional[str] = 'low' +): """Run the development agent with specified configuration.""" if not thread_manager: @@ -112,7 +123,7 @@ async def run_agent(thread_id: str, project_id: str, sandbox, stream: bool = Tru thread_id=thread_id, system_prompt=system_message, stream=stream, - llm_model=os.getenv("MODEL_TO_USE", "anthropic/claude-3-7-sonnet-latest"), + llm_model=model_name, llm_temperature=0, llm_max_tokens=64000, tool_choice="auto", @@ -128,6 +139,8 @@ async def run_agent(thread_id: str, project_id: str, sandbox, stream: bool = Tru ), native_max_auto_continues=native_max_auto_continues, include_xml_examples=True, + enable_thinking=enable_thinking, + reasoning_effort=reasoning_effort ) if isinstance(response, dict) and "status" in response and response["status"] == "error": @@ -250,7 +263,15 @@ async def test_agent(): print("\nšŸ‘‹ Test completed. 
Goodbye!") -async def process_agent_response(thread_id: str, project_id: str, thread_manager: ThreadManager): +async def process_agent_response( + thread_id: str, + project_id: str, + thread_manager: ThreadManager, + stream: bool = True, + model_name: str = "anthropic/claude-3-7-sonnet-latest", + enable_thinking: Optional[bool] = False, + reasoning_effort: Optional[str] = 'low' +): """Process the streaming response from the agent.""" chunk_counter = 0 current_response = "" @@ -259,9 +280,19 @@ async def process_agent_response(thread_id: str, project_id: str, thread_manager # Create a test sandbox for processing sandbox_pass = str(uuid4()) sandbox = create_sandbox(sandbox_pass) - print(f"\033[91mTest sandbox created: {sandbox.get_preview_link(6080)}/vnc_lite.html?password={sandbox_pass}\033[0m") + print(f"\033[91mTest sandbox created: {str(sandbox.get_preview_link(6080))}/vnc_lite.html?password={sandbox_pass}\033[0m") - async for chunk in run_agent(thread_id=thread_id, project_id=project_id, sandbox=sandbox, stream=True, thread_manager=thread_manager, native_max_auto_continues=25): + async for chunk in run_agent( + thread_id=thread_id, + project_id=project_id, + sandbox=sandbox, + stream=stream, + thread_manager=thread_manager, + native_max_auto_continues=25, + model_name=model_name, + enable_thinking=enable_thinking, + reasoning_effort=reasoning_effort + ): chunk_counter += 1 if chunk.get('type') == 'content' and 'content' in chunk: diff --git a/backend/agentpress/response_processor.py b/backend/agentpress/response_processor.py index 0e250cae..468a6b09 100644 --- a/backend/agentpress/response_processor.py +++ b/backend/agentpress/response_processor.py @@ -98,6 +98,8 @@ class ResponseProcessor: llm_response: AsyncGenerator, thread_id: str, config: ProcessorConfig = ProcessorConfig(), + prompt_messages: Optional[List[Dict[str, Any]]] = None, + llm_model: Optional[str] = None ) -> AsyncGenerator: """Process a streaming LLM response, handling tool calls and execution. 
@@ -105,6 +107,8 @@ class ResponseProcessor: llm_response: Streaming response from the LLM thread_id: ID of the conversation thread config: Configuration for parsing and execution + prompt_messages: List of messages used for cost calculation + llm_model: Name of the LLM model used for cost calculation Yields: Formatted chunks of the response including content and tool results @@ -175,15 +179,20 @@ class ResponseProcessor: accumulated_content += chunk_content current_xml_content += chunk_content - # Calculate cost using prompt and completion - try: - cost = completion_cost(model=chunk.model, prompt=accumulated_content, completion=chunk_content) - tcount = token_counter(model=chunk.model, messages=[{"role": "user", "content": accumulated_content}]) - accumulated_cost += cost - accumulated_token_count += tcount - logger.debug(f"Cost: {cost:.6f}, Token count: {tcount}") - except Exception as e: - logger.error(f"Error calculating cost: {str(e)}") + # Process reasoning content if present (Anthropic) + if hasattr(delta, 'reasoning_content') and delta.reasoning_content: + logger.info(f"[THINKING]: {delta.reasoning_content}") + accumulated_content += delta.reasoning_content # Append reasoning to main content + + # Calculate cost using prompt and completion - MOVED AFTER MESSAGE SAVE + # try: + # cost = completion_cost(model=chunk.model, prompt=accumulated_content, completion=chunk_content) + # tcount = token_counter(model=chunk.model, messages=[{"role": "user", "content": accumulated_content}]) + # accumulated_cost += cost + # accumulated_token_count += tcount + # logger.debug(f"Cost: {cost:.6f}, Token count: {tcount}") + # except Exception as e: + # logger.error(f"Error calculating cost: {str(e)}") # Check if we've reached the XML tool call limit before yielding content if config.max_xml_tool_calls > 0 and xml_tool_call_count >= config.max_xml_tool_calls: @@ -489,6 +498,35 @@ class ResponseProcessor: "thread_run_id": thread_run_id } + # --- Cost Calculation (moved here) --- + if prompt_messages and llm_model and accumulated_content: + try: + cost = completion_cost( + model=llm_model, + messages=prompt_messages, + completion=accumulated_content + ) + token_count = token_counter( + model=llm_model, + messages=prompt_messages + [{"role": "assistant", "content": accumulated_content}] + ) + await self.add_message( + thread_id=thread_id, + type="cost", + content={ + "cost": cost, + "prompt_tokens": token_count - token_counter(model=llm_model, messages=[{"role": "assistant", "content": accumulated_content}]), # Approx + "completion_tokens": token_counter(model=llm_model, messages=[{"role": "assistant", "content": accumulated_content}]), # Approx + "total_tokens": token_count, + "model_name": llm_model + }, + is_llm_message=False + ) + logger.info(f"Calculated cost for streaming response: {cost:.6f} using model {llm_model}") + except Exception as e: + logger.error(f"Error calculating cost: {str(e)}") + # --- End Cost Calculation --- + # --- Process All Tool Calls Now --- if config.execute_tools: final_tool_calls_to_process = [] @@ -626,23 +664,24 @@ class ResponseProcessor: yield {"type": "error", "message": str(e), "thread_run_id": thread_run_id if 'thread_run_id' in locals() else None} finally: - # Yield a finish signal including the final assistant message ID - if last_assistant_message_id: - # Yield the overall run end signal + # Yield the detected finish reason if one exists and wasn't suppressed + if finish_reason and finish_reason != "xml_tool_limit_reached": yield { - "type": "thread_run_end", - 
"thread_run_id": thread_run_id - } - else: - # Yield the overall run end signal - yield { - "type": "thread_run_end", + "type": "finish", + "finish_reason": finish_reason, "thread_run_id": thread_run_id if 'thread_run_id' in locals() else None } - + + # Yield a finish signal including the final assistant message ID + # Ensure thread_run_id is defined, even if an early error occurred + run_id = thread_run_id if 'thread_run_id' in locals() else str(uuid.uuid4()) # Fallback ID if needed + yield { + "type": "thread_run_end", + "thread_run_id": run_id + } + + # Remove old cost calculation code pass - # track the cost and token count - # todo: there is a bug as it adds every chunk to db because finally will run every time even in yield # await self.add_message( # thread_id=thread_id, # type="cost", @@ -658,7 +697,9 @@ class ResponseProcessor: self, llm_response: Any, thread_id: str, - config: ProcessorConfig = ProcessorConfig() + config: ProcessorConfig = ProcessorConfig(), + prompt_messages: Optional[List[Dict[str, Any]]] = None, + llm_model: Optional[str] = None ) -> AsyncGenerator[Dict[str, Any], None]: """Process a non-streaming LLM response, handling tool calls and execution. @@ -666,6 +707,8 @@ class ResponseProcessor: llm_response: Response from the LLM thread_id: ID of the conversation thread config: Configuration for parsing and execution + prompt_messages: List of messages used for cost calculation + llm_model: Name of the LLM model used for cost calculation Yields: Formatted response including content and tool results @@ -861,6 +904,50 @@ class ResponseProcessor: "thread_run_id": thread_run_id } + # --- Cost Calculation (moved here) --- + if prompt_messages and llm_model: + cost = None + # Attempt to get cost from LiteLLM response first + if hasattr(llm_response, '_hidden_params') and 'response_cost' in llm_response._hidden_params: + cost = llm_response._hidden_params['response_cost'] + logger.info(f"Using pre-calculated cost from LiteLLM: {cost:.6f}") + + # If no pre-calculated cost, calculate manually + if cost is None: + try: + cost = completion_cost( + model=llm_model, + messages=prompt_messages, + completion=content # Use extracted content + ) + logger.info(f"Manually calculated cost for non-streaming response: {cost:.6f} using model {llm_model}") + except Exception as e: + logger.error(f"Error calculating cost: {str(e)}") + + # Add cost message if cost was determined + if cost is not None: + try: + # Approximate token counts + completion_tokens = token_counter(model=llm_model, messages=[{"role": "assistant", "content": content}]) + prompt_tokens = token_counter(model=llm_model, messages=prompt_messages) + total_tokens = prompt_tokens + completion_tokens + + await self.add_message( + thread_id=thread_id, + type="cost", + content={ + "cost": cost, + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens, + "total_tokens": total_tokens, + "model_name": llm_model + }, + is_llm_message=False + ) + except Exception as e: + logger.error(f"Error saving cost message: {str(e)}") + # --- End Cost Calculation --- + except Exception as e: logger.error(f"Error processing response: {str(e)}", exc_info=True) yield {"type": "error", "message": str(e), "thread_run_id": thread_run_id if 'thread_run_id' in locals() else None} diff --git a/backend/agentpress/thread_manager.py b/backend/agentpress/thread_manager.py index f03459b4..8dbe7524 100644 --- a/backend/agentpress/thread_manager.py +++ b/backend/agentpress/thread_manager.py @@ -162,6 +162,8 @@ class ThreadManager: 
native_max_auto_continues: int = 25, max_xml_tool_calls: int = 0, include_xml_examples: bool = False, + enable_thinking: Optional[bool] = False, + reasoning_effort: Optional[str] = 'low' ) -> Union[Dict[str, Any], AsyncGenerator]: """Run a conversation thread with LLM integration and tool execution. @@ -179,6 +181,8 @@ class ThreadManager: finish_reason="tool_calls" (0 disables auto-continue) max_xml_tool_calls: Maximum number of XML tool calls to allow (0 = no limit) include_xml_examples: Whether to include XML tool examples in the system prompt + enable_thinking: Whether to enable thinking before making a decision + reasoning_effort: The effort level for reasoning Returns: An async generator yielding response chunks or error dict @@ -313,7 +317,9 @@ Here are the XML tools available with examples: max_tokens=llm_max_tokens, tools=openapi_tool_schemas, tool_choice=tool_choice if processor_config.native_tool_calling else None, - stream=stream + stream=stream, + enable_thinking=enable_thinking, + reasoning_effort=reasoning_effort ) logger.debug("Successfully received raw LLM API response stream/object") @@ -327,7 +333,9 @@ Here are the XML tools available with examples: response_generator = self.response_processor.process_streaming_response( llm_response=llm_response, thread_id=thread_id, - config=processor_config + config=processor_config, + prompt_messages=prepared_messages, + llm_model=llm_model ) return response_generator @@ -338,7 +346,9 @@ Here are the XML tools available with examples: response_generator = self.response_processor.process_non_streaming_response( llm_response=llm_response, thread_id=thread_id, - config=processor_config + config=processor_config, + prompt_messages=prepared_messages, + llm_model=llm_model ) return response_generator # Return the generator except Exception as e: diff --git a/backend/services/llm.py b/backend/services/llm.py index f3abe4df..2bcfb7ae 100644 --- a/backend/services/llm.py +++ b/backend/services/llm.py @@ -17,6 +17,8 @@ import asyncio from openai import OpenAIError import litellm from utils.logger import logger +from datetime import datetime +import traceback # litellm.set_verbose=True litellm.modify_params=True @@ -82,7 +84,9 @@ def prepare_params( api_base: Optional[str] = None, stream: bool = False, top_p: Optional[float] = None, - model_id: Optional[str] = None + model_id: Optional[str] = None, + enable_thinking: Optional[bool] = False, + reasoning_effort: Optional[str] = 'low' ) -> Dict[str, Any]: """Prepare parameters for the API call.""" params = { @@ -211,6 +215,16 @@ def prepare_params( else: logger.warning(f"Last user message (index {last_user_idx}) content is not a string or list ({type(content)}), skipping cache_control.") + # Add reasoning_effort for Anthropic models if enabled + use_thinking = enable_thinking if enable_thinking is not None else False + is_anthropic = "anthropic" in effective_model_name.lower() or "claude" in effective_model_name.lower() + + if is_anthropic and use_thinking: + effort_level = reasoning_effort if reasoning_effort else 'low' + params["reasoning_effort"] = effort_level + params["temperature"] = 1.0 # Required by Anthropic when reasoning_effort is used + logger.info(f"Anthropic thinking enabled with reasoning_effort='{effort_level}'") + return params async def make_llm_api_call( @@ -225,7 +239,9 @@ async def make_llm_api_call( api_base: Optional[str] = None, stream: bool = False, top_p: Optional[float] = None, - model_id: Optional[str] = None + model_id: Optional[str] = None, + enable_thinking: 
Optional[bool] = False, + reasoning_effort: Optional[str] = 'low' ) -> Union[Dict[str, Any], AsyncGenerator]: """ Make an API call to a language model using LiteLLM. @@ -243,6 +259,8 @@ async def make_llm_api_call( stream: Whether to stream the response top_p: Top-p sampling parameter model_id: Optional ARN for Bedrock inference profiles + enable_thinking: Whether to enable thinking + reasoning_effort: Level of reasoning effort Returns: Union[Dict[str, Any], AsyncGenerator]: API response or stream @@ -251,7 +269,7 @@ async def make_llm_api_call( LLMRetryError: If API call fails after retries LLMError: For other API-related errors """ - logger.debug(f"Making LLM API call to model: {model_name}") + logger.debug(f"Making LLM API call to model: {model_name} (Thinking: {enable_thinking}, Effort: {reasoning_effort})") params = prepare_params( messages=messages, model_name=model_name, @@ -264,7 +282,9 @@ async def make_llm_api_call( api_base=api_base, stream=stream, top_p=top_p, - model_id=model_id + model_id=model_id, + enable_thinking=enable_thinking, + reasoning_effort=reasoning_effort ) last_error = None
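# --- Illustrative client-side sketch of the new start-agent contract ---
# This is not part of the patch above. It assumes the router is mounted under
# /api on localhost:8000 and that auth is a plain Bearer JWT (both assumptions;
# adjust to the actual deployment). It only shows the AgentStartRequest body
# fields introduced in backend/agent/api.py.
import httpx

async def start_agent_with_thinking(thread_id: str, token: str) -> dict:
    payload = {
        "model_name": "anthropic/claude-3-7-sonnet-latest",
        "enable_thinking": True,        # turn on Anthropic extended reasoning
        "reasoning_effort": "medium",   # 'low' | 'medium' | 'high'
        "stream": True,                 # stream chunks through the agent run
    }
    async with httpx.AsyncClient(base_url="http://localhost:8000/api") as client:  # assumed base URL
        resp = await client.post(
            f"/thread/{thread_id}/agent/start",
            json=payload,
            headers={"Authorization": f"Bearer {token}"},
        )
        resp.raise_for_status()
        return resp.json()  # expected to include the new agent_run_id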
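# --- Standalone sketch of the relocated cost bookkeeping ---
# Mirrors what response_processor.py now does after the assistant message is
# saved: litellm's completion_cost/token_counter are fed the prompt messages
# plus the accumulated assistant text. Model name and message content here are
# illustrative only.
from litellm import completion_cost, token_counter

def estimate_cost(llm_model: str, prompt_messages: list, completion_text: str) -> dict:
    completion_tokens = token_counter(
        model=llm_model,
        messages=[{"role": "assistant", "content": completion_text}],
    )
    prompt_tokens = token_counter(model=llm_model, messages=prompt_messages)
    cost = completion_cost(
        model=llm_model,
        messages=prompt_messages,
        completion=completion_text,
    )
    return {
        "cost": cost,
        "prompt_tokens": prompt_tokens,
        "completion_tokens": completion_tokens,
        "total_tokens": prompt_tokens + completion_tokens,
        "model_name": llm_model,
    }

# Example (illustrative values):
# estimate_cost("anthropic/claude-3-7-sonnet-latest",
#               [{"role": "user", "content": "Summarise this repo"}],
#               "Here is a summary...")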
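# --- Usage sketch for the new thinking flags on make_llm_api_call ---
# Assumes the module is importable as services.llm (matching the repo's other
# imports). For Anthropic models, enable_thinking=True adds reasoning_effort to
# the LiteLLM request and forces temperature=1.0, as implemented in
# prepare_params above; the exact attribute carrying reasoning text on the
# response (e.g. reasoning_content) depends on the provider/LiteLLM version.
import asyncio
from services.llm import make_llm_api_call

async def demo() -> None:
    response = await make_llm_api_call(
        messages=[{"role": "user", "content": "Plan the refactor step by step."}],
        model_name="anthropic/claude-3-7-sonnet-latest",
        enable_thinking=True,       # request extended reasoning
        reasoning_effort="high",    # 'low' | 'medium' | 'high'
        stream=False,
    )
    print(response.choices[0].message.content)

if __name__ == "__main__":
    asyncio.run(demo())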