diff --git a/backend/agentpress/response_processor.py b/backend/agentpress/response_processor.py
index 8b7af0d7..048e581a 100644
--- a/backend/agentpress/response_processor.py
+++ b/backend/agentpress/response_processor.py
@@ -6,7 +6,6 @@ This module handles the processing of LLM responses, including:
 - XML and native tool call detection and parsing
 - Tool execution orchestration
 - Message formatting and persistence
-- Cost calculation and tracking
 """
 
 import json
@@ -20,13 +19,13 @@ from utils.logger import logger
 from agentpress.tool import ToolResult
 from agentpress.tool_registry import ToolRegistry
 from agentpress.xml_tool_parser import XMLToolParser
-from litellm import completion_cost
 from langfuse.client import StatefulTraceClient
 from services.langfuse import langfuse
 from agentpress.utils.json_helpers import (
     ensure_dict, ensure_list, safe_json_parse, to_json_string, format_for_yield
 )
+from litellm import token_counter
 
 # Type alias for XML result adding strategy
 XmlAddingStrategy = Literal["user_message", "assistant_message", "inline_edit"]
@@ -146,6 +145,21 @@ class ResponseProcessor:
         tool_result_message_objects = {} # tool_index -> full saved message object
         has_printed_thinking_prefix = False # Flag for printing thinking prefix only once
         agent_should_terminate = False # Flag to track if a terminating tool has been executed
+        complete_native_tool_calls = [] # Initialize early for use in assistant_response_end
+
+        # Collect metadata for reconstructing LiteLLM response object
+        streaming_metadata = {
+            "model": llm_model,
+            "created": None,
+            "usage": {
+                "prompt_tokens": 0,
+                "completion_tokens": 0,
+                "total_tokens": 0
+            },
+            "response_ms": None,
+            "first_chunk_time": None,
+            "last_chunk_time": None
+        }
 
         logger.info(f"Streaming Config: XML={config.xml_tool_calling}, Native={config.native_tool_calling}, "
                     f"Execute on stream={config.execute_on_stream}, Strategy={config.tool_execution_strategy}")
@@ -172,6 +186,26 @@ class ResponseProcessor:
             __sequence = 0
 
             async for chunk in llm_response:
+                # Extract streaming metadata from chunks
+                current_time = datetime.now(timezone.utc).timestamp()
+                if streaming_metadata["first_chunk_time"] is None:
+                    streaming_metadata["first_chunk_time"] = current_time
+                streaming_metadata["last_chunk_time"] = current_time
+
+                # Extract metadata from chunk attributes
+                if hasattr(chunk, 'created') and chunk.created:
+                    streaming_metadata["created"] = chunk.created
+                if hasattr(chunk, 'model') and chunk.model:
+                    streaming_metadata["model"] = chunk.model
+                if hasattr(chunk, 'usage') and chunk.usage:
+                    # Update usage information if available (including zero values)
+                    if hasattr(chunk.usage, 'prompt_tokens') and chunk.usage.prompt_tokens is not None:
+                        streaming_metadata["usage"]["prompt_tokens"] = chunk.usage.prompt_tokens
+                    if hasattr(chunk.usage, 'completion_tokens') and chunk.usage.completion_tokens is not None:
+                        streaming_metadata["usage"]["completion_tokens"] = chunk.usage.completion_tokens
+                    if hasattr(chunk.usage, 'total_tokens') and chunk.usage.total_tokens is not None:
+                        streaming_metadata["usage"]["total_tokens"] = chunk.usage.total_tokens
+
                 if hasattr(chunk, 'choices') and chunk.choices and hasattr(chunk.choices[0], 'finish_reason') and chunk.choices[0].finish_reason:
                     finish_reason = chunk.choices[0].finish_reason
                     logger.debug(f"Detected finish_reason: {finish_reason}")
@@ -317,6 +351,33 @@ class ResponseProcessor:
             # print() # Add a final newline after the streaming loop finishes
 
             # --- After Streaming Loop ---
+
+            if (
+                streaming_metadata["usage"]["total_tokens"] == 0
+            ):
+                logger.info("🔥 No usage data from provider, counting with litellm.token_counter")
+
+                # prompt side
+                prompt_tokens = token_counter(
+                    model=llm_model,
+                    messages=prompt_messages # chat or plain; token_counter handles both
+                )
+
+                # completion side
+                completion_tokens = token_counter(
+                    model=llm_model,
+                    text=accumulated_content or "" # empty string safe
+                )
+
+                streaming_metadata["usage"]["prompt_tokens"] = prompt_tokens
+                streaming_metadata["usage"]["completion_tokens"] = completion_tokens
+                streaming_metadata["usage"]["total_tokens"] = prompt_tokens + completion_tokens
+
+                logger.info(
+                    f"🔥 Estimated tokens – prompt: {prompt_tokens}, "
+                    f"completion: {completion_tokens}, total: {prompt_tokens + completion_tokens}"
+                )
+
             # Wait for pending tool executions from streaming phase
             tool_results_buffer = [] # Stores (tool_call, result, tool_index, context)
@@ -409,7 +470,7 @@ class ResponseProcessor:
                 accumulated_content = accumulated_content[:last_chunk_end_pos]
 
             # ... (Extract complete_native_tool_calls logic) ...
-            complete_native_tool_calls = []
+            # Update complete_native_tool_calls from buffer (initialized earlier)
             if config.native_tool_calling:
                 for idx, tc_buf in tool_calls_buffer.items():
                     if tc_buf['id'] and tc_buf['function']['name'] and tc_buf['function']['arguments']:
@@ -575,34 +636,6 @@ class ResponseProcessor:
                             self.trace.event(name="failed_to_save_tool_result_for_index", level="ERROR", status_message=(f"Failed to save tool result for index {tool_idx}, not yielding result message."))
                         # Optionally yield error status for saving failure?
 
-            # --- Calculate and Store Cost ---
-            if last_assistant_message_object: # Only calculate if assistant message was saved
-                try:
-                    # Use accumulated_content for streaming cost calculation
-                    final_cost = completion_cost(
-                        model=llm_model,
-                        messages=prompt_messages, # Use the prompt messages provided
-                        completion=accumulated_content
-                    )
-                    if final_cost is not None and final_cost > 0:
-                        logger.info(f"Calculated final cost for stream: {final_cost}")
-                        await self.add_message(
-                            thread_id=thread_id,
-                            type="cost",
-                            content={"cost": final_cost},
-                            is_llm_message=False, # Cost is metadata
-                            metadata={"thread_run_id": thread_run_id} # Keep track of the run
-                        )
-                        logger.info(f"Cost message saved for stream: {final_cost}")
-                        self.trace.update(metadata={"cost": final_cost})
-                    else:
-                        logger.info("Stream cost calculation resulted in zero or None, not storing cost message.")
-                        self.trace.update(metadata={"cost": 0})
-                except Exception as e:
-                    logger.error(f"Error calculating final cost for stream: {str(e)}")
-                    self.trace.event(name="error_calculating_final_cost_for_stream", level="ERROR", status_message=(f"Error calculating final cost for stream: {str(e)}"))
-
             # --- Final Finish Status ---
             if finish_reason and finish_reason != "xml_tool_limit_reached":
                 finish_content = {"status_type": "finish", "finish_reason": finish_reason}
@@ -628,9 +661,107 @@ class ResponseProcessor:
                 )
                 if finish_msg_obj: yield format_for_yield(finish_msg_obj)
 
+                # Save assistant_response_end BEFORE terminating
+                if last_assistant_message_object:
+                    try:
+                        # Calculate response time if we have timing data
+                        if streaming_metadata["first_chunk_time"] and streaming_metadata["last_chunk_time"]:
+                            streaming_metadata["response_ms"] = (streaming_metadata["last_chunk_time"] - streaming_metadata["first_chunk_time"]) * 1000
+
+                        # Create a LiteLLM-like response object for streaming (before termination)
+                        # Check if we have any actual usage data
+                        has_usage_data = (
+                            streaming_metadata["usage"]["prompt_tokens"] > 0 or
+                            streaming_metadata["usage"]["completion_tokens"] > 0 or
+                            streaming_metadata["usage"]["total_tokens"] > 0
+                        )
+
+                        assistant_end_content = {
+                            "choices": [
+                                {
+                                    "finish_reason": finish_reason or "stop",
+                                    "index": 0,
+                                    "message": {
+                                        "role": "assistant",
+                                        "content": accumulated_content,
+                                        "tool_calls": complete_native_tool_calls or None
+                                    }
+                                }
+                            ],
+                            "created": streaming_metadata.get("created"),
+                            "model": streaming_metadata.get("model", llm_model),
+                            "usage": streaming_metadata["usage"], # Always include usage like LiteLLM does
+                            "streaming": True, # Add flag to indicate this was reconstructed from streaming
+                        }
+
+                        # Only include response_ms if we have timing data
+                        if streaming_metadata.get("response_ms"):
+                            assistant_end_content["response_ms"] = streaming_metadata["response_ms"]
+
+                        await self.add_message(
+                            thread_id=thread_id,
+                            type="assistant_response_end",
+                            content=assistant_end_content,
+                            is_llm_message=False,
+                            metadata={"thread_run_id": thread_run_id}
+                        )
+                        logger.info("Assistant response end saved for stream (before termination)")
+                    except Exception as e:
+                        logger.error(f"Error saving assistant response end for stream (before termination): {str(e)}")
+                        self.trace.event(name="error_saving_assistant_response_end_for_stream_before_termination", level="ERROR", status_message=(f"Error saving assistant response end for stream (before termination): {str(e)}"))
+
+                # Skip all remaining processing and go to finally block
                 return
 
+            # --- Save and Yield assistant_response_end ---
+            if last_assistant_message_object: # Only save if assistant message was saved
+                try:
+                    # Calculate response time if we have timing data
+                    if streaming_metadata["first_chunk_time"] and streaming_metadata["last_chunk_time"]:
+                        streaming_metadata["response_ms"] = (streaming_metadata["last_chunk_time"] - streaming_metadata["first_chunk_time"]) * 1000
+
+                    # Create a LiteLLM-like response object for streaming
+                    # Check if we have any actual usage data
+                    has_usage_data = (
+                        streaming_metadata["usage"]["prompt_tokens"] > 0 or
+                        streaming_metadata["usage"]["completion_tokens"] > 0 or
+                        streaming_metadata["usage"]["total_tokens"] > 0
+                    )
+
+                    assistant_end_content = {
+                        "choices": [
+                            {
+                                "finish_reason": finish_reason or "stop",
+                                "index": 0,
+                                "message": {
+                                    "role": "assistant",
+                                    "content": accumulated_content,
+                                    "tool_calls": complete_native_tool_calls or None
+                                }
+                            }
+                        ],
+                        "created": streaming_metadata.get("created"),
+                        "model": streaming_metadata.get("model", llm_model),
+                        "usage": streaming_metadata["usage"], # Always include usage like LiteLLM does
+                        "streaming": True, # Add flag to indicate this was reconstructed from streaming
+                    }
+
+                    # Only include response_ms if we have timing data
+                    if streaming_metadata.get("response_ms"):
+                        assistant_end_content["response_ms"] = streaming_metadata["response_ms"]
+
+                    await self.add_message(
+                        thread_id=thread_id,
+                        type="assistant_response_end",
+                        content=assistant_end_content,
+                        is_llm_message=False,
+                        metadata={"thread_run_id": thread_run_id}
+                    )
+                    logger.info("Assistant response end saved for stream")
+                except Exception as e:
+                    logger.error(f"Error saving assistant response end for stream: {str(e)}")
+                    self.trace.event(name="error_saving_assistant_response_end_for_stream", level="ERROR", status_message=(f"Error saving assistant response end for stream: {str(e)}"))
+
         except Exception as e:
             logger.error(f"Error processing stream: {str(e)}", exc_info=True)
             self.trace.event(name="error_processing_stream", level="ERROR", status_message=(f"Error processing stream: {str(e)}"))
@@ -759,43 +890,7 @@ class ResponseProcessor:
                 )
                 if err_msg_obj: yield format_for_yield(err_msg_obj)
 
-            # --- Calculate and Store Cost ---
-            if assistant_message_object: # Only calculate if assistant message was saved
-                try:
-                    # Use the full llm_response object for potentially more accurate cost calculation
-                    final_cost = None
-                    if hasattr(llm_response, '_hidden_params') and 'response_cost' in llm_response._hidden_params and llm_response._hidden_params['response_cost'] is not None and llm_response._hidden_params['response_cost'] != 0.0:
-                        final_cost = llm_response._hidden_params['response_cost']
-                        logger.info(f"Using response_cost from _hidden_params: {final_cost}")
-
-                    if final_cost is None: # Fall back to calculating cost if direct cost not available or zero
-                        logger.info("Calculating cost using completion_cost function.")
-                        # Note: litellm might need 'messages' kwarg depending on model/provider
-                        final_cost = completion_cost(
-                            completion_response=llm_response,
-                            model=llm_model, # Explicitly pass the model name
-                            # messages=prompt_messages # Pass prompt messages if needed by litellm for this model
-                        )
-
-                    if final_cost is not None and final_cost > 0:
-                        logger.info(f"Calculated final cost for non-stream: {final_cost}")
-                        await self.add_message(
-                            thread_id=thread_id,
-                            type="cost",
-                            content={"cost": final_cost},
-                            is_llm_message=False, # Cost is metadata
-                            metadata={"thread_run_id": thread_run_id} # Keep track of the run
-                        )
-                        logger.info(f"Cost message saved for non-stream: {final_cost}")
-                        self.trace.update(metadata={"cost": final_cost})
-                    else:
-                        logger.info("Non-stream cost calculation resulted in zero or None, not storing cost message.")
-                        self.trace.update(metadata={"cost": 0})
-
-                except Exception as e:
-                    logger.error(f"Error calculating final cost for non-stream: {str(e)}")
-                    self.trace.event(name="error_calculating_final_cost_for_non_stream", level="ERROR", status_message=(f"Error calculating final cost for non-stream: {str(e)}"))
-            # --- Execute Tools and Yield Results ---
+            # --- Execute Tools and Yield Results ---
             tool_calls_to_execute = [item['tool_call'] for item in all_tool_data]
             if config.execute_tools and tool_calls_to_execute:
                 logger.info(f"Executing {len(tool_calls_to_execute)} tools with strategy: {config.tool_execution_strategy}")
@@ -850,6 +945,22 @@ class ResponseProcessor:
                 )
                 if finish_msg_obj: yield format_for_yield(finish_msg_obj)
 
+            # --- Save and Yield assistant_response_end ---
+            if assistant_message_object: # Only save if assistant message was saved
+                try:
+                    # Save the full LiteLLM response object directly in content
+                    await self.add_message(
+                        thread_id=thread_id,
+                        type="assistant_response_end",
+                        content=llm_response,
+                        is_llm_message=False,
+                        metadata={"thread_run_id": thread_run_id}
+                    )
+                    logger.info("Assistant response end saved for non-stream")
+                except Exception as e:
+                    logger.error(f"Error saving assistant response end for non-stream: {str(e)}")
+                    self.trace.event(name="error_saving_assistant_response_end_for_non_stream", level="ERROR", status_message=(f"Error saving assistant response end for non-stream: {str(e)}"))
+
         except Exception as e:
             logger.error(f"Error processing non-streaming response: {str(e)}", exc_info=True)
             self.trace.event(name="error_processing_non_streaming_response", level="ERROR", status_message=(f"Error processing non-streaming response: {str(e)}"))
@@ -1619,7 +1730,7 @@ class ResponseProcessor:
         # return summary
 
         summary_output = result.output if hasattr(result, 'output') else str(result)
-        success_status = structured_result["tool_execution"]["result"]["success"]
+        success_status = structured_result_v1["tool_execution"]["result"]["success"]
 
         # Create a more comprehensive summary for the LLM
         if xml_tag_name:
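Note on the token-counting fallback: when a provider streams no usage block, the patch estimates tokens locally with litellm's token_counter, which accepts either chat-format messages or plain text. A minimal standalone sketch of that path; the model name, messages, and content below are illustrative, not taken from this repo:

    from litellm import token_counter

    llm_model = "gpt-4o"  # illustrative; any model litellm can map to a tokenizer
    prompt_messages = [{"role": "user", "content": "Summarize the design doc."}]
    accumulated_content = "The processor streams chunks and reconstructs usage."

    # prompt side: token_counter accepts chat-format messages
    prompt_tokens = token_counter(model=llm_model, messages=prompt_messages)
    # completion side: plain text; "or ''" keeps an empty stream safe
    completion_tokens = token_counter(model=llm_model, text=accumulated_content or "")

    usage = {
        "prompt_tokens": prompt_tokens,
        "completion_tokens": completion_tokens,
        "total_tokens": prompt_tokens + completion_tokens,
    }
    print(usage)  # estimated locally, not provider-reported

This mirrors why the hunk at @@ -317 counts the prompt and the accumulated completion in two separate calls.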
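Note on downstream cost: with the inline "cost" messages removed, pricing can still be derived later from the usage persisted in assistant_response_end. A sketch of such a consumer under stated assumptions: compute_cost and the sample msg dict are hypothetical (the dict only mimics the assistant_end_content shape saved above); litellm's cost_per_token is the one real API used:

    from litellm import cost_per_token

    def compute_cost(assistant_end_content: dict) -> float:
        # usage follows the assistant_end_content shape persisted by the patch
        usage = assistant_end_content.get("usage", {})
        prompt_usd, completion_usd = cost_per_token(
            model=assistant_end_content.get("model", ""),
            prompt_tokens=usage.get("prompt_tokens", 0),
            completion_tokens=usage.get("completion_tokens", 0),
        )
        return prompt_usd + completion_usd

    # hypothetical message content, matching the persisted structure
    msg = {"model": "gpt-4o", "usage": {"prompt_tokens": 1200, "completion_tokens": 300, "total_tokens": 1500}}
    print(f"estimated cost: ${compute_cost(msg):.6f}")

Deferring the calculation this way keeps billing logic out of the streaming hot path while preserving everything needed to reprice runs later.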