mirror of https://github.com/kortix-ai/suna.git
track cost and tokens in message
commit 5a7a6ef7d3
parent b72614cd67
@@ -13,7 +13,9 @@ import asyncio
import re
import uuid
from typing import List, Dict, Any, Optional, Tuple, AsyncGenerator, Callable, Union, Literal
from dataclasses import dataclass, field
from dataclasses import dataclass

from litellm import completion_cost, token_counter

from agentpress.tool import Tool, ToolResult
from agentpress.tool_registry import ToolRegistry
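This hunk brings in litellm's completion_cost and token_counter, which the rest of the commit uses for per-chunk cost tracking. Both helpers can be exercised on their own; a minimal sketch, assuming litellm is installed. The model name and strings below are illustrative only and are not taken from this commit:

from litellm import completion_cost, token_counter

# Estimate the USD cost of a prompt/completion pair from raw text.
# "gpt-4o" is an example model name, not one pinned by this commit.
cost = completion_cost(
    model="gpt-4o",
    prompt="What is the capital of France?",
    completion="Paris.",
)

# Count tokens for a message list with the same example model.
tokens = token_counter(
    model="gpt-4o",
    messages=[{"role": "user", "content": "What is the capital of France?"}],
)

print(f"estimated cost: {cost:.6f} USD, tokens: {tokens}")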
@@ -137,11 +139,14 @@ class ResponseProcessor:
        # if config.max_xml_tool_calls > 0:
        # logger.info(f"XML tool call limit enabled: {config.max_xml_tool_calls}")

        accumulated_cost = 0
        accumulated_token_count = 0

        try:
            async for chunk in llm_response:
                # Default content to yield

                # Check for finish_reason
                if hasattr(chunk, 'choices') and chunk.choices and hasattr(chunk.choices[0], 'finish_reason') and chunk.choices[0].finish_reason:
                    finish_reason = chunk.choices[0].finish_reason
@@ -155,6 +160,16 @@ class ResponseProcessor:
                    chunk_content = delta.content
                    accumulated_content += chunk_content
                    current_xml_content += chunk_content

                    # Calculate cost using prompt and completion
                    try:
                        cost = completion_cost(model=chunk.model, prompt=accumulated_content, completion=chunk_content)
                        tcount = token_counter(model=chunk.model, messages=[{"role": "user", "content": accumulated_content}])
                        accumulated_cost += cost
                        accumulated_token_count += tcount
                        logger.debug(f"Cost: {cost:.6f}, Token count: {tcount}")
                    except Exception as e:
                        logger.error(f"Error calculating cost: {str(e)}")

                    # Check if we've reached the XML tool call limit before yielding content
                    if config.max_xml_tool_calls > 0 and xml_tool_call_count >= config.max_xml_tool_calls:
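The cost-accumulation block added in this hunk can be reproduced outside ResponseProcessor. Below is a standalone sketch of the same pattern against a streamed litellm call; the model and prompt are example inputs, not values from the diff. Note that, like the diff, it re-prices the growing accumulated text on every chunk, so the totals are rough running estimates rather than an exact bill:

import asyncio
from litellm import acompletion, completion_cost, token_counter

async def stream_with_usage(model: str, prompt: str):
    accumulated_content = ""
    accumulated_cost = 0.0
    accumulated_token_count = 0

    response = await acompletion(
        model=model,
        messages=[{"role": "user", "content": prompt}],
        stream=True,
    )
    async for chunk in response:
        delta = chunk.choices[0].delta
        if not getattr(delta, "content", None):
            continue
        chunk_content = delta.content
        accumulated_content += chunk_content
        try:
            # Same approach as the hunk above: price the accumulated text as the
            # "prompt" and the newest chunk as the "completion".
            accumulated_cost += completion_cost(
                model=model, prompt=accumulated_content, completion=chunk_content
            )
            accumulated_token_count += token_counter(
                model=model,
                messages=[{"role": "user", "content": accumulated_content}],
            )
        except Exception as e:
            print(f"Error calculating cost: {e}")
    return accumulated_content, accumulated_cost, accumulated_token_count

# Example invocation (requires provider credentials in the environment):
# asyncio.run(stream_with_usage("gpt-4o", "Summarize the plot of Hamlet."))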
@@ -321,7 +336,7 @@ class ResponseProcessor:
                if finish_reason == "xml_tool_limit_reached":
                    logger.info("Stopping stream due to XML tool call limit")
                    break

            # After streaming completes or is stopped due to limit, wait for any remaining tool executions
            if pending_tool_executions:
                logger.info(f"Waiting for {len(pending_tool_executions)} pending tool executions to complete")
@@ -535,6 +550,19 @@ class ResponseProcessor:
        except Exception as e:
            logger.error(f"Error processing stream: {str(e)}", exc_info=True)
            yield {"type": "error", "message": str(e)}

        finally:
            # track the cost and token count
            await self.add_message(
                thread_id=thread_id,
                type="cost",
                content={
                    "cost": accumulated_cost,
                    "token_count": accumulated_token_count
                },
                is_llm_message=False
            )


    async def process_non_streaming_response(
        self,
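The finally block added here records the totals as a non-LLM message of type "cost" via self.add_message, whose implementation is not part of this diff. A toy in-memory stand-in, only to show the shape of the record the frontend later filters on; the class name, ids, and figures are hypothetical:

from dataclasses import dataclass, field
from typing import Any, Dict, List

@dataclass
class ThreadMessageStore:
    # Hypothetical stand-in; the real add_message persists to a database.
    messages: List[Dict[str, Any]] = field(default_factory=list)

    def add_message(self, thread_id: str, type: str, content: Any, is_llm_message: bool) -> None:
        self.messages.append({
            "thread_id": thread_id,
            "type": type,
            "content": content,
            "is_llm_message": is_llm_message,
        })

store = ThreadMessageStore()
store.add_message(
    thread_id="thread-123",                          # example id
    type="cost",
    content={"cost": 0.000420, "token_count": 128},  # example figures
    is_llm_message=False,
)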
@@ -95,6 +95,7 @@ export type Thread = {
export type Message = {
  role: string;
  content: string;
  type: string;
}

export type AgentRun = {
@@ -369,6 +370,7 @@ export const getMessages = async (threadId: string, hideToolMsgs: boolean = fals
    .from('messages')
    .select('*')
    .eq('thread_id', threadId)
    .neq('type', 'cost')
    .order('created_at', { ascending: true });

  if (error) {
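On the read path, getMessages now excludes the bookkeeping rows with the new .neq('type', 'cost') filter. For illustration only, the same query expressed with the Python Supabase client (supabase-py); the URL, key, and thread id are placeholders, and the project itself uses the JavaScript client shown above:

from supabase import create_client

# Placeholder credentials; supply real project values in practice.
client = create_client("https://example.supabase.co", "public-anon-key")

result = (
    client.table("messages")
    .select("*")
    .eq("thread_id", "thread-123")   # example thread id
    .neq("type", "cost")             # hide cost-tracking rows from the UI
    .order("created_at", desc=False)
    .execute()
)
messages = result.data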