mirror of https://github.com/kortix-ai/suna.git
fix: remove cost calculation, add LiteLLM response usage tracking
parent d2564b49ed
commit 4cd6753e42
@@ -6,7 +6,6 @@ This module handles the processing of LLM responses, including:
 - XML and native tool call detection and parsing
 - Tool execution orchestration
 - Message formatting and persistence
-- Cost calculation and tracking
 """
 
 import json
@@ -20,13 +19,13 @@ from utils.logger import logger
 from agentpress.tool import ToolResult
 from agentpress.tool_registry import ToolRegistry
 from agentpress.xml_tool_parser import XMLToolParser
-from litellm import completion_cost
 from langfuse.client import StatefulTraceClient
 from services.langfuse import langfuse
 from agentpress.utils.json_helpers import (
     ensure_dict, ensure_list, safe_json_parse,
     to_json_string, format_for_yield
 )
+from litellm import token_counter
 
 # Type alias for XML result adding strategy
 XmlAddingStrategy = Literal["user_message", "assistant_message", "inline_edit"]
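
Note (not part of the diff): the import change above swaps litellm's completion_cost, which estimates a dollar amount, for token_counter, which returns raw token counts. A minimal standalone sketch of the two helpers; "gpt-4o" and the sample messages are placeholders.

# Illustrative sketch only; model name and messages are placeholders.
from litellm import completion_cost, token_counter

messages = [{"role": "user", "content": "Hello!"}]

# token_counter (newly imported in this commit) returns integer token counts,
# either for a chat message list or for a plain string.
prompt_tokens = token_counter(model="gpt-4o", messages=messages)
completion_tokens = token_counter(model="gpt-4o", text="Hi there!")

# completion_cost (removed by this commit) estimates a USD cost instead.
usd = completion_cost(model="gpt-4o", messages=messages, completion="Hi there!")
print(prompt_tokens, completion_tokens, usd)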
@@ -146,6 +145,21 @@ class ResponseProcessor:
         tool_result_message_objects = {} # tool_index -> full saved message object
         has_printed_thinking_prefix = False # Flag for printing thinking prefix only once
         agent_should_terminate = False # Flag to track if a terminating tool has been executed
+        complete_native_tool_calls = [] # Initialize early for use in assistant_response_end
+
+        # Collect metadata for reconstructing LiteLLM response object
+        streaming_metadata = {
+            "model": llm_model,
+            "created": None,
+            "usage": {
+                "prompt_tokens": 0,
+                "completion_tokens": 0,
+                "total_tokens": 0
+            },
+            "response_ms": None,
+            "first_chunk_time": None,
+            "last_chunk_time": None
+        }
 
         logger.info(f"Streaming Config: XML={config.xml_tool_calling}, Native={config.native_tool_calling}, "
                     f"Execute on stream={config.execute_on_stream}, Strategy={config.tool_execution_strategy}")
@@ -172,6 +186,29 @@ class ResponseProcessor:
             __sequence = 0
 
             async for chunk in llm_response:
+                # 🔥🔥🔥 Debug: Show entire raw chunk structure
+                # Extract streaming metadata from chunks
+                current_time = datetime.now(timezone.utc).timestamp()
+                if streaming_metadata["first_chunk_time"] is None:
+                    streaming_metadata["first_chunk_time"] = current_time
+                streaming_metadata["last_chunk_time"] = current_time
+
+                # Extract metadata from chunk attributes
+                if hasattr(chunk, 'created') and chunk.created:
+                    streaming_metadata["created"] = chunk.created
+                if hasattr(chunk, 'model') and chunk.model:
+                    streaming_metadata["model"] = chunk.model
+                if hasattr(chunk, 'usage') and chunk.usage:
+                    # 🔥🔥🔥 Debug: Show raw usage data from chunk
+                    logger.info(f"🔥🔥🔥 Raw chunk usage data: {chunk.usage}")
+                    # Update usage information if available (including zero values)
+                    if hasattr(chunk.usage, 'prompt_tokens') and chunk.usage.prompt_tokens is not None:
+                        streaming_metadata["usage"]["prompt_tokens"] = chunk.usage.prompt_tokens
+                    if hasattr(chunk.usage, 'completion_tokens') and chunk.usage.completion_tokens is not None:
+                        streaming_metadata["usage"]["completion_tokens"] = chunk.usage.completion_tokens
+                    if hasattr(chunk.usage, 'total_tokens') and chunk.usage.total_tokens is not None:
+                        streaming_metadata["usage"]["total_tokens"] = chunk.usage.total_tokens
+
                 if hasattr(chunk, 'choices') and chunk.choices and hasattr(chunk.choices[0], 'finish_reason') and chunk.choices[0].finish_reason:
                     finish_reason = chunk.choices[0].finish_reason
                     logger.debug(f"Detected finish_reason: {finish_reason}")
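
Note (not part of the diff): the per-chunk handling above only sees chunk.usage when the provider attaches usage to the stream. A hedged sketch of how a caller could request that through LiteLLM's stream_options passthrough; the model name is a placeholder, and include_usage support varies by provider (OpenAI-compatible APIs honor it).

# Illustrative sketch only; model is a placeholder, provider support varies.
import asyncio
from litellm import acompletion

async def main():
    stream = await acompletion(
        model="gpt-4o",
        messages=[{"role": "user", "content": "Hello!"}],
        stream=True,
        stream_options={"include_usage": True},  # ask for usage on the final chunk
    )
    async for chunk in stream:
        if getattr(chunk, "usage", None):
            print(chunk.usage)  # prompt_tokens / completion_tokens / total_tokens

asyncio.run(main())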
@@ -317,6 +354,34 @@
                 # print() # Add a final newline after the streaming loop finishes
 
             # --- After Streaming Loop ---
+
+            # 🔥🔥🔥 Fallback: Estimate tokens if no usage data was captured from streaming
+            if (
+                streaming_metadata["usage"]["total_tokens"] == 0
+            ):
+                logger.info("🔥 No usage data from provider, counting with litellm.token_counter")
+
+                # prompt side
+                prompt_tokens = token_counter(
+                    model=llm_model,
+                    messages=prompt_messages # chat or plain; token_counter handles both
+                )
+
+                # completion side
+                completion_tokens = token_counter(
+                    model=llm_model,
+                    text=accumulated_content or "" # empty string safe
+                )
+
+                streaming_metadata["usage"]["prompt_tokens"] = prompt_tokens
+                streaming_metadata["usage"]["completion_tokens"] = completion_tokens
+                streaming_metadata["usage"]["total_tokens"] = prompt_tokens + completion_tokens
+
+                logger.info(
+                    f"🔥 Estimated tokens – prompt: {prompt_tokens}, "
+                    f"completion: {completion_tokens}, total: {prompt_tokens + completion_tokens}"
+                )
+
 
             # Wait for pending tool executions from streaming phase
             tool_results_buffer = [] # Stores (tool_call, result, tool_index, context)
@@ -409,7 +474,7 @@
                 accumulated_content = accumulated_content[:last_chunk_end_pos]
 
             # ... (Extract complete_native_tool_calls logic) ...
-            complete_native_tool_calls = []
+            # Update complete_native_tool_calls from buffer (initialized earlier)
             if config.native_tool_calling:
                 for idx, tc_buf in tool_calls_buffer.items():
                     if tc_buf['id'] and tc_buf['function']['name'] and tc_buf['function']['arguments']:
@@ -575,34 +640,6 @@
                         self.trace.event(name="failed_to_save_tool_result_for_index", level="ERROR", status_message=(f"Failed to save tool result for index {tool_idx}, not yielding result message."))
                         # Optionally yield error status for saving failure?
 
-            # --- Calculate and Store Cost ---
-            if last_assistant_message_object: # Only calculate if assistant message was saved
-                try:
-                    # Use accumulated_content for streaming cost calculation
-                    final_cost = completion_cost(
-                        model=llm_model,
-                        messages=prompt_messages, # Use the prompt messages provided
-                        completion=accumulated_content
-                    )
-                    if final_cost is not None and final_cost > 0:
-                        logger.info(f"Calculated final cost for stream: {final_cost}")
-                        await self.add_message(
-                            thread_id=thread_id,
-                            type="cost",
-                            content={"cost": final_cost},
-                            is_llm_message=False, # Cost is metadata
-                            metadata={"thread_run_id": thread_run_id} # Keep track of the run
-                        )
-                        logger.info(f"Cost message saved for stream: {final_cost}")
-                        self.trace.update(metadata={"cost": final_cost})
-                    else:
-                        logger.info("Stream cost calculation resulted in zero or None, not storing cost message.")
-                        self.trace.update(metadata={"cost": 0})
-                except Exception as e:
-                    logger.error(f"Error calculating final cost for stream: {str(e)}")
-                    self.trace.event(name="error_calculating_final_cost_for_stream", level="ERROR", status_message=(f"Error calculating final cost for stream: {str(e)}"))
-
-
             # --- Final Finish Status ---
             if finish_reason and finish_reason != "xml_tool_limit_reached":
                 finish_content = {"status_type": "finish", "finish_reason": finish_reason}
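
Note (not part of the diff): with the per-run "cost" message removed, a consumer can still derive a cost estimate later from the usage counts this commit stores in assistant_response_end. A minimal sketch using litellm.cost_per_token; the model name and token counts are placeholders.

# Illustrative sketch only; model and token counts are placeholders.
from litellm import cost_per_token

usage = {"prompt_tokens": 1200, "completion_tokens": 350, "total_tokens": 1550}

prompt_cost, completion_cost_usd = cost_per_token(
    model="gpt-4o",
    prompt_tokens=usage["prompt_tokens"],
    completion_tokens=usage["completion_tokens"],
)
print(f"estimated cost: ${prompt_cost + completion_cost_usd:.6f}")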
@@ -628,9 +665,107 @@
                     )
                     if finish_msg_obj: yield format_for_yield(finish_msg_obj)
 
+                    # Save assistant_response_end BEFORE terminating
+                    if last_assistant_message_object:
+                        try:
+                            # Calculate response time if we have timing data
+                            if streaming_metadata["first_chunk_time"] and streaming_metadata["last_chunk_time"]:
+                                streaming_metadata["response_ms"] = (streaming_metadata["last_chunk_time"] - streaming_metadata["first_chunk_time"]) * 1000
+
+                            # Create a LiteLLM-like response object for streaming (before termination)
+                            # Check if we have any actual usage data
+                            has_usage_data = (
+                                streaming_metadata["usage"]["prompt_tokens"] > 0 or
+                                streaming_metadata["usage"]["completion_tokens"] > 0 or
+                                streaming_metadata["usage"]["total_tokens"] > 0
+                            )
+
+                            assistant_end_content = {
+                                "choices": [
+                                    {
+                                        "finish_reason": finish_reason or "stop",
+                                        "index": 0,
+                                        "message": {
+                                            "role": "assistant",
+                                            "content": accumulated_content,
+                                            "tool_calls": complete_native_tool_calls or None
+                                        }
+                                    }
+                                ],
+                                "created": streaming_metadata.get("created"),
+                                "model": streaming_metadata.get("model", llm_model),
+                                "usage": streaming_metadata["usage"], # Always include usage like LiteLLM does
+                                "streaming": True, # Add flag to indicate this was reconstructed from streaming
+                            }
+
+                            # Only include response_ms if we have timing data
+                            if streaming_metadata.get("response_ms"):
+                                assistant_end_content["response_ms"] = streaming_metadata["response_ms"]
+
+                            await self.add_message(
+                                thread_id=thread_id,
+                                type="assistant_response_end",
+                                content=assistant_end_content,
+                                is_llm_message=False,
+                                metadata={"thread_run_id": thread_run_id}
+                            )
+                            logger.info("Assistant response end saved for stream (before termination)")
+                        except Exception as e:
+                            logger.error(f"Error saving assistant response end for stream (before termination): {str(e)}")
+                            self.trace.event(name="error_saving_assistant_response_end_for_stream_before_termination", level="ERROR", status_message=(f"Error saving assistant response end for stream (before termination): {str(e)}"))
+
                     # Skip all remaining processing and go to finally block
                     return
 
+            # --- Save and Yield assistant_response_end ---
+            if last_assistant_message_object: # Only save if assistant message was saved
+                try:
+                    # Calculate response time if we have timing data
+                    if streaming_metadata["first_chunk_time"] and streaming_metadata["last_chunk_time"]:
+                        streaming_metadata["response_ms"] = (streaming_metadata["last_chunk_time"] - streaming_metadata["first_chunk_time"]) * 1000
+
+                    # Create a LiteLLM-like response object for streaming
+                    # Check if we have any actual usage data
+                    has_usage_data = (
+                        streaming_metadata["usage"]["prompt_tokens"] > 0 or
+                        streaming_metadata["usage"]["completion_tokens"] > 0 or
+                        streaming_metadata["usage"]["total_tokens"] > 0
+                    )
+
+                    assistant_end_content = {
+                        "choices": [
+                            {
+                                "finish_reason": finish_reason or "stop",
+                                "index": 0,
+                                "message": {
+                                    "role": "assistant",
+                                    "content": accumulated_content,
+                                    "tool_calls": complete_native_tool_calls or None
+                                }
+                            }
+                        ],
+                        "created": streaming_metadata.get("created"),
+                        "model": streaming_metadata.get("model", llm_model),
+                        "usage": streaming_metadata["usage"], # Always include usage like LiteLLM does
+                        "streaming": True, # Add flag to indicate this was reconstructed from streaming
+                    }
+
+                    # Only include response_ms if we have timing data
+                    if streaming_metadata.get("response_ms"):
+                        assistant_end_content["response_ms"] = streaming_metadata["response_ms"]
+
+                    await self.add_message(
+                        thread_id=thread_id,
+                        type="assistant_response_end",
+                        content=assistant_end_content,
+                        is_llm_message=False,
+                        metadata={"thread_run_id": thread_run_id}
+                    )
+                    logger.info("Assistant response end saved for stream")
+                except Exception as e:
+                    logger.error(f"Error saving assistant response end for stream: {str(e)}")
+                    self.trace.event(name="error_saving_assistant_response_end_for_stream", level="ERROR", status_message=(f"Error saving assistant response end for stream: {str(e)}"))
+
         except Exception as e:
             logger.error(f"Error processing stream: {str(e)}", exc_info=True)
             self.trace.event(name="error_processing_stream", level="ERROR", status_message=(f"Error processing stream: {str(e)}"))
@@ -759,43 +894,7 @@
                 )
                 if err_msg_obj: yield format_for_yield(err_msg_obj)
 
-            # --- Calculate and Store Cost ---
-            if assistant_message_object: # Only calculate if assistant message was saved
-                try:
-                    # Use the full llm_response object for potentially more accurate cost calculation
-                    final_cost = None
-                    if hasattr(llm_response, '_hidden_params') and 'response_cost' in llm_response._hidden_params and llm_response._hidden_params['response_cost'] is not None and llm_response._hidden_params['response_cost'] != 0.0:
-                        final_cost = llm_response._hidden_params['response_cost']
-                        logger.info(f"Using response_cost from _hidden_params: {final_cost}")
-
-                    if final_cost is None: # Fall back to calculating cost if direct cost not available or zero
-                        logger.info("Calculating cost using completion_cost function.")
-                        # Note: litellm might need 'messages' kwarg depending on model/provider
-                        final_cost = completion_cost(
-                            completion_response=llm_response,
-                            model=llm_model, # Explicitly pass the model name
-                            # messages=prompt_messages # Pass prompt messages if needed by litellm for this model
-                        )
-
-                    if final_cost is not None and final_cost > 0:
-                        logger.info(f"Calculated final cost for non-stream: {final_cost}")
-                        await self.add_message(
-                            thread_id=thread_id,
-                            type="cost",
-                            content={"cost": final_cost},
-                            is_llm_message=False, # Cost is metadata
-                            metadata={"thread_run_id": thread_run_id} # Keep track of the run
-                        )
-                        logger.info(f"Cost message saved for non-stream: {final_cost}")
-                        self.trace.update(metadata={"cost": final_cost})
-                    else:
-                        logger.info("Non-stream cost calculation resulted in zero or None, not storing cost message.")
-                        self.trace.update(metadata={"cost": 0})
-
-                except Exception as e:
-                    logger.error(f"Error calculating final cost for non-stream: {str(e)}")
-                    self.trace.event(name="error_calculating_final_cost_for_non_stream", level="ERROR", status_message=(f"Error calculating final cost for non-stream: {str(e)}"))
-            # --- Execute Tools and Yield Results ---
+            # --- Execute Tools and Yield Results ---
             tool_calls_to_execute = [item['tool_call'] for item in all_tool_data]
             if config.execute_tools and tool_calls_to_execute:
                 logger.info(f"Executing {len(tool_calls_to_execute)} tools with strategy: {config.tool_execution_strategy}")
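
Note (not part of the diff): in the non-streaming path the full LiteLLM response is now stored instead (see the next hunk), and that object already carries token usage plus, when LiteLLM can compute it, a cost estimate in _hidden_params["response_cost"], the same field the removed block above consulted. A minimal sketch of reading both; the model name is a placeholder.

# Illustrative sketch only; model is a placeholder.
from litellm import completion

response = completion(
    model="gpt-4o",
    messages=[{"role": "user", "content": "Hello!"}],
)

print(response.usage)  # prompt_tokens / completion_tokens / total_tokens
cost = (response._hidden_params or {}).get("response_cost")
if cost:
    print(f"litellm cost estimate: ${cost:.6f}")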
@@ -850,6 +949,22 @@
                 )
                 if finish_msg_obj: yield format_for_yield(finish_msg_obj)
 
+            # --- Save and Yield assistant_response_end ---
+            if assistant_message_object: # Only save if assistant message was saved
+                try:
+                    # Save the full LiteLLM response object directly in content
+                    await self.add_message(
+                        thread_id=thread_id,
+                        type="assistant_response_end",
+                        content=llm_response,
+                        is_llm_message=False,
+                        metadata={"thread_run_id": thread_run_id}
+                    )
+                    logger.info("Assistant response end saved for non-stream")
+                except Exception as e:
+                    logger.error(f"Error saving assistant response end for non-stream: {str(e)}")
+                    self.trace.event(name="error_saving_assistant_response_end_for_non_stream", level="ERROR", status_message=(f"Error saving assistant response end for non-stream: {str(e)}"))
+
         except Exception as e:
             logger.error(f"Error processing non-streaming response: {str(e)}", exc_info=True)
             self.trace.event(name="error_processing_non_streaming_response", level="ERROR", status_message=(f"Error processing non-streaming response: {str(e)}"))
@@ -1618,7 +1733,7 @@
         # For backwards compatibility with LLM, also include a human-readable summary
         # Use the original string output for the summary to avoid complex object representation
         summary_output = result.output if hasattr(result, 'output') else str(result)
-        success_status = structured_result["tool_execution"]["result"]["success"]
+        success_status = structured_result_v1["tool_execution"]["result"]["success"]
 
         # Create a more comprehensive summary for the LLM
         if xml_tag_name: