Merge pull request #604 from kubet/feat/port-fix-and-panel-ajustments

Feat: remove cost, save usage instead
This commit is contained in:
kubet 2025-06-02 22:03:43 +02:00 committed by GitHub
commit 45abd5a687
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed file with 180 additions and 69 deletions


@@ -6,7 +6,6 @@ This module handles the processing of LLM responses, including:
- XML and native tool call detection and parsing
- Tool execution orchestration
- Message formatting and persistence
- Cost calculation and tracking
"""
import json
@@ -20,13 +19,13 @@ from utils.logger import logger
from agentpress.tool import ToolResult
from agentpress.tool_registry import ToolRegistry
from agentpress.xml_tool_parser import XMLToolParser
from litellm import completion_cost
from langfuse.client import StatefulTraceClient
from services.langfuse import langfuse
from agentpress.utils.json_helpers import (
ensure_dict, ensure_list, safe_json_parse,
to_json_string, format_for_yield
)
from litellm import token_counter
# Type alias for XML result adding strategy
XmlAddingStrategy = Literal["user_message", "assistant_message", "inline_edit"]
@@ -146,6 +145,21 @@ class ResponseProcessor:
tool_result_message_objects = {} # tool_index -> full saved message object
has_printed_thinking_prefix = False # Flag for printing thinking prefix only once
agent_should_terminate = False # Flag to track if a terminating tool has been executed
complete_native_tool_calls = [] # Initialize early for use in assistant_response_end
# Collect metadata for reconstructing LiteLLM response object
streaming_metadata = {
"model": llm_model,
"created": None,
"usage": {
"prompt_tokens": 0,
"completion_tokens": 0,
"total_tokens": 0
},
"response_ms": None,
"first_chunk_time": None,
"last_chunk_time": None
}
logger.info(f"Streaming Config: XML={config.xml_tool_calling}, Native={config.native_tool_calling}, "
f"Execute on stream={config.execute_on_stream}, Strategy={config.tool_execution_strategy}")
@@ -172,6 +186,26 @@ class ResponseProcessor:
__sequence = 0
async for chunk in llm_response:
# Extract streaming metadata from chunks
current_time = datetime.now(timezone.utc).timestamp()
if streaming_metadata["first_chunk_time"] is None:
streaming_metadata["first_chunk_time"] = current_time
streaming_metadata["last_chunk_time"] = current_time
# Extract metadata from chunk attributes
if hasattr(chunk, 'created') and chunk.created:
streaming_metadata["created"] = chunk.created
if hasattr(chunk, 'model') and chunk.model:
streaming_metadata["model"] = chunk.model
if hasattr(chunk, 'usage') and chunk.usage:
# Update usage information if available (including zero values)
if hasattr(chunk.usage, 'prompt_tokens') and chunk.usage.prompt_tokens is not None:
streaming_metadata["usage"]["prompt_tokens"] = chunk.usage.prompt_tokens
if hasattr(chunk.usage, 'completion_tokens') and chunk.usage.completion_tokens is not None:
streaming_metadata["usage"]["completion_tokens"] = chunk.usage.completion_tokens
if hasattr(chunk.usage, 'total_tokens') and chunk.usage.total_tokens is not None:
streaming_metadata["usage"]["total_tokens"] = chunk.usage.total_tokens
if hasattr(chunk, 'choices') and chunk.choices and hasattr(chunk.choices[0], 'finish_reason') and chunk.choices[0].finish_reason:
finish_reason = chunk.choices[0].finish_reason
logger.debug(f"Detected finish_reason: {finish_reason}")
@@ -317,6 +351,33 @@ class ResponseProcessor:
# print() # Add a final newline after the streaming loop finishes
# --- After Streaming Loop ---
if (
streaming_metadata["usage"]["total_tokens"] == 0
):
logger.info("🔥 No usage data from provider, counting with litellm.token_counter")
# prompt side
prompt_tokens = token_counter(
model=llm_model,
messages=prompt_messages # chat or plain; token_counter handles both
)
# completion side
completion_tokens = token_counter(
model=llm_model,
text=accumulated_content or "" # empty string safe
)
streaming_metadata["usage"]["prompt_tokens"] = prompt_tokens
streaming_metadata["usage"]["completion_tokens"] = completion_tokens
streaming_metadata["usage"]["total_tokens"] = prompt_tokens + completion_tokens
logger.info(
f"🔥 Estimated tokens prompt: {prompt_tokens}, "
f"completion: {completion_tokens}, total: {prompt_tokens + completion_tokens}"
)
# Wait for pending tool executions from streaming phase
tool_results_buffer = [] # Stores (tool_call, result, tool_index, context)
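For reference, litellm.token_counter accepts either a chat-style messages list or raw text, so the fallback above can be checked in isolation. A minimal sketch (the model name and strings are arbitrary examples):

from litellm import token_counter

# Prompt side: count tokens over chat-formatted messages.
prompt_tokens = token_counter(
    model="gpt-4o-mini",  # tokenizer is resolved from the model name
    messages=[{"role": "user", "content": "Summarize this thread."}],
)
# Completion side: count tokens over the plain accumulated text.
completion_tokens = token_counter(model="gpt-4o-mini", text="Here is a short summary.")
print(prompt_tokens, completion_tokens, prompt_tokens + completion_tokens)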
@@ -409,7 +470,7 @@ class ResponseProcessor:
accumulated_content = accumulated_content[:last_chunk_end_pos]
# ... (Extract complete_native_tool_calls logic) ...
complete_native_tool_calls = []
# Update complete_native_tool_calls from buffer (initialized earlier)
if config.native_tool_calling:
for idx, tc_buf in tool_calls_buffer.items():
if tc_buf['id'] and tc_buf['function']['name'] and tc_buf['function']['arguments']:
@@ -575,34 +636,6 @@ class ResponseProcessor:
self.trace.event(name="failed_to_save_tool_result_for_index", level="ERROR", status_message=(f"Failed to save tool result for index {tool_idx}, not yielding result message."))
# Optionally yield error status for saving failure?
# --- Calculate and Store Cost ---
if last_assistant_message_object: # Only calculate if assistant message was saved
try:
# Use accumulated_content for streaming cost calculation
final_cost = completion_cost(
model=llm_model,
messages=prompt_messages, # Use the prompt messages provided
completion=accumulated_content
)
if final_cost is not None and final_cost > 0:
logger.info(f"Calculated final cost for stream: {final_cost}")
await self.add_message(
thread_id=thread_id,
type="cost",
content={"cost": final_cost},
is_llm_message=False, # Cost is metadata
metadata={"thread_run_id": thread_run_id} # Keep track of the run
)
logger.info(f"Cost message saved for stream: {final_cost}")
self.trace.update(metadata={"cost": final_cost})
else:
logger.info("Stream cost calculation resulted in zero or None, not storing cost message.")
self.trace.update(metadata={"cost": 0})
except Exception as e:
logger.error(f"Error calculating final cost for stream: {str(e)}")
self.trace.event(name="error_calculating_final_cost_for_stream", level="ERROR", status_message=(f"Error calculating final cost for stream: {str(e)}"))
# --- Final Finish Status ---
if finish_reason and finish_reason != "xml_tool_limit_reached":
finish_content = {"status_type": "finish", "finish_reason": finish_reason}
@@ -628,9 +661,107 @@ class ResponseProcessor:
)
if finish_msg_obj: yield format_for_yield(finish_msg_obj)
# Save assistant_response_end BEFORE terminating
if last_assistant_message_object:
try:
# Calculate response time if we have timing data
if streaming_metadata["first_chunk_time"] and streaming_metadata["last_chunk_time"]:
streaming_metadata["response_ms"] = (streaming_metadata["last_chunk_time"] - streaming_metadata["first_chunk_time"]) * 1000
# Create a LiteLLM-like response object for streaming (before termination)
# Check if we have any actual usage data
has_usage_data = (
streaming_metadata["usage"]["prompt_tokens"] > 0 or
streaming_metadata["usage"]["completion_tokens"] > 0 or
streaming_metadata["usage"]["total_tokens"] > 0
)
assistant_end_content = {
"choices": [
{
"finish_reason": finish_reason or "stop",
"index": 0,
"message": {
"role": "assistant",
"content": accumulated_content,
"tool_calls": complete_native_tool_calls or None
}
}
],
"created": streaming_metadata.get("created"),
"model": streaming_metadata.get("model", llm_model),
"usage": streaming_metadata["usage"], # Always include usage like LiteLLM does
"streaming": True, # Add flag to indicate this was reconstructed from streaming
}
# Only include response_ms if we have timing data
if streaming_metadata.get("response_ms"):
assistant_end_content["response_ms"] = streaming_metadata["response_ms"]
await self.add_message(
thread_id=thread_id,
type="assistant_response_end",
content=assistant_end_content,
is_llm_message=False,
metadata={"thread_run_id": thread_run_id}
)
logger.info("Assistant response end saved for stream (before termination)")
except Exception as e:
logger.error(f"Error saving assistant response end for stream (before termination): {str(e)}")
self.trace.event(name="error_saving_assistant_response_end_for_stream_before_termination", level="ERROR", status_message=(f"Error saving assistant response end for stream (before termination): {str(e)}"))
# Skip all remaining processing and go to finally block
return
# --- Save and Yield assistant_response_end ---
if last_assistant_message_object: # Only save if assistant message was saved
try:
# Calculate response time if we have timing data
if streaming_metadata["first_chunk_time"] and streaming_metadata["last_chunk_time"]:
streaming_metadata["response_ms"] = (streaming_metadata["last_chunk_time"] - streaming_metadata["first_chunk_time"]) * 1000
# Create a LiteLLM-like response object for streaming
# Check if we have any actual usage data
has_usage_data = (
streaming_metadata["usage"]["prompt_tokens"] > 0 or
streaming_metadata["usage"]["completion_tokens"] > 0 or
streaming_metadata["usage"]["total_tokens"] > 0
)
assistant_end_content = {
"choices": [
{
"finish_reason": finish_reason or "stop",
"index": 0,
"message": {
"role": "assistant",
"content": accumulated_content,
"tool_calls": complete_native_tool_calls or None
}
}
],
"created": streaming_metadata.get("created"),
"model": streaming_metadata.get("model", llm_model),
"usage": streaming_metadata["usage"], # Always include usage like LiteLLM does
"streaming": True, # Add flag to indicate this was reconstructed from streaming
}
# Only include response_ms if we have timing data
if streaming_metadata.get("response_ms"):
assistant_end_content["response_ms"] = streaming_metadata["response_ms"]
await self.add_message(
thread_id=thread_id,
type="assistant_response_end",
content=assistant_end_content,
is_llm_message=False,
metadata={"thread_run_id": thread_run_id}
)
logger.info("Assistant response end saved for stream")
except Exception as e:
logger.error(f"Error saving assistant response end for stream: {str(e)}")
self.trace.event(name="error_saving_assistant_response_end_for_stream", level="ERROR", status_message=(f"Error saving assistant response end for stream: {str(e)}"))
except Exception as e:
logger.error(f"Error processing stream: {str(e)}", exc_info=True)
self.trace.event(name="error_processing_stream", level="ERROR", status_message=(f"Error processing stream: {str(e)}"))
@@ -759,43 +890,7 @@ class ResponseProcessor:
)
if err_msg_obj: yield format_for_yield(err_msg_obj)
# --- Calculate and Store Cost ---
if assistant_message_object: # Only calculate if assistant message was saved
try:
# Use the full llm_response object for potentially more accurate cost calculation
final_cost = None
if hasattr(llm_response, '_hidden_params') and 'response_cost' in llm_response._hidden_params and llm_response._hidden_params['response_cost'] is not None and llm_response._hidden_params['response_cost'] != 0.0:
final_cost = llm_response._hidden_params['response_cost']
logger.info(f"Using response_cost from _hidden_params: {final_cost}")
if final_cost is None: # Fall back to calculating cost if direct cost not available or zero
logger.info("Calculating cost using completion_cost function.")
# Note: litellm might need 'messages' kwarg depending on model/provider
final_cost = completion_cost(
completion_response=llm_response,
model=llm_model, # Explicitly pass the model name
# messages=prompt_messages # Pass prompt messages if needed by litellm for this model
)
if final_cost is not None and final_cost > 0:
logger.info(f"Calculated final cost for non-stream: {final_cost}")
await self.add_message(
thread_id=thread_id,
type="cost",
content={"cost": final_cost},
is_llm_message=False, # Cost is metadata
metadata={"thread_run_id": thread_run_id} # Keep track of the run
)
logger.info(f"Cost message saved for non-stream: {final_cost}")
self.trace.update(metadata={"cost": final_cost})
else:
logger.info("Non-stream cost calculation resulted in zero or None, not storing cost message.")
self.trace.update(metadata={"cost": 0})
except Exception as e:
logger.error(f"Error calculating final cost for non-stream: {str(e)}")
self.trace.event(name="error_calculating_final_cost_for_non_stream", level="ERROR", status_message=(f"Error calculating final cost for non-stream: {str(e)}"))
# --- Execute Tools and Yield Results ---
# --- Execute Tools and Yield Results ---
tool_calls_to_execute = [item['tool_call'] for item in all_tool_data]
if config.execute_tools and tool_calls_to_execute:
logger.info(f"Executing {len(tool_calls_to_execute)} tools with strategy: {config.tool_execution_strategy}")
@@ -850,6 +945,22 @@ class ResponseProcessor:
)
if finish_msg_obj: yield format_for_yield(finish_msg_obj)
# --- Save and Yield assistant_response_end ---
if assistant_message_object: # Only save if assistant message was saved
try:
# Save the full LiteLLM response object directly in content
await self.add_message(
thread_id=thread_id,
type="assistant_response_end",
content=llm_response,
is_llm_message=False,
metadata={"thread_run_id": thread_run_id}
)
logger.info("Assistant response end saved for non-stream")
except Exception as e:
logger.error(f"Error saving assistant response end for non-stream: {str(e)}")
self.trace.event(name="error_saving_assistant_response_end_for_non_stream", level="ERROR", status_message=(f"Error saving assistant response end for non-stream: {str(e)}"))
except Exception as e:
logger.error(f"Error processing non-streaming response: {str(e)}", exc_info=True)
self.trace.event(name="error_processing_non_streaming_response", level="ERROR", status_message=(f"Error processing non-streaming response: {str(e)}"))
@@ -1619,7 +1730,7 @@ class ResponseProcessor:
# return summary
summary_output = result.output if hasattr(result, 'output') else str(result)
success_status = structured_result["tool_execution"]["result"]["success"]
success_status = structured_result_v1["tool_execution"]["result"]["success"]
# Create a more comprehensive summary for the LLM
if xml_tag_name: