resolved: prompt caching

This commit is contained in:
LE Quoc Dat 2025-04-17 00:54:06 +01:00
parent 7de2756b44
commit 88c0d7c934
7 changed files with 407 additions and 56 deletions

3
.gitignore vendored
View File

@@ -177,4 +177,5 @@ state.json
# .DS_Store files
.DS_Store
**/.DS_Store
**/.DS_Store
.aider*

View File

@@ -80,8 +80,6 @@ async def run_agent(thread_id: str, project_id: str, stream: bool = True, thread
for tag_name, example in thread_manager.tool_registry.get_xml_examples().items():
xml_examples += f"{example}\n"
system_message = { "role": "system", "content": get_system_prompt() + "\n\n" + f"<tool_examples>\n{xml_examples}\n</tool_examples>" }
iteration_count = 0
continue_execution = True
@@ -109,24 +107,46 @@ async def run_agent(thread_id: str, project_id: str, stream: bool = True, thread
print(f"Last message was from assistant, stopping execution")
continue_execution = False
break
# Get the latest message from the messages table whose type is browser_state
# Define Processor Config FIRST
processor_config = ProcessorConfig(
xml_tool_calling=True,
native_tool_calling=False,
execute_tools=True,
execute_on_stream=True,
tool_execution_strategy="parallel",
xml_adding_strategy="user_message"
)
# Construct System Message Conditionally
base_system_prompt_content = get_system_prompt()
system_message_content = base_system_prompt_content
# Conditionally add XML examples based on the config
if processor_config.xml_tool_calling:
# Use the already loaded xml_examples from outside the loop
if xml_examples:
system_message_content += "\n\n" + f"<tool_examples>\n{xml_examples}\n</tool_examples>"
system_message = { "role": "system", "content": system_message_content }
# Handle Temporary Message (Browser State)
latest_browser_state = await client.table('messages').select('*').eq('thread_id', thread_id).eq('type', 'browser_state').order('created_at', desc=True).limit(1).execute()
temporary_message = None
if latest_browser_state.data and len(latest_browser_state.data) > 0:
try:
content = json.loads(latest_browser_state.data[0]["content"])
screenshot_base64 = content["screenshot_base64"]
screenshot_base64 = content.get("screenshot_base64") # Use .get() for safety
# Create a copy of the browser state without screenshot
browser_state = content.copy()
browser_state.pop('screenshot_base64', None)
browser_state.pop('screenshot_url', None)
browser_state.pop('screenshot_url', None)
browser_state.pop('screenshot_url_base64', None)
temporary_message = { "role": "user", "content": [] }
if browser_state:
temporary_message["content"].append({
"type": "text",
"text": f"The following is the current state of the browser:\n{browser_state}"
"text": f"The following is the current state of the browser:\n{json.dumps(browser_state, indent=2)}" # Pretty print browser state
})
if screenshot_base64:
temporary_message["content"].append({
@@ -136,14 +156,15 @@ async def run_agent(thread_id: str, project_id: str, stream: bool = True, thread
}
})
else:
print("@@@@@ THIS TIME NO SCREENSHOT!!")
print("No screenshot found in the latest browser state message.")
except Exception as e:
print(f"Error parsing browser state: {e}")
# print(latest_browser_state.data[0])
# Run Thread
response = await thread_manager.run_thread(
thread_id=thread_id,
system_prompt=system_message,
system_prompt=system_message, # Pass the constructed message
stream=stream,
llm_model=os.getenv("MODEL_TO_USE", "anthropic/claude-3-7-sonnet-latest"),
llm_temperature=0,
@@ -151,16 +172,10 @@ async def run_agent(thread_id: str, project_id: str, stream: bool = True, thread
tool_choice="auto",
max_xml_tool_calls=1,
temporary_message=temporary_message,
processor_config=ProcessorConfig(
xml_tool_calling=True,
native_tool_calling=False,
execute_tools=True,
execute_on_stream=True,
tool_execution_strategy="parallel",
xml_adding_strategy="user_message"
),
processor_config=processor_config, # Pass the config object
native_max_auto_continues=native_max_auto_continues,
include_xml_examples=True,
# Explicitly set include_xml_examples to False here
include_xml_examples=False,
)
if isinstance(response, dict) and "status" in response and response["status"] == "error":

View File

@@ -447,14 +447,21 @@ class ResponseProcessor:
continue
# Add assistant message with accumulated content
# Start with base message data
message_data = {
"role": "assistant",
"content": accumulated_content,
"tool_calls": complete_native_tool_calls if config.native_tool_calling and complete_native_tool_calls else None
"content": accumulated_content
# tool_calls key is initially omitted
}
# Conditionally add tool_calls if they exist and native calling is enabled
if config.native_tool_calling and complete_native_tool_calls:
message_data["tool_calls"] = complete_native_tool_calls
# Add the message (tool_calls will only be present if added above)
await self.add_message(
thread_id=thread_id,
type="assistant",
thread_id=thread_id,
type="assistant",
content=message_data,
is_llm_message=True
)
@@ -657,14 +664,22 @@ class ResponseProcessor:
})
# Add assistant message FIRST - always do this regardless of finish_reason
# Start with base message data
message_data = {
"role": "assistant",
"content": content,
"tool_calls": native_tool_calls if config.native_tool_calling and 'native_tool_calls' in locals() else None
"content": content
# tool_calls key is initially omitted
}
# Conditionally add tool_calls if they exist and native calling is enabled
# Use 'native_tool_calls' in locals() check for safety as before
if config.native_tool_calling and 'native_tool_calls' in locals() and native_tool_calls:
message_data["tool_calls"] = native_tool_calls
# Add the message
await self.add_message(
thread_id=thread_id,
type="assistant",
thread_id=thread_id,
type="assistant",
content=message_data,
is_llm_message=True
)
@@ -1319,4 +1334,4 @@ class ResponseProcessor:
"xml_tag_name": context.xml_tag_name,
"message": f"Error executing tool: {error_msg}",
"tool_index": context.tool_index
}
}

View File

@@ -198,32 +198,6 @@ class ThreadManager:
if max_xml_tool_calls > 0:
processor_config.max_xml_tool_calls = max_xml_tool_calls
# Add XML examples to system prompt if requested
if include_xml_examples and processor_config.xml_tool_calling:
xml_examples = self.tool_registry.get_xml_examples()
if xml_examples:
# logger.debug(f"Adding {len(xml_examples)} XML examples to system prompt")
# Create or append to content
if isinstance(system_prompt['content'], str):
examples_content = """
--- XML TOOL CALLING ---
In this environment you have access to a set of tools you can use to answer the user's question. The tools are specified in XML format.
{{ FORMATTING INSTRUCTIONS }}
String and scalar parameters should be specified as attributes, while content goes between tags.
Note that spaces for string values are not stripped. The output is parsed with regular expressions.
Here are the XML tools available with examples:
"""
for tag_name, example in xml_examples.items():
examples_content += f"<{tag_name}> Example: {example}\n"
system_prompt['content'] += examples_content
else:
# If content is not a string (might be a list or dict), log a warning
logger.warning("System prompt content is not a string, cannot add XML examples")
# 1. Get messages from thread for LLM call
messages = await self.get_llm_messages(thread_id)

View File

@@ -14,6 +14,7 @@ from typing import Union, Dict, Any, Optional, AsyncGenerator, List
import os
import json
import asyncio
import time # Added for timestamp
from openai import OpenAIError
import litellm
from utils.logger import logger
@@ -26,6 +27,9 @@ MAX_RETRIES = 3
RATE_LIMIT_DELAY = 30
RETRY_DELAY = 5
# Define debug log directory relative to this file's location
DEBUG_LOG_DIR = os.path.join(os.path.dirname(__file__), '..', 'debug_logs') # Assumes backend/debug_logs
class LLMError(Exception):
"""Base exception for LLM-related errors."""
pass
@@ -208,15 +212,116 @@ async def make_llm_api_call(
model_id=model_id
)
# Apply Anthropic prompt caching (minimal implementation)
if params["model"].startswith("anthropic/"):
logger.debug("Applying minimal Anthropic prompt caching.")
messages = params["messages"] # Direct reference
# 1. Process the first message if it's a system prompt with string content
if messages and messages[0].get("role") == "system":
content = messages[0].get("content")
if isinstance(content, str):
messages[0]["content"] = [
{"type": "text", "text": content, "cache_control": {"type": "ephemeral"}}
]
logger.debug("Applied cache_control to system message.")
modified = True
elif not isinstance(content, list):
logger.warning("System message content is not a string or list, skipping cache_control.")
# else: content is already a list, do nothing
# 2. Find and process the last user message
last_user_idx = -1
for i in range(len(messages) - 1, -1, -1):
if messages[i].get("role") == "user":
last_user_idx = i
break
if last_user_idx != -1:
last_user_message = messages[last_user_idx]
content = last_user_message.get("content")
applied_to_user = False
if isinstance(content, str):
last_user_message["content"] = [
{"type": "text", "text": content, "cache_control": {"type": "ephemeral"}}
]
logger.debug(f"Applied cache_control to last user message (string content, index {last_user_idx}).")
applied_to_user = True
elif isinstance(content, list):
# Modify text blocks within the list directly
found_text_block = False
for item in content:
if isinstance(item, dict) and item.get("type") == "text":
# Add cache_control if not already present (avoids adding it multiple times)
if "cache_control" not in item:
item["cache_control"] = {"type": "ephemeral"}
found_text_block = True # Mark modification only if added
if found_text_block:
logger.debug(f"Applied cache_control to text part(s) of last user message (list content, index {last_user_idx}).")
applied_to_user = True
# else: No text block found or cache_control already present, do nothing
else:
logger.warning(f"Last user message (index {last_user_idx}) content is not a string or list ({type(content)}), skipping cache_control.")
if applied_to_user:
modified = True
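# Illustrative sketch only (not part of llm.py; the example prompt text is made up): after the
# block above, a plain string system prompt such as
#     {"role": "system", "content": "You are a helpful agent."}
# is rewritten into Anthropic's content-block form with an ephemeral cache marker:
#     {"role": "system", "content": [{"type": "text", "text": "You are a helpful agent.",
#                                     "cache_control": {"type": "ephemeral"}}]}
# The last user message gets the same treatment, so the long system prompt and the conversation
# prefix up to that message become cache checkpoints for subsequent calls.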
# --- Debug Logging Setup ---
# Initialize log path to None, it will be set only if logging is enabled
response_log_path = None
enable_debug_logging = os.environ.get('ENABLE_LLM_DEBUG_LOGGING', 'false').lower() == 'true'
if enable_debug_logging:
try:
os.makedirs(DEBUG_LOG_DIR, exist_ok=True)
timestamp = time.strftime("%Y%m%d_%H%M%S")
# Use a unique ID or counter if calls can happen in the same second
# For simplicity, using timestamp only for now
request_log_path = os.path.join(DEBUG_LOG_DIR, f"llm_request_{timestamp}.json")
response_log_path = os.path.join(DEBUG_LOG_DIR, f"llm_response_{timestamp}.json") # Set here if enabled
# Log the request parameters just before the attempt loop
logger.debug(f"Logging LLM request parameters to {request_log_path}")
with open(request_log_path, 'w') as f:
# Use default=str for potentially non-serializable items in params if needed
json.dump(params, f, indent=2, default=str)
except Exception as log_err:
logger.error(f"Failed to set up or write LLM debug request log: {log_err}", exc_info=True)
# Reset response path to None if setup failed, even if logging was enabled
response_log_path = None
else:
logger.debug("LLM debug logging is disabled via environment variable.")
# --- End Debug Logging Setup ---
last_error = None
for attempt in range(MAX_RETRIES):
try:
logger.debug(f"Attempt {attempt + 1}/{MAX_RETRIES}")
# logger.debug(f"API request parameters: {json.dumps(params, indent=2)}")
response = await litellm.acompletion(**params)
logger.debug(f"Successfully received API response from {model_name}")
logger.debug(f"Response: {response}")
# --- Debug Logging Response ---
if response_log_path: # Only log if request logging setup succeeded
try:
logger.debug(f"Logging LLM response object to {response_log_path}")
# Check if it's a streaming response (AsyncGenerator)
if isinstance(response, AsyncGenerator):
with open(response_log_path, 'w') as f:
json.dump({"status": "streaming_response", "message": "Full response logged chunk by chunk where consumed."}, f, indent=2)
else:
# Assume it's a LiteLLM ModelResponse object, convert to dict
response_dict = response.dict()
with open(response_log_path, 'w') as f:
# Use default=str for potentially non-serializable items like datetime
json.dump(response_dict, f, indent=2, default=str)
except Exception as log_err:
logger.error(f"Failed to write LLM debug response log: {log_err}", exc_info=True)
# --- End Debug Logging Response ---
return response
except (litellm.exceptions.RateLimitError, OpenAIError, json.JSONDecodeError) as e:

82
backend/tests/raw_test.py Normal file
View File

@@ -0,0 +1,82 @@
import asyncio
import litellm
async def main():
initial_messages=[
# System Message
{
"role": "system",
"content": [
{
"type": "text",
"text": "Here is the full text of a complex legal agreement"
* 400,
"cache_control": {"type": "ephemeral"},
}
],
},
# This user turn is marked for caching with the cache_control parameter, so that this checkpoint can read from the previous cache.
{
"role": "user",
"content": [
{
"type": "text",
"text": "What are the key terms and conditions in this agreement?",
"cache_control": {"type": "ephemeral"},
}
],
},
{
"role": "assistant",
"content": "Certainly! the key terms and conditions are the following: the contract is 1 year long for $10/month",
},
# The final turn is marked with cache-control, for continuing in followups.
{
"role": "user",
"content": [
{
"type": "text",
"text": "What are the key terms and conditions in this agreement?",
"cache_control": {"type": "ephemeral"},
}
],
},
]
print("--- First call ---")
first_response = await litellm.acompletion(
model="anthropic/claude-3-7-sonnet-latest",
messages=initial_messages
)
print(first_response)
# Prepare messages for the second call
second_call_messages = initial_messages + [
{
"role": "assistant",
# Extract the assistant's response content from the first call
"content": first_response.choices[0].message.content
},
{
"role": "user",
"content": [
{
"type": "text",
"text": "Can you elaborate on the termination clause based on the provided text? Remember the context.",
"cache_control": {"type": "ephemeral"}, # Mark for caching
}
],
},
]
print("\n--- Second call (testing cache) ---")
second_response = await litellm.acompletion(
model="anthropic/claude-3-7-sonnet-latest",
messages=second_call_messages
)
print(second_response)
if __name__ == "__main__":
asyncio.run(main())
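# A possible extra check inside main() (hedged sketch; the usage field names below are taken
# from the Usage object litellm returns for Anthropic models and are assumed to be present on
# both responses):
#     print("first call:", first_response.usage.cache_creation_input_tokens,
#           first_response.usage.cache_read_input_tokens)
#     print("second call:", second_response.usage.cache_creation_input_tokens,
#           second_response.usage.cache_read_input_tokens)
# On a cache hit the first call should report cache_creation_input_tokens > 0 and the second
# call cache_read_input_tokens > 0.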

View File

@@ -0,0 +1,159 @@
import asyncio
import json
import os
import sys
import traceback
from dotenv import load_dotenv
load_dotenv()
# Ensure the backend directory is in the Python path
backend_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
if backend_dir not in sys.path:
sys.path.insert(0, backend_dir)
import logging # Import logging module
from agentpress.thread_manager import ThreadManager
from services.supabase import DBConnection
from agent.run import run_agent
from utils.logger import logger
# Set logging level to DEBUG specifically for this test script
logger.setLevel(logging.DEBUG)
# Optionally, adjust handler levels if needed (e.g., for console output)
for handler in logger.handlers:
if isinstance(handler, logging.StreamHandler): # Target console handler
handler.setLevel(logging.DEBUG)
async def test_agent_limited_iterations():
"""
Test running the agent for a maximum of 3 iterations in non-streaming mode
and print the collected response chunks.
"""
print("\n" + "="*80)
print("🧪 TESTING AGENT RUN WITH MAX ITERATIONS (max_iterations=3, stream=False)")
print("="*80 + "\n")
# Load environment variables
load_dotenv()
# Initialize ThreadManager and DBConnection
thread_manager = ThreadManager()
db_connection = DBConnection()
client = await db_connection.client
thread_id = None
project_id = None
try:
# --- Test Setup ---
print("🔧 Setting up test environment (Project & Thread)...")
# Get user's personal account (replace with a specific test account if needed)
# Using a hardcoded account ID for consistency in tests
account_id = "a5fe9cb6-4812-407e-a61c-fe95b7320c59" # Replace if necessary
logger.info(f"Using Account ID: {account_id}")
if not account_id:
print("❌ Error: Could not determine Account ID.")
return
# Find or create a test project
project_name = "test_simple_dat"
project_result = await client.table('projects').select('*').eq('name', project_name).eq('account_id', account_id).execute()
if project_result.data and len(project_result.data) > 0:
project_id = project_result.data[0]['project_id']
print(f"🔄 Using existing test project: {project_id}")
else:
project_result = await client.table('projects').insert({
"name": project_name,
"account_id": account_id
}).execute()
project_id = project_result.data[0]['project_id']
print(f"✨ Created new test project: {project_id}")
# Create a new thread for this test
thread_result = await client.table('threads').insert({
'project_id': project_id,
'account_id': account_id
}).execute()
thread_id = thread_result.data[0]['thread_id']
print(f"🧵 Created new test thread: {thread_id}")
# Add an initial user message to kick off the agent
initial_message = ("Hello " * 123) + "\\n\\nHow many times did the word 'Hello' appear in the previous text?"
print(f"\\n💬 Adding initial user message: Preview='{initial_message[:50]}...'") # Print only a preview
await thread_manager.add_message(
thread_id=thread_id,
type="user",
content={
"role": "user",
"content": initial_message
},
is_llm_message=True
)
print("✅ Initial message added.")
# --- Run Agent ---
print("\n🔄 Running agent (max_iterations=3, stream=False)...")
all_chunks = []
agent_run_generator = run_agent(
thread_id=thread_id,
project_id=project_id,
stream=False, # Non-streaming
thread_manager=thread_manager,
max_iterations=3  # Limit iterations, matching the docstring and the prints above
)
async for chunk in agent_run_generator:
chunk_type = chunk.get('type', 'unknown')
print(f" 📦 Received chunk: type='{chunk_type}'")
all_chunks.append(chunk)
print("\n✅ Agent run finished.")
# --- Print Results ---
print("\n📄 Full collected response chunks:")
# Use json.dumps for pretty printing the list of dictionaries
print(json.dumps(all_chunks, indent=2, default=str)) # Use default=str for non-serializable types like datetime
except Exception as e:
print(f"\n❌ An error occurred during the test: {e}")
traceback.print_exc()
finally:
# Optional: Clean up the created thread and project
print("\n🧹 Cleaning up test resources...")
if thread_id:
await client.table('threads').delete().eq('thread_id', thread_id).execute()
print(f"🗑️ Deleted test thread: {thread_id}")
if project_id and not project_result.data: # Only delete if we created it
await client.table('projects').delete().eq('project_id', project_id).execute()
print(f"🗑️ Deleted test project: {project_id}")
print("\n" + "="*80)
print("🏁 TEST COMPLETE")
print("="*80 + "\n")
if __name__ == "__main__":
# Ensure the logger is configured
logger.info("Starting test_agent_max_iterations script...")
try:
asyncio.run(test_agent_limited_iterations())
print("\n✅ Test script completed successfully.")
sys.exit(0)
except KeyboardInterrupt:
print("\n\n❌ Test interrupted by user.")
sys.exit(1)
except Exception as e:
print(f"\n\n❌ Error running test script: {e}")
traceback.print_exc()
sys.exit(1)
# before result
# 2025-04-16 19:20:20,494 - DEBUG - Response: ModelResponse(id='chatcmpl-2c5c1418-4570-435c-8d31-5c7ef63a1a68', created=1744827620, model='claude-3-7-sonnet-20250219', object='chat.completion', system_fingerprint=None, choices=[Choices(finish_reason='stop', index=0, message=Message(content='I\'ll update the existing todo.md file and then proceed with counting the "Hello" occurrences.\n\n<full-file-rewrite file_path="todo.md">\n# Hello Count Task\n\n## Setup\n- [ ] Create a file to store the input text\n- [ ] Create a script to count occurrences of "Hello"\n\n## Analysis\n- [ ] Run the script to count occurrences\n- [ ] Verify the results\n\n## Delivery\n- [ ] Provide the final count to the user\n</full-file-rewrite>', role='assistant', tool_calls=None, function_call=None, provider_specific_fields={'citations': None, 'thinking_blocks': None}))], usage=Usage(completion_tokens=125, prompt_tokens=14892, total_tokens=15017, completion_tokens_details=None, prompt_tokens_details=PromptTokensDetailsWrapper(audio_tokens=None, cached_tokens=0, text_tokens=None, image_tokens=None), cache_creation_input_tokens=0, cache_read_input_tokens=0))
# after result
# cache_read_input_tokens should be > 0 (and it is)
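# One way to confirm this from the debug logs written by llm.py (assuming
# ENABLE_LLM_DEBUG_LOGGING=true so backend/debug_logs/llm_response_*.json files exist, and that
# the serialized usage includes the cache token counts shown in the trace above):
#     import glob, json
#     latest = sorted(glob.glob("backend/debug_logs/llm_response_*.json"))[-1]
#     with open(latest) as f:
#         usage = json.load(f).get("usage", {})
#     print(usage.get("cache_creation_input_tokens"), usage.get("cache_read_input_tokens"))
# A warm cache shows cache_read_input_tokens > 0 on the second and later calls.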