roll-back cost calculation

This commit is contained in:
LE Quoc Dat 2025-04-18 17:57:10 +01:00
parent db46c1aee5
commit e3b08b1326
1 changed files with 59 additions and 2 deletions

View File

@ -510,6 +510,31 @@ class ResponseProcessor:
logger.error(f"Failed to save tool result for index {tool_idx}, not yielding result message.")
# Optionally yield error status for saving failure?
# --- Calculate and Store Cost ---
if last_assistant_message_object: # Only calculate if assistant message was saved
try:
# Use accumulated_content for streaming cost calculation
final_cost = completion_cost(
model=llm_model,
messages=prompt_messages, # Use the prompt messages provided
completion=accumulated_content
)
if final_cost is not None and final_cost > 0:
logger.info(f"Calculated final cost for stream: {final_cost}")
await self.add_message(
thread_id=thread_id,
type="cost",
content={"cost": final_cost},
is_llm_message=False, # Cost is metadata
metadata={"thread_run_id": thread_run_id} # Keep track of the run
)
logger.info(f"Cost message saved for stream: {final_cost}")
else:
logger.info("Stream cost calculation resulted in zero or None, not storing cost message.")
except Exception as e:
logger.error(f"Error calculating final cost for stream: {str(e)}")
# --- Final Finish Status ---
if finish_reason and finish_reason != "xml_tool_limit_reached":
finish_content = {"status_type": "finish", "finish_reason": finish_reason}
@ -538,8 +563,6 @@ class ResponseProcessor:
)
if end_msg_obj: yield end_msg_obj
# ... (Cost tracking can remain if fixed) ...
async def process_non_streaming_response(
self,
llm_response: Any,
@ -637,6 +660,40 @@ class ResponseProcessor:
)
if err_msg_obj: yield err_msg_obj
# --- Calculate and Store Cost ---
if assistant_message_object: # Only calculate if assistant message was saved
try:
# Use the full llm_response object for potentially more accurate cost calculation
final_cost = None
if hasattr(llm_response, '_hidden_params') and 'response_cost' in llm_response._hidden_params and llm_response._hidden_params['response_cost'] is not None and llm_response._hidden_params['response_cost'] != 0.0:
final_cost = llm_response._hidden_params['response_cost']
logger.info(f"Using response_cost from _hidden_params: {final_cost}")
if final_cost is None: # Fall back to calculating cost if direct cost not available or zero
logger.info("Calculating cost using completion_cost function.")
# Note: litellm might need 'messages' kwarg depending on model/provider
final_cost = completion_cost(
completion_response=llm_response,
model=llm_model, # Explicitly pass the model name
# messages=prompt_messages # Pass prompt messages if needed by litellm for this model
)
if final_cost is not None and final_cost > 0:
logger.info(f"Calculated final cost for non-stream: {final_cost}")
await self.add_message(
thread_id=thread_id,
type="cost",
content={"cost": final_cost},
is_llm_message=False, # Cost is metadata
metadata={"thread_run_id": thread_run_id} # Keep track of the run
)
logger.info(f"Cost message saved for non-stream: {final_cost}")
else:
logger.info("Non-stream cost calculation resulted in zero or None, not storing cost message.")
except Exception as e:
logger.error(f"Error calculating final cost for non-stream: {str(e)}")
# --- Execute Tools and Yield Results ---
tool_calls_to_execute = [item['tool_call'] for item in all_tool_data]
if config.execute_tools and tool_calls_to_execute: