mirror of https://github.com/kortix-ai/suna.git
track cost and tokens in message
This commit is contained in:
parent
b72614cd67
commit
5a7a6ef7d3
|
@ -13,7 +13,9 @@ import asyncio
|
||||||
import re
|
import re
|
||||||
import uuid
|
import uuid
|
||||||
from typing import List, Dict, Any, Optional, Tuple, AsyncGenerator, Callable, Union, Literal
|
from typing import List, Dict, Any, Optional, Tuple, AsyncGenerator, Callable, Union, Literal
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
from litellm import completion_cost, token_counter
|
||||||
|
|
||||||
from agentpress.tool import Tool, ToolResult
|
from agentpress.tool import Tool, ToolResult
|
||||||
from agentpress.tool_registry import ToolRegistry
|
from agentpress.tool_registry import ToolRegistry
|
||||||
|
@ -137,11 +139,14 @@ class ResponseProcessor:
|
||||||
|
|
||||||
# if config.max_xml_tool_calls > 0:
|
# if config.max_xml_tool_calls > 0:
|
||||||
# logger.info(f"XML tool call limit enabled: {config.max_xml_tool_calls}")
|
# logger.info(f"XML tool call limit enabled: {config.max_xml_tool_calls}")
|
||||||
|
|
||||||
|
accumulated_cost = 0
|
||||||
|
accumulated_token_count = 0
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async for chunk in llm_response:
|
async for chunk in llm_response:
|
||||||
# Default content to yield
|
# Default content to yield
|
||||||
|
|
||||||
# Check for finish_reason
|
# Check for finish_reason
|
||||||
if hasattr(chunk, 'choices') and chunk.choices and hasattr(chunk.choices[0], 'finish_reason') and chunk.choices[0].finish_reason:
|
if hasattr(chunk, 'choices') and chunk.choices and hasattr(chunk.choices[0], 'finish_reason') and chunk.choices[0].finish_reason:
|
||||||
finish_reason = chunk.choices[0].finish_reason
|
finish_reason = chunk.choices[0].finish_reason
|
||||||
|
@ -155,6 +160,16 @@ class ResponseProcessor:
|
||||||
chunk_content = delta.content
|
chunk_content = delta.content
|
||||||
accumulated_content += chunk_content
|
accumulated_content += chunk_content
|
||||||
current_xml_content += chunk_content
|
current_xml_content += chunk_content
|
||||||
|
|
||||||
|
# Calculate cost using prompt and completion
|
||||||
|
try:
|
||||||
|
cost = completion_cost(model=chunk.model, prompt=accumulated_content, completion=chunk_content)
|
||||||
|
tcount = token_counter(model=chunk.model, messages=[{"role": "user", "content": accumulated_content}])
|
||||||
|
accumulated_cost += cost
|
||||||
|
accumulated_token_count += tcount
|
||||||
|
logger.debug(f"Cost: {cost:.6f}, Token count: {tcount}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error calculating cost: {str(e)}")
|
||||||
|
|
||||||
# Check if we've reached the XML tool call limit before yielding content
|
# Check if we've reached the XML tool call limit before yielding content
|
||||||
if config.max_xml_tool_calls > 0 and xml_tool_call_count >= config.max_xml_tool_calls:
|
if config.max_xml_tool_calls > 0 and xml_tool_call_count >= config.max_xml_tool_calls:
|
||||||
|
@ -321,7 +336,7 @@ class ResponseProcessor:
|
||||||
if finish_reason == "xml_tool_limit_reached":
|
if finish_reason == "xml_tool_limit_reached":
|
||||||
logger.info("Stopping stream due to XML tool call limit")
|
logger.info("Stopping stream due to XML tool call limit")
|
||||||
break
|
break
|
||||||
|
|
||||||
# After streaming completes or is stopped due to limit, wait for any remaining tool executions
|
# After streaming completes or is stopped due to limit, wait for any remaining tool executions
|
||||||
if pending_tool_executions:
|
if pending_tool_executions:
|
||||||
logger.info(f"Waiting for {len(pending_tool_executions)} pending tool executions to complete")
|
logger.info(f"Waiting for {len(pending_tool_executions)} pending tool executions to complete")
|
||||||
|
@ -535,6 +550,19 @@ class ResponseProcessor:
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error processing stream: {str(e)}", exc_info=True)
|
logger.error(f"Error processing stream: {str(e)}", exc_info=True)
|
||||||
yield {"type": "error", "message": str(e)}
|
yield {"type": "error", "message": str(e)}
|
||||||
|
|
||||||
|
finally:
|
||||||
|
# track the cost and token count
|
||||||
|
await self.add_message(
|
||||||
|
thread_id=thread_id,
|
||||||
|
type="cost",
|
||||||
|
content={
|
||||||
|
"cost": accumulated_cost,
|
||||||
|
"token_count": accumulated_token_count
|
||||||
|
},
|
||||||
|
is_llm_message=False
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def process_non_streaming_response(
|
async def process_non_streaming_response(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -95,6 +95,7 @@ export type Thread = {
|
||||||
export type Message = {
|
export type Message = {
|
||||||
role: string;
|
role: string;
|
||||||
content: string;
|
content: string;
|
||||||
|
type: string;
|
||||||
}
|
}
|
||||||
|
|
||||||
export type AgentRun = {
|
export type AgentRun = {
|
||||||
|
@ -369,6 +370,7 @@ export const getMessages = async (threadId: string, hideToolMsgs: boolean = fals
|
||||||
.from('messages')
|
.from('messages')
|
||||||
.select('*')
|
.select('*')
|
||||||
.eq('thread_id', threadId)
|
.eq('thread_id', threadId)
|
||||||
|
.neq('type', 'cost')
|
||||||
.order('created_at', { ascending: true });
|
.order('created_at', { ascending: true });
|
||||||
|
|
||||||
if (error) {
|
if (error) {
|
||||||
|
|
Loading…
Reference in New Issue