From b59d0dc90b097c031058b375c2274bdede0b7652 Mon Sep 17 00:00:00 2001
From: marko-kraemer
Date: Thu, 17 Apr 2025 23:27:24 +0100
Subject: [PATCH] wip

---
 backend/agentpress/response_processor.py | 319 ++++++++++++++---
 1 file changed, 199 insertions(+), 120 deletions(-)

diff --git a/backend/agentpress/response_processor.py b/backend/agentpress/response_processor.py
index 6f65d33c..8d8f5cd8 100644
--- a/backend/agentpress/response_processor.py
+++ b/backend/agentpress/response_processor.py
@@ -132,6 +132,10 @@ class ResponseProcessor:
         # Track finish reason
         finish_reason = None
 
+        # Store message IDs associated with yielded content/tools
+        last_assistant_message_id = None
+        tool_result_message_ids = {} # tool_index -> message_id
+
         # logger.debug(f"Starting to process streaming response for thread {thread_id}")
         logger.info(f"Config: XML={config.xml_tool_calling}, Native={config.native_tool_calling}, "
                     f"Execute on stream={config.execute_on_stream}, Execution strategy={config.tool_execution_strategy}")
@@ -144,6 +148,15 @@ class ResponseProcessor:
         accumulated_token_count = 0
 
         try:
+            # Generate a unique ID for this response run
+            thread_run_id = str(uuid.uuid4())
+
+            # Yield the overall run start signal
+            yield {"type": "thread_run_start", "thread_run_id": thread_run_id}
+
+            # Yield the assistant response start signal
+            yield {"type": "assistant_response_start", "thread_run_id": thread_run_id}
+
             async for chunk in llm_response:
 
                 # Default content to yield
@@ -177,7 +190,7 @@ class ResponseProcessor:
                             logger.info("XML tool call limit reached - not yielding more content")
                         else:
                             # Always yield the content chunk if we haven't reached the limit
-                            yield {"type": "content", "content": chunk_content}
+                            yield {"type": "content", "content": chunk_content, "thread_run_id": thread_run_id}
 
                         # Parse XML tool calls if enabled
                         if config.xml_tool_calling:
@@ -208,12 +221,12 @@ class ResponseProcessor:
                                     # Execute tool if needed, but in background
                                     if config.execute_tools and config.execute_on_stream:
                                         # Yield tool execution start message
-                                        yield self._yield_tool_started(context)
+                                        yield self._yield_tool_started(context, thread_run_id)
 
                                         # Start tool execution as a background task
                                         execution_task = asyncio.create_task(self._execute_tool(tool_call))
 
-                                        # Store the task for later retrieval
+                                        # Store the task for later retrieval (to get result after stream)
                                         pending_tool_executions.append({
                                             "task": execution_task,
                                             "tool_call": tool_call,
@@ -260,7 +273,8 @@ class ResponseProcessor:
                             # Yield the chunk data
                             yield {
                                 "type": "content",
-                                "tool_call": tool_call_data
+                                "tool_call": tool_call_data,
+                                "thread_run_id": thread_run_id
                             }
 
                             # Log the tool call chunk for debugging
@@ -316,12 +330,12 @@ class ResponseProcessor:
                                 )
 
                                 # Yield tool execution start message
-                                yield self._yield_tool_started(context)
+                                yield self._yield_tool_started(context, thread_run_id)
 
                                 # Start tool execution as a background task
                                 execution_task = asyncio.create_task(self._execute_tool(tool_call_data))
 
-                                # Store the task for later retrieval
+                                # Store the task for later retrieval (to get result after stream)
                                 pending_tool_executions.append({
                                     "task": execution_task,
                                     "tool_call": tool_call_data,
@@ -353,7 +367,7 @@ class ResponseProcessor:
                         tool_call = execution["tool_call"]
                         tool_index = execution.get("tool_index", -1)
 
-                        # Store result for later
+                        # Store result for later processing AFTER assistant message is saved
                        tool_results_buffer.append((tool_call, result, tool_index))
 
                         # Get or create the context
@@ -369,13 +383,13 @@ class ResponseProcessor:
                             logger.info(f"Skipping duplicate yield for tool index {tool_index}")
                             continue
 
-                        # Yield tool status message first
-                        yield self._yield_tool_completed(context)
+                        # Yield tool status message first (without DB message ID yet)
+                        yield self._yield_tool_completed(context, tool_message_id=None, thread_run_id=thread_run_id)
 
-                        # Yield tool execution result
-                        yield self._yield_tool_result(context)
-
-                        # Track that we've yielded this tool result
+                        # DO NOT yield the tool_result chunk here yet.
+                        # It will be yielded after the assistant message is saved.
+
+                        # Track that we've yielded this tool result (status, not the result itself)
                         yielded_tool_indices.add(tool_index)
                     except Exception as e:
                         logger.error(f"Error processing remaining tool execution: {str(e)}")
@@ -398,7 +412,7 @@ class ResponseProcessor:
                         context.error = e
 
                         # Yield error status for the tool
-                        yield self._yield_tool_error(context)
+                        yield self._yield_tool_error(context, thread_run_id)
 
                         # Track that we've yielded this tool error
                         yielded_tool_indices.add(tool_index)
@@ -407,7 +421,8 @@ class ResponseProcessor:
             if finish_reason == "xml_tool_limit_reached":
                 yield {
                     "type": "finish",
-                    "finish_reason": "xml_tool_limit_reached"
+                    "finish_reason": "xml_tool_limit_reached",
+                    "thread_run_id": thread_run_id
                 }
                 logger.info(f"Stream finished with reason: xml_tool_limit_reached after {xml_tool_call_count} XML tool calls")
 
@@ -452,106 +467,129 @@ class ResponseProcessor:
                     "content": accumulated_content,
                     "tool_calls": complete_native_tool_calls if config.native_tool_calling and complete_native_tool_calls else None
                 }
-                await self.add_message(
+                last_assistant_message_id = await self.add_message(
                     thread_id=thread_id,
                     type="assistant",
                     content=message_data,
                     is_llm_message=True
                 )
 
-                # Now add all buffered tool results AFTER the assistant message, but don't yield if already yielded
-                for tool_call, result, result_tool_index in tool_results_buffer:
-                    # Add result based on tool type to the conversation history
-                    await self._add_tool_result(
-                        thread_id,
-                        tool_call,
-                        result,
-                        config.xml_adding_strategy
-                    )
-
-                    # We don't need to yield again for tools that were already yielded during streaming
-                    if result_tool_index in yielded_tool_indices:
-                        logger.info(f"Skipping duplicate yield for tool index {result_tool_index}")
-                        continue
-
-                    # Create context for tool result
-                    context = self._create_tool_context(tool_call, result_tool_index)
-                    context.result = result
-
-                    # Yield tool execution result
-                    yield self._yield_tool_result(context)
-
-                    # Increment tool index for next tool
-                    tool_index += 1
+                # Yield the assistant response end signal *immediately* after saving
+                if last_assistant_message_id:
+                    yield {
+                        "type": "assistant_response_end",
+                        "assistant_message_id": last_assistant_message_id,
+                        "thread_run_id": thread_run_id
+                    }
+                else:
+                    # Handle case where saving failed (though it should raise an exception)
+                    yield {
+                        "type": "assistant_response_end",
+                        "assistant_message_id": None,
+                        "thread_run_id": thread_run_id
+                    }
 
-            # Execute any remaining tool calls if not done during streaming
-            # Only process if we haven't reached the XML limit
-            if config.execute_tools and not config.execute_on_stream and (config.max_xml_tool_calls == 0 or xml_tool_call_count < config.max_xml_tool_calls):
-                tool_calls_to_execute = []
+            # --- Process All Tool Calls Now ---
+            if config.execute_tools:
+                final_tool_calls_to_process = []
 
-                # Process native tool calls
+                # Gather native tool calls from buffer
                 if config.native_tool_calling and complete_native_tool_calls:
-                    for tool_call in complete_native_tool_calls:
-                        tool_calls_to_execute.append({
-                            "function_name": tool_call["function"]["name"],
-                            "arguments": tool_call["function"]["arguments"],
-                            "id": tool_call["id"]
+                    for tc in complete_native_tool_calls:
+                        final_tool_calls_to_process.append({
+                            "function_name": tc["function"]["name"],
+                            "arguments": tc["function"]["arguments"],
+                            "id": tc["id"]
                         })
 
-                # Process XML tool calls - only if we haven't hit the limit
-                if config.xml_tool_calling and (config.max_xml_tool_calls == 0 or xml_tool_call_count < config.max_xml_tool_calls):
-                    # Extract any remaining complete XML chunks
+                # Gather XML tool calls from buffer (up to limit)
+                if config.xml_tool_calling:
                     xml_chunks = self._extract_xml_chunks(current_xml_content)
                     xml_chunks_buffer.extend(xml_chunks)
-
-                    # Only process up to the limit
-                    remaining_xml_calls = config.max_xml_tool_calls - xml_tool_call_count if config.max_xml_tool_calls > 0 else len(xml_chunks_buffer)
-                    xml_chunks_to_process = xml_chunks_buffer[:remaining_xml_calls] if remaining_xml_calls > 0 else []
-
-                    for xml_chunk in xml_chunks_to_process:
-                        tool_call = self._parse_xml_tool_call(xml_chunk)
-                        if tool_call:
-                            tool_calls_to_execute.append(tool_call)
-                            xml_tool_call_count += 1
+                    remaining_limit = config.max_xml_tool_calls - xml_tool_call_count if config.max_xml_tool_calls > 0 else len(xml_chunks_buffer)
+                    xml_chunks_to_process = xml_chunks_buffer[:remaining_limit]
+                    for chunk in xml_chunks_to_process:
+                        tc = self._parse_xml_tool_call(chunk)
+                        if tc: final_tool_calls_to_process.append(tc)
 
-                # Execute all collected tool calls
-                if tool_calls_to_execute:
-                    tool_results = await self._execute_tools(
-                        tool_calls_to_execute,
-                        config.tool_execution_strategy
-                    )
+                # Get results (either from pending tasks or by executing now)
+                tool_results_map = {} # tool_index -> (tool_call, result)
+                if config.execute_on_stream and pending_tool_executions:
+                    logger.info(f"Waiting for {len(pending_tool_executions)} pending streamed tool executions")
+                    tasks = {exec["tool_index"]: exec["task"] for exec in pending_tool_executions}
+                    tool_calls_by_index = {exec["tool_index"]: exec["tool_call"] for exec in pending_tool_executions}
+                    done, _ = await asyncio.wait(tasks.values())
+                    for idx, task in tasks.items():
+                        try:
+                            result = task.result()
+                            tool_results_map[idx] = (tool_calls_by_index[idx], result)
+                        except Exception as e:
+                            logger.error(f"Error getting result for streamed tool index {idx}: {e}")
+                            tool_results_map[idx] = (tool_calls_by_index[idx], ToolResult(success=False, output=f"Error: {e}"))
+                elif final_tool_calls_to_process: # Execute tools now if not streamed
+                    logger.info(f"Executing {len(final_tool_calls_to_process)} tools sequentially/parallelly")
+                    results_list = await self._execute_tools(final_tool_calls_to_process, config.tool_execution_strategy)
+                    # Map results back to original tool index if possible (difficult without original index)
+                    # For simplicity, we'll process them in the order returned
+                    current_tool_idx = 0 # Reset index for non-streamed execution results
+                    for tc, res in results_list:
+                        tool_results_map[current_tool_idx] = (tc, res)
+                        current_tool_idx += 1
+
+                # Now, process and yield each result sequentially
+                logger.info(f"Processing and yielding {len(tool_results_map)} tool results")
+                processed_tool_indices = set()
+                # We need a deterministic order, sort by index
+                for tool_idx in sorted(tool_results_map.keys()):
+                    tool_call, result = tool_results_map[tool_idx]
+                    context = self._create_tool_context(tool_call, tool_idx)
+                    context.result = result
 
-                    for tool_call, result in tool_results:
-                        # Add result based on tool type
-                        await self._add_tool_result(
-                            thread_id,
-                            tool_call,
-                            result,
-                            config.xml_adding_strategy
-                        )
-
-                        # Create context for tool result
-                        context = self._create_tool_context(tool_call, tool_index)
-                        context.result = result
-
-                        # Yield tool execution result
-                        yield self._yield_tool_result(context)
-
-                        # Increment tool index for next tool
-                        tool_index += 1
+                    # Yield start status (even if streamed, yield again here for strict order)
+                    yield self._yield_tool_started(context, thread_run_id)
+
+                    # Save result to DB and get ID
+                    tool_msg_id = await self._add_tool_result(thread_id, tool_call, result, config.xml_adding_strategy)
+                    if tool_msg_id:
+                        tool_result_message_ids[tool_idx] = tool_msg_id # Store for reference
+                    else:
+                        logger.error(f"Failed to get message ID for tool index {tool_idx}")
+
+                    # Yield completed status with ID
+                    yield self._yield_tool_completed(context, tool_message_id=tool_msg_id, thread_run_id=thread_run_id)
+
+                    # Yield result with ID
+                    yield self._yield_tool_result(context, tool_message_id=tool_msg_id, thread_run_id=thread_run_id)
+
+                    processed_tool_indices.add(tool_idx)
 
             # Finally, if we detected a finish reason, yield it
             if finish_reason and finish_reason != "xml_tool_limit_reached": # Already yielded if limit reached
                 yield {
                     "type": "finish",
-                    "finish_reason": finish_reason
+                    "finish_reason": finish_reason,
+                    "thread_run_id": thread_run_id
                 }
 
         except Exception as e:
             logger.error(f"Error processing stream: {str(e)}", exc_info=True)
-            yield {"type": "error", "message": str(e)}
+            yield {"type": "error", "message": str(e), "thread_run_id": thread_run_id if 'thread_run_id' in locals() else None}
 
         finally:
+            # Yield a finish signal including the final assistant message ID
+            if last_assistant_message_id:
+                # Yield the overall run end signal
+                yield {
+                    "type": "thread_run_end",
+                    "thread_run_id": thread_run_id
+                }
+            else:
+                # Yield the overall run end signal
+                yield {
+                    "type": "thread_run_end",
+                    "thread_run_id": thread_run_id if 'thread_run_id' in locals() else None
+                }
+
             pass
             # track the cost and token count
             # todo: there is a bug as it adds every chunk to db because finally will run every time even in yield
@@ -570,8 +608,8 @@ class ResponseProcessor:
         self,
         llm_response: Any,
         thread_id: str,
-        config: ProcessorConfig = ProcessorConfig(),
-    ) -> AsyncGenerator:
+        config: ProcessorConfig = ProcessorConfig()
+    ) -> AsyncGenerator[Dict[str, Any], None]:
         """Process a non-streaming LLM response, handling tool calls and execution.
 
         Args:
@@ -585,13 +623,20 @@ class ResponseProcessor:
         try:
             # Extract content and tool calls from response
             content = ""
+            # Generate a unique ID for this thread run
+            thread_run_id = str(uuid.uuid4())
+
             tool_calls = []
             # Tool execution counter
             tool_index = 0
             # XML tool call counter
             xml_tool_call_count = 0
             # Set to track yielded tool results
-            yielded_tool_indices = set()
+            # yielded_tool_indices = set() # Not needed for non-streaming as we yield all at once
+
+            # Store message IDs
+            assistant_message_id = None
+            tool_result_message_ids = {} # tool_index -> message_id
 
             # Extract finish_reason if available
             finish_reason = None
@@ -662,7 +707,7 @@ class ResponseProcessor:
                 "content": content,
                 "tool_calls": native_tool_calls if config.native_tool_calling and 'native_tool_calls' in locals() else None
             }
-            await self.add_message(
+            assistant_message_id = await self.add_message(
                 thread_id=thread_id,
                 type="assistant",
                 content=message_data,
@@ -670,7 +715,27 @@ class ResponseProcessor:
             )
 
             # Yield content first
-            yield {"type": "content", "content": content}
+            yield {
+                "type": "content",
+                "content": content,
+                "assistant_message_id": assistant_message_id,
+                "thread_run_id": thread_run_id
+            }
+
+            # Yield the assistant response end signal *immediately* after saving
+            if assistant_message_id:
+                yield {
+                    "type": "assistant_response_end",
+                    "assistant_message_id": assistant_message_id,
+                    "thread_run_id": thread_run_id
+                }
+            else:
+                # Handle case where saving failed (though it should raise an exception)
+                yield {
+                    "type": "assistant_response_end",
+                    "assistant_message_id": None,
+                    "thread_run_id": thread_run_id
+                }
 
             # Execute tools if needed - AFTER assistant message has been added
             if config.execute_tools and tool_calls:
@@ -684,23 +749,22 @@ class ResponseProcessor:
                 )
 
                 for tool_call, result in tool_results:
-                    # Add result based on tool type
-                    await self._add_tool_result(
+                    # Capture the message ID for this tool result
+                    message_id = await self._add_tool_result(
                         thread_id,
                         tool_call,
                         result,
                         config.xml_adding_strategy
                     )
-
+                    if message_id:
+                        tool_result_message_ids[tool_index] = message_id
+
                     # Create context for tool result
                     context = self._create_tool_context(tool_call, tool_index)
                     context.result = result
 
                     # Yield tool execution result
-                    yield self._yield_tool_result(context)
-
-                    # Track that we've yielded this tool result
-                    yielded_tool_indices.add(tool_index)
+                    yield self._yield_tool_result(context, tool_message_id=message_id, thread_run_id=thread_run_id)
 
                     # Increment tool index for next tool
                     tool_index += 1
@@ -709,19 +773,21 @@ class ResponseProcessor:
             if finish_reason == "xml_tool_limit_reached":
                 yield {
                     "type": "finish",
-                    "finish_reason": "xml_tool_limit_reached"
+                    "finish_reason": "xml_tool_limit_reached",
+                    "thread_run_id": thread_run_id
                 }
                 logger.info(f"Non-streaming response finished with reason: xml_tool_limit_reached after {xml_tool_call_count} XML tool calls")
             # Otherwise yield the regular finish reason if available
            elif finish_reason:
                 yield {
                     "type": "finish",
-                    "finish_reason": finish_reason
+                    "finish_reason": finish_reason,
+                    "thread_run_id": thread_run_id
                 }
 
         except Exception as e:
             logger.error(f"Error processing response: {str(e)}", exc_info=True)
-            yield {"type": "error", "message": str(e)}
+            yield {"type": "error", "message": str(e), "thread_run_id": thread_run_id if 'thread_run_id' in locals() else None}
 
     # XML parsing methods
     def _extract_tag_content(self, xml_chunk: str, tag_name: str) -> Tuple[Optional[str], Optional[str]]:
@@ -1125,7 +1191,7 @@ class ResponseProcessor:
         tool_call: Dict[str, Any],
         result: ToolResult,
         strategy: Union[XmlAddingStrategy, str] = "assistant_message"
-    ):
+    ) -> Optional[str]: # Return the message ID
         """Add a tool result to the conversation thread based on the specified format.
 
         This method formats tool results and adds them to the conversation history,
@@ -1141,6 +1207,7 @@
                 ("user_message", "assistant_message", or "inline_edit")
         """
         try:
+            message_id = None # Initialize message_id
             # Check if this is a native function call (has id field)
            if "id" in tool_call:
                 # Format as a proper tool message according to OpenAI spec
@@ -1175,13 +1242,13 @@
 
                 # Add as a tool message to the conversation history
                 # This makes the result visible to the LLM in the next turn
-                await self.add_message(
+                message_id = await self.add_message(
                     thread_id=thread_id,
                     type="tool", # Special type for tool responses
                     content=tool_message,
                     is_llm_message=True
                 )
-                return
+                return message_id # Return the message ID
 
             # For XML and other non-native tools, continue with the original logic
             # Determine message role based on strategy
@@ -1200,12 +1267,13 @@
                 "role": result_role,
                 "content": content
             }
-            await self.add_message(
+            message_id = await self.add_message(
                 thread_id=thread_id,
                 type="tool",
                 content=result_message,
                 is_llm_message=True
             )
+            return message_id # Return the message ID
         except Exception as e:
             logger.error(f"Error adding tool result: {str(e)}", exc_info=True)
             # Fallback to a simple message
@@ -1214,14 +1282,16 @@
                     "role": "user",
                     "content": str(result)
                 }
-                await self.add_message(
+                message_id = await self.add_message(
                     thread_id=thread_id,
                     type="tool",
                     content=fallback_message,
                     is_llm_message=True
                 )
+                return message_id # Return the message ID
             except Exception as e2:
                 logger.error(f"Failed even with fallback message: {str(e2)}", exc_info=True)
+                return None # Return None on error
 
     def _format_xml_tool_result(self, tool_call: Dict[str, Any], result: ToolResult) -> str:
         """Format a tool result wrapped in a tag.
@@ -1243,15 +1313,17 @@ class ResponseProcessor:
             return f"Result for {function_name}: {str(result)}"
 
     # At class level, define a method for yielding tool results
-    def _yield_tool_result(self, context: ToolExecutionContext) -> Dict[str, Any]:
+    def _yield_tool_result(self, context: ToolExecutionContext, tool_message_id: Optional[str], thread_run_id: str) -> Dict[str, Any]:
         """Format and return a tool result message."""
         if not context.result:
             return {
                 "type": "tool_result",
                 "function_name": context.function_name,
                 "xml_tag_name": context.xml_tag_name,
-                "result": "No result available",
-                "tool_index": context.tool_index
+                "result": "Error: No result available in context",
+                "tool_index": context.tool_index,
+                "tool_message_id": tool_message_id,
+                "thread_run_id": thread_run_id
             }
 
         formatted_result = self._format_xml_tool_result(context.tool_call, context.result)
@@ -1260,7 +1332,9 @@
             "function_name": context.function_name,
             "xml_tag_name": context.xml_tag_name,
             "result": formatted_result,
-            "tool_index": context.tool_index
+            "tool_index": context.tool_index,
+            "tool_message_id": tool_message_id,
+            "thread_run_id": thread_run_id
         }
 
     def _create_tool_context(self, tool_call: Dict[str, Any], tool_index: int) -> ToolExecutionContext:
@@ -1281,7 +1355,7 @@
 
         return context
 
-    def _yield_tool_started(self, context: ToolExecutionContext) -> Dict[str, Any]:
+    def _yield_tool_started(self, context: ToolExecutionContext, thread_run_id: str) -> Dict[str, Any]:
         """Format and return a tool started status message."""
         tool_name = context.xml_tag_name or context.function_name
         return {
@@ -1290,13 +1364,14 @@
             "function_name": context.function_name,
             "xml_tag_name": context.xml_tag_name,
             "message": f"Starting execution of {tool_name}",
-            "tool_index": context.tool_index
+            "tool_index": context.tool_index,
+            "thread_run_id": thread_run_id
         }
 
-    def _yield_tool_completed(self, context: ToolExecutionContext) -> Dict[str, Any]:
+    def _yield_tool_completed(self, context: ToolExecutionContext, tool_message_id: Optional[str], thread_run_id: str) -> Dict[str, Any]:
         """Format and return a tool completed/failed status message."""
         if not context.result:
-            return self._yield_tool_error(context)
+            return self._yield_tool_error(context, thread_run_id)
 
         tool_name = context.xml_tag_name or context.function_name
         return {
@@ -1305,10 +1380,12 @@
             "function_name": context.function_name,
             "xml_tag_name": context.xml_tag_name,
             "message": f"Tool {tool_name} {'completed successfully' if context.result.success else 'failed'}",
-            "tool_index": context.tool_index
+            "tool_index": context.tool_index,
+            "tool_message_id": tool_message_id,
+            "thread_run_id": thread_run_id
         }
 
-    def _yield_tool_error(self, context: ToolExecutionContext) -> Dict[str, Any]:
+    def _yield_tool_error(self, context: ToolExecutionContext, thread_run_id: str) -> Dict[str, Any]:
         """Format and return a tool error status message."""
         error_msg = str(context.error) if context.error else "Unknown error"
         tool_name = context.xml_tag_name or context.function_name
@@ -1318,5 +1395,7 @@
             "function_name": context.function_name,
             "xml_tag_name": context.xml_tag_name,
             "message": f"Error executing tool: {error_msg}",
-            "tool_index": context.tool_index
+            "tool_index": context.tool_index,
+            "tool_message_id": None,
+            "thread_run_id": thread_run_id
         }
\ No newline at end of file
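
Reviewer note (not part of the patch): the sketch below illustrates how a client might consume the richer chunk stream this change introduces (thread_run_start, content, tool_status, tool_result, assistant_response_end, thread_run_end). It is a minimal, hypothetical consumer written against the chunk shapes visible in the diff; fake_stream stands in for ResponseProcessor.process_streaming_response, and all IDs and field values are invented for illustration.

import asyncio
from typing import Any, AsyncGenerator, Dict

async def fake_stream() -> AsyncGenerator[Dict[str, Any], None]:
    """Stand-in for ResponseProcessor.process_streaming_response (illustrative only)."""
    run_id = "run-123"  # hypothetical thread_run_id
    yield {"type": "thread_run_start", "thread_run_id": run_id}
    yield {"type": "assistant_response_start", "thread_run_id": run_id}
    yield {"type": "content", "content": "Hello", "thread_run_id": run_id}
    yield {"type": "assistant_response_end", "assistant_message_id": "msg-1", "thread_run_id": run_id}
    yield {"type": "tool_status", "status": "started", "function_name": "search",
           "xml_tag_name": None, "message": "Starting execution of search",
           "tool_index": 0, "thread_run_id": run_id}
    yield {"type": "tool_result", "function_name": "search", "xml_tag_name": None,
           "result": "Result for search: ...", "tool_index": 0,
           "tool_message_id": "msg-2", "thread_run_id": run_id}
    yield {"type": "finish", "finish_reason": "stop", "thread_run_id": run_id}
    yield {"type": "thread_run_end", "thread_run_id": run_id}

async def consume(stream: AsyncGenerator[Dict[str, Any], None]) -> None:
    # Dispatch on the "type" field that every chunk in the new protocol carries.
    async for chunk in stream:
        kind = chunk.get("type")
        if kind == "content":
            print(chunk.get("content", ""), end="")
        elif kind == "tool_status":
            print(f"\n[tool {chunk['tool_index']}] {chunk['status']}")
        elif kind == "tool_result":
            # tool_message_id ties the streamed result to the persisted DB message.
            print(f"\n[tool {chunk['tool_index']}] result saved as message {chunk['tool_message_id']}")
        elif kind == "assistant_response_end":
            print(f"\n[assistant message persisted: {chunk['assistant_message_id']}]")
        elif kind == "thread_run_end":
            print(f"\n[run {chunk['thread_run_id']} finished]")

if __name__ == "__main__":
    asyncio.run(consume(fake_stream()))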