Merge pull request #47 from kortix-ai/reasoning

Reasoning + Cost Calculation
Authored by Marko Kraemer on 2025-04-17 15:15:06 -07:00; committed by GitHub
commit be9e6d46cc
11 changed files with 944 additions and 173 deletions
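
For reference, the new AgentStartRequest body introduced in this PR can be exercised with a request along these lines (a minimal sketch; the base URL, thread ID, and bearer token are placeholders, not part of this diff):

# Sketch of a client call to the updated endpoint POST /thread/{thread_id}/agent/start.
# The base URL, thread ID, and auth token below are illustrative placeholders.
import requests

resp = requests.post(
    "http://localhost:8000/thread/<thread_id>/agent/start",
    headers={"Authorization": "Bearer <access_token>"},
    json={
        "model_name": "anthropic/claude-3-7-sonnet-latest",
        "enable_thinking": True,
        "reasoning_effort": "low",
        "stream": False,
    },
)
print(resp.status_code, resp.json())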

View File

@@ -1,4 +1,4 @@
-from fastapi import APIRouter, HTTPException, Depends, Request
+from fastapi import APIRouter, HTTPException, Depends, Request, Body
 from fastapi.responses import StreamingResponse
 import asyncio
 import json
@@ -7,6 +7,7 @@ from datetime import datetime, timezone
 import uuid
 from typing import Optional, List, Dict, Any
 import jwt
+from pydantic import BaseModel
 from agentpress.thread_manager import ThreadManager
 from services.supabase import DBConnection
@@ -15,7 +16,14 @@ from agent.run import run_agent
 from utils.auth_utils import get_current_user_id, get_user_id_from_stream_auth, verify_thread_access
 from utils.logger import logger
 from utils.billing import check_billing_status, get_account_id_from_thread
-from utils.db import update_agent_run_status
+# Removed duplicate import of update_agent_run_status from utils.db as it's defined locally
+
+# Define request model for starting agent
+class AgentStartRequest(BaseModel):
+    model_name: Optional[str] = "anthropic/claude-3-7-sonnet-latest"
+    enable_thinking: Optional[bool] = False
+    reasoning_effort: Optional[str] = 'low'
+    stream: Optional[bool] = False  # Add stream option, default to False
 
 # Initialize shared resources
 router = APIRouter()
@@ -236,11 +244,15 @@ async def _cleanup_agent_run(agent_run_id: str):
         # Non-fatal error, can continue
 
 @router.post("/thread/{thread_id}/agent/start")
-async def start_agent(thread_id: str, user_id: str = Depends(get_current_user_id)):
-    """Start an agent for a specific thread in the background."""
-    logger.info(f"Starting new agent for thread: {thread_id}")
+async def start_agent(
+    thread_id: str,
+    body: AgentStartRequest = Body(...),  # Accept request body
+    user_id: str = Depends(get_current_user_id)
+):
+    """Start an agent for a specific thread in the background with dynamic settings."""
+    logger.info(f"Starting new agent for thread: {thread_id} with model: {body.model_name}, thinking: {body.enable_thinking}, effort: {body.reasoning_effort}, stream: {body.stream}")
     client = await db.client
 
     # Verify user has access to this thread
     await verify_thread_access(client, thread_id, user_id)
@@ -291,11 +303,20 @@ async def start_agent(thread_id: str, user_id: str = Depends(get_current_user_id
     except Exception as e:
         logger.warning(f"Failed to register agent run in Redis, continuing without Redis tracking: {str(e)}")
 
-    # Run the agent in the background
+    # Run the agent in the background, passing the dynamic settings
     task = asyncio.create_task(
-        run_agent_background(agent_run_id, thread_id, instance_id, project_id)
+        run_agent_background(
+            agent_run_id=agent_run_id,
+            thread_id=thread_id,
+            instance_id=instance_id,
+            project_id=project_id,
+            model_name=body.model_name,
+            enable_thinking=body.enable_thinking,
+            reasoning_effort=body.reasoning_effort,
+            stream=body.stream  # Pass stream parameter
+        )
     )
 
     # Set a callback to clean up when task is done
     task.add_done_callback(
        lambda _: asyncio.create_task(
@@ -420,11 +441,20 @@ async def stream_agent_run(
         }
     )
 
-async def run_agent_background(agent_run_id: str, thread_id: str, instance_id: str, project_id: str):
+async def run_agent_background(
+    agent_run_id: str,
+    thread_id: str,
+    instance_id: str,
+    project_id: str,
+    model_name: str,  # Add model_name parameter
+    enable_thinking: Optional[bool],  # Add enable_thinking parameter
+    reasoning_effort: Optional[str],  # Add reasoning_effort parameter
+    stream: bool  # Add stream parameter
+):
     """Run the agent in the background and handle status updates."""
-    logger.debug(f"Starting background agent run: {agent_run_id} for thread: {thread_id} (instance: {instance_id})")
+    logger.debug(f"Starting background agent run: {agent_run_id} for thread: {thread_id} (instance: {instance_id}) with model: {model_name}, stream: {stream}")
     client = await db.client
 
     # Tracking variables
     total_responses = 0
     start_time = datetime.now(timezone.utc)
@@ -538,11 +568,18 @@ async def run_agent_background(agent_run_id: str, thread_id: str, instance_id: s
         logger.warning(f"No stop signal checker for agent run: {agent_run_id} - pubsub unavailable")
 
     try:
-        # Run the agent
+        # Run the agent, passing the dynamic settings
        logger.debug(f"Initializing agent generator for thread: {thread_id} (instance: {instance_id})")
-        agent_gen = run_agent(thread_id, stream=True,
-                              thread_manager=thread_manager, project_id=project_id)
+        agent_gen = run_agent(
+            thread_id=thread_id,
+            project_id=project_id,
+            stream=stream,  # Pass stream parameter from API request
+            thread_manager=thread_manager,
+            model_name=model_name,  # Pass model_name
+            enable_thinking=enable_thinking,  # Pass enable_thinking
+            reasoning_effort=reasoning_effort  # Pass reasoning_effort
+        )
 
         # Collect all responses to save to database
         all_responses = []
@@ -654,4 +691,4 @@ async def run_agent_background(agent_run_id: str, thread_id: str, instance_id: s
         except Exception as e:
             logger.warning(f"Error deleting active run key: {str(e)}")
 
     logger.info(f"Agent run background task fully completed for: {agent_run_id} (instance: {instance_id})")

View File

@@ -20,9 +20,19 @@ from utils.billing import check_billing_status, get_account_id_from_thread
 load_dotenv()
 
-async def run_agent(thread_id: str, project_id: str, stream: bool = True, thread_manager: Optional[ThreadManager] = None, native_max_auto_continues: int = 25, max_iterations: int = 150):
+async def run_agent(
+    thread_id: str,
+    project_id: str,
+    stream: bool,  # Accept stream parameter from caller (api.py)
+    thread_manager: Optional[ThreadManager] = None,
+    native_max_auto_continues: int = 25,
+    max_iterations: int = 150,
+    model_name: str = "anthropic/claude-3-7-sonnet-latest",  # Add model_name parameter with default
+    enable_thinking: Optional[bool] = False,  # Add enable_thinking parameter
+    reasoning_effort: Optional[str] = 'low'  # Add reasoning_effort parameter
+):
     """Run the development agent with specified configuration."""
     if not thread_manager:
         thread_manager = ThreadManager()
     client = await thread_manager.db.client
@@ -59,8 +69,8 @@ async def run_agent(thread_id: str, project_id: str, stream: bool = True, thread
             'sandbox': {
                 'id': sandbox_id,
                 'pass': sandbox_pass,
-                'vnc_preview': sandbox.get_preview_link(6080),
-                'sandbox_url': sandbox.get_preview_link(8080)
+                'vnc_preview': str(sandbox.get_preview_link(6080)),  # Convert to string
+                'sandbox_url': str(sandbox.get_preview_link(8080))  # Convert to string
             }
         }).eq('project_id', project_id).execute()
@@ -161,23 +171,29 @@ async def run_agent(thread_id: str, project_id: str, stream: bool = True, thread
                 print(f"Error parsing browser state: {e}")
             # print(latest_browser_state.data[0])
 
-        # Run Thread
+        # Determine max tokens based on the passed model_name
+        max_tokens = None
+        if model_name == "anthropic/claude-3-7-sonnet-latest":
+            max_tokens = 64000  # Example: Set max tokens for a specific model
+
+        # Run Thread, passing the dynamic settings
         response = await thread_manager.run_thread(
             thread_id=thread_id,
             system_prompt=system_message,  # Pass the constructed message
-            stream=stream,
-            llm_model=os.getenv("MODEL_TO_USE", "anthropic/claude-3-7-sonnet-latest"),
-            llm_temperature=0,
-            llm_max_tokens=64000,
+            stream=stream,  # Pass the stream parameter received by run_agent
+            llm_model=model_name,  # Use the passed model_name
+            llm_temperature=1,  # Example temperature
+            llm_max_tokens=max_tokens,  # Use the determined value
             tool_choice="auto",
             max_xml_tool_calls=1,
             temporary_message=temporary_message,
             processor_config=processor_config,  # Pass the config object
             native_max_auto_continues=native_max_auto_continues,
-            # Explicitly set include_xml_examples to False here
-            include_xml_examples=False,
+            include_xml_examples=False,  # Explicitly set include_xml_examples to False here
+            enable_thinking=enable_thinking,  # Pass enable_thinking
+            reasoning_effort=reasoning_effort  # Pass reasoning_effort
         )
 
         if isinstance(response, dict) and "status" in response and response["status"] == "error":
             yield response
             break
@@ -279,7 +295,8 @@ async def test_agent():
         if not user_message.strip():
             print("\n🔄 Running agent...\n")
-            await process_agent_response(thread_id, project_id, thread_manager)
+            # Pass stream=True explicitly when calling from test_agent
+            await process_agent_response(thread_id, project_id, thread_manager, stream=True)
             continue
 
         # Add the user message to the thread
@@ -294,19 +311,40 @@ async def test_agent():
         )
 
         print("\n🔄 Running agent...\n")
-        await process_agent_response(thread_id, project_id, thread_manager)
+        # Pass stream=True explicitly when calling from test_agent
+        await process_agent_response(thread_id, project_id, thread_manager, stream=True)
 
     print("\n👋 Test completed. Goodbye!")
 
-async def process_agent_response(thread_id: str, project_id: str, thread_manager: ThreadManager):
-    """Process the streaming response from the agent."""
+async def process_agent_response(
+    thread_id: str,
+    project_id: str,
+    thread_manager: ThreadManager,
+    stream: bool = True,  # Add stream parameter, default to True for testing
+    model_name: str = "anthropic/claude-3-7-sonnet-latest",  # Add model_name with default
+    enable_thinking: Optional[bool] = False,  # Add enable_thinking
+    reasoning_effort: Optional[str] = 'low'  # Add reasoning_effort
+):
+    """Process the streaming response from the agent, passing model/thinking parameters."""
     chunk_counter = 0
     current_response = ""
     tool_call_counter = 0  # Track number of tool calls
 
-    async for chunk in run_agent(thread_id=thread_id, project_id=project_id, stream=True, thread_manager=thread_manager, native_max_auto_continues=25):
+    # Pass the received parameters to the run_agent call
+    agent_generator = run_agent(
+        thread_id=thread_id,
+        project_id=project_id,
+        stream=stream,  # Pass the stream parameter here
+        thread_manager=thread_manager,
+        native_max_auto_continues=25,
+        model_name=model_name,
+        enable_thinking=enable_thinking,
+        reasoning_effort=reasoning_effort
+    )
+
+    async for chunk in agent_generator:
         chunk_counter += 1
 
         if chunk.get('type') == 'content' and 'content' in chunk:
             current_response += chunk.get('content', '')
             # Print the response as it comes in
@@ -382,4 +420,4 @@ if __name__ == "__main__":
     load_dotenv()  # Ensure environment variables are loaded
 
     # Run the test function
     asyncio.run(test_agent())

View File

@@ -95,6 +95,8 @@ class ResponseProcessor:
         self,
         llm_response: AsyncGenerator,
         thread_id: str,
+        prompt_messages: List[Dict[str, Any]],
+        llm_model: str,
         config: ProcessorConfig = ProcessorConfig(),
     ) -> AsyncGenerator:
         """Process a streaming LLM response, handling tool calls and execution.
@@ -102,6 +104,8 @@ class ResponseProcessor:
         Args:
             llm_response: Streaming response from the LLM
             thread_id: ID of the conversation thread
+            prompt_messages: List of messages sent to the LLM (the prompt)
+            llm_model: The name of the LLM model used
             config: Configuration for parsing and execution
 
         Yields:
@@ -140,9 +144,6 @@ class ResponseProcessor:
         # if config.max_xml_tool_calls > 0:
         #     logger.info(f"XML tool call limit enabled: {config.max_xml_tool_calls}")
 
-        accumulated_cost = 0
-        accumulated_token_count = 0
-
         try:
             async for chunk in llm_response:
                 # Default content to yield
@@ -155,22 +156,17 @@ class ResponseProcessor:
                 if hasattr(chunk, 'choices') and chunk.choices:
                     delta = chunk.choices[0].delta if hasattr(chunk.choices[0], 'delta') else None
 
+                    # Check for and log Anthropic thinking content
+                    if delta and hasattr(delta, 'reasoning_content') and delta.reasoning_content:
+                        logger.info(f"[THINKING]: {delta.reasoning_content}")
+                        accumulated_content += delta.reasoning_content  # Append reasoning to main content
+
                     # Process content chunk
                     if delta and hasattr(delta, 'content') and delta.content:
                         chunk_content = delta.content
                         accumulated_content += chunk_content
                         current_xml_content += chunk_content
 
-                        # Calculate cost using prompt and completion
-                        try:
-                            cost = completion_cost(model=chunk.model, prompt=accumulated_content, completion=chunk_content)
-                            tcount = token_counter(model=chunk.model, messages=[{"role": "user", "content": accumulated_content}])
-                            accumulated_cost += cost
-                            accumulated_token_count += tcount
-                            logger.debug(f"Cost: {cost:.6f}, Token count: {tcount}")
-                        except Exception as e:
-                            logger.error(f"Error calculating cost: {str(e)}")
-
                         # Check if we've reached the XML tool call limit before yielding content
                         if config.max_xml_tool_calls > 0 and xml_tool_call_count >= config.max_xml_tool_calls:
                             # We've reached the limit, don't yield any more content
@@ -334,7 +330,7 @@ class ResponseProcessor:
                 # If we've reached the XML tool call limit, stop streaming
                 if finish_reason == "xml_tool_limit_reached":
-                    logger.info("Stopping stream due to XML tool call limit")
+                    logger.info("Stopping stream processing after loop due to XML tool call limit")
                     break
 
             # After streaming completes or is stopped due to limit, wait for any remaining tool executions
@@ -466,6 +462,27 @@ class ResponseProcessor:
                     is_llm_message=True
                 )
 
+                # Calculate and store cost AFTER adding the main assistant message
+                if accumulated_content:  # Calculate cost if there was content (now includes reasoning)
+                    try:
+                        final_cost = completion_cost(
+                            model=llm_model,  # Use the passed model name
+                            messages=prompt_messages,
+                            completion=accumulated_content
+                        )
+                        if final_cost is not None and final_cost > 0:
+                            logger.info(f"Calculated final cost for stream: {final_cost}")
+                            await self.add_message(
+                                thread_id=thread_id,
+                                type="cost",
+                                content={"cost": final_cost},
+                                is_llm_message=False  # Cost is metadata, not LLM content
+                            )
+                        else:
+                            logger.info("Cost calculation resulted in zero or None, not storing cost message.")
+                    except Exception as e:
+                        logger.error(f"Error calculating final cost for stream: {str(e)}")
+
                 # Now add all buffered tool results AFTER the assistant message, but don't yield if already yielded
                 for tool_call, result, result_tool_index in tool_results_buffer:
                     # Add result based on tool type to the conversation history
@@ -559,24 +576,19 @@ class ResponseProcessor:
             yield {"type": "error", "message": str(e)}
 
         finally:
-            pass
-            # track the cost and token count
-            # todo: there is a bug as it adds every chunk to db because finally will run every time even in yield
-            # await self.add_message(
-            #     thread_id=thread_id,
-            #     type="cost",
-            #     content={
-            #         "cost": accumulated_cost,
-            #         "token_count": accumulated_token_count
-            #     },
-            #     is_llm_message=False
-            # )
+            # Finally, yield the finish reason if it was detected
+            if finish_reason:
+                yield {
+                    "type": "finish",
+                    "finish_reason": finish_reason
+                }
 
     async def process_non_streaming_response(
         self,
         llm_response: Any,
         thread_id: str,
+        prompt_messages: List[Dict[str, Any]],
+        llm_model: str,
         config: ProcessorConfig = ProcessorConfig(),
     ) -> AsyncGenerator:
         """Process a non-streaming LLM response, handling tool calls and execution.
@@ -584,6 +596,8 @@ class ResponseProcessor:
         Args:
             llm_response: Response from the LLM
             thread_id: ID of the conversation thread
+            prompt_messages: List of messages sent to the LLM (the prompt)
+            llm_model: The name of the LLM model used
             config: Configuration for parsing and execution
 
         Yields:
@@ -684,6 +698,33 @@ class ResponseProcessor:
                 is_llm_message=True
             )
 
+            # Calculate and store cost AFTER adding the main assistant message
+            if content or tool_calls:  # Calculate cost if there's content or tool calls
+                try:
+                    # Use the full response object for potentially more accurate cost calculation
+                    # Pass model explicitly as it might not be reliably in response_object for all providers
+                    # First check if response_cost is directly available in _hidden_params
+                    if hasattr(llm_response, '_hidden_params') and 'response_cost' in llm_response._hidden_params and llm_response._hidden_params['response_cost'] != 0.0:
+                        final_cost = llm_response._hidden_params['response_cost']
+                    else:
+                        # Fall back to calculating cost if direct cost not available
+                        final_cost = completion_cost(
+                            completion_response=llm_response,
+                            model=llm_model,
+                            call_type="completion"  # Assuming 'completion' type for this context
+                        )
+
+                    if final_cost is not None and final_cost > 0:
+                        logger.info(f"Calculated final cost for non-stream: {final_cost}")
+                        await self.add_message(
+                            thread_id=thread_id,
+                            type="cost",
+                            content={"cost": final_cost},
+                            is_llm_message=False  # Cost is metadata
+                        )
+                except Exception as e:
+                    logger.error(f"Error calculating final cost for non-stream: {str(e)}")
+
             # Yield content first
             yield {"type": "content", "content": content}

View File

@@ -154,9 +154,10 @@ class ThreadManager:
         native_max_auto_continues: int = 25,
         max_xml_tool_calls: int = 0,
         include_xml_examples: bool = False,
+        enable_thinking: Optional[bool] = False,  # Add enable_thinking parameter
+        reasoning_effort: Optional[str] = 'low'  # Add reasoning_effort parameter
     ) -> Union[Dict[str, Any], AsyncGenerator]:
         """Run a conversation thread with LLM integration and tool execution.
 
         Args:
             thread_id: The ID of the thread to run
             system_prompt: System message to set the assistant's behavior
@@ -279,7 +280,9 @@ class ThreadManager:
                     max_tokens=llm_max_tokens,
                     tools=openapi_tool_schemas,
                     tool_choice=tool_choice if processor_config.native_tool_calling else None,
-                    stream=stream
+                    stream=stream,
+                    enable_thinking=enable_thinking,  # Pass enable_thinking
+                    reasoning_effort=reasoning_effort  # Pass reasoning_effort
                 )
                 logger.debug("Successfully received raw LLM API response stream/object")
@@ -293,6 +296,8 @@ class ThreadManager:
                     response_generator = self.response_processor.process_streaming_response(
                         llm_response=llm_response,
                         thread_id=thread_id,
+                        prompt_messages=prepared_messages,
+                        llm_model=llm_model,
                         config=processor_config
                     )
@@ -303,55 +308,91 @@ class ThreadManager:
                     response_generator = self.response_processor.process_non_streaming_response(
                         llm_response=llm_response,
                         thread_id=thread_id,
+                        prompt_messages=prepared_messages,
+                        llm_model=llm_model,
                         config=processor_config
                     )
                     return response_generator  # Return the generator
 
             except Exception as e:
-                logger.error(f"Error in run_thread: {str(e)}", exc_info=True)
-                return {
-                    "status": "error",
-                    "message": str(e)
-                }
+                logger.error(f"Error in _run_once: {str(e)}", exc_info=True)
+                # For generators, we need to yield an error structure if returning a generator is expected
+                async def error_generator():
+                    yield {
+                        "type": "error",
+                        "message": f"Error during LLM call or setup: {str(e)}"
+                    }
+                return error_generator()
 
         # Define a wrapper generator that handles auto-continue logic
         async def auto_continue_wrapper():
-            nonlocal auto_continue, auto_continue_count
+            nonlocal auto_continue, auto_continue_count, temporary_message
+
+            current_temp_message = temporary_message  # Use a local copy for the first run
 
             while auto_continue and (native_max_auto_continues == 0 or auto_continue_count < native_max_auto_continues):
                 # Reset auto_continue for this iteration
                 auto_continue = False
 
                 # Run the thread once
-                response_gen = await _run_once(temporary_message if auto_continue_count == 0 else None)
+                # Pass current_temp_message, which is only set for the first iteration
+                response_gen = await _run_once(temp_msg=current_temp_message)
 
-                # Handle error responses
-                if isinstance(response_gen, dict) and "status" in response_gen and response_gen["status"] == "error":
+                # Clear the temporary message after the first run
+                current_temp_message = None
+
+                # Handle error responses (checking if it's an error dict, which _run_once might return directly)
+                if isinstance(response_gen, dict) and response_gen.get("status") == "error":
                     yield response_gen
                     return
 
+                # Check if it's the error generator from _run_once exception handling
+                # Need a way to check if it's the specific error generator or just inspect the first item
+                first_chunk = None
+                try:
+                    first_chunk = await anext(response_gen)
+                except StopAsyncIteration:
+                    # Empty generator, possibly due to an issue before yielding.
+                    logger.warning("Response generator was empty.")
+                    break
+                except Exception as e:
+                    logger.error(f"Error getting first chunk from generator: {e}")
+                    yield {"type": "error", "message": f"Error processing response: {e}"}
+                    break
+
+                if first_chunk and first_chunk.get('type') == 'error' and "Error during LLM call" in first_chunk.get('message', ''):
+                    yield first_chunk
+                    return  # Stop processing if setup failed
+
+                # Yield the first chunk if it wasn't an error
+                if first_chunk:
+                    yield first_chunk
+
-                # Process each chunk
+                # Process remaining chunks
                 async for chunk in response_gen:
                     # Check if this is a finish reason chunk with tool_calls or xml_tool_limit_reached
                     if chunk.get('type') == 'finish':
-                        if chunk.get('finish_reason') == 'tool_calls':
+                        finish_reason = chunk.get('finish_reason')
+                        if finish_reason == 'tool_calls':
                             # Only auto-continue if enabled (max > 0)
                             if native_max_auto_continues > 0:
                                 logger.info(f"Detected finish_reason='tool_calls', auto-continuing ({auto_continue_count + 1}/{native_max_auto_continues})")
                                 auto_continue = True
                                 auto_continue_count += 1
-                                # Don't yield the finish chunk to avoid confusing the client
+                                # Don't yield the finish chunk to avoid confusing the client during auto-continue
                                 continue
-                        elif chunk.get('finish_reason') == 'xml_tool_limit_reached':
+                        elif finish_reason == 'xml_tool_limit_reached':
                             # Don't auto-continue if XML tool limit was reached
                             logger.info(f"Detected finish_reason='xml_tool_limit_reached', stopping auto-continue")
                             auto_continue = False
                             # Still yield the chunk to inform the client
+                        # Yield other finish reasons normally
 
-                    # Otherwise just yield the chunk normally
+                    # Yield the chunk normally
                     yield chunk
 
-                # If not auto-continuing, we're done
+                # If not auto-continuing, we're done with the loop
                 if not auto_continue:
                     break

View File

@@ -1,6 +1,6 @@
 streamlit-quill==0.0.3
 python-dotenv==1.0.1
-litellm>=1.44.0
+litellm>=1.66.2
 click==8.1.7
 questionary==2.0.1
 requests>=2.31.0

View File

@@ -28,7 +28,7 @@ RATE_LIMIT_DELAY = 30
 RETRY_DELAY = 5
 
 # Define debug log directory relative to this file's location
-DEBUG_LOG_DIR = os.path.join(os.path.dirname(__file__), '..', 'debug_logs')  # Assumes backend/debug_logs
+DEBUG_LOG_DIR = os.path.join(os.path.dirname(__file__), 'debug_logs')
 
 class LLMError(Exception):
     """Base exception for LLM-related errors."""
@@ -86,7 +86,10 @@ def prepare_params(
     api_base: Optional[str] = None,
     stream: bool = False,
     top_p: Optional[float] = None,
-    model_id: Optional[str] = None
+    model_id: Optional[str] = None,
+    # Add parameters for thinking/reasoning
+    enable_thinking: Optional[bool] = None,
+    reasoning_effort: Optional[str] = None
 ) -> Dict[str, Any]:
     """Prepare parameters for the API call."""
     params = {
@@ -156,6 +159,24 @@ def prepare_params(
         params["model_id"] = "arn:aws:bedrock:us-west-2:935064898258:inference-profile/us.anthropic.claude-3-7-sonnet-20250219-v1:0"
         logger.debug(f"Auto-set model_id for Claude 3.7 Sonnet: {params['model_id']}")
 
+    # --- Add Anthropic Thinking/Reasoning Effort ---
+    # Determine if thinking should be enabled based on the passed parameter
+    use_thinking = enable_thinking if enable_thinking is not None else False
+
+    # Check if the model is Anthropic
+    is_anthropic = "sonnet-3-7" in model_name.lower() or "anthropic" in model_name.lower()
+
+    # Add reasoning_effort parameter if enabled and applicable
+    if is_anthropic and use_thinking:
+        # Determine reasoning effort based on the passed parameter, defaulting to 'low'
+        effort_level = reasoning_effort if reasoning_effort else 'low'
+        params["reasoning_effort"] = effort_level
+        logger.info(f"Anthropic thinking enabled with reasoning_effort='{effort_level}'")
+        # Anthropic requires temperature=1 when thinking/reasoning_effort is enabled
+        params["temperature"] = 1.0
+
     return params
 
 async def make_llm_api_call(
@@ -170,7 +191,10 @@ async def make_llm_api_call(
     api_base: Optional[str] = None,
     stream: bool = False,
     top_p: Optional[float] = None,
-    model_id: Optional[str] = None
+    model_id: Optional[str] = None,
+    # Add parameters for thinking/reasoning
+    enable_thinking: Optional[bool] = None,
+    reasoning_effort: Optional[str] = None
 ) -> Union[Dict[str, Any], AsyncGenerator]:
     """
     Make an API call to a language model using LiteLLM.
@@ -209,7 +233,10 @@ async def make_llm_api_call(
         api_base=api_base,
         stream=stream,
         top_p=top_p,
-        model_id=model_id
+        model_id=model_id,
+        # Add parameters for thinking/reasoning
+        enable_thinking=enable_thinking,
+        reasoning_effort=reasoning_effort
     )
 
     # Apply Anthropic prompt caching (minimal implementation)
@@ -272,15 +299,19 @@ async def make_llm_api_call(
     # Initialize log path to None, it will be set only if logging is enabled
     response_log_path = None
     enable_debug_logging = os.environ.get('ENABLE_LLM_DEBUG_LOGGING', 'false').lower() == 'true'
+    # enable_debug_logging = True
 
     if enable_debug_logging:
         try:
             os.makedirs(DEBUG_LOG_DIR, exist_ok=True)
+            # save the model name too
+            model_name = params["model"]
             timestamp = time.strftime("%Y%m%d_%H%M%S")
             # Use a unique ID or counter if calls can happen in the same second
             # For simplicity, using timestamp only for now
-            request_log_path = os.path.join(DEBUG_LOG_DIR, f"llm_request_{timestamp}.json")
-            response_log_path = os.path.join(DEBUG_LOG_DIR, f"llm_response_{timestamp}.json")  # Set here if enabled
+            request_log_path = os.path.join(DEBUG_LOG_DIR, f"llm_request_{timestamp}_{model_name}.json")
+            response_log_path = os.path.join(DEBUG_LOG_DIR, f"llm_response_{timestamp}_{model_name}.json")  # Set here if enabled
 
             # Log the request parameters just before the attempt loop
             logger.debug(f"Logging LLM request parameters to {request_log_path}")
@@ -300,7 +331,7 @@ async def make_llm_api_call(
     for attempt in range(MAX_RETRIES):
         try:
             logger.debug(f"Attempt {attempt + 1}/{MAX_RETRIES}")
+            # print(params)
             response = await litellm.acompletion(**params)
             logger.debug(f"Successfully received API response from {model_name}")

View File

@@ -0,0 +1,154 @@
import asyncio
import litellm
import copy
import json
from dotenv import load_dotenv
load_dotenv()
async def run_conversation_turn(model: str, messages: list, user_prompt: str | list, reasoning_effort: str | None = None):
"""
Handles a single turn of the conversation, prepares arguments, and calls litellm.acompletion.
Args:
model: The model name string.
messages: The list of message dictionaries (will be modified in place).
user_prompt: The user's prompt for this turn (string or list).
reasoning_effort: Optional reasoning effort string for Anthropic models.
Returns:
The response object from litellm.acompletion.
"""
# Append user prompt
if isinstance(user_prompt, str):
messages.append({"role": "user", "content": user_prompt})
elif isinstance(user_prompt, list): # Handle list/dict content structure
messages.append({"role": "user", "content": user_prompt})
# --- Start of merged logic from call_litellm_with_cache ---
processed_messages = copy.deepcopy(messages) # Work on a copy for modification
is_anthropic = model.startswith("anthropic")
kwargs = {
"model": model,
"messages": processed_messages,
}
call_description = [f"Calling {model}", f"{len(processed_messages)} messages"]
if is_anthropic:
# Add cache_control for Anthropic models
if processed_messages and processed_messages[0]["role"] == "system":
content = processed_messages[0].get("content")
if isinstance(content, list):
for part in content:
if isinstance(part, dict):
part["cache_control"] = {"type": "ephemeral"}
elif isinstance(content, str):
processed_messages[0]["content"] = [{"type": "text", "text": content, "cache_control": {"type": "ephemeral"}}]
if processed_messages and processed_messages[-1]["role"] == "user":
content = processed_messages[-1].get("content")
if isinstance(content, list):
for part in content:
if isinstance(part, dict):
part["cache_control"] = {"type": "ephemeral"}
elif isinstance(content, str):
processed_messages[-1]["content"] = [{"type": "text", "text": content, "cache_control": {"type": "ephemeral"}}]
call_description.append("cache enabled")
# Add reasoning_effort only for Anthropic models if provided
if reasoning_effort:
kwargs["reasoning_effort"] = reasoning_effort
call_description.append(f"reasoning: {reasoning_effort}")
print(f"\n--- {' | '.join(call_description)} ---")
response = await litellm.acompletion(**kwargs)
print("--- Full Response Object ---")
# Convert response object to dict and print as indented JSON
try:
print(json.dumps(response.dict(), indent=2))
print(response._hidden_params)
except Exception as e:
print(f"Could not format response as JSON: {e}")
print(response) # Fallback to printing the raw object if conversion fails
print("--- End Response ---")
# --- End of merged logic ---
# Append assistant response to the original messages list
if response.choices and response.choices[0].message.content:
messages.append({
"role": "assistant",
"content": response.choices[0].message.content
})
else:
# Handle cases where response might be empty or malformed
print("Warning: Assistant response content is missing.")
messages.append({"role": "assistant", "content": ""}) # Append empty content
return response
async def main(model_name: str, reasoning_effort: str = "medium"):
hello_string = "Hello " * 1234
# Initial messages
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": f"Here is some text: {hello_string}"},
{"role": "assistant", "content": "Okay, I have received the text."},
]
# Turn 1: Ask to count "Hello"
print("\n=== Turn 1: Counting 'Hello' ===")
await run_conversation_turn(
model=model_name,
messages=messages,
user_prompt="How many times does the word 'Hello' appear in the text I provided?",
reasoning_effort=reasoning_effort,
)
# Turn 2: Ask for a short story
print("\n=== Turn 2: Short Story Request ===")
await run_conversation_turn(
model=model_name,
messages=messages,
user_prompt=[ # Using list/dict format for user content
{
"type": "text",
"text": "Great, thanks for counting. Now, can you write a short story (less than 50 words) where the word 'Hello' appears exactly 5 times?",
}
],
reasoning_effort=reasoning_effort,
)
# Turn 3: Ask about the main character
print("\n=== Turn 3: Main Character Question ===")
await run_conversation_turn(
model=model_name,
messages=messages,
user_prompt=[ # Using list/dict format for user content
{
"type": "text",
"text": "Based on the short story you just wrote, who is the main character?",
}
],
reasoning_effort=reasoning_effort,
)
if __name__ == "__main__":
# Select the model to test
model = "anthropic/claude-3-7-sonnet-latest"
# model = "groq/llama-3.3-70b-versatile"
# model = "openai/gpt-4o-mini"
# model = "openai/gpt-4.1-2025-04-14" # Placeholder if needed
print(f"Running test with model: {model}")
asyncio.run(main(
model_name=model,
# reasoning_effort="medium"
reasoning_effort="low"
))

View File

@@ -0,0 +1,181 @@
import asyncio
import litellm
import copy
import json
from dotenv import load_dotenv
load_dotenv()
async def run_streaming_conversation_turn(model: str, messages: list, user_prompt: str | list, enable_thinking: bool = False, reasoning_effort: str | None = None):
"""
Handles a single turn of the conversation using streaming, prepares arguments,
and calls litellm.acompletion with stream=True.
Args:
model: The model name string.
messages: The list of message dictionaries (will be modified in place).
user_prompt: The user's prompt for this turn (string or list).
enable_thinking: Boolean to enable thinking for Anthropic models.
reasoning_effort: Optional reasoning effort string for Anthropic models.
Returns:
The final accumulated assistant response dictionary.
"""
# Append user prompt
if isinstance(user_prompt, str):
messages.append({"role": "user", "content": user_prompt})
elif isinstance(user_prompt, list): # Handle list/dict content structure
messages.append({"role": "user", "content": user_prompt})
processed_messages = copy.deepcopy(messages) # Work on a copy for modification
is_anthropic = model.startswith("anthropic")
kwargs = {
"model": model,
"messages": processed_messages,
"stream": True, # Enable streaming
}
if is_anthropic:
# Add cache_control for Anthropic models
if processed_messages and processed_messages[0]["role"] == "system":
content = processed_messages[0].get("content")
if isinstance(content, list):
for part in content:
if isinstance(part, dict):
part["cache_control"] = {"type": "ephemeral"}
elif isinstance(content, str):
processed_messages[0]["content"] = [{"type": "text", "text": content, "cache_control": {"type": "ephemeral"}}]
if processed_messages and processed_messages[-1]["role"] == "user":
content = processed_messages[-1].get("content")
if isinstance(content, list):
for part in content:
if isinstance(part, dict):
part["cache_control"] = {"type": "ephemeral"}
elif isinstance(content, str):
processed_messages[-1]["content"] = [{"type": "text", "text": content, "cache_control": {"type": "ephemeral"}}]
# Add reasoning_effort only for Anthropic models if provided and thinking is enabled
if enable_thinking and reasoning_effort:
kwargs["reasoning_effort"] = reasoning_effort
stream_response = await litellm.acompletion(**kwargs)
# Collect the full response from streaming chunks
full_response_content = ""
thinking_printed = False
response_printed = False
async for chunk in stream_response:
if chunk.choices and chunk.choices[0].delta:
delta = chunk.choices[0].delta
# Print thinking/reasoning content if present
if hasattr(delta, "reasoning_content") and delta.reasoning_content:
if not thinking_printed:
print("[Thinking]: ", end="", flush=True)
thinking_printed = True
print(f"{delta.reasoning_content}", end="", flush=True) # Print thinking step
# Print and accumulate regular content
if delta.content:
if not response_printed:
# Add newline if thinking was printed before response starts
if thinking_printed:
print() # Newline to separate thinking and response
print("[Response]: ", end="", flush=True)
response_printed = True
chunk_content = delta.content
full_response_content += chunk_content
# Stream to stdout in real-time
print(chunk_content, end="", flush=True)
print() # Newline after streaming finishes
# Print hidden params if available
try:
print("--- Hidden Params ---")
print(stream_response._hidden_params)
print("--- End Hidden Params ---")
except AttributeError:
print("(_hidden_params attribute not found on stream response object)")
except Exception as e:
print(f"Could not print _hidden_params: {e}")
print("--------------------------------")
print() # Add another newline for separation
# Create a complete response object with the full content
final_response = {
"model": model,
"choices": [{
"message": {"role": "assistant", "content": full_response_content}
}]
}
# Add the assistant's response to the messages
messages.append({"role": "assistant", "content": full_response_content})
return final_response
async def main(model_name: str, enable_thinking: bool = False, reasoning_effort: str = "medium"):
"""Runs a multi-turn conversation test with streaming enabled."""
hello_string = "Hello " * 1234
# Initial messages
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": f"Here is some text: {hello_string}"},
{"role": "assistant", "content": "Okay, I have received the text."},
]
# Turn 1: Ask to count "Hello"
await run_streaming_conversation_turn(
model=model_name,
messages=messages,
user_prompt="How many times does the word 'Hello' appear in the text I provided?",
enable_thinking=enable_thinking,
reasoning_effort=reasoning_effort,
)
# Turn 2: Ask for a short story
await run_streaming_conversation_turn(
model=model_name,
messages=messages,
user_prompt=[ # Using list/dict format for user content
{
"type": "text",
"text": "Great, thanks for counting. Now, can you write a short story (less than 50 words) where the word 'Hello' appears exactly 5 times?",
}
],
enable_thinking=enable_thinking,
reasoning_effort=reasoning_effort,
)
# Turn 3: Ask about the main character
await run_streaming_conversation_turn(
model=model_name,
messages=messages,
user_prompt=[ # Using list/dict format for user content
{
"type": "text",
"text": "Based on the short story you just wrote, who is the main character?",
}
],
enable_thinking=enable_thinking,
reasoning_effort=reasoning_effort,
)
if __name__ == "__main__":
# Select the model to test
model = "anthropic/claude-3-7-sonnet-latest"
# model = "openai/gpt-4o-mini"
# model = "openai/gpt-4.1-2025-04-14" # Placeholder if needed
asyncio.run(main(
model_name=model,
enable_thinking=True, # Enable thinking for the test run
# reasoning_effort="medium"
reasoning_effort="low" # Start with low for faster responses
))

View File

@@ -1,82 +0,0 @@
import asyncio
import litellm
async def main():
initial_messages=[
# System Message
{
"role": "system",
"content": [
{
"type": "text",
"text": "Here is the full text of a complex legal agreement"
* 400,
"cache_control": {"type": "ephemeral"},
}
],
},
# marked for caching with the cache_control parameter, so that this checkpoint can read from the previous cache.
{
"role": "user",
"content": [
{
"type": "text",
"text": "What are the key terms and conditions in this agreement?",
"cache_control": {"type": "ephemeral"},
}
],
},
{
"role": "assistant",
"content": "Certainly! the key terms and conditions are the following: the contract is 1 year long for $10/month",
},
# The final turn is marked with cache-control, for continuing in followups.
{
"role": "user",
"content": [
{
"type": "text",
"text": "What are the key terms and conditions in this agreement?",
"cache_control": {"type": "ephemeral"},
}
],
},
]
print("--- First call ---")
first_response = await litellm.acompletion(
model="anthropic/claude-3-7-sonnet-latest",
messages=initial_messages
)
print(first_response)
# Prepare messages for the second call
second_call_messages = initial_messages + [
{
"role": "assistant",
# Extract the assistant's response content from the first call
"content": first_response.choices[0].message.content
},
{
"role": "user",
"content": [
{
"type": "text",
"text": "Can you elaborate on the termination clause based on the provided text? Remember the context.",
"cache_control": {"type": "ephemeral"}, # Mark for caching
}
],
},
]
print("\n--- Second call (testing cache) ---")
second_response = await litellm.acompletion(
model="anthropic/claude-3-7-sonnet-latest",
messages=second_call_messages
)
print(second_response)
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -0,0 +1,115 @@
"""
Automated test script for the AgentPress agent.
This script sends a specific query ("short report for today's news") to the agent
and prints the streaming output to the console.
"""
import asyncio
from dotenv import load_dotenv
from agentpress.thread_manager import ThreadManager
from services.supabase import DBConnection
from agent.run import process_agent_response # Reusing the processing logic
from utils.logger import logger
async def main():
"""Main function to run the automated agent test."""
load_dotenv() # Ensure environment variables are loaded
logger.info("--- Starting Automated Agent Test: News Report ---")
# Initialize ThreadManager and DBConnection
thread_manager = ThreadManager()
db_connection = DBConnection()
await db_connection.initialize() # Ensure connection is ready
client = await db_connection.client
project_id = None
thread_id = None
try:
# 1. Set up Test Project and Thread
logger.info("Setting up test project and thread...")
# Get user's personal account (replace with actual logic if needed)
# Using a fixed account ID for simplicity in this example
# In a real scenario, you might fetch this dynamically
account_id = "a5fe9cb6-4812-407e-a61c-fe95b7320c59" # Example account ID
test_project_name = "automated_test_project_news"
if not account_id:
logger.error("Error: Could not determine account ID.")
return
# Find or create a test project
project_result = await client.table('projects').select('*').eq('name', test_project_name).eq('account_id',
account_id).limit(1).execute()
if project_result.data:
project_id = project_result.data[0]['project_id']
logger.info(f"Using existing test project: {project_id}")
else:
project_insert_result = await client.table('projects').insert({
"name": test_project_name,
"account_id": account_id
}).execute()
if not project_insert_result.data:
logger.error("Failed to create test project.")
return
project_id = project_insert_result.data[0]['project_id']
logger.info(f"Created new test project: {project_id}")
# Create a new thread for this test run
thread_result = await client.table('threads').insert({
'project_id': project_id,
'account_id': account_id
}).execute()
if not thread_result.data:
logger.error("Error: Failed to create test thread.")
return
thread_id = thread_result.data[0]['thread_id']
logger.info(f"Test Thread Created: {thread_id}")
# 2. Define and Add User Message
user_message = "short report for today's news"
logger.info(f"Adding user message to thread: '{user_message}'")
await thread_manager.add_message(
thread_id=thread_id,
type="user",
content={
"role": "user",
"content": user_message
},
is_llm_message=True # Treat it as a message the LLM should see
)
# 3. Run the Agent and Process Response
logger.info("Running agent and processing response...")
# We reuse the process_agent_response function from run.py which handles the streaming output
await process_agent_response(thread_id, project_id, thread_manager)
logger.info("--- Agent Test Completed ---")
except Exception as e:
logger.error(f"An error occurred during the test: {str(e)}", exc_info=True)
finally:
# Optional: Clean up the created thread?
# if thread_id:
# logger.info(f"Cleaning up test thread: {thread_id}")
# await client.table('messages').delete().eq('thread_id', thread_id).execute()
# await client.table('threads').delete().eq('thread_id', thread_id).execute()
# Disconnect DB
await db_connection.disconnect()
logger.info("Database connection closed.")
if __name__ == "__main__":
# Configure logging if needed (e.g., set level)
# logging.getLogger('agentpress').setLevel(logging.DEBUG)
# Run the main async function
asyncio.run(main())

View File

@@ -0,0 +1,215 @@
"""
Test script for running the AgentPress agent with thinking enabled.
This test specifically targets Anthropic models that support the 'reasoning_effort'
parameter to observe the agent's behavior when thinking is explicitly enabled.
"""
import asyncio
import json
import os
import sys
import traceback
from dotenv import load_dotenv
# Ensure the backend directory is in the Python path
backend_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
if backend_dir not in sys.path:
sys.path.insert(0, backend_dir)
import logging
from agentpress.thread_manager import ThreadManager
from services.supabase import DBConnection
from agent.run import run_agent, process_agent_response # Reuse processing logic
from utils.logger import logger
logger.setLevel(logging.DEBUG)
async def test_agent_with_thinking():
"""
Test running the agent with thinking enabled for an Anthropic model.
"""
print("\n" + "="*80)
print("🧪 TESTING AGENT RUN WITH THINKING ENABLED (Anthropic)")
print("="*80 + "\n")
# Load environment variables
load_dotenv()
# Initialize ThreadManager and DBConnection
thread_manager = ThreadManager()
db_connection = DBConnection()
await db_connection.initialize() # Ensure connection is ready
client = await db_connection.client
thread_id = None
project_id = None
project_created = False # Flag to track if we created the project
try:
# --- Test Setup ---
print("🔧 Setting up test environment (Project & Thread)...")
logger.info("Setting up test project and thread...")
# Using a hardcoded account ID for consistency in tests
account_id = "a5fe9cb6-4812-407e-a61c-fe95b7320c59" # Replace if necessary
test_project_name = "test_agent_thinking_project"
logger.info(f"Using Account ID: {account_id}")
if not account_id:
print("❌ Error: Could not determine Account ID.")
logger.error("Could not determine Account ID.")
return
# Find or create a test project
project_result = await client.table('projects').select('*').eq('name', test_project_name).eq('account_id', account_id).limit(1).execute()
if project_result.data:
project_id = project_result.data[0]['project_id']
print(f"🔄 Using existing test project: {project_id}")
logger.info(f"Using existing test project: {project_id}")
else:
project_insert_result = await client.table('projects').insert({
"name": test_project_name,
"account_id": account_id
}).execute()
if not project_insert_result.data:
print("❌ Error: Failed to create test project.")
logger.error("Failed to create test project.")
return
project_id = project_insert_result.data[0]['project_id']
project_created = True
print(f"✨ Created new test project: {project_id}")
logger.info(f"Created new test project: {project_id}")
# Create a new thread for this test run
thread_result = await client.table('threads').insert({
'project_id': project_id,
'account_id': account_id
}).execute()
if not thread_result.data:
print("❌ Error: Failed to create test thread.")
logger.error("Failed to create test thread.")
return
thread_id = thread_result.data[0]['thread_id']
print(f"🧵 Created new test thread: {thread_id}")
logger.info(f"Test Thread Created: {thread_id}")
# Add an initial user message that requires planning
initial_message = "Create a plan to build a simple 'Hello World' HTML page in the workspace, then execute the first step of the plan."
print(f"\n💬 Adding initial user message: '{initial_message}'")
logger.info(f"Adding initial user message: '{initial_message}'")
await thread_manager.add_message(
thread_id=thread_id,
type="user",
content={
"role": "user",
"content": initial_message
},
is_llm_message=True
)
print("✅ Initial message added.")
# --- Run Agent with Thinking Enabled ---
logger.info("Running agent ...")
# Use the process_agent_response helper to handle streaming output.
# Pass the desired model, thinking, and stream parameters directly to it.
await process_agent_response(
thread_id=thread_id,
project_id=project_id,
thread_manager=thread_manager,
stream=False, # Explicitly set stream to True for testing
model_name="anthropic/claude-3-7-sonnet-latest", # Specify the model here
enable_thinking=True, # Enable thinking here
reasoning_effort='low' # Specify effort here
)
# await process_agent_response(
# thread_id=thread_id,
# project_id=project_id,
# thread_manager=thread_manager,
# model_name="openai/gpt-4.1-2025-04-14", # Specify the model here
# model_name="groq/llama-3.3-70b-versatile",
# enable_thinking=False, # Enable thinking here
# reasoning_effort='low' # Specify effort here
# )
# --- Direct Stream Processing (Alternative to process_agent_response) ---
# The direct run_agent call above was removed as process_agent_response handles it.
# print("\n--- Agent Response Stream ---")
# async for chunk in agent_run_generator:
# chunk_type = chunk.get('type', 'unknown')
# if chunk_type == 'content' and 'content' in chunk:
# print(chunk['content'], end='', flush=True)
# elif chunk_type == 'tool_result':
# tool_name = chunk.get('function_name', 'Tool')
# result = chunk.get('result', '')
# print(f"\n\n🛠 TOOL RESULT [{tool_name}] → {result}", flush=True)
# elif chunk_type == 'tool_status':
# status = chunk.get('status', '')
# func_name = chunk.get('function_name', '')
# if status and func_name:
# emoji = "✅" if status == "completed" else "⏳" if status == "started" else "❌"
# print(f"\n{emoji} TOOL {status.upper()}: {func_name}", flush=True)
# elif chunk_type == 'finish':
# reason = chunk.get('finish_reason', '')
# if reason:
# print(f"\n📌 Finished: {reason}", flush=True)
# elif chunk_type == 'error':
# print(f"\n❌ ERROR: {chunk.get('message', 'Unknown error')}", flush=True)
# break # Stop processing on error
print("\n\n✅ Agent run finished.")
logger.info("Agent run finished.")
except Exception as e:
print(f"\n❌ An error occurred during the test: {e}")
logger.error(f"An error occurred during the test: {str(e)}", exc_info=True)
traceback.print_exc()
finally:
# --- Cleanup ---
print("\n🧹 Cleaning up test resources...")
logger.info("Cleaning up test resources...")
if thread_id:
try:
await client.table('messages').delete().eq('thread_id', thread_id).execute()
await client.table('threads').delete().eq('thread_id', thread_id).execute()
print(f"🗑️ Deleted test thread: {thread_id}")
logger.info(f"Deleted test thread: {thread_id}")
except Exception as e:
print(f"⚠️ Error cleaning up thread {thread_id}: {e}")
logger.warning(f"Error cleaning up thread {thread_id}: {e}")
if project_id and project_created: # Only delete if we created it in this run
try:
await client.table('projects').delete().eq('project_id', project_id).execute()
print(f"🗑️ Deleted test project: {project_id}")
logger.info(f"Deleted test project: {project_id}")
except Exception as e:
print(f"⚠️ Error cleaning up project {project_id}: {e}")
logger.warning(f"Error cleaning up project {project_id}: {e}")
# Disconnect DB
await db_connection.disconnect()
logger.info("Database connection closed.")
print("\n" + "="*80)
print("🏁 THINKING TEST COMPLETE")
print("="*80 + "\n")
if __name__ == "__main__":
# Ensure the logger is configured
logger.info("Starting test_agent_thinking script...")
try:
asyncio.run(test_agent_with_thinking())
print("\n✅ Test script completed successfully.")
sys.exit(0)
except KeyboardInterrupt:
print("\n\n❌ Test interrupted by user.")
sys.exit(1)
except Exception as e:
print(f"\n\n❌ Error running test script: {e}")
traceback.print_exc()
sys.exit(1)