fix cost calculation

LE Quoc Dat 2025-04-17 22:51:12 +01:00
parent 50b71c064a
commit 36c8fce0f6
2 changed files with 118 additions and 43 deletions


@@ -95,6 +95,8 @@ class ResponseProcessor:
self,
llm_response: AsyncGenerator,
thread_id: str,
prompt_messages: List[Dict[str, Any]],
llm_model: str,
config: ProcessorConfig = ProcessorConfig(),
) -> AsyncGenerator:
"""Process a streaming LLM response, handling tool calls and execution.
@@ -102,6 +104,8 @@ class ResponseProcessor:
Args:
llm_response: Streaming response from the LLM
thread_id: ID of the conversation thread
prompt_messages: List of messages sent to the LLM (the prompt)
llm_model: The name of the LLM model used
config: Configuration for parsing and execution
Yields:
@@ -140,9 +144,6 @@ class ResponseProcessor:
# if config.max_xml_tool_calls > 0:
# logger.info(f"XML tool call limit enabled: {config.max_xml_tool_calls}")
accumulated_cost = 0
accumulated_token_count = 0
try:
async for chunk in llm_response:
# Default content to yield
@@ -158,6 +159,7 @@ class ResponseProcessor:
# Check for and log Anthropic thinking content
if delta and hasattr(delta, 'reasoning_content') and delta.reasoning_content:
logger.info(f"[THINKING]: {delta.reasoning_content}")
accumulated_content += delta.reasoning_content # Append reasoning to main content
# Process content chunk
if delta and hasattr(delta, 'content') and delta.content:
@@ -165,16 +167,6 @@ class ResponseProcessor:
accumulated_content += chunk_content
current_xml_content += chunk_content
# Calculate cost using prompt and completion
try:
cost = completion_cost(model=chunk.model, prompt=accumulated_content, completion=chunk_content)
tcount = token_counter(model=chunk.model, messages=[{"role": "user", "content": accumulated_content}])
accumulated_cost += cost
accumulated_token_count += tcount
logger.debug(f"Cost: {cost:.6f}, Token count: {tcount}")
except Exception as e:
logger.error(f"Error calculating cost: {str(e)}")
# Check if we've reached the XML tool call limit before yielding content
if config.max_xml_tool_calls > 0 and xml_tool_call_count >= config.max_xml_tool_calls:
# We've reached the limit, don't yield any more content
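An aside on the per-chunk cost loop removed above (a hedged illustration of mine, not part of the commit): calling completion_cost on every chunk with the full accumulated text as the prompt re-counts earlier tokens on each iteration, so the summed cost grows roughly quadratically with stream length rather than linearly. A minimal sketch of that over-counting, with made-up token counts and an assumed flat per-token price:

# Hedged illustration only: 100 one-token chunks, assumed flat price per token.
PRICE_PER_TOKEN = 1e-6

chunks = ["tok"] * 100

accumulated_tokens = 0
per_chunk_total = 0.0
for _ in chunks:
    accumulated_tokens += 1
    # Old approach: the whole accumulated text is billed as "prompt" on every chunk,
    # plus the current chunk as "completion".
    per_chunk_total += (accumulated_tokens + 1) * PRICE_PER_TOKEN

# New approach (simplified to completion tokens only): bill once after streaming finishes.
final_total = len(chunks) * PRICE_PER_TOKEN

print(f"per-chunk sum: {per_chunk_total:.6f}  single final calculation: {final_total:.6f}")
# per-chunk sum: 0.005150  single final calculation: 0.000100

This is consistent with the commit moving cost accounting to a single calculation after the stream completes.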
@@ -338,7 +330,7 @@ class ResponseProcessor:
# If we've reached the XML tool call limit, stop streaming
if finish_reason == "xml_tool_limit_reached":
logger.info("Stopping stream due to XML tool call limit")
logger.info("Stopping stream processing after loop due to XML tool call limit")
break
# After streaming completes or is stopped due to limit, wait for any remaining tool executions
@@ -470,6 +462,27 @@ class ResponseProcessor:
is_llm_message=True
)
# Calculate and store cost AFTER adding the main assistant message
if accumulated_content: # Calculate cost if there was content (now includes reasoning)
try:
final_cost = completion_cost(
model=llm_model, # Use the passed model name
messages=prompt_messages,
completion=accumulated_content
)
if final_cost is not None and final_cost > 0:
logger.info(f"Calculated final cost for stream: {final_cost}")
await self.add_message(
thread_id=thread_id,
type="cost",
content={"cost": final_cost},
is_llm_message=False # Cost is metadata, not LLM content
)
else:
logger.info("Cost calculation resulted in zero or None, not storing cost message.")
except Exception as e:
logger.error(f"Error calculating final cost for stream: {str(e)}")
# Now add all buffered tool results AFTER the assistant message, but don't yield if already yielded
for tool_call, result, result_tool_index in tool_results_buffer:
# Add result based on tool type to the conversation history
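For reference, a minimal, hedged sketch of the final-cost step added above: litellm's completion_cost is called once with the prompt messages and the accumulated completion text. The model id and messages below are illustrative assumptions, not values from the commit:

from litellm import completion_cost

# Illustrative inputs; in the commit these come from prompt_messages and accumulated_content.
prompt_messages = [{"role": "user", "content": "Summarize the repository."}]
accumulated_content = "The repository implements a thread manager and a response processor."

try:
    # completion_cost tokenizes the messages and the completion text and applies
    # litellm's per-model pricing table.
    final_cost = completion_cost(
        model="anthropic/claude-3-7-sonnet-20250219",  # assumed model id
        messages=prompt_messages,
        completion=accumulated_content,
    )
    if final_cost:
        print(f"Estimated stream cost: {final_cost:.6f} USD")
except Exception as exc:
    # Models without a pricing entry raise; the cost stays best-effort metadata.
    print(f"Cost estimation failed: {exc}")

Storing the result as a separate "cost" message keeps the figure out of the LLM-visible conversation history, matching the is_llm_message=False flag above.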
@@ -563,24 +576,19 @@ class ResponseProcessor:
yield {"type": "error", "message": str(e)}
finally:
pass
# track the cost and token count
# todo: there is a bug as it adds every chunk to db because finally will run every time even in yield
# await self.add_message(
# thread_id=thread_id,
# type="cost",
# content={
# "cost": accumulated_cost,
# "token_count": accumulated_token_count
# },
# is_llm_message=False
# )
# Finally, yield the finish reason if it was detected
if finish_reason:
yield {
"type": "finish",
"finish_reason": finish_reason
}
async def process_non_streaming_response(
self,
llm_response: Any,
thread_id: str,
prompt_messages: List[Dict[str, Any]],
llm_model: str,
config: ProcessorConfig = ProcessorConfig(),
) -> AsyncGenerator:
"""Process a non-streaming LLM response, handling tool calls and execution.
@@ -588,6 +596,8 @@ class ResponseProcessor:
Args:
llm_response: Response from the LLM
thread_id: ID of the conversation thread
prompt_messages: List of messages sent to the LLM (the prompt)
llm_model: The name of the LLM model used
config: Configuration for parsing and execution
Yields:
@@ -688,6 +698,33 @@ class ResponseProcessor:
is_llm_message=True
)
# Calculate and store cost AFTER adding the main assistant message
if content or tool_calls: # Calculate cost if there's content or tool calls
try:
# Use the full response object for potentially more accurate cost calculation
# Pass the model explicitly, as it may not be reliably present in the response object for all providers
# First check if response_cost is directly available in _hidden_params
if hasattr(llm_response, '_hidden_params') and 'response_cost' in llm_response._hidden_params and llm_response._hidden_params['response_cost'] != 0.0:
final_cost = llm_response._hidden_params['response_cost']
else:
# Fall back to calculating cost if direct cost not available
final_cost = completion_cost(
completion_response=llm_response,
model=llm_model,
call_type="completion" # Assuming 'completion' type for this context
)
if final_cost is not None and final_cost > 0:
logger.info(f"Calculated final cost for non-stream: {final_cost}")
await self.add_message(
thread_id=thread_id,
type="cost",
content={"cost": final_cost},
is_llm_message=False # Cost is metadata
)
except Exception as e:
logger.error(f"Error calculating final cost for non-stream: {str(e)}")
# Yield content first
yield {"type": "content", "content": content}


@@ -296,6 +296,8 @@ class ThreadManager:
response_generator = self.response_processor.process_streaming_response(
llm_response=llm_response,
thread_id=thread_id,
prompt_messages=prepared_messages,
llm_model=llm_model,
config=processor_config
)
@@ -306,55 +308,91 @@ class ThreadManager:
response_generator = self.response_processor.process_non_streaming_response(
llm_response=llm_response,
thread_id=thread_id,
prompt_messages=prepared_messages,
llm_model=llm_model,
config=processor_config
)
return response_generator # Return the generator
except Exception as e:
logger.error(f"Error in run_thread: {str(e)}", exc_info=True)
return {
"status": "error",
"message": str(e)
}
logger.error(f"Error in _run_once: {str(e)}", exc_info=True)
# For generators, we need to yield an error structure if returning a generator is expected
error_message = f"Error during LLM call or setup: {str(e)}"  # capture the message now; `e` is cleared when the except block exits
async def error_generator():
yield {
"type": "error",
"message": error_message
}
return error_generator()
# Define a wrapper generator that handles auto-continue logic
async def auto_continue_wrapper():
nonlocal auto_continue, auto_continue_count
nonlocal auto_continue, auto_continue_count, temporary_message
current_temp_message = temporary_message # Use a local copy for the first run
while auto_continue and (native_max_auto_continues == 0 or auto_continue_count < native_max_auto_continues):
# Reset auto_continue for this iteration
auto_continue = False
# Run the thread once
response_gen = await _run_once(temporary_message if auto_continue_count == 0 else None)
# Pass current_temp_message, which is only set for the first iteration
response_gen = await _run_once(temp_msg=current_temp_message)
# Handle error responses
if isinstance(response_gen, dict) and "status" in response_gen and response_gen["status"] == "error":
# Clear the temporary message after the first run
current_temp_message = None
# Handle error responses (checking if it's an error dict, which _run_once might return directly)
if isinstance(response_gen, dict) and response_gen.get("status") == "error":
yield response_gen
return
# Check whether this is the error generator produced by _run_once's exception handling.
# There is no reliable way to identify that generator directly, so inspect its first item instead.
first_chunk = None
try:
first_chunk = await anext(response_gen)
except StopAsyncIteration:
# Empty generator, possibly due to an issue before yielding.
logger.warning("Response generator was empty.")
break
except Exception as e:
logger.error(f"Error getting first chunk from generator: {e}")
yield {"type": "error", "message": f"Error processing response: {e}"}
break
if first_chunk and first_chunk.get('type') == 'error' and "Error during LLM call" in first_chunk.get('message', ''):
yield first_chunk
return # Stop processing if setup failed
# Yield the first chunk if it wasn't an error
if first_chunk:
yield first_chunk
# Process each chunk
# Process remaining chunks
async for chunk in response_gen:
# Check if this is a finish reason chunk with tool_calls or xml_tool_limit_reached
if chunk.get('type') == 'finish':
if chunk.get('finish_reason') == 'tool_calls':
finish_reason = chunk.get('finish_reason')
if finish_reason == 'tool_calls':
# Only auto-continue if enabled (max > 0)
if native_max_auto_continues > 0:
logger.info(f"Detected finish_reason='tool_calls', auto-continuing ({auto_continue_count + 1}/{native_max_auto_continues})")
auto_continue = True
auto_continue_count += 1
# Don't yield the finish chunk to avoid confusing the client
continue
elif chunk.get('finish_reason') == 'xml_tool_limit_reached':
# Don't yield the finish chunk to avoid confusing the client during auto-continue
continue
elif finish_reason == 'xml_tool_limit_reached':
# Don't auto-continue if XML tool limit was reached
logger.info(f"Detected finish_reason='xml_tool_limit_reached', stopping auto-continue")
auto_continue = False
# Still yield the chunk to inform the client
# Yield other finish reasons normally
# Otherwise just yield the chunk normally
# Yield the chunk normally
yield chunk
# If not auto-continuing, we're done
# If not auto-continuing, we're done with the loop
if not auto_continue:
break
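Finally, a small runnable sketch (hypothetical names, not from the commit) of the pattern the wrapper above relies on: pull the first chunk with anext() to detect an up-front error, then stream the remaining chunks unchanged. The anext() builtin requires Python 3.10 or later:

import asyncio
from typing import Any, AsyncGenerator, Dict

Chunk = Dict[str, Any]

async def fake_response() -> AsyncGenerator[Chunk, None]:
    # Stand-in for the generator returned by _run_once.
    yield {"type": "content", "content": "hello"}
    yield {"type": "finish", "finish_reason": "stop"}

async def peek_and_stream(gen: AsyncGenerator[Chunk, None]) -> AsyncGenerator[Chunk, None]:
    try:
        first = await anext(gen)  # Python 3.10+; use gen.__anext__() on older versions
    except StopAsyncIteration:
        return  # empty generator: nothing to stream
    yield first
    if first.get("type") == "error":
        return  # setup failed; stop after surfacing the error
    async for chunk in gen:  # pass the remaining chunks through untouched
        yield chunk

async def main() -> None:
    async for chunk in peek_and_stream(fake_response()):
        print(chunk)

asyncio.run(main())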