cost calculation

LE Quoc Dat 2025-04-18 05:59:00 +01:00
parent c84ee59dc6
commit 6e4e2673d5
1 changed file with 85 additions and 112 deletions

@@ -97,18 +97,18 @@ class ResponseProcessor:
self,
llm_response: AsyncGenerator,
thread_id: str,
prompt_messages: List[Dict[str, Any]],
llm_model: str,
config: ProcessorConfig = ProcessorConfig(),
prompt_messages: Optional[List[Dict[str, Any]]] = None,
llm_model: Optional[str] = None
) -> AsyncGenerator:
"""Process a streaming LLM response, handling tool calls and execution.
Args:
llm_response: Streaming response from the LLM
thread_id: ID of the conversation thread
prompt_messages: List of messages sent to the LLM (the prompt)
llm_model: The name of the LLM model used
config: Configuration for parsing and execution
prompt_messages: List of messages used for cost calculation
llm_model: Name of the LLM model used for cost calculation
Yields:
Formatted chunks of the response including content and tool results
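For context, the cost path in this method leans on litellm's completion_cost and token_counter helpers. Below is a minimal, stand-alone sketch of how they can be driven by prompt_messages, llm_model, and the accumulated completion text; the model name and message contents are placeholders for illustration, not values from this commit.

from litellm import completion_cost, token_counter

# Placeholder inputs for illustration only.
llm_model = "gpt-4o"
prompt_messages = [{"role": "user", "content": "Summarize this thread."}]
accumulated_content = "Here is a short summary of the thread."

# Estimate the dollar cost from the prompt messages plus the completed text.
cost = completion_cost(
    model=llm_model,
    messages=prompt_messages,
    completion=accumulated_content,
)

# Count tokens for the full exchange (prompt plus assistant reply).
tokens = token_counter(
    model=llm_model,
    messages=prompt_messages + [{"role": "assistant", "content": accumulated_content}],
)
print(f"cost={cost:.6f}, tokens={tokens}")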
@@ -173,27 +173,18 @@ class ResponseProcessor:
if hasattr(chunk, 'choices') and chunk.choices:
delta = chunk.choices[0].delta if hasattr(chunk.choices[0], 'delta') else None
# Check for and log Anthropic thinking content
if delta and hasattr(delta, 'reasoning_content') and delta.reasoning_content:
logger.info(f"[THINKING]: {delta.reasoning_content}")
# Append reasoning to main content to be saved in the final message
accumulated_content += delta.reasoning_content
# Process content chunk
if delta and hasattr(delta, 'content') and delta.content:
chunk_content = delta.content
accumulated_content += chunk_content
current_xml_content += chunk_content
# Process reasoning content if present (Anthropic)
if hasattr(delta, 'reasoning_content') and delta.reasoning_content:
logger.info(f"[THINKING]: {delta.reasoning_content}")
accumulated_content += delta.reasoning_content # Append reasoning to main content
# Calculate cost using prompt and completion - MOVED AFTER MESSAGE SAVE
# try:
# cost = completion_cost(model=chunk.model, prompt=accumulated_content, completion=chunk_content)
# tcount = token_counter(model=chunk.model, messages=[{"role": "user", "content": accumulated_content}])
# accumulated_cost += cost
# accumulated_token_count += tcount
# logger.debug(f"Cost: {cost:.6f}, Token count: {tcount}")
# except Exception as e:
# logger.error(f"Error calculating cost: {str(e)}")
# Check if we've reached the XML tool call limit before yielding content
if config.max_xml_tool_calls > 0 and xml_tool_call_count >= config.max_xml_tool_calls:
# We've reached the limit, don't yield any more content
@@ -362,7 +353,7 @@ class ResponseProcessor:
# If we've reached the XML tool call limit, stop streaming
if finish_reason == "xml_tool_limit_reached":
logger.info("Stopping stream due to XML tool call limit")
logger.info("Stopping stream processing after loop due to XML tool call limit")
break
# After streaming completes or is stopped due to limit, wait for any remaining tool executions
@@ -483,6 +474,27 @@ class ResponseProcessor:
is_llm_message=True
)
# Calculate and store cost AFTER adding the main assistant message
if accumulated_content: # Calculate cost if there was content (now includes reasoning)
try:
final_cost = completion_cost(
model=llm_model, # Use the passed model name
messages=prompt_messages, # Use the provided prompt messages
completion=accumulated_content
)
if final_cost is not None and final_cost > 0:
logger.info(f"Calculated final cost for stream: {final_cost}")
await self.add_message(
thread_id=thread_id,
type="cost",
content={"cost": final_cost},
is_llm_message=False # Cost is metadata, not LLM content
)
else:
logger.info("Cost calculation resulted in zero or None, not storing cost message.")
except Exception as e:
logger.error(f"Error calculating final cost for stream: {str(e)}")
# Yield the assistant response end signal *immediately* after saving
if last_assistant_message_id:
yield {
@@ -498,35 +510,6 @@ class ResponseProcessor:
"thread_run_id": thread_run_id
}
# --- Cost Calculation (moved here) ---
if prompt_messages and llm_model and accumulated_content:
try:
cost = completion_cost(
model=llm_model,
messages=prompt_messages,
completion=accumulated_content
)
token_count = token_counter(
model=llm_model,
messages=prompt_messages + [{"role": "assistant", "content": accumulated_content}]
)
await self.add_message(
thread_id=thread_id,
type="cost",
content={
"cost": cost,
"prompt_tokens": token_count - token_counter(model=llm_model, messages=[{"role": "assistant", "content": accumulated_content}]), # Approx
"completion_tokens": token_counter(model=llm_model, messages=[{"role": "assistant", "content": accumulated_content}]), # Approx
"total_tokens": token_count,
"model_name": llm_model
},
is_llm_message=False
)
logger.info(f"Calculated cost for streaming response: {cost:.6f} using model {llm_model}")
except Exception as e:
logger.error(f"Error calculating cost: {str(e)}")
# --- End Cost Calculation ---
# --- Process All Tool Calls Now ---
if config.execute_tools:
final_tool_calls_to_process = []
@@ -664,24 +647,23 @@ class ResponseProcessor:
yield {"type": "error", "message": str(e), "thread_run_id": thread_run_id if 'thread_run_id' in locals() else None}
finally:
# Yield the detected finish reason if one exists and wasn't suppressed
if finish_reason and finish_reason != "xml_tool_limit_reached":
# Yield a finish signal including the final assistant message ID
if last_assistant_message_id:
# Yield the overall run end signal
yield {
"type": "finish",
"finish_reason": finish_reason,
"type": "thread_run_end",
"thread_run_id": thread_run_id
}
else:
# Yield the overall run end signal
yield {
"type": "thread_run_end",
"thread_run_id": thread_run_id if 'thread_run_id' in locals() else None
}
# Yield a finish signal including the final assistant message ID
# Ensure thread_run_id is defined, even if an early error occurred
run_id = thread_run_id if 'thread_run_id' in locals() else str(uuid.uuid4()) # Fallback ID if needed
yield {
"type": "thread_run_end",
"thread_run_id": run_id
}
# Remove old cost calculation code
pass
# track the cost and token count
# TODO: there is a bug here: a cost row would be added to the db for every chunk, because the finally block runs every time, even during a yield
# await self.add_message(
# thread_id=thread_id,
# type="cost",
@@ -697,18 +679,18 @@ class ResponseProcessor:
self,
llm_response: Any,
thread_id: str,
config: ProcessorConfig = ProcessorConfig(),
prompt_messages: Optional[List[Dict[str, Any]]] = None,
llm_model: Optional[str] = None
prompt_messages: List[Dict[str, Any]],
llm_model: str,
config: ProcessorConfig = ProcessorConfig()
) -> AsyncGenerator[Dict[str, Any], None]:
"""Process a non-streaming LLM response, handling tool calls and execution.
Args:
llm_response: Response from the LLM
thread_id: ID of the conversation thread
prompt_messages: List of messages sent to the LLM (the prompt)
llm_model: The name of the LLM model used
config: Configuration for parsing and execution
prompt_messages: List of messages used for cost calculation
llm_model: Name of the LLM model used for cost calculation
Yields:
Formatted response including content and tool results
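The non-streaming path below prefers a cost that litellm has already attached to the response and only computes one as a fallback. A condensed sketch of that preference order follows; the function name estimate_cost and the response/model arguments are illustrative, not part of this commit.

from litellm import completion_cost

def estimate_cost(llm_response, llm_model: str):
    # 1) Prefer the cost litellm reports in _hidden_params, if present and non-zero.
    hidden = getattr(llm_response, "_hidden_params", {}) or {}
    cost = hidden.get("response_cost")
    if cost:
        return cost
    # 2) Otherwise compute the cost from the full response object.
    return completion_cost(
        completion_response=llm_response,
        model=llm_model,
        call_type="completion",
    )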
@@ -814,6 +796,41 @@ class ResponseProcessor:
is_llm_message=True
)
# Calculate and store cost AFTER adding the main assistant message
if content or (config.native_tool_calling and 'native_tool_calls_for_message' in locals() and native_tool_calls_for_message): # Calculate cost if there's content or tool calls
try:
# Use the full response object for potentially more accurate cost calculation
# Pass the model explicitly, as it may not be reliably present on the response object for all providers
# First check if response_cost is directly available in _hidden_params
final_cost = None
if hasattr(llm_response, '_hidden_params') and 'response_cost' in llm_response._hidden_params and llm_response._hidden_params['response_cost'] != 0.0:
final_cost = llm_response._hidden_params['response_cost']
logger.info(f"Using response_cost from _hidden_params: {final_cost}")
if final_cost is None: # Fall back to calculating cost if direct cost not available or zero
logger.info("Calculating cost using completion_cost function.")
final_cost = completion_cost(
completion_response=llm_response,
model=llm_model, # Use the passed model name
# prompt_messages might be needed for some models/providers
# messages=prompt_messages, # Uncomment if needed
call_type="completion" # Assuming 'completion' type for this context
)
if final_cost is not None and final_cost > 0:
logger.info(f"Calculated final cost for non-stream: {final_cost}")
await self.add_message(
thread_id=thread_id,
type="cost",
content={"cost": final_cost},
is_llm_message=False # Cost is metadata
)
else:
logger.info("Final cost is zero or None, not storing cost message.")
except Exception as e:
logger.error(f"Error calculating final cost for non-stream: {str(e)}")
# Yield content first
yield {
"type": "content",
@@ -904,50 +921,6 @@ class ResponseProcessor:
"thread_run_id": thread_run_id
}
# --- Cost Calculation (moved here) ---
if prompt_messages and llm_model:
cost = None
# Attempt to get cost from LiteLLM response first
if hasattr(llm_response, '_hidden_params') and 'response_cost' in llm_response._hidden_params:
cost = llm_response._hidden_params['response_cost']
logger.info(f"Using pre-calculated cost from LiteLLM: {cost:.6f}")
# If no pre-calculated cost, calculate manually
if cost is None:
try:
cost = completion_cost(
model=llm_model,
messages=prompt_messages,
completion=content # Use extracted content
)
logger.info(f"Manually calculated cost for non-streaming response: {cost:.6f} using model {llm_model}")
except Exception as e:
logger.error(f"Error calculating cost: {str(e)}")
# Add cost message if cost was determined
if cost is not None:
try:
# Approximate token counts
completion_tokens = token_counter(model=llm_model, messages=[{"role": "assistant", "content": content}])
prompt_tokens = token_counter(model=llm_model, messages=prompt_messages)
total_tokens = prompt_tokens + completion_tokens
await self.add_message(
thread_id=thread_id,
type="cost",
content={
"cost": cost,
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"total_tokens": total_tokens,
"model_name": llm_model
},
is_llm_message=False
)
except Exception as e:
logger.error(f"Error saving cost message: {str(e)}")
# --- End Cost Calculation ---
except Exception as e:
logger.error(f"Error processing response: {str(e)}", exc_info=True)
yield {"type": "error", "message": str(e), "thread_run_id": thread_run_id if 'thread_run_id' in locals() else None}