track cost and tokens in message

This commit is contained in:
Adam Cohen Hillel 2025-04-14 01:31:23 +01:00
parent b72614cd67
commit 5a7a6ef7d3
2 changed files with 33 additions and 3 deletions

View File

@ -13,7 +13,9 @@ import asyncio
import re import re
import uuid import uuid
from typing import List, Dict, Any, Optional, Tuple, AsyncGenerator, Callable, Union, Literal from typing import List, Dict, Any, Optional, Tuple, AsyncGenerator, Callable, Union, Literal
from dataclasses import dataclass, field from dataclasses import dataclass
from litellm import completion_cost, token_counter
from agentpress.tool import Tool, ToolResult from agentpress.tool import Tool, ToolResult
from agentpress.tool_registry import ToolRegistry from agentpress.tool_registry import ToolRegistry
@ -137,11 +139,14 @@ class ResponseProcessor:
# if config.max_xml_tool_calls > 0: # if config.max_xml_tool_calls > 0:
# logger.info(f"XML tool call limit enabled: {config.max_xml_tool_calls}") # logger.info(f"XML tool call limit enabled: {config.max_xml_tool_calls}")
accumulated_cost = 0
accumulated_token_count = 0
try: try:
async for chunk in llm_response: async for chunk in llm_response:
# Default content to yield # Default content to yield
# Check for finish_reason # Check for finish_reason
if hasattr(chunk, 'choices') and chunk.choices and hasattr(chunk.choices[0], 'finish_reason') and chunk.choices[0].finish_reason: if hasattr(chunk, 'choices') and chunk.choices and hasattr(chunk.choices[0], 'finish_reason') and chunk.choices[0].finish_reason:
finish_reason = chunk.choices[0].finish_reason finish_reason = chunk.choices[0].finish_reason
@ -155,6 +160,16 @@ class ResponseProcessor:
chunk_content = delta.content chunk_content = delta.content
accumulated_content += chunk_content accumulated_content += chunk_content
current_xml_content += chunk_content current_xml_content += chunk_content
# Calculate cost using prompt and completion
try:
cost = completion_cost(model=chunk.model, prompt=accumulated_content, completion=chunk_content)
tcount = token_counter(model=chunk.model, messages=[{"role": "user", "content": accumulated_content}])
accumulated_cost += cost
accumulated_token_count += tcount
logger.debug(f"Cost: {cost:.6f}, Token count: {tcount}")
except Exception as e:
logger.error(f"Error calculating cost: {str(e)}")
# Check if we've reached the XML tool call limit before yielding content # Check if we've reached the XML tool call limit before yielding content
if config.max_xml_tool_calls > 0 and xml_tool_call_count >= config.max_xml_tool_calls: if config.max_xml_tool_calls > 0 and xml_tool_call_count >= config.max_xml_tool_calls:
@ -321,7 +336,7 @@ class ResponseProcessor:
if finish_reason == "xml_tool_limit_reached": if finish_reason == "xml_tool_limit_reached":
logger.info("Stopping stream due to XML tool call limit") logger.info("Stopping stream due to XML tool call limit")
break break
# After streaming completes or is stopped due to limit, wait for any remaining tool executions # After streaming completes or is stopped due to limit, wait for any remaining tool executions
if pending_tool_executions: if pending_tool_executions:
logger.info(f"Waiting for {len(pending_tool_executions)} pending tool executions to complete") logger.info(f"Waiting for {len(pending_tool_executions)} pending tool executions to complete")
@ -535,6 +550,19 @@ class ResponseProcessor:
except Exception as e: except Exception as e:
logger.error(f"Error processing stream: {str(e)}", exc_info=True) logger.error(f"Error processing stream: {str(e)}", exc_info=True)
yield {"type": "error", "message": str(e)} yield {"type": "error", "message": str(e)}
finally:
# track the cost and token count
await self.add_message(
thread_id=thread_id,
type="cost",
content={
"cost": accumulated_cost,
"token_count": accumulated_token_count
},
is_llm_message=False
)
async def process_non_streaming_response( async def process_non_streaming_response(
self, self,

View File

@ -95,6 +95,7 @@ export type Thread = {
export type Message = { export type Message = {
role: string; role: string;
content: string; content: string;
type: string;
} }
export type AgentRun = { export type AgentRun = {
@ -369,6 +370,7 @@ export const getMessages = async (threadId: string, hideToolMsgs: boolean = fals
.from('messages') .from('messages')
.select('*') .select('*')
.eq('thread_id', threadId) .eq('thread_id', threadId)
.neq('type', 'cost')
.order('created_at', { ascending: true }); .order('created_at', { ascending: true });
if (error) { if (error) {