"""
Tests for tool execution strategies in AgentPress.

This module tests both sequential and parallel execution strategies using the WaitTool
in a realistic thread with XML tool calls.
"""
import os
import asyncio
import sys
from unittest.mock import AsyncMock, patch
from dotenv import load_dotenv
from agentpress.thread_manager import ThreadManager
from agentpress.response_processor import ProcessorConfig
from agent.tools.wait_tool import WaitTool
# Load environment variables
load_dotenv()
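
# The fixtures below each contain three 2-second <wait> calls plus one <wait-sequence>;
# sequential execution should take roughly the sum of the waits, while parallel
# execution should approach the duration of the longest single wait.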
TOOL_XML_SEQUENTIAL = """
Here are some examples of using the wait tool:
<wait seconds="2">This is sequential wait 1</wait>
<wait seconds="2">This is sequential wait 2</wait>
<wait seconds="2">This is sequential wait 3</wait>
Now wait sequence:
<wait-sequence count="3" seconds="1" label="Sequential Test"/>
"""

TOOL_XML_PARALLEL = """
Here are some examples of using the wait tool:
<wait seconds="2">This is parallel wait 1</wait>
<wait seconds="2">This is parallel wait 2</wait>
<wait seconds="2">This is parallel wait 3</wait>
Now wait sequence:
<wait-sequence count="3" seconds="1" label="Parallel Test"/>
"""
# Create a simple mock function that logs instead of accessing the database
async def mock_add_message(thread_id, message=None, **kwargs):
    # Accept both call styles used in this test: a positional message dict and
    # the keyword form (type=..., content=..., is_llm_message=...) used below.
    message = message if message is not None else kwargs.get("content", {})
    print(f"MOCK: Adding message to thread {thread_id}")
    print(f"MOCK: Message role: {message.get('role')}")
    print(f"MOCK: Content length: {len(message.get('content', ''))}")
    return {"id": "mock-message-id", "thread_id": thread_id}

async def test_execution_strategies():
    """Test both sequential and parallel execution strategies in a thread."""
    print("\n" + "=" * 80)
    print("🧪 TESTING TOOL EXECUTION STRATEGIES")
    print("=" * 80 + "\n")
    # Initialize ThreadManager and register tools
    thread_manager = ThreadManager()
    thread_manager.add_tool(WaitTool)

    # Mock both ThreadManager's and ResponseProcessor's add_message method
    thread_manager.add_message = AsyncMock(side_effect=mock_add_message)
    # This is crucial - the ResponseProcessor receives add_message as a callback
    thread_manager.response_processor.add_message = AsyncMock(side_effect=mock_add_message)

    # Create a test thread - we'll use a dummy ID since we're mocking the database
    thread_id = "test-thread-id"
    print(f"🧵 Using test thread: {thread_id}\n")
    # Set up the get_llm_messages mock
    original_get_llm_messages = thread_manager.get_llm_messages
    thread_manager.get_llm_messages = AsyncMock()
    # Test both strategies
    test_cases = [
        {"name": "Sequential", "strategy": "sequential", "content": TOOL_XML_SEQUENTIAL},
        {"name": "Parallel", "strategy": "parallel", "content": TOOL_XML_PARALLEL}
    ]

    # Expected values for validation - this varies based on XML parsing
    # For reliable testing, we look at <wait> tags which we know are being parsed
    expected_wait_count = 3  # 3 wait tags per test
    test_results = {}
    for test in test_cases:
        print("\n" + "-" * 60)
        print(f"🔍 Testing {test['name']} Execution Strategy")
        print("-" * 60 + "\n")
        # Set up mock for get_llm_messages to return our test content
        thread_manager.get_llm_messages.return_value = [
            {
                "role": "system",
                "content": "You are a testing assistant that will execute wait commands."
            },
            {
                "role": "assistant",
                "content": test["content"]
            }
        ]

        # Simulate adding message (mocked)
        print(f"MOCK: Adding test message with {test['name']} execution strategy content")
        await thread_manager.add_message(
            thread_id=thread_id,
            type="assistant",
            content={
                "role": "assistant",
                "content": test["content"]
            },
            is_llm_message=True
        )
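
        # Time the full run so the two strategies' wall-clock durations can be compared below.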
        start_time = asyncio.get_event_loop().time()
        print(f"⏱️ Starting execution with {test['strategy']} strategy at {start_time:.2f}s")

        # Process the response with appropriate strategy
        config = ProcessorConfig(
            xml_tool_calling=True,
            native_tool_calling=False,
            execute_tools=True,
            execute_on_stream=False,
            tool_execution_strategy=test["strategy"]
        )
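        # The config above enables XML tool-call parsing (the <wait> tags in the message
        # content) and selects the execution strategy under test.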
        # Get the last message to process (mocked)
        messages = await thread_manager.get_llm_messages(thread_id)
        last_message = messages[-1]
        # Create a simple non-streaming response object
        class MockResponse:
            def __init__(self, content):
                self.choices = [type('obj', (object,), {
                    'message': type('obj', (object,), {
                        'content': content
                    })
                })]
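
        # MockResponse mirrors the choices[0].message.content shape of a non-streaming completion object.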
        mock_response = MockResponse(last_message["content"])
        # Process using the response processor
        tool_execution_count = 0
        wait_tool_count = 0
        tool_results = []
        async for chunk in thread_manager.response_processor.process_non_streaming_response(
            llm_response=mock_response,
            thread_id=thread_id,
            config=config
        ):
            if chunk.get('type') == 'tool_result':
                tool_name = chunk.get('name', '')
                tool_execution_count += 1
                if tool_name == 'wait':
                    wait_tool_count += 1
                elapsed = asyncio.get_event_loop().time() - start_time
                print(f"⏱️ [{elapsed:.2f}s] Tool result: {chunk['name']}")
                print(f"   {chunk['result']}")
                print()
                tool_results.append(chunk)

        end_time = asyncio.get_event_loop().time()
        elapsed = end_time - start_time

        print(f"\n⏱️ {test['name']} execution completed in {elapsed:.2f} seconds")
        print(f"🔢 Total tool executions: {tool_execution_count}")
        print(f"🔢 Wait tool executions: {wait_tool_count}")
        # Store results for validation
        test_results[test['name']] = {
            'execution_time': elapsed,
            'tool_count': tool_execution_count,
            'wait_count': wait_tool_count,
            'tool_results': tool_results
        }

        # Assert the correct number of wait tool executions (this is more reliable than the total count)
        assert wait_tool_count == expected_wait_count, f"❌ Expected {expected_wait_count} wait tool executions, got {wait_tool_count} in {test['name']} strategy"
        print(f"✅ PASS: {test['name']} executed {wait_tool_count} wait tools as expected")
    # Restore original get_llm_messages method
    thread_manager.get_llm_messages = original_get_llm_messages
    # Additional assertions for both test cases
    assert 'Sequential' in test_results, "❌ Sequential test not completed"
    assert 'Parallel' in test_results, "❌ Parallel test not completed"

    # Validate parallel is faster than sequential for multiple wait tools
    sequential_time = test_results['Sequential']['execution_time']
    parallel_time = test_results['Parallel']['execution_time']
    speedup = sequential_time / parallel_time if parallel_time > 0 else 0

    # Parallel should be faster than sequential (at least 1.5x speedup expected)
    print("\n⏱️ Execution time comparison:")
    print(f"   Sequential: {sequential_time:.2f}s")
    print(f"   Parallel: {parallel_time:.2f}s")
    print(f"   Speedup: {speedup:.2f}x")

    min_expected_speedup = 1.5
    assert speedup >= min_expected_speedup, f"❌ Expected parallel execution to be at least {min_expected_speedup}x faster than sequential, but got {speedup:.2f}x"
    print(f"✅ PASS: Parallel execution is {speedup:.2f}x faster than sequential")
    # Check if all results have a status field
    all_have_status = all(
        'status' in result
        for test_data in test_results.values()
        for result in test_data['tool_results']
    )

    # If results have a status field, check if they're all successful
    if all_have_status:
        all_successful = all(
            result.get('status') == 'success'
            for test_data in test_results.values()
            for result in test_data['tool_results']
        )
        assert all_successful, "❌ Not all tool executions were successful"
        print("✅ PASS: All tool executions completed successfully")
    print("\n" + "=" * 80)
    print("✅ ALL TESTS PASSED")
    print("=" * 80 + "\n")

    return test_results

if __name__ == "__main__":
    try:
        asyncio.run(test_execution_strategies())
        print("\n✅ Test completed successfully")
        sys.exit(0)
    except AssertionError as e:
        print(f"\n\n❌ Test failed: {str(e)}")
        sys.exit(1)
    except KeyboardInterrupt:
        print("\n\n❌ Test interrupted by user")
        sys.exit(1)
    except Exception as e:
        print(f"\n\n❌ Error during test: {str(e)}")
        sys.exit(1)