mirror of https://github.com/kortix-ai/suna.git
Merge pull request #604 from kubet/feat/port-fix-and-panel-ajustments
Feat: remove cost, save usage instead
commit 45abd5a687
@@ -6,7 +6,6 @@ This module handles the processing of LLM responses, including:
- XML and native tool call detection and parsing
- Tool execution orchestration
- Message formatting and persistence
- Cost calculation and tracking
"""

import json
@@ -20,13 +19,13 @@ from utils.logger import logger
from agentpress.tool import ToolResult
from agentpress.tool_registry import ToolRegistry
from agentpress.xml_tool_parser import XMLToolParser
from litellm import completion_cost
from langfuse.client import StatefulTraceClient
from services.langfuse import langfuse
from agentpress.utils.json_helpers import (
    ensure_dict, ensure_list, safe_json_parse,
    to_json_string, format_for_yield
)
from litellm import token_counter

# Type alias for XML result adding strategy
XmlAddingStrategy = Literal["user_message", "assistant_message", "inline_edit"]
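Read together with the later hunks, the import change trades litellm's cost helper for its token counter: runs are no longer priced up front, only measured. A minimal sketch of how the two helpers differ (model name and texts are placeholders, not values from this repo):

from litellm import completion_cost, token_counter

messages = [{"role": "user", "content": "Hello there"}]
completion = "Hi! How can I help?"

# Pricing directly (the old path): an estimated USD cost for the prompt/completion pair.
cost = completion_cost(model="gpt-4o", messages=messages, completion=completion)

# Counting tokens (the new path): plain integers that can be stored and priced later.
prompt_tokens = token_counter(model="gpt-4o", messages=messages)
completion_tokens = token_counter(model="gpt-4o", text=completion)
print(cost, prompt_tokens, completion_tokens)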
@@ -146,6 +145,21 @@ class ResponseProcessor:
        tool_result_message_objects = {} # tool_index -> full saved message object
        has_printed_thinking_prefix = False # Flag for printing thinking prefix only once
        agent_should_terminate = False # Flag to track if a terminating tool has been executed
        complete_native_tool_calls = [] # Initialize early for use in assistant_response_end

        # Collect metadata for reconstructing LiteLLM response object
        streaming_metadata = {
            "model": llm_model,
            "created": None,
            "usage": {
                "prompt_tokens": 0,
                "completion_tokens": 0,
                "total_tokens": 0
            },
            "response_ms": None,
            "first_chunk_time": None,
            "last_chunk_time": None
        }

        logger.info(f"Streaming Config: XML={config.xml_tool_calling}, Native={config.native_tool_calling}, "
                    f"Execute on stream={config.execute_on_stream}, Strategy={config.tool_execution_strategy}")
@@ -172,6 +186,26 @@ class ResponseProcessor:
            __sequence = 0

            async for chunk in llm_response:
                # Extract streaming metadata from chunks
                current_time = datetime.now(timezone.utc).timestamp()
                if streaming_metadata["first_chunk_time"] is None:
                    streaming_metadata["first_chunk_time"] = current_time
                streaming_metadata["last_chunk_time"] = current_time

                # Extract metadata from chunk attributes
                if hasattr(chunk, 'created') and chunk.created:
                    streaming_metadata["created"] = chunk.created
                if hasattr(chunk, 'model') and chunk.model:
                    streaming_metadata["model"] = chunk.model
                if hasattr(chunk, 'usage') and chunk.usage:
                    # Update usage information if available (including zero values)
                    if hasattr(chunk.usage, 'prompt_tokens') and chunk.usage.prompt_tokens is not None:
                        streaming_metadata["usage"]["prompt_tokens"] = chunk.usage.prompt_tokens
                    if hasattr(chunk.usage, 'completion_tokens') and chunk.usage.completion_tokens is not None:
                        streaming_metadata["usage"]["completion_tokens"] = chunk.usage.completion_tokens
                    if hasattr(chunk.usage, 'total_tokens') and chunk.usage.total_tokens is not None:
                        streaming_metadata["usage"]["total_tokens"] = chunk.usage.total_tokens

                if hasattr(chunk, 'choices') and chunk.choices and hasattr(chunk.choices[0], 'finish_reason') and chunk.choices[0].finish_reason:
                    finish_reason = chunk.choices[0].finish_reason
                    logger.debug(f"Detected finish_reason: {finish_reason}")
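The usage fields above are only populated if the provider actually attaches a usage block to the stream; for OpenAI-style APIs that has to be requested explicitly. A hedged sketch of asking for it through litellm (provider support varies; model name is a placeholder):

import litellm

stream = litellm.completion(
    model="gpt-4o",
    messages=[{"role": "user", "content": "Hello"}],
    stream=True,
    stream_options={"include_usage": True},  # ask for a usage block on the final chunk
)
for chunk in stream:
    usage = getattr(chunk, "usage", None)
    if usage:  # usually only the last chunk carries it
        print(usage.prompt_tokens, usage.completion_tokens, usage.total_tokens)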
@@ -317,6 +351,33 @@ class ResponseProcessor:
            # print() # Add a final newline after the streaming loop finishes

            # --- After Streaming Loop ---

            if (
                streaming_metadata["usage"]["total_tokens"] == 0
            ):
                logger.info("🔥 No usage data from provider, counting with litellm.token_counter")

                # prompt side
                prompt_tokens = token_counter(
                    model=llm_model,
                    messages=prompt_messages # chat or plain; token_counter handles both
                )

                # completion side
                completion_tokens = token_counter(
                    model=llm_model,
                    text=accumulated_content or "" # empty string safe
                )

                streaming_metadata["usage"]["prompt_tokens"] = prompt_tokens
                streaming_metadata["usage"]["completion_tokens"] = completion_tokens
                streaming_metadata["usage"]["total_tokens"] = prompt_tokens + completion_tokens

                logger.info(
                    f"🔥 Estimated tokens – prompt: {prompt_tokens}, "
                    f"completion: {completion_tokens}, total: {prompt_tokens + completion_tokens}"
                )


            # Wait for pending tool executions from streaming phase
            tool_results_buffer = [] # Stores (tool_call, result, tool_index, context)
@@ -409,7 +470,7 @@ class ResponseProcessor:
                accumulated_content = accumulated_content[:last_chunk_end_pos]

            # ... (Extract complete_native_tool_calls logic) ...
            complete_native_tool_calls = []
            # Update complete_native_tool_calls from buffer (initialized earlier)
            if config.native_tool_calling:
                for idx, tc_buf in tool_calls_buffer.items():
                    if tc_buf['id'] and tc_buf['function']['name'] and tc_buf['function']['arguments']:
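The check on tc_buf['id'], tc_buf['function']['name'] and tc_buf['function']['arguments'] implies the buffer shape built up during streaming. A hedged sketch of how OpenAI-style tool-call deltas are typically folded into that shape (the helper name and delta object are illustrative, not code from this repo):

def accumulate_tool_call_delta(tool_calls_buffer, delta_tool_call):
    # delta_tool_call: one entry of chunk.choices[0].delta.tool_calls (OpenAI-style)
    buf = tool_calls_buffer.setdefault(delta_tool_call.index, {
        "id": None,
        "type": "function",
        "function": {"name": None, "arguments": ""},
    })
    if delta_tool_call.id:
        buf["id"] = delta_tool_call.id
    fn = delta_tool_call.function
    if fn and fn.name:
        buf["function"]["name"] = fn.name
    if fn and fn.arguments:
        buf["function"]["arguments"] += fn.arguments  # arguments arrive in fragments
    return buf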
@@ -575,34 +636,6 @@ class ResponseProcessor:
                    self.trace.event(name="failed_to_save_tool_result_for_index", level="ERROR", status_message=(f"Failed to save tool result for index {tool_idx}, not yielding result message."))
                    # Optionally yield error status for saving failure?

            # --- Calculate and Store Cost ---
            if last_assistant_message_object: # Only calculate if assistant message was saved
                try:
                    # Use accumulated_content for streaming cost calculation
                    final_cost = completion_cost(
                        model=llm_model,
                        messages=prompt_messages, # Use the prompt messages provided
                        completion=accumulated_content
                    )
                    if final_cost is not None and final_cost > 0:
                        logger.info(f"Calculated final cost for stream: {final_cost}")
                        await self.add_message(
                            thread_id=thread_id,
                            type="cost",
                            content={"cost": final_cost},
                            is_llm_message=False, # Cost is metadata
                            metadata={"thread_run_id": thread_run_id} # Keep track of the run
                        )
                        logger.info(f"Cost message saved for stream: {final_cost}")
                        self.trace.update(metadata={"cost": final_cost})
                    else:
                        logger.info("Stream cost calculation resulted in zero or None, not storing cost message.")
                        self.trace.update(metadata={"cost": 0})
                except Exception as e:
                    logger.error(f"Error calculating final cost for stream: {str(e)}")
                    self.trace.event(name="error_calculating_final_cost_for_stream", level="ERROR", status_message=(f"Error calculating final cost for stream: {str(e)}"))


            # --- Final Finish Status ---
            if finish_reason and finish_reason != "xml_tool_limit_reached":
                finish_content = {"status_type": "finish", "finish_reason": finish_reason}
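This is the block the commit removes from the streaming path: no per-run "cost" message is written any more. Since the token counts are now persisted with assistant_response_end, cost can still be derived later; a minimal sketch assuming litellm's cost_per_token helper and placeholder numbers:

from litellm import cost_per_token

usage = {"prompt_tokens": 1200, "completion_tokens": 350}  # e.g. read back from assistant_response_end
prompt_cost, completion_cost_usd = cost_per_token(
    model="gpt-4o",
    prompt_tokens=usage["prompt_tokens"],
    completion_tokens=usage["completion_tokens"],
)
print(f"estimated cost: ${prompt_cost + completion_cost_usd:.6f}")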
@@ -628,9 +661,107 @@ class ResponseProcessor:
                )
                if finish_msg_obj: yield format_for_yield(finish_msg_obj)

                # Save assistant_response_end BEFORE terminating
                if last_assistant_message_object:
                    try:
                        # Calculate response time if we have timing data
                        if streaming_metadata["first_chunk_time"] and streaming_metadata["last_chunk_time"]:
                            streaming_metadata["response_ms"] = (streaming_metadata["last_chunk_time"] - streaming_metadata["first_chunk_time"]) * 1000

                        # Create a LiteLLM-like response object for streaming (before termination)
                        # Check if we have any actual usage data
                        has_usage_data = (
                            streaming_metadata["usage"]["prompt_tokens"] > 0 or
                            streaming_metadata["usage"]["completion_tokens"] > 0 or
                            streaming_metadata["usage"]["total_tokens"] > 0
                        )

                        assistant_end_content = {
                            "choices": [
                                {
                                    "finish_reason": finish_reason or "stop",
                                    "index": 0,
                                    "message": {
                                        "role": "assistant",
                                        "content": accumulated_content,
                                        "tool_calls": complete_native_tool_calls or None
                                    }
                                }
                            ],
                            "created": streaming_metadata.get("created"),
                            "model": streaming_metadata.get("model", llm_model),
                            "usage": streaming_metadata["usage"], # Always include usage like LiteLLM does
                            "streaming": True, # Add flag to indicate this was reconstructed from streaming
                        }

                        # Only include response_ms if we have timing data
                        if streaming_metadata.get("response_ms"):
                            assistant_end_content["response_ms"] = streaming_metadata["response_ms"]

                        await self.add_message(
                            thread_id=thread_id,
                            type="assistant_response_end",
                            content=assistant_end_content,
                            is_llm_message=False,
                            metadata={"thread_run_id": thread_run_id}
                        )
                        logger.info("Assistant response end saved for stream (before termination)")
                    except Exception as e:
                        logger.error(f"Error saving assistant response end for stream (before termination): {str(e)}")
                        self.trace.event(name="error_saving_assistant_response_end_for_stream_before_termination", level="ERROR", status_message=(f"Error saving assistant response end for stream (before termination): {str(e)}"))

                # Skip all remaining processing and go to finally block
                return

            # --- Save and Yield assistant_response_end ---
            if last_assistant_message_object: # Only save if assistant message was saved
                try:
                    # Calculate response time if we have timing data
                    if streaming_metadata["first_chunk_time"] and streaming_metadata["last_chunk_time"]:
                        streaming_metadata["response_ms"] = (streaming_metadata["last_chunk_time"] - streaming_metadata["first_chunk_time"]) * 1000

                    # Create a LiteLLM-like response object for streaming
                    # Check if we have any actual usage data
                    has_usage_data = (
                        streaming_metadata["usage"]["prompt_tokens"] > 0 or
                        streaming_metadata["usage"]["completion_tokens"] > 0 or
                        streaming_metadata["usage"]["total_tokens"] > 0
                    )

                    assistant_end_content = {
                        "choices": [
                            {
                                "finish_reason": finish_reason or "stop",
                                "index": 0,
                                "message": {
                                    "role": "assistant",
                                    "content": accumulated_content,
                                    "tool_calls": complete_native_tool_calls or None
                                }
                            }
                        ],
                        "created": streaming_metadata.get("created"),
                        "model": streaming_metadata.get("model", llm_model),
                        "usage": streaming_metadata["usage"], # Always include usage like LiteLLM does
                        "streaming": True, # Add flag to indicate this was reconstructed from streaming
                    }

                    # Only include response_ms if we have timing data
                    if streaming_metadata.get("response_ms"):
                        assistant_end_content["response_ms"] = streaming_metadata["response_ms"]

                    await self.add_message(
                        thread_id=thread_id,
                        type="assistant_response_end",
                        content=assistant_end_content,
                        is_llm_message=False,
                        metadata={"thread_run_id": thread_run_id}
                    )
                    logger.info("Assistant response end saved for stream")
                except Exception as e:
                    logger.error(f"Error saving assistant response end for stream: {str(e)}")
                    self.trace.event(name="error_saving_assistant_response_end_for_stream", level="ERROR", status_message=(f"Error saving assistant response end for stream: {str(e)}"))

        except Exception as e:
            logger.error(f"Error processing stream: {str(e)}", exc_info=True)
            self.trace.event(name="error_processing_stream", level="ERROR", status_message=(f"Error processing stream: {str(e)}"))
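The same LiteLLM-shaped payload is assembled twice above, once on the early-termination path and once on the normal path. A sketch of that shared shape pulled into one helper (a hypothetical refactor, using only the variables already in scope in the hunk):

def build_assistant_end_content(streaming_metadata, accumulated_content,
                                complete_native_tool_calls, finish_reason, llm_model):
    # Mirrors the dict built in both branches above; not code from this repo.
    content = {
        "choices": [
            {
                "finish_reason": finish_reason or "stop",
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": accumulated_content,
                    "tool_calls": complete_native_tool_calls or None,
                },
            }
        ],
        "created": streaming_metadata.get("created"),
        "model": streaming_metadata.get("model", llm_model),
        "usage": streaming_metadata["usage"],
        "streaming": True,
    }
    if streaming_metadata.get("response_ms"):
        content["response_ms"] = streaming_metadata["response_ms"]
    return content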
@@ -759,43 +890,7 @@ class ResponseProcessor:
                )
                if err_msg_obj: yield format_for_yield(err_msg_obj)

            # --- Calculate and Store Cost ---
            if assistant_message_object: # Only calculate if assistant message was saved
                try:
                    # Use the full llm_response object for potentially more accurate cost calculation
                    final_cost = None
                    if hasattr(llm_response, '_hidden_params') and 'response_cost' in llm_response._hidden_params and llm_response._hidden_params['response_cost'] is not None and llm_response._hidden_params['response_cost'] != 0.0:
                        final_cost = llm_response._hidden_params['response_cost']
                        logger.info(f"Using response_cost from _hidden_params: {final_cost}")

                    if final_cost is None: # Fall back to calculating cost if direct cost not available or zero
                        logger.info("Calculating cost using completion_cost function.")
                        # Note: litellm might need 'messages' kwarg depending on model/provider
                        final_cost = completion_cost(
                            completion_response=llm_response,
                            model=llm_model, # Explicitly pass the model name
                            # messages=prompt_messages # Pass prompt messages if needed by litellm for this model
                        )

                    if final_cost is not None and final_cost > 0:
                        logger.info(f"Calculated final cost for non-stream: {final_cost}")
                        await self.add_message(
                            thread_id=thread_id,
                            type="cost",
                            content={"cost": final_cost},
                            is_llm_message=False, # Cost is metadata
                            metadata={"thread_run_id": thread_run_id} # Keep track of the run
                        )
                        logger.info(f"Cost message saved for non-stream: {final_cost}")
                        self.trace.update(metadata={"cost": final_cost})
                    else:
                        logger.info("Non-stream cost calculation resulted in zero or None, not storing cost message.")
                        self.trace.update(metadata={"cost": 0})

                except Exception as e:
                    logger.error(f"Error calculating final cost for non-stream: {str(e)}")
                    self.trace.event(name="error_calculating_final_cost_for_non_stream", level="ERROR", status_message=(f"Error calculating final cost for non-stream: {str(e)}"))
            # --- Execute Tools and Yield Results ---
            # --- Execute Tools and Yield Results ---
            tool_calls_to_execute = [item['tool_call'] for item in all_tool_data]
            if config.execute_tools and tool_calls_to_execute:
                logger.info(f"Executing {len(tool_calls_to_execute)} tools with strategy: {config.tool_execution_strategy}")
@@ -850,6 +945,22 @@ class ResponseProcessor:
                )
                if finish_msg_obj: yield format_for_yield(finish_msg_obj)

            # --- Save and Yield assistant_response_end ---
            if assistant_message_object: # Only save if assistant message was saved
                try:
                    # Save the full LiteLLM response object directly in content
                    await self.add_message(
                        thread_id=thread_id,
                        type="assistant_response_end",
                        content=llm_response,
                        is_llm_message=False,
                        metadata={"thread_run_id": thread_run_id}
                    )
                    logger.info("Assistant response end saved for non-stream")
                except Exception as e:
                    logger.error(f"Error saving assistant response end for non-stream: {str(e)}")
                    self.trace.event(name="error_saving_assistant_response_end_for_non_stream", level="ERROR", status_message=(f"Error saving assistant response end for non-stream: {str(e)}"))

        except Exception as e:
            logger.error(f"Error processing non-streaming response: {str(e)}", exc_info=True)
            self.trace.event(name="error_processing_non_streaming_response", level="ERROR", status_message=(f"Error processing non-streaming response: {str(e)}"))
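In the non-streaming path the whole LiteLLM response is stored as the assistant_response_end content, so the usage block travels with it instead of a separate cost row. A minimal sketch of where those numbers live on a non-streaming litellm response (model name is a placeholder):

import litellm

llm_response = litellm.completion(
    model="gpt-4o",
    messages=[{"role": "user", "content": "Hello"}],
)
usage = getattr(llm_response, "usage", None)  # ModelResponse carries prompt/completion/total tokens
if usage is not None:
    print(usage.prompt_tokens, usage.completion_tokens, usage.total_tokens)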
@@ -1619,7 +1730,7 @@ class ResponseProcessor:
        # return summary

        summary_output = result.output if hasattr(result, 'output') else str(result)
        success_status = structured_result["tool_execution"]["result"]["success"]
        success_status = structured_result_v1["tool_execution"]["result"]["success"]

        # Create a more comprehensive summary for the LLM
        if xml_tag_name: