mirror of https://github.com/kortix-ai/suna.git
fix calculation cost
This commit is contained in:
parent 50b71c064a
commit 36c8fce0f6
@@ -95,6 +95,8 @@ class ResponseProcessor:
         self,
         llm_response: AsyncGenerator,
         thread_id: str,
+        prompt_messages: List[Dict[str, Any]],
+        llm_model: str,
         config: ProcessorConfig = ProcessorConfig(),
     ) -> AsyncGenerator:
         """Process a streaming LLM response, handling tool calls and execution.
@@ -102,6 +104,8 @@ class ResponseProcessor:
         Args:
             llm_response: Streaming response from the LLM
             thread_id: ID of the conversation thread
+            prompt_messages: List of messages sent to the LLM (the prompt)
+            llm_model: The name of the LLM model used
             config: Configuration for parsing and execution

         Yields:
@@ -140,9 +144,6 @@ class ResponseProcessor:
         # if config.max_xml_tool_calls > 0:
         #     logger.info(f"XML tool call limit enabled: {config.max_xml_tool_calls}")

-        accumulated_cost = 0
-        accumulated_token_count = 0
-
         try:
             async for chunk in llm_response:
                 # Default content to yield
@@ -158,6 +159,7 @@ class ResponseProcessor:
                 # Check for and log Anthropic thinking content
                 if delta and hasattr(delta, 'reasoning_content') and delta.reasoning_content:
                     logger.info(f"[THINKING]: {delta.reasoning_content}")
+                    accumulated_content += delta.reasoning_content # Append reasoning to main content

                 # Process content chunk
                 if delta and hasattr(delta, 'content') and delta.content:
@@ -165,16 +167,6 @@ class ResponseProcessor:
                     accumulated_content += chunk_content
                     current_xml_content += chunk_content

-                    # Calculate cost using prompt and completion
-                    try:
-                        cost = completion_cost(model=chunk.model, prompt=accumulated_content, completion=chunk_content)
-                        tcount = token_counter(model=chunk.model, messages=[{"role": "user", "content": accumulated_content}])
-                        accumulated_cost += cost
-                        accumulated_token_count += tcount
-                        logger.debug(f"Cost: {cost:.6f}, Token count: {tcount}")
-                    except Exception as e:
-                        logger.error(f"Error calculating cost: {str(e)}")
-
                     # Check if we've reached the XML tool call limit before yielding content
                     if config.max_xml_tool_calls > 0 and xml_tool_call_count >= config.max_xml_tool_calls:
                         # We've reached the limit, don't yield any more content
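The block removed above re-priced the entire accumulated response as "prompt" on every chunk (and never priced the real prompt at all), so the running cost and token totals drifted further from the true figures as the stream grew. A minimal sketch of the difference, not the repository's code, assuming litellm's completion_cost and token_counter as imported by this module; the model name and toy chunks are placeholders:

from litellm import completion_cost, token_counter

model = "gpt-4o"  # placeholder model name
prompt_messages = [{"role": "user", "content": "Say hello"}]
chunks = ["Hel", "lo ", "there", "!"]  # pretend streamed deltas

# Removed approach: a running total that re-prices earlier output as "prompt" each time.
accumulated, per_chunk_total = "", 0.0
for chunk in chunks:
    accumulated += chunk
    per_chunk_total += completion_cost(model=model, prompt=accumulated, completion=chunk)

# New approach: one call against the real prompt and the finished completion.
final_cost = completion_cost(model=model, messages=prompt_messages, completion=accumulated)

print(f"per-chunk total: {per_chunk_total:.8f}  final: {final_cost:.8f}")  # the totals disagree
print("completion tokens, counted once:", token_counter(model=model, text=accumulated))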
@@ -338,7 +330,7 @@ class ResponseProcessor:

                 # If we've reached the XML tool call limit, stop streaming
                 if finish_reason == "xml_tool_limit_reached":
-                    logger.info("Stopping stream due to XML tool call limit")
+                    logger.info("Stopping stream processing after loop due to XML tool call limit")
                     break

             # After streaming completes or is stopped due to limit, wait for any remaining tool executions
@@ -470,6 +462,27 @@ class ResponseProcessor:
                     is_llm_message=True
                 )

+                # Calculate and store cost AFTER adding the main assistant message
+                if accumulated_content: # Calculate cost if there was content (now includes reasoning)
+                    try:
+                        final_cost = completion_cost(
+                            model=llm_model, # Use the passed model name
+                            messages=prompt_messages,
+                            completion=accumulated_content
+                        )
+                        if final_cost is not None and final_cost > 0:
+                            logger.info(f"Calculated final cost for stream: {final_cost}")
+                            await self.add_message(
+                                thread_id=thread_id,
+                                type="cost",
+                                content={"cost": final_cost},
+                                is_llm_message=False # Cost is metadata, not LLM content
+                            )
+                        else:
+                            logger.info("Cost calculation resulted in zero or None, not storing cost message.")
+                    except Exception as e:
+                        logger.error(f"Error calculating final cost for stream: {str(e)}")
+
                 # Now add all buffered tool results AFTER the assistant message, but don't yield if already yielded
                 for tool_call, result, result_tool_index in tool_results_buffer:
                     # Add result based on tool type to the conversation history
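The hunk above replaces the per-chunk accounting with a single end-of-stream calculation: price the prompt plus the fully accumulated completion once, and store the result as a metadata message only when it is a positive number. A compact standalone rendering of that flow, assuming litellm's completion_cost; record_cost(), the thread id, and the model name are hypothetical stand-ins, with record_cost() playing the role of self.add_message(..., type="cost", is_llm_message=False):

import asyncio
from litellm import completion_cost

async def record_cost(thread_id: str, cost: float) -> None:
    # Hypothetical stand-in for self.add_message(thread_id, type="cost",
    # content={"cost": cost}, is_llm_message=False).
    print(f"thread {thread_id}: storing cost row {cost:.6f}")

async def store_stream_cost(thread_id: str, llm_model: str,
                            prompt_messages: list[dict], accumulated_content: str) -> None:
    if not accumulated_content:              # nothing was generated, nothing to price
        return
    try:
        cost = completion_cost(model=llm_model, messages=prompt_messages,
                               completion=accumulated_content)
    except Exception as e:                   # e.g. model missing from litellm's cost map
        print(f"cost calculation failed: {e}")
        return
    if cost and cost > 0:                    # skip zero/None results, as the hunk does
        await record_cost(thread_id, cost)

asyncio.run(store_stream_cost("thread-123", "gpt-4o",
                              [{"role": "user", "content": "Say hello"}], "Hello there!"))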
@@ -563,24 +576,19 @@ class ResponseProcessor:
             yield {"type": "error", "message": str(e)}

         finally:
-            pass
-            # track the cost and token count
-            # todo: there is a bug as it adds every chunk to db because finally will run every time even in yield
-            # await self.add_message(
-            #     thread_id=thread_id,
-            #     type="cost",
-            #     content={
-            #         "cost": accumulated_cost,
-            #         "token_count": accumulated_token_count
-            #     },
-            #     is_llm_message=False
-            # )
-
+            # Finally, yield the finish reason if it was detected
+            if finish_reason:
+                yield {
+                    "type": "finish",
+                    "finish_reason": finish_reason
+                }

     async def process_non_streaming_response(
         self,
         llm_response: Any,
         thread_id: str,
+        prompt_messages: List[Dict[str, Any]],
+        llm_model: str,
         config: ProcessorConfig = ProcessorConfig(),
     ) -> AsyncGenerator:
         """Process a non-streaming LLM response, handling tool calls and execution.
@@ -588,6 +596,8 @@ class ResponseProcessor:
         Args:
             llm_response: Response from the LLM
             thread_id: ID of the conversation thread
+            prompt_messages: List of messages sent to the LLM (the prompt)
+            llm_model: The name of the LLM model used
             config: Configuration for parsing and execution

         Yields:
@@ -688,6 +698,33 @@ class ResponseProcessor:
                 is_llm_message=True
             )

+            # Calculate and store cost AFTER adding the main assistant message
+            if content or tool_calls: # Calculate cost if there's content or tool calls
+                try:
+                    # Use the full response object for potentially more accurate cost calculation
+                    # Pass model explicitly as it might not be reliably in response_object for all providers
+                    # First check if response_cost is directly available in _hidden_params
+                    if hasattr(llm_response, '_hidden_params') and 'response_cost' in llm_response._hidden_params and llm_response._hidden_params['response_cost'] != 0.0:
+                        final_cost = llm_response._hidden_params['response_cost']
+                    else:
+                        # Fall back to calculating cost if direct cost not available
+                        final_cost = completion_cost(
+                            completion_response=llm_response,
+                            model=llm_model,
+                            call_type="completion" # Assuming 'completion' type for this context
+                        )
+
+                    if final_cost is not None and final_cost > 0:
+                        logger.info(f"Calculated final cost for non-stream: {final_cost}")
+                        await self.add_message(
+                            thread_id=thread_id,
+                            type="cost",
+                            content={"cost": final_cost},
+                            is_llm_message=False # Cost is metadata
+                        )
+                except Exception as e:
+                    logger.error(f"Error calculating final cost for non-stream: {str(e)}")
+
             # Yield content first
             yield {"type": "content", "content": content}
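The non-streaming path added above prefers the cost figure litellm attaches to the response object in _hidden_params and only falls back to completion_cost() when that figure is missing or zero. A standalone sketch of the same fallback chain, not the repository's code; _FakeResponse and the model name are placeholders for a real litellm ModelResponse and model:

from litellm import completion_cost

def extract_response_cost(llm_response, llm_model: str) -> float | None:
    # Prefer the figure litellm already attached to the response itself.
    hidden = getattr(llm_response, "_hidden_params", {}) or {}
    direct = hidden.get("response_cost")
    if direct:                       # present and non-zero
        return direct
    try:
        # Otherwise price the full response object against the model's cost map.
        return completion_cost(completion_response=llm_response,
                               model=llm_model,
                               call_type="completion")
    except Exception:
        return None                  # unknown model or malformed response: record nothing

class _FakeResponse:                 # stand-in for a litellm ModelResponse
    _hidden_params = {"response_cost": 0.00042}

print(extract_response_cost(_FakeResponse(), "gpt-4o"))   # 0.00042, no recomputation needed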
@@ -296,6 +296,8 @@ class ThreadManager:
                 response_generator = self.response_processor.process_streaming_response(
                     llm_response=llm_response,
                     thread_id=thread_id,
+                    prompt_messages=prepared_messages,
+                    llm_model=llm_model,
                     config=processor_config
                 )

@@ -306,55 +308,91 @@ class ThreadManager:
                 response_generator = self.response_processor.process_non_streaming_response(
                     llm_response=llm_response,
                     thread_id=thread_id,
+                    prompt_messages=prepared_messages,
+                    llm_model=llm_model,
                     config=processor_config
                 )
                 return response_generator # Return the generator

             except Exception as e:
-                logger.error(f"Error in run_thread: {str(e)}", exc_info=True)
-                return {
-                    "status": "error",
-                    "message": str(e)
-                }
+                logger.error(f"Error in _run_once: {str(e)}", exc_info=True)
+                # For generators, we need to yield an error structure if returning a generator is expected
+                async def error_generator():
+                    yield {
+                        "type": "error",
+                        "message": f"Error during LLM call or setup: {str(e)}"
+                    }
+                return error_generator()

         # Define a wrapper generator that handles auto-continue logic
         async def auto_continue_wrapper():
-            nonlocal auto_continue, auto_continue_count
+            nonlocal auto_continue, auto_continue_count, temporary_message
+
+            current_temp_message = temporary_message # Use a local copy for the first run

             while auto_continue and (native_max_auto_continues == 0 or auto_continue_count < native_max_auto_continues):
                 # Reset auto_continue for this iteration
                 auto_continue = False

-                # Run the thread once
-                response_gen = await _run_once(temporary_message if auto_continue_count == 0 else None)
+                # Pass current_temp_message, which is only set for the first iteration
+                response_gen = await _run_once(temp_msg=current_temp_message)

-                # Handle error responses
-                if isinstance(response_gen, dict) and "status" in response_gen and response_gen["status"] == "error":
+                # Clear the temporary message after the first run
+                current_temp_message = None
+
+                # Handle error responses (checking if it's an error dict, which _run_once might return directly)
+                if isinstance(response_gen, dict) and response_gen.get("status") == "error":
                     yield response_gen
                     return

+                # Check if it's the error generator from _run_once exception handling
+                # Need a way to check if it's the specific error generator or just inspect the first item
+                first_chunk = None
+                try:
+                    first_chunk = await anext(response_gen)
+                except StopAsyncIteration:
+                    # Empty generator, possibly due to an issue before yielding.
+                    logger.warning("Response generator was empty.")
+                    break
+                except Exception as e:
+                    logger.error(f"Error getting first chunk from generator: {e}")
+                    yield {"type": "error", "message": f"Error processing response: {e}"}
+                    break
+
+                if first_chunk and first_chunk.get('type') == 'error' and "Error during LLM call" in first_chunk.get('message', ''):
+                    yield first_chunk
+                    return # Stop processing if setup failed
+
+                # Yield the first chunk if it wasn't an error
+                if first_chunk:
+                    yield first_chunk

-                # Process each chunk
+                # Process remaining chunks
                 async for chunk in response_gen:
                     # Check if this is a finish reason chunk with tool_calls or xml_tool_limit_reached
                     if chunk.get('type') == 'finish':
-                        if chunk.get('finish_reason') == 'tool_calls':
+                        finish_reason = chunk.get('finish_reason')
+                        if finish_reason == 'tool_calls':
                             # Only auto-continue if enabled (max > 0)
                             if native_max_auto_continues > 0:
                                 logger.info(f"Detected finish_reason='tool_calls', auto-continuing ({auto_continue_count + 1}/{native_max_auto_continues})")
                                 auto_continue = True
                                 auto_continue_count += 1
                                 # Don't yield the finish chunk to avoid confusing the client
                                 continue
-                        elif chunk.get('finish_reason') == 'xml_tool_limit_reached':
-                            # Don't yield the finish chunk to avoid confusing the client during auto-continue
-                            continue
+                        elif finish_reason == 'xml_tool_limit_reached':
+                            # Don't auto-continue if XML tool limit was reached
+                            logger.info(f"Detected finish_reason='xml_tool_limit_reached', stopping auto-continue")
+                            auto_continue = False
+                            # Still yield the chunk to inform the client
+
+                        # Yield other finish reasons normally

-                    # Otherwise just yield the chunk normally
+                    # Yield the chunk normally
                     yield chunk

-                # If not auto-continuing, we're done
+                # If not auto-continuing, we're done with the loop
                 if not auto_continue:
                     break
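The rewritten wrapper above primes each run with anext() so it can inspect the first chunk for a setup error before deciding whether to forward the stream. A minimal, self-contained sketch of that pattern, not ThreadManager itself; make_run() is a hypothetical stand-in for _run_once():

import asyncio

async def make_run(fail: bool = False):
    # Stand-in for _run_once(): either a normal run or one that fails during setup.
    if fail:
        yield {"type": "error", "message": "Error during LLM call or setup: boom"}
        return
    for part in ({"type": "content", "content": "hi"}, {"type": "finish", "finish_reason": "stop"}):
        yield part

async def wrapper(fail: bool = False):
    response_gen = make_run(fail)
    try:
        first_chunk = await anext(response_gen)          # prime the generator
    except StopAsyncIteration:
        return                                           # empty run: nothing to forward
    if first_chunk.get("type") == "error":
        yield first_chunk                                # surface setup errors and stop
        return
    yield first_chunk                                    # otherwise forward everything
    async for chunk in response_gen:
        yield chunk

async def main():
    async for chunk in wrapper():
        print(chunk)

asyncio.run(main())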