diff --git a/backend/agentpress/response_processor.py b/backend/agentpress/response_processor.py index 02a28ed9..d3fc30b0 100644 --- a/backend/agentpress/response_processor.py +++ b/backend/agentpress/response_processor.py @@ -510,6 +510,31 @@ class ResponseProcessor: logger.error(f"Failed to save tool result for index {tool_idx}, not yielding result message.") # Optionally yield error status for saving failure? + # --- Calculate and Store Cost --- + if last_assistant_message_object: # Only calculate if assistant message was saved + try: + # Use accumulated_content for streaming cost calculation + final_cost = completion_cost( + model=llm_model, + messages=prompt_messages, # Use the prompt messages provided + completion=accumulated_content + ) + if final_cost is not None and final_cost > 0: + logger.info(f"Calculated final cost for stream: {final_cost}") + await self.add_message( + thread_id=thread_id, + type="cost", + content={"cost": final_cost}, + is_llm_message=False, # Cost is metadata + metadata={"thread_run_id": thread_run_id} # Keep track of the run + ) + logger.info(f"Cost message saved for stream: {final_cost}") + else: + logger.info("Stream cost calculation resulted in zero or None, not storing cost message.") + except Exception as e: + logger.error(f"Error calculating final cost for stream: {str(e)}") + + # --- Final Finish Status --- if finish_reason and finish_reason != "xml_tool_limit_reached": finish_content = {"status_type": "finish", "finish_reason": finish_reason} @@ -538,8 +563,6 @@ class ResponseProcessor: ) if end_msg_obj: yield end_msg_obj - # ... (Cost tracking can remain if fixed) ... - async def process_non_streaming_response( self, llm_response: Any, @@ -637,6 +660,40 @@ class ResponseProcessor: ) if err_msg_obj: yield err_msg_obj + # --- Calculate and Store Cost --- + if assistant_message_object: # Only calculate if assistant message was saved + try: + # Use the full llm_response object for potentially more accurate cost calculation + final_cost = None + if hasattr(llm_response, '_hidden_params') and 'response_cost' in llm_response._hidden_params and llm_response._hidden_params['response_cost'] is not None and llm_response._hidden_params['response_cost'] != 0.0: + final_cost = llm_response._hidden_params['response_cost'] + logger.info(f"Using response_cost from _hidden_params: {final_cost}") + + if final_cost is None: # Fall back to calculating cost if direct cost not available or zero + logger.info("Calculating cost using completion_cost function.") + # Note: litellm might need 'messages' kwarg depending on model/provider + final_cost = completion_cost( + completion_response=llm_response, + model=llm_model, # Explicitly pass the model name + # messages=prompt_messages # Pass prompt messages if needed by litellm for this model + ) + + if final_cost is not None and final_cost > 0: + logger.info(f"Calculated final cost for non-stream: {final_cost}") + await self.add_message( + thread_id=thread_id, + type="cost", + content={"cost": final_cost}, + is_llm_message=False, # Cost is metadata + metadata={"thread_run_id": thread_run_id} # Keep track of the run + ) + logger.info(f"Cost message saved for non-stream: {final_cost}") + else: + logger.info("Non-stream cost calculation resulted in zero or None, not storing cost message.") + + except Exception as e: + logger.error(f"Error calculating final cost for non-stream: {str(e)}") + # --- Execute Tools and Yield Results --- tool_calls_to_execute = [item['tool_call'] for item in all_tool_data] if config.execute_tools and tool_calls_to_execute: