fix billing calculation

Saumya 2025-10-03 23:18:07 +05:30
parent 58313084c7
commit 381081964f
1 changed file with 101 additions and 0 deletions

@@ -26,6 +26,7 @@ from core.utils.json_helpers import (
    ensure_dict, ensure_list, safe_json_parse,
    to_json_string, format_for_yield
)
from litellm import token_counter
# Type alias for XML result adding strategy
XmlAddingStrategy = Literal["user_message", "assistant_message", "inline_edit"]
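
The newly imported litellm token_counter is what the estimation helper added below relies on; it accepts either a chat-style messages list or raw text. A minimal sketch of how it is typically called (the model name and messages here are illustrative, not part of this commit):

from litellm import token_counter

# Model-aware counting for a chat prompt and for plain completion text.
messages = [{"role": "user", "content": "Summarize my last invoice."}]
prompt_tokens = token_counter(model="gpt-4o", messages=messages)
completion_tokens = token_counter(model="gpt-4o", text="Your last run used 1,204 tokens.")
print(prompt_tokens, completion_tokens)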
@@ -117,6 +118,38 @@ class ResponseProcessor:
            return format_for_yield(message_obj)
        return None

    def _estimate_token_usage(self, prompt_messages: List[Dict[str, Any]], accumulated_content: str, llm_model: str) -> Dict[str, Any]:
"""
Estimate token usage when exact usage data is unavailable.
This is critical for billing on timeouts, crashes, disconnects, etc.
"""
try:
prompt_tokens = token_counter(model=llm_model, messages=prompt_messages)
completion_tokens = token_counter(model=llm_model, text=accumulated_content) if accumulated_content else 0
logger.warning(f"⚠️ ESTIMATED TOKEN USAGE (no exact data): prompt={prompt_tokens}, completion={completion_tokens}")
return {
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"total_tokens": prompt_tokens + completion_tokens,
"estimated": True
}
        except Exception as e:
            logger.error(f"Failed to estimate token usage: {e}")
            # Last-resort heuristic: roughly 1.3 tokens per whitespace-separated word.
            fallback_prompt = len(' '.join(str(m.get('content', '')) for m in prompt_messages).split()) * 1.3
            fallback_completion = len(accumulated_content.split()) * 1.3 if accumulated_content else 0
            logger.warning(f"⚠️ FALLBACK TOKEN ESTIMATION: prompt≈{int(fallback_prompt)}, completion≈{int(fallback_completion)}")
            return {
                "prompt_tokens": int(fallback_prompt),
                "completion_tokens": int(fallback_completion),
                "total_tokens": int(fallback_prompt + fallback_completion),
                "estimated": True,
                "fallback": True
            }

    def _serialize_model_response(self, model_response) -> Dict[str, Any]:
        """Convert a LiteLLM ModelResponse object to a JSON-serializable dictionary.
@@ -241,6 +274,7 @@ class ResponseProcessor:
        final_llm_response = None
        first_chunk_time = None
        last_chunk_time = None
        # Tracks whether billing data (assistant_response_end) has already been persisted.
        assistant_response_end_saved = False

        logger.debug(f"Streaming Config: XML={config.xml_tool_calling}, Native={config.native_tool_calling}, "
                     f"Execute on stream={config.execute_on_stream}, Strategy={config.tool_execution_strategy}")
@@ -788,6 +822,7 @@
                            is_llm_message=False,
                            metadata={"thread_run_id": thread_run_id}
                        )
                        assistant_response_end_saved = True
                        logger.debug("Assistant response end saved for stream (before termination)")
                    except Exception as e:
                        logger.error(f"Error saving assistant response end for stream (before termination): {str(e)}")
@@ -844,6 +879,7 @@
                            is_llm_message=False,
                            metadata={"thread_run_id": thread_run_id}
                        )
                        assistant_response_end_saved = True
                    else:
                        logger.warning("⚠️ No complete LiteLLM response available, skipping assistant_response_end")
                    logger.debug("Assistant response end saved for stream")
@@ -867,6 +903,71 @@
            raise
        finally:
            # CRITICAL BULLETPROOF BILLING: Always save assistant_response_end if not already saved
            # This handles: timeouts, crashes, disconnects, rate limits, incomplete generations, manual stops
            if not assistant_response_end_saved and last_assistant_message_object and auto_continue_count == 0:
                try:
                    logger.info("💰 BULLETPROOF BILLING: Saving assistant_response_end in finally block")

                    # Try to get exact usage from final_llm_response, fall back to estimation
                    if final_llm_response:
                        logger.info("💰 Using exact usage from LLM response")
                        assistant_end_content = self._serialize_model_response(final_llm_response)
                    else:
                        logger.warning("💰 No LLM response with usage - ESTIMATING token usage for billing")
                        # CRITICAL: Estimate tokens to ensure billing even without exact data
                        estimated_usage = self._estimate_token_usage(prompt_messages, accumulated_content, llm_model)
                        assistant_end_content = {
                            "model": llm_model,
                            "usage": estimated_usage
                        }

                    assistant_end_content["streaming"] = True

                    # Add response timing if available
                    response_ms = None
                    if first_chunk_time and last_chunk_time:
                        response_ms = int((last_chunk_time - first_chunk_time) * 1000)
                        assistant_end_content["response_ms"] = response_ms

                    # Add choices structure
                    assistant_end_content["choices"] = [
                        {
                            "finish_reason": finish_reason or "interrupted",
                            "index": 0,
                            "message": {
                                "role": "assistant",
                                "content": accumulated_content,
                                "tool_calls": complete_native_tool_calls or None
                            }
                        }
                    ]

                    usage_info = assistant_end_content.get('usage', {})
                    is_estimated = usage_info.get('estimated', False)
                    logger.info(f"💰 BILLING RECOVERY - Usage ({'ESTIMATED' if is_estimated else 'EXACT'}): {usage_info}")

                    # SAVE THE BILLING DATA - this triggers _handle_billing in thread_manager
                    await self.add_message(
                        thread_id=thread_id,
                        type="assistant_response_end",
                        content=assistant_end_content,
                        is_llm_message=False,
                        metadata={"thread_run_id": thread_run_id}
                    )
                    assistant_response_end_saved = True
                    logger.info(f"✅ BILLING SUCCESS: Saved assistant_response_end ({'estimated' if is_estimated else 'exact'} usage)")
                except Exception as billing_e:
                    logger.error(f"❌ CRITICAL BILLING FAILURE: Could not save assistant_response_end: {str(billing_e)}", exc_info=True)
                    self.trace.event(
                        name="critical_billing_failure_in_finally",
                        level="ERROR",
                        status_message=(f"Failed to save assistant_response_end for billing: {str(billing_e)}")
                    )
            elif assistant_response_end_saved:
                logger.debug("✅ Billing already handled (assistant_response_end was saved earlier)")

            # Update continuous state for potential auto-continue
            if should_auto_continue:
                continuous_state['accumulated_content'] = accumulated_content
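
The structural change this commit makes is easier to see in isolation: a guard flag plus a finally block guarantees that usage is recorded exactly once, however the stream ends. A simplified, self-contained sketch of that pattern (the function and record_usage stub are illustrative, not the project's API):

import asyncio

async def record_usage(thread_id: str, usage: dict) -> None:
    # Stand-in for add_message(type="assistant_response_end", ...), which triggers billing.
    print(f"billing {thread_id}: {usage}")

async def stream_and_bill(thread_id: str, chunks) -> None:
    usage_saved = False
    exact_usage = None
    streamed_text = ""
    try:
        async for chunk in chunks:
            streamed_text += chunk.get("text", "")
            exact_usage = chunk.get("usage") or exact_usage  # exact data may never arrive
    finally:
        # Runs on normal completion, timeouts, disconnects, and cancellation alike.
        if not usage_saved:
            usage = exact_usage or {
                # Word-count fallback mirroring the diff's ~1.3 tokens-per-word heuristic.
                "completion_tokens": int(len(streamed_text.split()) * 1.3),
                "estimated": True,
            }
            await record_usage(thread_id, usage)
            usage_saved = True

async def main():
    async def broken_stream():
        yield {"text": "partial answer "}
        raise TimeoutError("stream died mid-generation")  # simulate an interrupted stream
    try:
        await stream_and_bill("thread-123", broken_stream())
    except TimeoutError:
        pass  # usage was already recorded in the finally block

asyncio.run(main())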