track cost and tokens in message

This commit is contained in:
Adam Cohen Hillel 2025-04-14 01:31:23 +01:00
parent b72614cd67
commit 5a7a6ef7d3
2 changed files with 33 additions and 3 deletions

View File

@ -13,7 +13,9 @@ import asyncio
import re
import uuid
from typing import List, Dict, Any, Optional, Tuple, AsyncGenerator, Callable, Union, Literal
from dataclasses import dataclass, field
from dataclasses import dataclass
from litellm import completion_cost, token_counter
from agentpress.tool import Tool, ToolResult
from agentpress.tool_registry import ToolRegistry
@ -137,11 +139,14 @@ class ResponseProcessor:
# if config.max_xml_tool_calls > 0:
# logger.info(f"XML tool call limit enabled: {config.max_xml_tool_calls}")
accumulated_cost = 0
accumulated_token_count = 0
try:
async for chunk in llm_response:
# Default content to yield
# Check for finish_reason
if hasattr(chunk, 'choices') and chunk.choices and hasattr(chunk.choices[0], 'finish_reason') and chunk.choices[0].finish_reason:
finish_reason = chunk.choices[0].finish_reason
@ -155,6 +160,16 @@ class ResponseProcessor:
chunk_content = delta.content
accumulated_content += chunk_content
current_xml_content += chunk_content
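# Estimate cost and token count for this chunk and add them to the running totals.
# NOTE(review): `accumulated_content` already includes this chunk and is passed both as
# the cost `prompt` and as the text handed to `token_counter`, so earlier text is
# re-counted on every chunk — confirm the totals are meant to grow this way.
# Any failure below is logged and does not interrupt streaming.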
try:
cost = completion_cost(model=chunk.model, prompt=accumulated_content, completion=chunk_content)
tcount = token_counter(model=chunk.model, messages=[{"role": "user", "content": accumulated_content}])
accumulated_cost += cost
accumulated_token_count += tcount
logger.debug(f"Cost: {cost:.6f}, Token count: {tcount}")
except Exception as e:
logger.error(f"Error calculating cost: {str(e)}")
# Check if we've reached the XML tool call limit before yielding content
if config.max_xml_tool_calls > 0 and xml_tool_call_count >= config.max_xml_tool_calls:
@ -321,7 +336,7 @@ class ResponseProcessor:
if finish_reason == "xml_tool_limit_reached":
logger.info("Stopping stream due to XML tool call limit")
break
# After streaming completes or is stopped due to limit, wait for any remaining tool executions
if pending_tool_executions:
logger.info(f"Waiting for {len(pending_tool_executions)} pending tool executions to complete")
@ -535,6 +550,19 @@ class ResponseProcessor:
except Exception as e:
logger.error(f"Error processing stream: {str(e)}", exc_info=True)
yield {"type": "error", "message": str(e)}
finally:
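# Record the accumulated cost and token count as a "cost" message on the thread.
# This sits in the `finally` clause, so it runs whether or not the stream raised.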
await self.add_message(
thread_id=thread_id,
type="cost",
content={
"cost": accumulated_cost,
"token_count": accumulated_token_count
},
is_llm_message=False
)
async def process_non_streaming_response(
self,

View File

@ -95,6 +95,7 @@ export type Thread = {
/**
 * Shape of a message object; every field is a plain string.
 */
export type Message = {
role: string;
content: string;
// Message category. getMessages filters out rows whose type is 'cost'
// (presumably internal cost-tracking records — confirm against the backend).
type: string;
}
export type AgentRun = {
@ -369,6 +370,7 @@ export const getMessages = async (threadId: string, hideToolMsgs: boolean = fals
.from('messages')
.select('*')
.eq('thread_id', threadId)
.neq('type', 'cost')
.order('created_at', { ascending: true });
if (error) {