From 88c0d7c9348b025a402a25bddfe74c456c93c91e Mon Sep 17 00:00:00 2001 From: LE Quoc Dat Date: Thu, 17 Apr 2025 00:54:06 +0100 Subject: [PATCH] resolved: prompt caching --- .gitignore | 3 +- backend/agent/run.py | 51 ++++--- backend/agentpress/response_processor.py | 33 ++-- backend/agentpress/thread_manager.py | 26 ---- backend/services/llm.py | 109 +++++++++++++- backend/tests/raw_test.py | 82 ++++++++++ backend/tests/test_simple_prompt_caching.py | 159 ++++++++++++++++++++ 7 files changed, 407 insertions(+), 56 deletions(-) create mode 100644 backend/tests/raw_test.py create mode 100644 backend/tests/test_simple_prompt_caching.py diff --git a/.gitignore b/.gitignore index 89b44174..51e28a7c 100644 --- a/.gitignore +++ b/.gitignore @@ -177,4 +177,5 @@ state.json # .DS_Store files .DS_Store -**/.DS_Store \ No newline at end of file +**/.DS_Store +.aider* diff --git a/backend/agent/run.py b/backend/agent/run.py index 196072da..38129baa 100644 --- a/backend/agent/run.py +++ b/backend/agent/run.py @@ -80,8 +80,6 @@ async def run_agent(thread_id: str, project_id: str, stream: bool = True, thread for tag_name, example in thread_manager.tool_registry.get_xml_examples().items(): xml_examples += f"{example}\n" - system_message = { "role": "system", "content": get_system_prompt() + "\n\n" + f"\n{xml_examples}\n" } - iteration_count = 0 continue_execution = True @@ -109,24 +107,46 @@ async def run_agent(thread_id: str, project_id: str, stream: bool = True, thread print(f"Last message was from assistant, stopping execution") continue_execution = False break - # Get the latest message from messages table that its tpye is browser_state - + + # Define Processor Config FIRST + processor_config = ProcessorConfig( + xml_tool_calling=True, + native_tool_calling=False, + execute_tools=True, + execute_on_stream=True, + tool_execution_strategy="parallel", + xml_adding_strategy="user_message" + ) + + # Construct System Message Conditionally + base_system_prompt_content = get_system_prompt() + system_message_content = base_system_prompt_content + + # Conditionally add XML examples based on the config + if processor_config.xml_tool_calling: + # Use the already loaded xml_examples from outside the loop + if xml_examples: + system_message_content += "\n\n" + f"\n{xml_examples}\n" + + system_message = { "role": "system", "content": system_message_content } + + # Handle Temporary Message (Browser State) latest_browser_state = await client.table('messages').select('*').eq('thread_id', thread_id).eq('type', 'browser_state').order('created_at', desc=True).limit(1).execute() temporary_message = None if latest_browser_state.data and len(latest_browser_state.data) > 0: try: content = json.loads(latest_browser_state.data[0]["content"]) - screenshot_base64 = content["screenshot_base64"] + screenshot_base64 = content.get("screenshot_base64") # Use .get() for safety # Create a copy of the browser state without screenshot browser_state = content.copy() browser_state.pop('screenshot_base64', None) - browser_state.pop('screenshot_url', None) + browser_state.pop('screenshot_url', None) browser_state.pop('screenshot_url_base64', None) temporary_message = { "role": "user", "content": [] } if browser_state: temporary_message["content"].append({ "type": "text", - "text": f"The following is the current state of the browser:\n{browser_state}" + "text": f"The following is the current state of the browser:\n{json.dumps(browser_state, indent=2)}" # Pretty print browser state }) if screenshot_base64: temporary_message["content"].append({ 
@@ -136,14 +156,15 @@ async def run_agent(thread_id: str, project_id: str, stream: bool = True, thread } }) else: - print("@@@@@ THIS TIME NO SCREENSHOT!!") + print("No screenshot found in the latest browser state message.") except Exception as e: print(f"Error parsing browser state: {e}") # print(latest_browser_state.data[0]) + # Run Thread response = await thread_manager.run_thread( thread_id=thread_id, - system_prompt=system_message, + system_prompt=system_message, # Pass the constructed message stream=stream, llm_model=os.getenv("MODEL_TO_USE", "anthropic/claude-3-7-sonnet-latest"), llm_temperature=0, @@ -151,16 +172,10 @@ async def run_agent(thread_id: str, project_id: str, stream: bool = True, thread tool_choice="auto", max_xml_tool_calls=1, temporary_message=temporary_message, - processor_config=ProcessorConfig( - xml_tool_calling=True, - native_tool_calling=False, - execute_tools=True, - execute_on_stream=True, - tool_execution_strategy="parallel", - xml_adding_strategy="user_message" - ), + processor_config=processor_config, # Pass the config object native_max_auto_continues=native_max_auto_continues, - include_xml_examples=True, + # Explicitly set include_xml_examples to False here + include_xml_examples=False, ) if isinstance(response, dict) and "status" in response and response["status"] == "error": diff --git a/backend/agentpress/response_processor.py b/backend/agentpress/response_processor.py index 6f65d33c..3f0d48e1 100644 --- a/backend/agentpress/response_processor.py +++ b/backend/agentpress/response_processor.py @@ -447,14 +447,21 @@ class ResponseProcessor: continue # Add assistant message with accumulated content + # Start with base message data message_data = { "role": "assistant", - "content": accumulated_content, - "tool_calls": complete_native_tool_calls if config.native_tool_calling and complete_native_tool_calls else None + "content": accumulated_content + # tool_calls key is initially omitted } + + # Conditionally add tool_calls if they exist and native calling is enabled + if config.native_tool_calling and complete_native_tool_calls: + message_data["tool_calls"] = complete_native_tool_calls + + # Add the message (tool_calls will only be present if added above) await self.add_message( - thread_id=thread_id, - type="assistant", + thread_id=thread_id, + type="assistant", content=message_data, is_llm_message=True ) @@ -657,14 +664,22 @@ class ResponseProcessor: }) # Add assistant message FIRST - always do this regardless of finish_reason + # Start with base message data message_data = { "role": "assistant", - "content": content, - "tool_calls": native_tool_calls if config.native_tool_calling and 'native_tool_calls' in locals() else None + "content": content + # tool_calls key is initially omitted } + + # Conditionally add tool_calls if they exist and native calling is enabled + # Use 'native_tool_calls' in locals() check for safety as before + if config.native_tool_calling and 'native_tool_calls' in locals() and native_tool_calls: + message_data["tool_calls"] = native_tool_calls + + # Add the message await self.add_message( - thread_id=thread_id, - type="assistant", + thread_id=thread_id, + type="assistant", content=message_data, is_llm_message=True ) @@ -1319,4 +1334,4 @@ class ResponseProcessor: "xml_tag_name": context.xml_tag_name, "message": f"Error executing tool: {error_msg}", "tool_index": context.tool_index - } \ No newline at end of file + } diff --git a/backend/agentpress/thread_manager.py b/backend/agentpress/thread_manager.py index d2523f26..401bd280 
100644 --- a/backend/agentpress/thread_manager.py +++ b/backend/agentpress/thread_manager.py @@ -198,32 +198,6 @@ class ThreadManager: if max_xml_tool_calls > 0: processor_config.max_xml_tool_calls = max_xml_tool_calls - # Add XML examples to system prompt if requested - if include_xml_examples and processor_config.xml_tool_calling: - xml_examples = self.tool_registry.get_xml_examples() - if xml_examples: - # logger.debug(f"Adding {len(xml_examples)} XML examples to system prompt") - - # Create or append to content - if isinstance(system_prompt['content'], str): - examples_content = """ ---- XML TOOL CALLING --- - -In this environment you have access to a set of tools you can use to answer the user's question. The tools are specified in XML format. -{{ FORMATTING INSTRUCTIONS }} -String and scalar parameters should be specified as attributes, while content goes between tags. -Note that spaces for string values are not stripped. The output is parsed with regular expressions. - -Here are the XML tools available with examples: -""" - for tag_name, example in xml_examples.items(): - examples_content += f"<{tag_name}> Example: {example}\n" - - system_prompt['content'] += examples_content - else: - # If content is not a string (might be a list or dict), log a warning - logger.warning("System prompt content is not a string, cannot add XML examples") - # 1. Get messages from thread for LLM call messages = await self.get_llm_messages(thread_id) diff --git a/backend/services/llm.py b/backend/services/llm.py index 162418b5..cd5e8d31 100644 --- a/backend/services/llm.py +++ b/backend/services/llm.py @@ -14,6 +14,7 @@ from typing import Union, Dict, Any, Optional, AsyncGenerator, List import os import json import asyncio +import time # Added for timestamp from openai import OpenAIError import litellm from utils.logger import logger @@ -26,6 +27,9 @@ MAX_RETRIES = 3 RATE_LIMIT_DELAY = 30 RETRY_DELAY = 5 +# Define debug log directory relative to this file's location +DEBUG_LOG_DIR = os.path.join(os.path.dirname(__file__), '..', 'debug_logs') # Assumes backend/debug_logs + class LLMError(Exception): """Base exception for LLM-related errors.""" pass @@ -208,15 +212,116 @@ async def make_llm_api_call( model_id=model_id ) + # Apply Anthropic prompt caching (minimal implementation) + if params["model"].startswith("anthropic/"): + logger.debug("Applying minimal Anthropic prompt caching.") + messages = params["messages"] # Direct reference + + # 1. Process the first message if it's a system prompt with string content + if messages and messages[0].get("role") == "system": + content = messages[0].get("content") + if isinstance(content, str): + messages[0]["content"] = [ + {"type": "text", "text": content, "cache_control": {"type": "ephemeral"}} + ] + logger.debug("Applied cache_control to system message.") + modified = True + elif not isinstance(content, list): + logger.warning("System message content is not a string or list, skipping cache_control.") + # else: content is already a list, do nothing + + # 2. 
Find and process the last user message + last_user_idx = -1 + for i in range(len(messages) - 1, -1, -1): + if messages[i].get("role") == "user": + last_user_idx = i + break + + if last_user_idx != -1: + last_user_message = messages[last_user_idx] + content = last_user_message.get("content") + applied_to_user = False + + if isinstance(content, str): + last_user_message["content"] = [ + {"type": "text", "text": content, "cache_control": {"type": "ephemeral"}} + ] + logger.debug(f"Applied cache_control to last user message (string content, index {last_user_idx}).") + applied_to_user = True + elif isinstance(content, list): + # Modify text blocks within the list directly + found_text_block = False + for item in content: + if isinstance(item, dict) and item.get("type") == "text": + # Add cache_control if not already present (avoids adding it multiple times) + if "cache_control" not in item: + item["cache_control"] = {"type": "ephemeral"} + found_text_block = True # Mark modification only if added + + if found_text_block: + logger.debug(f"Applied cache_control to text part(s) of last user message (list content, index {last_user_idx}).") + applied_to_user = True + # else: No text block found or cache_control already present, do nothing + else: + logger.warning(f"Last user message (index {last_user_idx}) content is not a string or list ({type(content)}), skipping cache_control.") + + if applied_to_user: + modified = True + + # --- Debug Logging Setup --- + # Initialize log path to None, it will be set only if logging is enabled + response_log_path = None + enable_debug_logging = os.environ.get('ENABLE_LLM_DEBUG_LOGGING', 'false').lower() == 'true' + + if enable_debug_logging: + try: + os.makedirs(DEBUG_LOG_DIR, exist_ok=True) + timestamp = time.strftime("%Y%m%d_%H%M%S") + # Use a unique ID or counter if calls can happen in the same second + # For simplicity, using timestamp only for now + request_log_path = os.path.join(DEBUG_LOG_DIR, f"llm_request_{timestamp}.json") + response_log_path = os.path.join(DEBUG_LOG_DIR, f"llm_response_{timestamp}.json") # Set here if enabled + + # Log the request parameters just before the attempt loop + logger.debug(f"Logging LLM request parameters to {request_log_path}") + with open(request_log_path, 'w') as f: + # Use default=str for potentially non-serializable items in params if needed + json.dump(params, f, indent=2, default=str) + + except Exception as log_err: + logger.error(f"Failed to set up or write LLM debug request log: {log_err}", exc_info=True) + # Reset response path to None if setup failed, even if logging was enabled + response_log_path = None + else: + logger.debug("LLM debug logging is disabled via environment variable.") + # --- End Debug Logging Setup --- + last_error = None for attempt in range(MAX_RETRIES): try: logger.debug(f"Attempt {attempt + 1}/{MAX_RETRIES}") - # logger.debug(f"API request parameters: {json.dumps(params, indent=2)}") response = await litellm.acompletion(**params) logger.debug(f"Successfully received API response from {model_name}") - logger.debug(f"Response: {response}") + + # --- Debug Logging Response --- + if response_log_path: # Only log if request logging setup succeeded + try: + logger.debug(f"Logging LLM response object to {response_log_path}") + # Check if it's a streaming response (AsyncGenerator) + if isinstance(response, AsyncGenerator): + with open(response_log_path, 'w') as f: + json.dump({"status": "streaming_response", "message": "Full response logged chunk by chunk where consumed."}, f, indent=2) + else: + 
# Assume it's a LiteLLM ModelResponse object, convert to dict + response_dict = response.dict() + with open(response_log_path, 'w') as f: + # Use default=str for potentially non-serializable items like datetime + json.dump(response_dict, f, indent=2, default=str) + except Exception as log_err: + logger.error(f"Failed to write LLM debug response log: {log_err}", exc_info=True) + # --- End Debug Logging Response --- + return response except (litellm.exceptions.RateLimitError, OpenAIError, json.JSONDecodeError) as e: diff --git a/backend/tests/raw_test.py b/backend/tests/raw_test.py new file mode 100644 index 00000000..51a0d80f --- /dev/null +++ b/backend/tests/raw_test.py @@ -0,0 +1,82 @@ +import asyncio +import litellm + +async def main(): + initial_messages=[ + # System Message + { + "role": "system", + "content": [ + { + "type": "text", + "text": "Here is the full text of a complex legal agreement" + * 400, + "cache_control": {"type": "ephemeral"}, + } + ], + }, + # marked for caching with the cache_control parameter, so that this checkpoint can read from the previous cache. + { + "role": "user", + "content": [ + { + "type": "text", + "text": "What are the key terms and conditions in this agreement?", + "cache_control": {"type": "ephemeral"}, + } + ], + }, + { + "role": "assistant", + "content": "Certainly! the key terms and conditions are the following: the contract is 1 year long for $10/month", + }, + # The final turn is marked with cache-control, for continuing in followups. + { + "role": "user", + "content": [ + { + "type": "text", + "text": "What are the key terms and conditions in this agreement?", + "cache_control": {"type": "ephemeral"}, + } + ], + }, + ] + + print("--- First call ---") + first_response = await litellm.acompletion( + model="anthropic/claude-3-7-sonnet-latest", + messages=initial_messages + ) + print(first_response) + + # Prepare messages for the second call + second_call_messages = initial_messages + [ + { + "role": "assistant", + # Extract the assistant's response content from the first call + "content": first_response.choices[0].message.content + }, + { + "role": "user", + "content": [ + { + "type": "text", + "text": "Can you elaborate on the termination clause based on the provided text? 
Remember the context.", + "cache_control": {"type": "ephemeral"}, # Mark for caching + } + ], + }, + ] + + print("\n--- Second call (testing cache) ---") + second_response = await litellm.acompletion( + model="anthropic/claude-3-7-sonnet-latest", + messages=second_call_messages + ) + print(second_response) + +if __name__ == "__main__": + asyncio.run(main()) + + diff --git a/backend/tests/test_simple_prompt_caching.py b/backend/tests/test_simple_prompt_caching.py new file mode 100644 index 00000000..af724b17 --- /dev/null +++ b/backend/tests/test_simple_prompt_caching.py @@ -0,0 +1,159 @@ +import asyncio +import json +import os +import sys +import traceback +from dotenv import load_dotenv +load_dotenv() + +# Ensure the backend directory is in the Python path +backend_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) +if backend_dir not in sys.path: + sys.path.insert(0, backend_dir) + +import logging # Import logging module +from agentpress.thread_manager import ThreadManager +from services.supabase import DBConnection +from agent.run import run_agent +from utils.logger import logger + +# Set logging level to DEBUG specifically for this test script +logger.setLevel(logging.DEBUG) +# Optionally, adjust handler levels if needed (e.g., for console output) +for handler in logger.handlers: + if isinstance(handler, logging.StreamHandler): # Target console handler + handler.setLevel(logging.DEBUG) + +async def test_agent_limited_iterations(): + """ + Test running the agent for a maximum of 3 iterations in non-streaming mode + and print the collected response chunks. + """ + print("\n" + "="*80) + print("๐Ÿงช TESTING AGENT RUN WITH MAX ITERATIONS (max_iterations=3, stream=False)") + print("="*80 + "\n") + + # Load environment variables + load_dotenv() + + # Initialize ThreadManager and DBConnection + thread_manager = ThreadManager() + db_connection = DBConnection() + client = await db_connection.client + + thread_id = None + project_id = None + + try: + # --- Test Setup --- + print("๐Ÿ”ง Setting up test environment (Project & Thread)...") + + # Get user's personal account (replace with a specific test account if needed) + # Using a hardcoded account ID for consistency in tests + account_id = "a5fe9cb6-4812-407e-a61c-fe95b7320c59" # Replace if necessary + logger.info(f"Using Account ID: {account_id}") + + if not account_id: + print("โŒ Error: Could not determine Account ID.") + return + + # Find or create a test project + project_name = "test_simple_dat" + project_result = await client.table('projects').select('*').eq('name', project_name).eq('account_id', account_id).execute() + + if project_result.data and len(project_result.data) > 0: + project_id = project_result.data[0]['project_id'] + print(f"๐Ÿ”„ Using existing test project: {project_id}") + else: + project_result = await client.table('projects').insert({ + "name": project_name, + "account_id": account_id + }).execute() + project_id = project_result.data[0]['project_id'] + print(f"โœจ Created new test project: {project_id}") + + # Create a new thread for this test + thread_result = await client.table('threads').insert({ + 'project_id': project_id, + 'account_id': account_id + }).execute() + thread_id = thread_result.data[0]['thread_id'] + print(f"๐Ÿงต Created new test thread: {thread_id}") + + # Add an initial user message to kick off the agent + initial_message = ("Hello " * 123) + "\\n\\nHow many times did the word 'Hello' appear in the previous text?" 
+ print(f"\\n๐Ÿ’ฌ Adding initial user message: Preview='{initial_message[:50]}...'") # Print only a preview + await thread_manager.add_message( + thread_id=thread_id, + type="user", + content={ + "role": "user", + "content": initial_message + }, + is_llm_message=True + ) + print("โœ… Initial message added.") + + # --- Run Agent --- + print("\n๐Ÿ”„ Running agent (max_iterations=3, stream=False)...") + all_chunks = [] + agent_run_generator = run_agent( + thread_id=thread_id, + project_id=project_id, + stream=False, # Non-streaming + thread_manager=thread_manager, + max_iterations=5 # Limit iterations + ) + + async for chunk in agent_run_generator: + chunk_type = chunk.get('type', 'unknown') + print(f" ๐Ÿ“ฆ Received chunk: type='{chunk_type}'") + all_chunks.append(chunk) + + print("\nโœ… Agent run finished.") + + # --- Print Results --- + print("\n๐Ÿ“„ Full collected response chunks:") + # Use json.dumps for pretty printing the list of dictionaries + print(json.dumps(all_chunks, indent=2, default=str)) # Use default=str for non-serializable types like datetime + + except Exception as e: + print(f"\nโŒ An error occurred during the test: {e}") + traceback.print_exc() + finally: + # Optional: Clean up the created thread and project + print("\n๐Ÿงน Cleaning up test resources...") + if thread_id: + await client.table('threads').delete().eq('thread_id', thread_id).execute() + print(f"๐Ÿ—‘๏ธ Deleted test thread: {thread_id}") + if project_id and not project_result.data: # Only delete if we created it + await client.table('projects').delete().eq('project_id', project_id).execute() + print(f"๐Ÿ—‘๏ธ Deleted test project: {project_id}") + + print("\n" + "="*80) + print("๐Ÿ TEST COMPLETE") + print("="*80 + "\n") + + +if __name__ == "__main__": + # Ensure the logger is configured + logger.info("Starting test_agent_max_iterations script...") + try: + asyncio.run(test_agent_limited_iterations()) + print("\nโœ… Test script completed successfully.") + sys.exit(0) + except KeyboardInterrupt: + print("\n\nโŒ Test interrupted by user.") + sys.exit(1) + except Exception as e: + print(f"\n\nโŒ Error running test script: {e}") + traceback.print_exc() + sys.exit(1) + +# before result +# 2025-04-16 19:20:20,494 - DEBUG - Response: ModelResponse(id='chatcmpl-2c5c1418-4570-435c-8d31-5c7ef63a1a68', created=1744827620, model='claude-3-7-sonnet-20250219', object='chat.completion', system_fingerprint=None, choices=[Choices(finish_reason='stop', index=0, message=Message(content='I\'ll update the existing todo.md file and then proceed with counting the "Hello" occurrences.\n\n\n# Hello Count Task\n\n## Setup\n- [ ] Create a file to store the input text\n- [ ] Create a script to count occurrences of "Hello"\n\n## Analysis\n- [ ] Run the script to count occurrences\n- [ ] Verify the results\n\n## Delivery\n- [ ] Provide the final count to the user\n', role='assistant', tool_calls=None, function_call=None, provider_specific_fields={'citations': None, 'thinking_blocks': None}))], usage=Usage(completion_tokens=125, prompt_tokens=14892, total_tokens=15017, completion_tokens_details=None, prompt_tokens_details=PromptTokensDetailsWrapper(audio_tokens=None, cached_tokens=0, text_tokens=None, image_tokens=None), cache_creation_input_tokens=0, cache_read_input_tokens=0)) + + + +# after result +# read cache should > 0 (and it does) \ No newline at end of file
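A minimal sketch of how the cache hit could be verified end to end, outside this patch. It reuses the two-call pattern from backend/tests/raw_test.py and reads the cache_creation_input_tokens / cache_read_input_tokens fields that appear in the logged Usage object above; the exact token counts and the helper names below are illustrative, not part of the change.

import asyncio
import litellm

# Large, stable prefix so the prompt exceeds Anthropic's minimum cacheable size.
SYSTEM_TEXT = "Here is the full text of a complex legal agreement. " * 400

def cached_messages(question: str):
    # Mirror the patch's make_llm_api_call logic: mark the system prompt and the
    # last user turn with ephemeral cache_control.
    return [
        {"role": "system", "content": [
            {"type": "text", "text": SYSTEM_TEXT, "cache_control": {"type": "ephemeral"}}
        ]},
        {"role": "user", "content": [
            {"type": "text", "text": question, "cache_control": {"type": "ephemeral"}}
        ]},
    ]

async def main():
    # First call writes the cache; the second call with the same prefix should read it.
    first = await litellm.acompletion(
        model="anthropic/claude-3-7-sonnet-latest",
        messages=cached_messages("Summarize the agreement."),
    )
    second = await litellm.acompletion(
        model="anthropic/claude-3-7-sonnet-latest",
        messages=cached_messages("What are the key terms and conditions?"),
    )
    print("first  cache_creation_input_tokens:", first.usage.cache_creation_input_tokens)
    print("second cache_read_input_tokens:   ", second.usage.cache_read_input_tokens)
    # Per the "after result" note above, the read count should be greater than zero.
    assert second.usage.cache_read_input_tokens > 0

if __name__ == "__main__":
    asyncio.run(main())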