fix usage deduction

This commit is contained in:
Saumya 2025-10-04 03:05:54 +05:30
parent 381081964f
commit 056150059f
2 changed files with 94 additions and 87 deletions

View File

@ -274,7 +274,7 @@ class ResponseProcessor:
final_llm_response = None
first_chunk_time = None
last_chunk_time = None
assistant_response_end_saved = False
llm_response_end_saved = False
logger.debug(f"Streaming Config: XML={config.xml_tool_calling}, Native={config.native_tool_calling}, "
f"Execute on stream={config.execute_on_stream}, Strategy={config.tool_execution_strategy}")
@ -282,9 +282,13 @@ class ResponseProcessor:
# Reuse thread_run_id for auto-continue or create new one
thread_run_id = continuous_state.get('thread_run_id') or str(uuid.uuid4())
continuous_state['thread_run_id'] = thread_run_id
# CRITICAL: Generate unique ID for THIS specific LLM call (not per thread run)
llm_response_id = str(uuid.uuid4())
logger.info(f"🔵 LLM CALL #{auto_continue_count + 1} starting - llm_response_id: {llm_response_id}")
try:
# --- Save and Yield Start Events (only if not auto-continuing) ---
# --- Save and Yield Start Events ---
if auto_continue_count == 0:
start_content = {"status_type": "thread_run_start", "thread_run_id": thread_run_id}
start_msg_obj = await self.add_message(
@ -292,19 +296,24 @@ class ResponseProcessor:
is_llm_message=False, metadata={"thread_run_id": thread_run_id}
)
if start_msg_obj:
# logger.debug(f"📤 About to yield start_msg_obj")
yield format_for_yield(start_msg_obj)
# logger.debug(f"✅ Successfully yielded start_msg_obj")
assist_start_content = {"status_type": "assistant_response_start"}
assist_start_msg_obj = await self.add_message(
thread_id=thread_id, type="status", content=assist_start_content,
is_llm_message=False, metadata={"thread_run_id": thread_run_id}
)
if assist_start_msg_obj:
# logger.debug(f"📤 About to yield assist_start_msg_obj")
yield format_for_yield(assist_start_msg_obj)
# logger.debug(f"✅ Successfully yielded assist_start_msg_obj")
llm_start_content = {
"llm_response_id": llm_response_id,
"auto_continue_count": auto_continue_count,
"model": llm_model,
"timestamp": datetime.now(timezone.utc).isoformat()
}
llm_start_msg_obj = await self.add_message(
thread_id=thread_id, type="llm_response_start", content=llm_start_content,
is_llm_message=False, metadata={
"thread_run_id": thread_run_id,
"llm_response_id": llm_response_id
}
)
if llm_start_msg_obj:
yield format_for_yield(llm_start_msg_obj)
logger.info(f"✅ Saved llm_response_start for call #{auto_continue_count + 1}")
# --- End Start Events ---
__sequence = continuous_state.get('sequence', 0) # get the sequence from the previous auto-continue cycle
@ -782,23 +791,23 @@ class ResponseProcessor:
)
if finish_msg_obj: yield format_for_yield(finish_msg_obj)
# Save assistant_response_end BEFORE terminating
# Save llm_response_end BEFORE terminating
if last_assistant_message_object:
try:
# Use the complete LiteLLM response object as received
if final_llm_response:
logger.info("✅ Using complete LiteLLM response for assistant_response_end (before termination)")
logger.info("✅ Using complete LiteLLM response for llm_response_end (before termination)")
# Serialize the complete response object as-is
assistant_end_content = self._serialize_model_response(final_llm_response)
llm_end_content = self._serialize_model_response(final_llm_response)
# Add streaming flag and response timing if available
assistant_end_content["streaming"] = True
llm_end_content["streaming"] = True
if response_ms:
assistant_end_content["response_ms"] = response_ms
llm_end_content["response_ms"] = response_ms
# For streaming responses, we need to construct the choices manually
# since the streaming chunk doesn't have the complete message structure
assistant_end_content["choices"] = [
llm_end_content["choices"] = [
{
"finish_reason": finish_reason or "stop",
"index": 0,
@ -809,36 +818,40 @@ class ResponseProcessor:
}
}
]
llm_end_content["llm_response_id"] = llm_response_id
else:
logger.warning("⚠️ No complete LiteLLM response available, skipping assistant_response_end")
assistant_end_content = None
logger.warning("⚠️ No complete LiteLLM response available, skipping llm_response_end")
llm_end_content = None
# Only save if we have content
if assistant_end_content:
if llm_end_content:
await self.add_message(
thread_id=thread_id,
type="assistant_response_end",
content=assistant_end_content,
type="llm_response_end",
content=llm_end_content,
is_llm_message=False,
metadata={"thread_run_id": thread_run_id}
metadata={
"thread_run_id": thread_run_id,
"llm_response_id": llm_response_id
}
)
assistant_response_end_saved = True
logger.debug("Assistant response end saved for stream (before termination)")
llm_response_end_saved = True
logger.info(f"✅ llm_response_end saved for call #{auto_continue_count + 1} (before termination)")
except Exception as e:
logger.error(f"Error saving assistant response end for stream (before termination): {str(e)}")
self.trace.event(name="error_saving_assistant_response_end_for_stream_before_termination", level="ERROR", status_message=(f"Error saving assistant response end for stream (before termination): {str(e)}"))
logger.error(f"Error saving llm_response_end (before termination): {str(e)}")
self.trace.event(name="error_saving_llm_response_end_before_termination", level="ERROR", status_message=(f"Error saving llm_response_end (before termination): {str(e)}"))
# Skip all remaining processing and go to finally block
return
# --- Save and Yield assistant_response_end ---
# Only save assistant_response_end if not auto-continuing (response is actually complete)
# --- Save and Yield llm_response_end ---
# Only save llm_response_end if not auto-continuing (response is actually complete)
if not should_auto_continue:
if last_assistant_message_object: # Only save if assistant message was saved
if last_assistant_message_object:
try:
# Use the complete LiteLLM response object as received
if final_llm_response:
logger.info("✅ Using complete LiteLLM response for assistant_response_end (normal)")
logger.info("✅ Using complete LiteLLM response for llm_response_end (normal completion)")
# Log the complete response object for debugging
logger.info(f"🔍 COMPLETE RESPONSE OBJECT: {final_llm_response}")
@ -846,17 +859,17 @@ class ResponseProcessor:
logger.info(f"🔍 RESPONSE OBJECT DICT: {final_llm_response.__dict__ if hasattr(final_llm_response, '__dict__') else 'NO_DICT'}")
# Serialize the complete response object as-is
assistant_end_content = self._serialize_model_response(final_llm_response)
logger.info(f"🔍 SERIALIZED CONTENT: {assistant_end_content}")
llm_end_content = self._serialize_model_response(final_llm_response)
logger.info(f"🔍 SERIALIZED CONTENT: {llm_end_content}")
# Add streaming flag and response timing if available
assistant_end_content["streaming"] = True
llm_end_content["streaming"] = True
if response_ms:
assistant_end_content["response_ms"] = response_ms
llm_end_content["response_ms"] = response_ms
# For streaming responses, we need to construct the choices manually
# since the streaming chunk doesn't have the complete message structure
assistant_end_content["choices"] = [
llm_end_content["choices"] = [
{
"finish_reason": finish_reason or "stop",
"index": 0,
@ -867,25 +880,29 @@ class ResponseProcessor:
}
}
]
llm_end_content["llm_response_id"] = llm_response_id
# DEBUG: Log the actual response usage
logger.info(f"🔍 RESPONSE PROCESSOR COMPLETE USAGE (normal): {assistant_end_content.get('usage', 'NO_USAGE')}")
logger.info(f"🔍 FINAL ASSISTANT END CONTENT: {assistant_end_content}")
logger.info(f"🔍 RESPONSE PROCESSOR COMPLETE USAGE (normal): {llm_end_content.get('usage', 'NO_USAGE')}")
logger.info(f"🔍 FINAL LLM END CONTENT: {llm_end_content}")
await self.add_message(
thread_id=thread_id,
type="assistant_response_end",
content=assistant_end_content,
type="llm_response_end",
content=llm_end_content,
is_llm_message=False,
metadata={"thread_run_id": thread_run_id}
metadata={
"thread_run_id": thread_run_id,
"llm_response_id": llm_response_id
}
)
assistant_response_end_saved = True
llm_response_end_saved = True
else:
logger.warning("⚠️ No complete LiteLLM response available, skipping assistant_response_end")
logger.debug("Assistant response end saved for stream")
logger.warning("⚠️ No complete LiteLLM response available, skipping llm_response_end")
logger.info(f"✅ llm_response_end saved for call #{auto_continue_count + 1} (normal completion)")
except Exception as e:
logger.error(f"Error saving assistant response end for stream: {str(e)}")
self.trace.event(name="error_saving_assistant_response_end_for_stream", level="ERROR", status_message=(f"Error saving assistant response end for stream: {str(e)}"))
logger.error(f"Error saving llm_response_end: {str(e)}")
self.trace.event(name="error_saving_llm_response_end", level="ERROR", status_message=(f"Error saving llm_response_end: {str(e)}"))
except Exception as e:
# Use ErrorProcessor for consistent error handling
@ -903,35 +920,29 @@ class ResponseProcessor:
raise
finally:
# CRITICAL BULLETPROOF BILLING: Always save assistant_response_end if not already saved
# This handles: timeouts, crashes, disconnects, rate limits, incomplete generations, manual stops
if not assistant_response_end_saved and last_assistant_message_object and auto_continue_count == 0:
if not llm_response_end_saved and last_assistant_message_object:
try:
logger.info("💰 BULLETPROOF BILLING: Saving assistant_response_end in finally block")
# Try to get exact usage from final_llm_response, fall back to estimation
logger.info(f"💰 BULLETPROOF BILLING: Saving llm_response_end in finally block for call #{auto_continue_count + 1}")
if final_llm_response:
logger.info("💰 Using exact usage from LLM response")
assistant_end_content = self._serialize_model_response(final_llm_response)
llm_end_content = self._serialize_model_response(final_llm_response)
else:
logger.warning("💰 No LLM response with usage - ESTIMATING token usage for billing")
# CRITICAL: Estimate tokens to ensure billing even without exact data
estimated_usage = self._estimate_token_usage(prompt_messages, accumulated_content, llm_model)
assistant_end_content = {
llm_end_content = {
"model": llm_model,
"usage": estimated_usage
}
assistant_end_content["streaming"] = True
llm_end_content["streaming"] = True
llm_end_content["llm_response_id"] = llm_response_id
# Add response timing if available
response_ms = None
if first_chunk_time and last_chunk_time:
response_ms = int((last_chunk_time - first_chunk_time) * 1000)
assistant_end_content["response_ms"] = response_ms
llm_end_content["response_ms"] = response_ms
# Add choices structure
assistant_end_content["choices"] = [
llm_end_content["choices"] = [
{
"finish_reason": finish_reason or "interrupted",
"index": 0,
@ -943,42 +954,41 @@ class ResponseProcessor:
}
]
usage_info = assistant_end_content.get('usage', {})
usage_info = llm_end_content.get('usage', {})
is_estimated = usage_info.get('estimated', False)
logger.info(f"💰 BILLING RECOVERY - Usage ({'ESTIMATED' if is_estimated else 'EXACT'}): {usage_info}")
# SAVE THE BILLING DATA - this triggers _handle_billing in thread_manager
await self.add_message(
thread_id=thread_id,
type="assistant_response_end",
content=assistant_end_content,
type="llm_response_end",
content=llm_end_content,
is_llm_message=False,
metadata={"thread_run_id": thread_run_id}
metadata={
"thread_run_id": thread_run_id,
"llm_response_id": llm_response_id
}
)
assistant_response_end_saved = True
logger.info(f"✅ BILLING SUCCESS: Saved assistant_response_end ({'estimated' if is_estimated else 'exact'} usage)")
llm_response_end_saved = True
logger.info(f"✅ BILLING SUCCESS: Saved llm_response_end for call #{auto_continue_count + 1} ({'estimated' if is_estimated else 'exact'} usage)")
except Exception as billing_e:
logger.error(f"❌ CRITICAL BILLING FAILURE: Could not save assistant_response_end: {str(billing_e)}", exc_info=True)
logger.error(f"❌ CRITICAL BILLING FAILURE: Could not save llm_response_end: {str(billing_e)}", exc_info=True)
self.trace.event(
name="critical_billing_failure_in_finally",
level="ERROR",
status_message=(f"Failed to save assistant_response_end for billing: {str(billing_e)}")
status_message=(f"Failed to save llm_response_end for billing: {str(billing_e)}")
)
elif assistant_response_end_saved:
logger.debug("✅ Billing already handled (assistant_response_end was saved earlier)")
elif llm_response_end_saved:
logger.debug(f"✅ Billing already handled for call #{auto_continue_count + 1} (llm_response_end was saved earlier)")
# Update continuous state for potential auto-continue
if should_auto_continue:
continuous_state['accumulated_content'] = accumulated_content
continuous_state['sequence'] = __sequence
logger.debug(f"Updated continuous state for auto-continue with {len(accumulated_content)} chars")
else:
# Set the final output in the generation object if provided
if generation and 'accumulated_content' in locals():
try:
# Update generation with usage metrics from the complete LiteLLM response
if final_llm_response and hasattr(final_llm_response, 'usage'):
generation.update(
usage=final_llm_response.usage.model_dump() if hasattr(final_llm_response.usage, 'model_dump') else dict(final_llm_response.usage),

View File

@ -102,13 +102,11 @@ class ThreadManager:
try:
result = await client.table('messages').insert(data_to_insert).execute()
# logger.debug(f"Successfully added message to thread {thread_id}")
if result.data and len(result.data) > 0 and 'message_id' in result.data[0]:
saved_message = result.data[0]
# Handle billing for assistant response end messages
if type == "assistant_response_end" and isinstance(content, dict):
if type == "llm_response_end" and isinstance(content, dict):
await self._handle_billing(thread_id, content, saved_message)
return saved_message
@ -120,18 +118,17 @@ class ThreadManager:
raise
async def _handle_billing(self, thread_id: str, content: dict, saved_message: dict):
"""Handle billing for LLM usage."""
try:
usage = content.get("usage", {})
llm_response_id = content.get("llm_response_id", "unknown")
logger.info(f"💰 Processing billing for LLM response: {llm_response_id}")
# DEBUG: Log the complete usage object to see what data we have
# logger.info(f"🔍 THREAD MANAGER USAGE: {usage}")
# logger.info(f"🔍 THREAD MANAGER CONTENT: {content}")
usage = content.get("usage", {})
prompt_tokens = int(usage.get("prompt_tokens", 0) or 0)
completion_tokens = int(usage.get("completion_tokens", 0) or 0)
is_estimated = usage.get("estimated", False)
is_fallback = usage.get("fallback", False)
# Try cache_read_input_tokens first (Anthropic standard), then fallback to prompt_tokens_details.cached_tokens
cache_read_tokens = int(usage.get("cache_read_input_tokens", 0) or 0)
if cache_read_tokens == 0:
cache_read_tokens = int(usage.get("prompt_tokens_details", {}).get("cached_tokens", 0) or 0)
@ -139,8 +136,8 @@ class ThreadManager:
cache_creation_tokens = int(usage.get("cache_creation_input_tokens", 0) or 0)
model = content.get("model")
# DEBUG: Log what we detected
logger.info(f"🔍 CACHE DETECTION: cache_read={cache_read_tokens}, cache_creation={cache_creation_tokens}, prompt={prompt_tokens}")
usage_type = "FALLBACK ESTIMATE" if is_fallback else ("ESTIMATED" if is_estimated else "EXACT")
logger.info(f"💰 Usage type: {usage_type} - prompt={prompt_tokens}, completion={completion_tokens}, cache_read={cache_read_tokens}, cache_creation={cache_creation_tokens}")
client = await self.db.client
thread_row = await client.table('threads').select('account_id').eq('thread_id', thread_id).limit(1).execute()