mirror of https://github.com/kortix-ai/suna.git

resolved: prompt caching

parent 7de2756b44
commit 88c0d7c934
@@ -177,4 +177,5 @@ state.json
 # .DS_Store files
 .DS_Store
 **/.DS_Store
+.aider*
@@ -80,8 +80,6 @@ async def run_agent(thread_id: str, project_id: str, stream: bool = True, thread
     for tag_name, example in thread_manager.tool_registry.get_xml_examples().items():
         xml_examples += f"{example}\n"
 
-    system_message = { "role": "system", "content": get_system_prompt() + "\n\n" + f"<tool_examples>\n{xml_examples}\n</tool_examples>" }
-
     iteration_count = 0
     continue_execution = True
 
@@ -109,24 +107,46 @@ async def run_agent(thread_id: str, project_id: str, stream: bool = True, thread
             print(f"Last message was from assistant, stopping execution")
             continue_execution = False
             break
-        # Get the latest message from messages table that its tpye is browser_state
+        # Define Processor Config FIRST
+        processor_config = ProcessorConfig(
+            xml_tool_calling=True,
+            native_tool_calling=False,
+            execute_tools=True,
+            execute_on_stream=True,
+            tool_execution_strategy="parallel",
+            xml_adding_strategy="user_message"
+        )
+
+        # Construct System Message Conditionally
+        base_system_prompt_content = get_system_prompt()
+        system_message_content = base_system_prompt_content
+
+        # Conditionally add XML examples based on the config
+        if processor_config.xml_tool_calling:
+            # Use the already loaded xml_examples from outside the loop
+            if xml_examples:
+                system_message_content += "\n\n" + f"<tool_examples>\n{xml_examples}\n</tool_examples>"
+
+        system_message = { "role": "system", "content": system_message_content }
+
+        # Handle Temporary Message (Browser State)
         latest_browser_state = await client.table('messages').select('*').eq('thread_id', thread_id).eq('type', 'browser_state').order('created_at', desc=True).limit(1).execute()
         temporary_message = None
         if latest_browser_state.data and len(latest_browser_state.data) > 0:
             try:
                 content = json.loads(latest_browser_state.data[0]["content"])
-                screenshot_base64 = content["screenshot_base64"]
+                screenshot_base64 = content.get("screenshot_base64") # Use .get() for safety
                 # Create a copy of the browser state without screenshot
                 browser_state = content.copy()
                 browser_state.pop('screenshot_base64', None)
                 browser_state.pop('screenshot_url', None)
                 browser_state.pop('screenshot_url_base64', None)
                 temporary_message = { "role": "user", "content": [] }
                 if browser_state:
                     temporary_message["content"].append({
                         "type": "text",
-                        "text": f"The following is the current state of the browser:\n{browser_state}"
+                        "text": f"The following is the current state of the browser:\n{json.dumps(browser_state, indent=2)}" # Pretty print browser state
                     })
                 if screenshot_base64:
                     temporary_message["content"].append({
@@ -136,14 +156,15 @@ async def run_agent(thread_id: str, project_id: str, stream: bool = True, thread
                         }
                     })
                 else:
-                    print("@@@@@ THIS TIME NO SCREENSHOT!!")
+                    print("No screenshot found in the latest browser state message.")
             except Exception as e:
                 print(f"Error parsing browser state: {e}")
                 # print(latest_browser_state.data[0])
 
+        # Run Thread
         response = await thread_manager.run_thread(
             thread_id=thread_id,
-            system_prompt=system_message,
+            system_prompt=system_message, # Pass the constructed message
             stream=stream,
             llm_model=os.getenv("MODEL_TO_USE", "anthropic/claude-3-7-sonnet-latest"),
             llm_temperature=0,
@@ -151,16 +172,10 @@ async def run_agent(thread_id: str, project_id: str, stream: bool = True, thread
             tool_choice="auto",
             max_xml_tool_calls=1,
             temporary_message=temporary_message,
-            processor_config=ProcessorConfig(
-                xml_tool_calling=True,
-                native_tool_calling=False,
-                execute_tools=True,
-                execute_on_stream=True,
-                tool_execution_strategy="parallel",
-                xml_adding_strategy="user_message"
-            ),
+            processor_config=processor_config, # Pass the config object
             native_max_auto_continues=native_max_auto_continues,
-            include_xml_examples=True,
+            # Explicitly set include_xml_examples to False here
+            include_xml_examples=False,
         )
 
         if isinstance(response, dict) and "status" in response and response["status"] == "error":
@@ -447,14 +447,21 @@ class ResponseProcessor:
                continue
 
            # Add assistant message with accumulated content
+           # Start with base message data
            message_data = {
                "role": "assistant",
-               "content": accumulated_content,
-               "tool_calls": complete_native_tool_calls if config.native_tool_calling and complete_native_tool_calls else None
+               "content": accumulated_content
+               # tool_calls key is initially omitted
            }
 
+           # Conditionally add tool_calls if they exist and native calling is enabled
+           if config.native_tool_calling and complete_native_tool_calls:
+               message_data["tool_calls"] = complete_native_tool_calls
+
+           # Add the message (tool_calls will only be present if added above)
            await self.add_message(
                thread_id=thread_id,
                type="assistant",
                content=message_data,
                is_llm_message=True
            )
@@ -657,14 +664,22 @@ class ResponseProcessor:
                })
 
            # Add assistant message FIRST - always do this regardless of finish_reason
+           # Start with base message data
            message_data = {
                "role": "assistant",
-               "content": content,
-               "tool_calls": native_tool_calls if config.native_tool_calling and 'native_tool_calls' in locals() else None
+               "content": content
+               # tool_calls key is initially omitted
            }
 
+           # Conditionally add tool_calls if they exist and native calling is enabled
+           # Use 'native_tool_calls' in locals() check for safety as before
+           if config.native_tool_calling and 'native_tool_calls' in locals() and native_tool_calls:
+               message_data["tool_calls"] = native_tool_calls
+
+           # Add the message
            await self.add_message(
                thread_id=thread_id,
                type="assistant",
                content=message_data,
                is_llm_message=True
            )
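For reference, with the @@ -447 and @@ -657 hunks above the persisted assistant message only carries a tool_calls key when there is actually something to put in it, instead of always writing "tool_calls": None. A minimal standalone sketch with placeholder values (not the real runtime data):

# Sketch of the resulting message shapes after the change above.
native_tool_calling = False            # stands in for config.native_tool_calling
accumulated_content = "example assistant text"
complete_native_tool_calls = []        # stands in for any parsed native tool calls

message_data = {"role": "assistant", "content": accumulated_content}
# tool_calls is added only when native calling is enabled and calls exist.
if native_tool_calling and complete_native_tool_calls:
    message_data["tool_calls"] = complete_native_tool_calls

print(message_data)  # {'role': 'assistant', 'content': 'example assistant text'}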
@@ -1319,4 +1334,4 @@ class ResponseProcessor:
            "xml_tag_name": context.xml_tag_name,
            "message": f"Error executing tool: {error_msg}",
            "tool_index": context.tool_index
        }
@ -198,32 +198,6 @@ class ThreadManager:
|
||||||
if max_xml_tool_calls > 0:
|
if max_xml_tool_calls > 0:
|
||||||
processor_config.max_xml_tool_calls = max_xml_tool_calls
|
processor_config.max_xml_tool_calls = max_xml_tool_calls
|
||||||
|
|
||||||
# Add XML examples to system prompt if requested
|
|
||||||
if include_xml_examples and processor_config.xml_tool_calling:
|
|
||||||
xml_examples = self.tool_registry.get_xml_examples()
|
|
||||||
if xml_examples:
|
|
||||||
# logger.debug(f"Adding {len(xml_examples)} XML examples to system prompt")
|
|
||||||
|
|
||||||
# Create or append to content
|
|
||||||
if isinstance(system_prompt['content'], str):
|
|
||||||
examples_content = """
|
|
||||||
--- XML TOOL CALLING ---
|
|
||||||
|
|
||||||
In this environment you have access to a set of tools you can use to answer the user's question. The tools are specified in XML format.
|
|
||||||
{{ FORMATTING INSTRUCTIONS }}
|
|
||||||
String and scalar parameters should be specified as attributes, while content goes between tags.
|
|
||||||
Note that spaces for string values are not stripped. The output is parsed with regular expressions.
|
|
||||||
|
|
||||||
Here are the XML tools available with examples:
|
|
||||||
"""
|
|
||||||
for tag_name, example in xml_examples.items():
|
|
||||||
examples_content += f"<{tag_name}> Example: {example}\n"
|
|
||||||
|
|
||||||
system_prompt['content'] += examples_content
|
|
||||||
else:
|
|
||||||
# If content is not a string (might be a list or dict), log a warning
|
|
||||||
logger.warning("System prompt content is not a string, cannot add XML examples")
|
|
||||||
|
|
||||||
# 1. Get messages from thread for LLM call
|
# 1. Get messages from thread for LLM call
|
||||||
messages = await self.get_llm_messages(thread_id)
|
messages = await self.get_llm_messages(thread_id)
|
||||||
|
|
||||||
|
|
|
@@ -14,6 +14,7 @@ from typing import Union, Dict, Any, Optional, AsyncGenerator, List
 import os
 import json
 import asyncio
+import time # Added for timestamp
 from openai import OpenAIError
 import litellm
 from utils.logger import logger
@@ -26,6 +27,9 @@ MAX_RETRIES = 3
 RATE_LIMIT_DELAY = 30
 RETRY_DELAY = 5
 
+# Define debug log directory relative to this file's location
+DEBUG_LOG_DIR = os.path.join(os.path.dirname(__file__), '..', 'debug_logs') # Assumes backend/debug_logs
+
 class LLMError(Exception):
     """Base exception for LLM-related errors."""
     pass
@@ -208,15 +212,116 @@ async def make_llm_api_call(
         model_id=model_id
     )
 
+    # Apply Anthropic prompt caching (minimal implementation)
+    if params["model"].startswith("anthropic/"):
+        logger.debug("Applying minimal Anthropic prompt caching.")
+        messages = params["messages"] # Direct reference
+
+        # 1. Process the first message if it's a system prompt with string content
+        if messages and messages[0].get("role") == "system":
+            content = messages[0].get("content")
+            if isinstance(content, str):
+                messages[0]["content"] = [
+                    {"type": "text", "text": content, "cache_control": {"type": "ephemeral"}}
+                ]
+                logger.debug("Applied cache_control to system message.")
+                modified = True
+            elif not isinstance(content, list):
+                logger.warning("System message content is not a string or list, skipping cache_control.")
+            # else: content is already a list, do nothing
+
+        # 2. Find and process the last user message
+        last_user_idx = -1
+        for i in range(len(messages) - 1, -1, -1):
+            if messages[i].get("role") == "user":
+                last_user_idx = i
+                break
+
+        if last_user_idx != -1:
+            last_user_message = messages[last_user_idx]
+            content = last_user_message.get("content")
+            applied_to_user = False
+
+            if isinstance(content, str):
+                last_user_message["content"] = [
+                    {"type": "text", "text": content, "cache_control": {"type": "ephemeral"}}
+                ]
+                logger.debug(f"Applied cache_control to last user message (string content, index {last_user_idx}).")
+                applied_to_user = True
+            elif isinstance(content, list):
+                # Modify text blocks within the list directly
+                found_text_block = False
+                for item in content:
+                    if isinstance(item, dict) and item.get("type") == "text":
+                        # Add cache_control if not already present (avoids adding it multiple times)
+                        if "cache_control" not in item:
+                            item["cache_control"] = {"type": "ephemeral"}
+                            found_text_block = True # Mark modification only if added
+
+                if found_text_block:
+                    logger.debug(f"Applied cache_control to text part(s) of last user message (list content, index {last_user_idx}).")
+                    applied_to_user = True
+                # else: No text block found or cache_control already present, do nothing
+            else:
+                logger.warning(f"Last user message (index {last_user_idx}) content is not a string or list ({type(content)}), skipping cache_control.")
+
+            if applied_to_user:
+                modified = True
+
+    # --- Debug Logging Setup ---
+    # Initialize log path to None, it will be set only if logging is enabled
+    response_log_path = None
+    enable_debug_logging = os.environ.get('ENABLE_LLM_DEBUG_LOGGING', 'false').lower() == 'true'
+
+    if enable_debug_logging:
+        try:
+            os.makedirs(DEBUG_LOG_DIR, exist_ok=True)
+            timestamp = time.strftime("%Y%m%d_%H%M%S")
+            # Use a unique ID or counter if calls can happen in the same second
+            # For simplicity, using timestamp only for now
+            request_log_path = os.path.join(DEBUG_LOG_DIR, f"llm_request_{timestamp}.json")
+            response_log_path = os.path.join(DEBUG_LOG_DIR, f"llm_response_{timestamp}.json") # Set here if enabled
+
+            # Log the request parameters just before the attempt loop
+            logger.debug(f"Logging LLM request parameters to {request_log_path}")
+            with open(request_log_path, 'w') as f:
+                # Use default=str for potentially non-serializable items in params if needed
+                json.dump(params, f, indent=2, default=str)
+
+        except Exception as log_err:
+            logger.error(f"Failed to set up or write LLM debug request log: {log_err}", exc_info=True)
+            # Reset response path to None if setup failed, even if logging was enabled
+            response_log_path = None
+    else:
+        logger.debug("LLM debug logging is disabled via environment variable.")
+    # --- End Debug Logging Setup ---
+
     last_error = None
     for attempt in range(MAX_RETRIES):
         try:
             logger.debug(f"Attempt {attempt + 1}/{MAX_RETRIES}")
-            # logger.debug(f"API request parameters: {json.dumps(params, indent=2)}")
 
             response = await litellm.acompletion(**params)
             logger.debug(f"Successfully received API response from {model_name}")
-            logger.debug(f"Response: {response}")
+
+            # --- Debug Logging Response ---
+            if response_log_path: # Only log if request logging setup succeeded
+                try:
+                    logger.debug(f"Logging LLM response object to {response_log_path}")
+                    # Check if it's a streaming response (AsyncGenerator)
+                    if isinstance(response, AsyncGenerator):
+                        with open(response_log_path, 'w') as f:
+                            json.dump({"status": "streaming_response", "message": "Full response logged chunk by chunk where consumed."}, f, indent=2)
+                    else:
+                        # Assume it's a LiteLLM ModelResponse object, convert to dict
+                        response_dict = response.dict()
+                        with open(response_log_path, 'w') as f:
+                            # Use default=str for potentially non-serializable items like datetime
+                            json.dump(response_dict, f, indent=2, default=str)
+                except Exception as log_err:
+                    logger.error(f"Failed to write LLM debug response log: {log_err}", exc_info=True)
+            # --- End Debug Logging Response ---
+
             return response
 
         except (litellm.exceptions.RateLimitError, OpenAIError, json.JSONDecodeError) as e:
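For clarity, a minimal sketch of what the caching block above does to a plain-string system prompt, standalone and with placeholder prompt text (the real code operates on the params built earlier in make_llm_api_call):

# Minimal sketch of the system-message transformation performed above.
params = {
    "model": "anthropic/claude-3-7-sonnet-latest",
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},  # placeholder prompt
        {"role": "user", "content": "Summarize the agreement."},        # placeholder prompt
    ],
}

if params["model"].startswith("anthropic/"):
    first = params["messages"][0]
    if first.get("role") == "system" and isinstance(first.get("content"), str):
        # Wrap the string content in a single text block carrying cache_control.
        first["content"] = [
            {"type": "text", "text": first["content"], "cache_control": {"type": "ephemeral"}}
        ]

print(params["messages"][0])
# -> {'role': 'system', 'content': [{'type': 'text', 'text': 'You are a helpful assistant.',
#     'cache_control': {'type': 'ephemeral'}}]}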
@@ -0,0 +1,82 @@
+import asyncio
+import litellm
+
+async def main():
+    initial_messages=[
+        # System Message
+        {
+            "role": "system",
+            "content": [
+                {
+                    "type": "text",
+                    "text": "Here is the full text of a complex legal agreement"
+                    * 400,
+                    "cache_control": {"type": "ephemeral"},
+                }
+            ],
+        },
+        # marked for caching with the cache_control parameter, so that this checkpoint can read from the previous cache.
+        {
+            "role": "user",
+            "content": [
+                {
+                    "type": "text",
+                    "text": "What are the key terms and conditions in this agreement?",
+                    "cache_control": {"type": "ephemeral"},
+                }
+            ],
+        },
+        {
+            "role": "assistant",
+            "content": "Certainly! the key terms and conditions are the following: the contract is 1 year long for $10/month",
+        },
+        # The final turn is marked with cache-control, for continuing in followups.
+        {
+            "role": "user",
+            "content": [
+                {
+                    "type": "text",
+                    "text": "What are the key terms and conditions in this agreement?",
+                    "cache_control": {"type": "ephemeral"},
+                }
+            ],
+        },
+    ]
+
+    print("--- First call ---")
+    first_response = await litellm.acompletion(
+        model="anthropic/claude-3-7-sonnet-latest",
+        messages=initial_messages
+    )
+    print(first_response)
+
+    # Prepare messages for the second call
+    second_call_messages = initial_messages + [
+        {
+            "role": "assistant",
+            # Extract the assistant's response content from the first call
+            "content": first_response.choices[0].message.content
+        },
+        {
+            "role": "user",
+            "content": [
+                {
+                    "type": "text",
+                    "text": "Can you elaborate on the termination clause based on the provided text? Remember the context.",
+                    "cache_control": {"type": "ephemeral"}, # Mark for caching
+                }
+            ],
+        },
+    ]
+
+    print("\n--- Second call (testing cache) ---")
+    second_response = await litellm.acompletion(
+        model="anthropic/claude-3-7-sonnet-latest",
+        messages=second_call_messages
+    )
+    print(second_response)
+
+if __name__ == "__main__":
+    asyncio.run(main())
@@ -0,0 +1,159 @@
+import asyncio
+import json
+import os
+import sys
+import traceback
+from dotenv import load_dotenv
+load_dotenv()
+
+# Ensure the backend directory is in the Python path
+backend_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
+if backend_dir not in sys.path:
+    sys.path.insert(0, backend_dir)
+
+import logging # Import logging module
+from agentpress.thread_manager import ThreadManager
+from services.supabase import DBConnection
+from agent.run import run_agent
+from utils.logger import logger
+
+# Set logging level to DEBUG specifically for this test script
+logger.setLevel(logging.DEBUG)
+# Optionally, adjust handler levels if needed (e.g., for console output)
+for handler in logger.handlers:
+    if isinstance(handler, logging.StreamHandler): # Target console handler
+        handler.setLevel(logging.DEBUG)
+
+async def test_agent_limited_iterations():
+    """
+    Test running the agent for a maximum of 3 iterations in non-streaming mode
+    and print the collected response chunks.
+    """
+    print("\n" + "="*80)
+    print("🧪 TESTING AGENT RUN WITH MAX ITERATIONS (max_iterations=3, stream=False)")
+    print("="*80 + "\n")
+
+    # Load environment variables
+    load_dotenv()
+
+    # Initialize ThreadManager and DBConnection
+    thread_manager = ThreadManager()
+    db_connection = DBConnection()
+    client = await db_connection.client
+
+    thread_id = None
+    project_id = None
+
+    try:
+        # --- Test Setup ---
+        print("🔧 Setting up test environment (Project & Thread)...")
+
+        # Get user's personal account (replace with a specific test account if needed)
+        # Using a hardcoded account ID for consistency in tests
+        account_id = "a5fe9cb6-4812-407e-a61c-fe95b7320c59" # Replace if necessary
+        logger.info(f"Using Account ID: {account_id}")
+
+        if not account_id:
+            print("❌ Error: Could not determine Account ID.")
+            return
+
+        # Find or create a test project
+        project_name = "test_simple_dat"
+        project_result = await client.table('projects').select('*').eq('name', project_name).eq('account_id', account_id).execute()
+
+        if project_result.data and len(project_result.data) > 0:
+            project_id = project_result.data[0]['project_id']
+            print(f"🔄 Using existing test project: {project_id}")
+        else:
+            project_result = await client.table('projects').insert({
+                "name": project_name,
+                "account_id": account_id
+            }).execute()
+            project_id = project_result.data[0]['project_id']
+            print(f"✨ Created new test project: {project_id}")
+
+        # Create a new thread for this test
+        thread_result = await client.table('threads').insert({
+            'project_id': project_id,
+            'account_id': account_id
+        }).execute()
+        thread_id = thread_result.data[0]['thread_id']
+        print(f"🧵 Created new test thread: {thread_id}")
+
+        # Add an initial user message to kick off the agent
+        initial_message = ("Hello " * 123) + "\n\nHow many times did the word 'Hello' appear in the previous text?"
+        print(f"\n💬 Adding initial user message: Preview='{initial_message[:50]}...'") # Print only a preview
+        await thread_manager.add_message(
+            thread_id=thread_id,
+            type="user",
+            content={
+                "role": "user",
+                "content": initial_message
+            },
+            is_llm_message=True
+        )
+        print("✅ Initial message added.")
+
+        # --- Run Agent ---
+        print("\n🔄 Running agent (max_iterations=3, stream=False)...")
+        all_chunks = []
+        agent_run_generator = run_agent(
+            thread_id=thread_id,
+            project_id=project_id,
+            stream=False, # Non-streaming
+            thread_manager=thread_manager,
+            max_iterations=5 # Limit iterations
+        )
+
+        async for chunk in agent_run_generator:
+            chunk_type = chunk.get('type', 'unknown')
+            print(f" 📦 Received chunk: type='{chunk_type}'")
+            all_chunks.append(chunk)
+
+        print("\n✅ Agent run finished.")
+
+        # --- Print Results ---
+        print("\n📄 Full collected response chunks:")
+        # Use json.dumps for pretty printing the list of dictionaries
+        print(json.dumps(all_chunks, indent=2, default=str)) # Use default=str for non-serializable types like datetime
+
+    except Exception as e:
+        print(f"\n❌ An error occurred during the test: {e}")
+        traceback.print_exc()
+    finally:
+        # Optional: Clean up the created thread and project
+        print("\n🧹 Cleaning up test resources...")
+        if thread_id:
+            await client.table('threads').delete().eq('thread_id', thread_id).execute()
+            print(f"🗑️ Deleted test thread: {thread_id}")
+        if project_id and not project_result.data: # Only delete if we created it
+            await client.table('projects').delete().eq('project_id', project_id).execute()
+            print(f"🗑️ Deleted test project: {project_id}")
+
+        print("\n" + "="*80)
+        print("🏁 TEST COMPLETE")
+        print("="*80 + "\n")
+
+if __name__ == "__main__":
+    # Ensure the logger is configured
+    logger.info("Starting test_agent_max_iterations script...")
+    try:
+        asyncio.run(test_agent_limited_iterations())
+        print("\n✅ Test script completed successfully.")
+        sys.exit(0)
+    except KeyboardInterrupt:
+        print("\n\n❌ Test interrupted by user.")
+        sys.exit(1)
+    except Exception as e:
+        print(f"\n\n❌ Error running test script: {e}")
+        traceback.print_exc()
+        sys.exit(1)
+
+# before result
+# 2025-04-16 19:20:20,494 - DEBUG - Response: ModelResponse(id='chatcmpl-2c5c1418-4570-435c-8d31-5c7ef63a1a68', created=1744827620, model='claude-3-7-sonnet-20250219', object='chat.completion', system_fingerprint=None, choices=[Choices(finish_reason='stop', index=0, message=Message(content='I\'ll update the existing todo.md file and then proceed with counting the "Hello" occurrences.\n\n<full-file-rewrite file_path="todo.md">\n# Hello Count Task\n\n## Setup\n- [ ] Create a file to store the input text\n- [ ] Create a script to count occurrences of "Hello"\n\n## Analysis\n- [ ] Run the script to count occurrences\n- [ ] Verify the results\n\n## Delivery\n- [ ] Provide the final count to the user\n</full-file-rewrite>', role='assistant', tool_calls=None, function_call=None, provider_specific_fields={'citations': None, 'thinking_blocks': None}))], usage=Usage(completion_tokens=125, prompt_tokens=14892, total_tokens=15017, completion_tokens_details=None, prompt_tokens_details=PromptTokensDetailsWrapper(audio_tokens=None, cached_tokens=0, text_tokens=None, image_tokens=None), cache_creation_input_tokens=0, cache_read_input_tokens=0))
+
+# after result
+# read cache should > 0 (and it does)
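To check the counters programmatically instead of reading the full repr, something like the following could be added inside main() of the standalone caching script above, after the two acompletion calls (a sketch; it assumes the LiteLLM Usage object exposes the Anthropic cache counters as attributes, as the logged ModelResponse suggests):

    # Compare cache counters across the two calls made above.
    # cache_creation_input_tokens / cache_read_input_tokens are the fields visible
    # in the logged Usage(...) repr; getattr guards against providers that omit them.
    for label, resp in [("first call", first_response), ("second call", second_response)]:
        usage = resp.usage
        print(label,
              "cache write:", getattr(usage, "cache_creation_input_tokens", None),
              "cache read:", getattr(usage, "cache_read_input_tokens", None))
    # After this fix, the second call should report a cache read count greater than 0.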