From 381081964f484af853e1c67e2ce4640fb6b30d02 Mon Sep 17 00:00:00 2001
From: Saumya
Date: Fri, 3 Oct 2025 23:18:07 +0530
Subject: [PATCH] fix billing calculation

---
 backend/core/agentpress/response_processor.py | 101 ++++++++++++++++++
 1 file changed, 101 insertions(+)

diff --git a/backend/core/agentpress/response_processor.py b/backend/core/agentpress/response_processor.py
index 73a6e2e8..de82ea19 100644
--- a/backend/core/agentpress/response_processor.py
+++ b/backend/core/agentpress/response_processor.py
@@ -26,6 +26,7 @@ from core.utils.json_helpers import (
     ensure_dict, ensure_list, safe_json_parse, to_json_string, format_for_yield
 )
+from litellm import token_counter
 
 # Type alias for XML result adding strategy
 XmlAddingStrategy = Literal["user_message", "assistant_message", "inline_edit"]
@@ -117,6 +118,38 @@ class ResponseProcessor:
                 return format_for_yield(message_obj)
         return None
 
+    def _estimate_token_usage(self, prompt_messages: List[Dict[str, Any]], accumulated_content: str, llm_model: str) -> Dict[str, Any]:
+        """
+        Estimate token usage when exact usage data is unavailable.
+        This is critical for billing on timeouts, crashes, disconnects, etc.
+        """
+        try:
+            prompt_tokens = token_counter(model=llm_model, messages=prompt_messages)
+            completion_tokens = token_counter(model=llm_model, text=accumulated_content) if accumulated_content else 0
+
+            logger.warning(f"⚠️ ESTIMATED TOKEN USAGE (no exact data): prompt={prompt_tokens}, completion={completion_tokens}")
+
+            return {
+                "prompt_tokens": prompt_tokens,
+                "completion_tokens": completion_tokens,
+                "total_tokens": prompt_tokens + completion_tokens,
+                "estimated": True
+            }
+        except Exception as e:
+            logger.error(f"Failed to estimate token usage: {e}")
+            fallback_prompt = len(' '.join(str(m.get('content', '')) for m in prompt_messages).split()) * 1.3
+            fallback_completion = len(accumulated_content.split()) * 1.3 if accumulated_content else 0
+
+            logger.warning(f"⚠️ FALLBACK TOKEN ESTIMATION: prompt≈{int(fallback_prompt)}, completion≈{int(fallback_completion)}")
+
+            return {
+                "prompt_tokens": int(fallback_prompt),
+                "completion_tokens": int(fallback_completion),
+                "total_tokens": int(fallback_prompt + fallback_completion),
+                "estimated": True,
+                "fallback": True
+            }
+
     def _serialize_model_response(self, model_response) -> Dict[str, Any]:
         """Convert a LiteLLM ModelResponse object to a JSON-serializable dictionary.
 
@@ -241,6 +274,7 @@ class ResponseProcessor:
         final_llm_response = None
         first_chunk_time = None
         last_chunk_time = None
+        assistant_response_end_saved = False
 
         logger.debug(f"Streaming Config: XML={config.xml_tool_calling}, Native={config.native_tool_calling}, "
                      f"Execute on stream={config.execute_on_stream}, Strategy={config.tool_execution_strategy}")
@@ -788,6 +822,7 @@ class ResponseProcessor:
                                 is_llm_message=False,
                                 metadata={"thread_run_id": thread_run_id}
                             )
+                            assistant_response_end_saved = True
                             logger.debug("Assistant response end saved for stream (before termination)")
                         except Exception as e:
                             logger.error(f"Error saving assistant response end for stream (before termination): {str(e)}")
@@ -844,6 +879,7 @@ class ResponseProcessor:
                                 is_llm_message=False,
                                 metadata={"thread_run_id": thread_run_id}
                             )
+                            assistant_response_end_saved = True
                         else:
                             logger.warning("⚠️ No complete LiteLLM response available, skipping assistant_response_end")
                         logger.debug("Assistant response end saved for stream")
@@ -867,6 +903,71 @@ class ResponseProcessor:
             raise
 
         finally:
+            # CRITICAL BULLETPROOF BILLING: Always save assistant_response_end if not already saved
+            # This handles: timeouts, crashes, disconnects, rate limits, incomplete generations, manual stops
+            if not assistant_response_end_saved and last_assistant_message_object and auto_continue_count == 0:
+                try:
+                    logger.info("💰 BULLETPROOF BILLING: Saving assistant_response_end in finally block")
+
+                    # Try to get exact usage from final_llm_response, fall back to estimation
+                    if final_llm_response:
+                        logger.info("💰 Using exact usage from LLM response")
+                        assistant_end_content = self._serialize_model_response(final_llm_response)
+                    else:
+                        logger.warning("💰 No LLM response with usage - ESTIMATING token usage for billing")
+                        # CRITICAL: Estimate tokens to ensure billing even without exact data
+                        estimated_usage = self._estimate_token_usage(prompt_messages, accumulated_content, llm_model)
+                        assistant_end_content = {
+                            "model": llm_model,
+                            "usage": estimated_usage
+                        }
+
+                    assistant_end_content["streaming"] = True
+
+                    # Add response timing if available
+                    response_ms = None
+                    if first_chunk_time and last_chunk_time:
+                        response_ms = int((last_chunk_time - first_chunk_time) * 1000)
+                        assistant_end_content["response_ms"] = response_ms
+
+                    # Add choices structure
+                    assistant_end_content["choices"] = [
+                        {
+                            "finish_reason": finish_reason or "interrupted",
+                            "index": 0,
+                            "message": {
+                                "role": "assistant",
+                                "content": accumulated_content,
+                                "tool_calls": complete_native_tool_calls or None
+                            }
+                        }
+                    ]
+
+                    usage_info = assistant_end_content.get('usage', {})
+                    is_estimated = usage_info.get('estimated', False)
+                    logger.info(f"💰 BILLING RECOVERY - Usage ({'ESTIMATED' if is_estimated else 'EXACT'}): {usage_info}")
+
+                    # SAVE THE BILLING DATA - this triggers _handle_billing in thread_manager
+                    await self.add_message(
+                        thread_id=thread_id,
+                        type="assistant_response_end",
+                        content=assistant_end_content,
+                        is_llm_message=False,
+                        metadata={"thread_run_id": thread_run_id}
+                    )
+                    assistant_response_end_saved = True
+                    logger.info(f"✅ BILLING SUCCESS: Saved assistant_response_end ({'estimated' if is_estimated else 'exact'} usage)")
+
+                except Exception as billing_e:
+                    logger.error(f"❌ CRITICAL BILLING FAILURE: Could not save assistant_response_end: {str(billing_e)}", exc_info=True)
+                    self.trace.event(
+                        name="critical_billing_failure_in_finally",
+                        level="ERROR",
+                        status_message=(f"Failed to save assistant_response_end for billing: {str(billing_e)}")
+                    )
+            elif assistant_response_end_saved:
+                logger.debug("✅ Billing already handled (assistant_response_end was saved earlier)")
+
             # Update continuous state for potential auto-continue
             if should_auto_continue:
                 continuous_state['accumulated_content'] = accumulated_content
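
Note on the estimation path: the sketch below sits outside the diff and is only illustrative. It mirrors the strategy _estimate_token_usage relies on (litellm's token_counter first, a rough words-times-1.3 heuristic as a last resort); the estimate_usage helper name, the "gpt-4o" model string, and the sample messages are made-up placeholders, and litellm is assumed to be installed.

    from litellm import token_counter

    def estimate_usage(llm_model, prompt_messages, accumulated_content):
        try:
            # Preferred path: count tokens with the model's tokenizer via litellm.
            prompt_tokens = token_counter(model=llm_model, messages=prompt_messages)
            completion_tokens = token_counter(model=llm_model, text=accumulated_content) if accumulated_content else 0
            extra = {"estimated": True}
        except Exception:
            # Last resort: roughly 1.3 tokens per whitespace-separated word.
            words = " ".join(str(m.get("content", "")) for m in prompt_messages).split()
            prompt_tokens = int(len(words) * 1.3)
            completion_tokens = int(len(accumulated_content.split()) * 1.3) if accumulated_content else 0
            extra = {"estimated": True, "fallback": True}
        return {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": prompt_tokens + completion_tokens,
            **extra,
        }

    if __name__ == "__main__":
        # Placeholder inputs standing in for real thread history and streamed output.
        messages = [{"role": "user", "content": "Summarize the latest invoice."}]
        print(estimate_usage("gpt-4o", messages, "Here is a short summary of the invoice."))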