mirror of https://github.com/kortix-ai/suna.git
roll-back cost calculation
This commit is contained in:
parent
db46c1aee5
commit
e3b08b1326
|
@ -510,6 +510,31 @@ class ResponseProcessor:
|
||||||
logger.error(f"Failed to save tool result for index {tool_idx}, not yielding result message.")
|
logger.error(f"Failed to save tool result for index {tool_idx}, not yielding result message.")
|
||||||
# Optionally yield error status for saving failure?
|
# Optionally yield error status for saving failure?
|
||||||
|
|
||||||
|
# --- Calculate and Store Cost ---
|
||||||
|
if last_assistant_message_object: # Only calculate if assistant message was saved
|
||||||
|
try:
|
||||||
|
# Use accumulated_content for streaming cost calculation
|
||||||
|
final_cost = completion_cost(
|
||||||
|
model=llm_model,
|
||||||
|
messages=prompt_messages, # Use the prompt messages provided
|
||||||
|
completion=accumulated_content
|
||||||
|
)
|
||||||
|
if final_cost is not None and final_cost > 0:
|
||||||
|
logger.info(f"Calculated final cost for stream: {final_cost}")
|
||||||
|
await self.add_message(
|
||||||
|
thread_id=thread_id,
|
||||||
|
type="cost",
|
||||||
|
content={"cost": final_cost},
|
||||||
|
is_llm_message=False, # Cost is metadata
|
||||||
|
metadata={"thread_run_id": thread_run_id} # Keep track of the run
|
||||||
|
)
|
||||||
|
logger.info(f"Cost message saved for stream: {final_cost}")
|
||||||
|
else:
|
||||||
|
logger.info("Stream cost calculation resulted in zero or None, not storing cost message.")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error calculating final cost for stream: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
# --- Final Finish Status ---
|
# --- Final Finish Status ---
|
||||||
if finish_reason and finish_reason != "xml_tool_limit_reached":
|
if finish_reason and finish_reason != "xml_tool_limit_reached":
|
||||||
finish_content = {"status_type": "finish", "finish_reason": finish_reason}
|
finish_content = {"status_type": "finish", "finish_reason": finish_reason}
|
||||||
|
@ -538,8 +563,6 @@ class ResponseProcessor:
|
||||||
)
|
)
|
||||||
if end_msg_obj: yield end_msg_obj
|
if end_msg_obj: yield end_msg_obj
|
||||||
|
|
||||||
# ... (Cost tracking can remain if fixed) ...
|
|
||||||
|
|
||||||
async def process_non_streaming_response(
|
async def process_non_streaming_response(
|
||||||
self,
|
self,
|
||||||
llm_response: Any,
|
llm_response: Any,
|
||||||
|
@ -637,6 +660,40 @@ class ResponseProcessor:
|
||||||
)
|
)
|
||||||
if err_msg_obj: yield err_msg_obj
|
if err_msg_obj: yield err_msg_obj
|
||||||
|
|
||||||
|
# --- Calculate and Store Cost ---
|
||||||
|
if assistant_message_object: # Only calculate if assistant message was saved
|
||||||
|
try:
|
||||||
|
# Use the full llm_response object for potentially more accurate cost calculation
|
||||||
|
final_cost = None
|
||||||
|
if hasattr(llm_response, '_hidden_params') and 'response_cost' in llm_response._hidden_params and llm_response._hidden_params['response_cost'] is not None and llm_response._hidden_params['response_cost'] != 0.0:
|
||||||
|
final_cost = llm_response._hidden_params['response_cost']
|
||||||
|
logger.info(f"Using response_cost from _hidden_params: {final_cost}")
|
||||||
|
|
||||||
|
if final_cost is None: # Fall back to calculating cost if direct cost not available or zero
|
||||||
|
logger.info("Calculating cost using completion_cost function.")
|
||||||
|
# Note: litellm might need 'messages' kwarg depending on model/provider
|
||||||
|
final_cost = completion_cost(
|
||||||
|
completion_response=llm_response,
|
||||||
|
model=llm_model, # Explicitly pass the model name
|
||||||
|
# messages=prompt_messages # Pass prompt messages if needed by litellm for this model
|
||||||
|
)
|
||||||
|
|
||||||
|
if final_cost is not None and final_cost > 0:
|
||||||
|
logger.info(f"Calculated final cost for non-stream: {final_cost}")
|
||||||
|
await self.add_message(
|
||||||
|
thread_id=thread_id,
|
||||||
|
type="cost",
|
||||||
|
content={"cost": final_cost},
|
||||||
|
is_llm_message=False, # Cost is metadata
|
||||||
|
metadata={"thread_run_id": thread_run_id} # Keep track of the run
|
||||||
|
)
|
||||||
|
logger.info(f"Cost message saved for non-stream: {final_cost}")
|
||||||
|
else:
|
||||||
|
logger.info("Non-stream cost calculation resulted in zero or None, not storing cost message.")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error calculating final cost for non-stream: {str(e)}")
|
||||||
|
|
||||||
# --- Execute Tools and Yield Results ---
|
# --- Execute Tools and Yield Results ---
|
||||||
tool_calls_to_execute = [item['tool_call'] for item in all_tool_data]
|
tool_calls_to_execute = [item['tool_call'] for item in all_tool_data]
|
||||||
if config.execute_tools and tool_calls_to_execute:
|
if config.execute_tools and tool_calls_to_execute:
|
||||||
|
|
Loading…
Reference in New Issue