mirror of https://github.com/kortix-ai/suna.git
cost calculation
This commit is contained in:
parent c84ee59dc6
commit 6e4e2673d5
@@ -97,18 +97,18 @@ class ResponseProcessor:
         self,
         llm_response: AsyncGenerator,
         thread_id: str,
+        prompt_messages: List[Dict[str, Any]],
+        llm_model: str,
         config: ProcessorConfig = ProcessorConfig(),
-        prompt_messages: Optional[List[Dict[str, Any]]] = None,
-        llm_model: Optional[str] = None
     ) -> AsyncGenerator:
         """Process a streaming LLM response, handling tool calls and execution.

         Args:
             llm_response: Streaming response from the LLM
             thread_id: ID of the conversation thread
+            prompt_messages: List of messages sent to the LLM (the prompt)
+            llm_model: The name of the LLM model used
             config: Configuration for parsing and execution
-            prompt_messages: List of messages used for cost calculation
-            llm_model: Name of the LLM model used for cost calculation

         Yields:
             Formatted chunks of the response including content and tool results
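The updated signature makes prompt_messages and llm_model required instead of optional keyword arguments. A minimal sketch of a call site under that assumption — the processor instance, model string, and message shape below are illustrative, not taken from this commit:

# Hypothetical call site (names and model identifier are assumptions, not from the diff).
async def run_stream(processor, llm_response, thread_id):
    prompt_messages = [{"role": "user", "content": "Summarize the latest report."}]
    async for chunk in processor.process_streaming_response(
        llm_response=llm_response,
        thread_id=thread_id,
        prompt_messages=prompt_messages,  # now required; reused later for cost calculation
        llm_model="anthropic/claude-3-7-sonnet-latest",  # assumed model identifier
    ):
        print(chunk.get("type"), chunk.get("content", ""))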
@@ -173,27 +173,18 @@ class ResponseProcessor:
                 if hasattr(chunk, 'choices') and chunk.choices:
                     delta = chunk.choices[0].delta if hasattr(chunk.choices[0], 'delta') else None

+                    # Check for and log Anthropic thinking content
+                    if delta and hasattr(delta, 'reasoning_content') and delta.reasoning_content:
+                        logger.info(f"[THINKING]: {delta.reasoning_content}")
+                        # Append reasoning to main content to be saved in the final message
+                        accumulated_content += delta.reasoning_content
+
                     # Process content chunk
                     if delta and hasattr(delta, 'content') and delta.content:
                         chunk_content = delta.content
                         accumulated_content += chunk_content
                         current_xml_content += chunk_content

-                        # Process reasoning content if present (Anthropic)
-                        if hasattr(delta, 'reasoning_content') and delta.reasoning_content:
-                            logger.info(f"[THINKING]: {delta.reasoning_content}")
-                            accumulated_content += delta.reasoning_content # Append reasoning to main content
-
-                        # Calculate cost using prompt and completion - MOVED AFTER MESSAGE SAVE
-                        # try:
-                        #     cost = completion_cost(model=chunk.model, prompt=accumulated_content, completion=chunk_content)
-                        #     tcount = token_counter(model=chunk.model, messages=[{"role": "user", "content": accumulated_content}])
-                        #     accumulated_cost += cost
-                        #     accumulated_token_count += tcount
-                        #     logger.debug(f"Cost: {cost:.6f}, Token count: {tcount}")
-                        # except Exception as e:
-                        #     logger.error(f"Error calculating cost: {str(e)}")
-
                         # Check if we've reached the XML tool call limit before yielding content
                         if config.max_xml_tool_calls > 0 and xml_tool_call_count >= config.max_xml_tool_calls:
                             # We've reached the limit, don't yield any more content
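The block added above logs Anthropic reasoning deltas and folds them into accumulated_content before ordinary content is handled. A standalone sketch of that accumulation pattern, assuming a LiteLLM-style async stream whose deltas may carry reasoning_content:

# Sketch of the delta-accumulation pattern; stream shape is assumed, not part of this commit.
async def accumulate_stream(llm_response):
    accumulated_content = ""
    async for chunk in llm_response:
        if not (hasattr(chunk, 'choices') and chunk.choices):
            continue
        delta = chunk.choices[0].delta if hasattr(chunk.choices[0], 'delta') else None
        if delta and getattr(delta, 'reasoning_content', None):
            accumulated_content += delta.reasoning_content  # keep thinking text in the saved message
        if delta and getattr(delta, 'content', None):
            accumulated_content += delta.content
    return accumulated_content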
@@ -362,7 +353,7 @@ class ResponseProcessor:

             # If we've reached the XML tool call limit, stop streaming
             if finish_reason == "xml_tool_limit_reached":
-                logger.info("Stopping stream due to XML tool call limit")
+                logger.info("Stopping stream processing after loop due to XML tool call limit")
                 break

         # After streaming completes or is stopped due to limit, wait for any remaining tool executions
@@ -483,6 +474,27 @@ class ResponseProcessor:
                     is_llm_message=True
                 )

+                # Calculate and store cost AFTER adding the main assistant message
+                if accumulated_content: # Calculate cost if there was content (now includes reasoning)
+                    try:
+                        final_cost = completion_cost(
+                            model=llm_model, # Use the passed model name
+                            messages=prompt_messages, # Use the provided prompt messages
+                            completion=accumulated_content
+                        )
+                        if final_cost is not None and final_cost > 0:
+                            logger.info(f"Calculated final cost for stream: {final_cost}")
+                            await self.add_message(
+                                thread_id=thread_id,
+                                type="cost",
+                                content={"cost": final_cost},
+                                is_llm_message=False # Cost is metadata, not LLM content
+                            )
+                        else:
+                            logger.info("Cost calculation resulted in zero or None, not storing cost message.")
+                    except Exception as e:
+                        logger.error(f"Error calculating final cost for stream: {str(e)}")
+
                 # Yield the assistant response end signal *immediately* after saving
                 if last_assistant_message_id:
                     yield {
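With this change the streaming cost is computed once, after the assistant message is saved, from the prompt messages plus the accumulated completion text. A minimal sketch of that step using LiteLLM's completion_cost; the add_message callable and the "cost" message type follow the convention in the diff, the rest is illustrative:

from litellm import completion_cost

# Sketch only: add_message is assumed to match the self.add_message signature used above.
async def store_stream_cost(add_message, thread_id, prompt_messages, llm_model, accumulated_content):
    try:
        final_cost = completion_cost(
            model=llm_model,
            messages=prompt_messages,
            completion=accumulated_content,
        )
        if final_cost and final_cost > 0:
            await add_message(
                thread_id=thread_id,
                type="cost",
                content={"cost": final_cost},
                is_llm_message=False,  # metadata, not LLM content
            )
    except Exception as exc:
        print(f"Error calculating final cost for stream: {exc}")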
@@ -498,35 +510,6 @@ class ResponseProcessor:
                         "thread_run_id": thread_run_id
                     }

-            # --- Cost Calculation (moved here) ---
-            if prompt_messages and llm_model and accumulated_content:
-                try:
-                    cost = completion_cost(
-                        model=llm_model,
-                        messages=prompt_messages,
-                        completion=accumulated_content
-                    )
-                    token_count = token_counter(
-                        model=llm_model,
-                        messages=prompt_messages + [{"role": "assistant", "content": accumulated_content}]
-                    )
-                    await self.add_message(
-                        thread_id=thread_id,
-                        type="cost",
-                        content={
-                            "cost": cost,
-                            "prompt_tokens": token_count - token_counter(model=llm_model, messages=[{"role": "assistant", "content": accumulated_content}]), # Approx
-                            "completion_tokens": token_counter(model=llm_model, messages=[{"role": "assistant", "content": accumulated_content}]), # Approx
-                            "total_tokens": token_count,
-                            "model_name": llm_model
-                        },
-                        is_llm_message=False
-                    )
-                    logger.info(f"Calculated cost for streaming response: {cost:.6f} using model {llm_model}")
-                except Exception as e:
-                    logger.error(f"Error calculating cost: {str(e)}")
-            # --- End Cost Calculation ---
-
             # --- Process All Tool Calls Now ---
             if config.execute_tools:
                 final_tool_calls_to_process = []
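The removed block also derived approximate prompt and completion token counts with LiteLLM's token_counter. A sketch of that accounting, kept here for reference; the returned dict layout mirrors the removed code and is otherwise illustrative:

from litellm import token_counter

def approx_token_usage(llm_model, prompt_messages, completion_text):
    # Approximate counts: the completion is wrapped as a single assistant message.
    prompt_tokens = token_counter(model=llm_model, messages=prompt_messages)
    completion_tokens = token_counter(
        model=llm_model,
        messages=[{"role": "assistant", "content": completion_text}],
    )
    return {
        "prompt_tokens": prompt_tokens,
        "completion_tokens": completion_tokens,
        "total_tokens": prompt_tokens + completion_tokens,
        "model_name": llm_model,
    }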
@@ -664,24 +647,23 @@ class ResponseProcessor:
            yield {"type": "error", "message": str(e), "thread_run_id": thread_run_id if 'thread_run_id' in locals() else None}

        finally:
            # Yield the detected finish reason if one exists and wasn't suppressed
            if finish_reason and finish_reason != "xml_tool_limit_reached":
                # Yield a finish signal including the final assistant message ID
                if last_assistant_message_id:
                    # Yield the overall run end signal
                    yield {
                        "type": "finish",
                        "finish_reason": finish_reason,
                        "type": "thread_run_end",
                        "thread_run_id": thread_run_id
                    }
                else:
                    # Yield the overall run end signal
                    yield {
                        "type": "thread_run_end",
                        "thread_run_id": thread_run_id if 'thread_run_id' in locals() else None
                    }

            # Yield a finish signal including the final assistant message ID
            # Ensure thread_run_id is defined, even if an early error occurred
            run_id = thread_run_id if 'thread_run_id' in locals() else str(uuid.uuid4()) # Fallback ID if needed
            yield {
                "type": "thread_run_end",
                "thread_run_id": run_id
            }

            # Remove old cost calculation code

            pass
            # track the cost and token count
            # todo: there is a bug as it adds every chunk to db because finally will run every time even in yield
            # await self.add_message(
            #     thread_id=thread_id,
            #     type="cost",
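The finally block yields terminal signal dicts ("finish" and "thread_run_end") so a consumer can distinguish a suppressed stop from a normal end of run. A sketch of a consumer loop for those signals; the chunk shapes mirror the yields above, everything else is assumed:

# Illustrative consumer; only the "type" keys shown in the diff are relied on.
async def consume(stream):
    async for chunk in stream:
        kind = chunk.get("type")
        if kind == "content":
            print(chunk.get("content", ""), end="")
        elif kind == "finish":
            print(f"\nfinish_reason={chunk.get('finish_reason')}")
        elif kind == "thread_run_end":
            print(f"run {chunk.get('thread_run_id')} ended")
            break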
@@ -697,18 +679,18 @@ class ResponseProcessor:
         self,
         llm_response: Any,
         thread_id: str,
-        config: ProcessorConfig = ProcessorConfig(),
-        prompt_messages: Optional[List[Dict[str, Any]]] = None,
-        llm_model: Optional[str] = None
+        prompt_messages: List[Dict[str, Any]],
+        llm_model: str,
+        config: ProcessorConfig = ProcessorConfig()
     ) -> AsyncGenerator[Dict[str, Any], None]:
         """Process a non-streaming LLM response, handling tool calls and execution.

         Args:
             llm_response: Response from the LLM
             thread_id: ID of the conversation thread
+            prompt_messages: List of messages sent to the LLM (the prompt)
+            llm_model: The name of the LLM model used
             config: Configuration for parsing and execution
-            prompt_messages: List of messages used for cost calculation
-            llm_model: Name of the LLM model used for cost calculation

         Yields:
             Formatted response including content and tool results
@@ -814,6 +796,41 @@ class ResponseProcessor:
                 is_llm_message=True
             )

+            # Calculate and store cost AFTER adding the main assistant message
+            if content or (config.native_tool_calling and 'native_tool_calls_for_message' in locals() and native_tool_calls_for_message): # Calculate cost if there's content or tool calls
+                try:
+                    # Use the full response object for potentially more accurate cost calculation
+                    # Pass model explicitly as it might not be reliably in response_object for all providers
+                    # First check if response_cost is directly available in _hidden_params
+                    final_cost = None
+                    if hasattr(llm_response, '_hidden_params') and 'response_cost' in llm_response._hidden_params and llm_response._hidden_params['response_cost'] != 0.0:
+                        final_cost = llm_response._hidden_params['response_cost']
+                        logger.info(f"Using response_cost from _hidden_params: {final_cost}")
+
+                    if final_cost is None: # Fall back to calculating cost if direct cost not available or zero
+                        logger.info("Calculating cost using completion_cost function.")
+                        final_cost = completion_cost(
+                            completion_response=llm_response,
+                            model=llm_model, # Use the passed model name
+                            # prompt_messages might be needed for some models/providers
+                            # messages=prompt_messages, # Uncomment if needed
+                            call_type="completion" # Assuming 'completion' type for this context
+                        )

+                    if final_cost is not None and final_cost > 0:
+                        logger.info(f"Calculated final cost for non-stream: {final_cost}")
+                        await self.add_message(
+                            thread_id=thread_id,
+                            type="cost",
+                            content={"cost": final_cost},
+                            is_llm_message=False # Cost is metadata
+                        )
+                    else:
+                        logger.info("Final cost is zero or None, not storing cost message.")
+
+                except Exception as e:
+                    logger.error(f"Error calculating final cost for non-stream: {str(e)}")
+
             # Yield content first
             yield {
                 "type": "content",
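For the non-streaming path, the lookup above prefers LiteLLM's precomputed response_cost from _hidden_params and only falls back to completion_cost when that value is missing or zero. A compact sketch of the same two-step lookup (the helper name and defaults are assumptions):

from litellm import completion_cost

def resolve_cost(llm_response, llm_model):
    # Prefer the cost LiteLLM already attached to the response, if any.
    hidden = getattr(llm_response, '_hidden_params', {}) or {}
    cost = hidden.get('response_cost')
    if cost:  # present and non-zero
        return cost
    # Otherwise compute it from the full response object.
    return completion_cost(
        completion_response=llm_response,
        model=llm_model,
        call_type="completion",
    )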
@@ -904,50 +921,6 @@ class ResponseProcessor:
                     "thread_run_id": thread_run_id
                 }

-            # --- Cost Calculation (moved here) ---
-            if prompt_messages and llm_model:
-                cost = None
-                # Attempt to get cost from LiteLLM response first
-                if hasattr(llm_response, '_hidden_params') and 'response_cost' in llm_response._hidden_params:
-                    cost = llm_response._hidden_params['response_cost']
-                    logger.info(f"Using pre-calculated cost from LiteLLM: {cost:.6f}")
-
-                # If no pre-calculated cost, calculate manually
-                if cost is None:
-                    try:
-                        cost = completion_cost(
-                            model=llm_model,
-                            messages=prompt_messages,
-                            completion=content # Use extracted content
-                        )
-                        logger.info(f"Manually calculated cost for non-streaming response: {cost:.6f} using model {llm_model}")
-                    except Exception as e:
-                        logger.error(f"Error calculating cost: {str(e)}")
-
-                # Add cost message if cost was determined
-                if cost is not None:
-                    try:
-                        # Approximate token counts
-                        completion_tokens = token_counter(model=llm_model, messages=[{"role": "assistant", "content": content}])
-                        prompt_tokens = token_counter(model=llm_model, messages=prompt_messages)
-                        total_tokens = prompt_tokens + completion_tokens
-
-                        await self.add_message(
-                            thread_id=thread_id,
-                            type="cost",
-                            content={
-                                "cost": cost,
-                                "prompt_tokens": prompt_tokens,
-                                "completion_tokens": completion_tokens,
-                                "total_tokens": total_tokens,
-                                "model_name": llm_model
-                            },
-                            is_llm_message=False
-                        )
-                    except Exception as e:
-                        logger.error(f"Error saving cost message: {str(e)}")
-            # --- End Cost Calculation ---
-
         except Exception as e:
             logger.error(f"Error processing response: {str(e)}", exc_info=True)
             yield {"type": "error", "message": str(e), "thread_run_id": thread_run_id if 'thread_run_id' in locals() else None}