diff --git a/backend/agentpress/response_processor.py b/backend/agentpress/response_processor.py index f3b85433..889ab893 100644 --- a/backend/agentpress/response_processor.py +++ b/backend/agentpress/response_processor.py @@ -13,7 +13,9 @@ import asyncio import re import uuid from typing import List, Dict, Any, Optional, Tuple, AsyncGenerator, Callable, Union, Literal -from dataclasses import dataclass, field +from dataclasses import dataclass + +from litellm import completion_cost, token_counter from agentpress.tool import Tool, ToolResult from agentpress.tool_registry import ToolRegistry @@ -137,11 +139,14 @@ class ResponseProcessor: # if config.max_xml_tool_calls > 0: # logger.info(f"XML tool call limit enabled: {config.max_xml_tool_calls}") + + accumulated_cost = 0 + accumulated_token_count = 0 try: async for chunk in llm_response: # Default content to yield - + # Check for finish_reason if hasattr(chunk, 'choices') and chunk.choices and hasattr(chunk.choices[0], 'finish_reason') and chunk.choices[0].finish_reason: finish_reason = chunk.choices[0].finish_reason @@ -155,6 +160,16 @@ class ResponseProcessor: chunk_content = delta.content accumulated_content += chunk_content current_xml_content += chunk_content + + # Calculate cost using prompt and completion + try: + cost = completion_cost(model=chunk.model, prompt=accumulated_content, completion=chunk_content) + tcount = token_counter(model=chunk.model, messages=[{"role": "user", "content": accumulated_content}]) + accumulated_cost += cost + accumulated_token_count += tcount + logger.debug(f"Cost: {cost:.6f}, Token count: {tcount}") + except Exception as e: + logger.error(f"Error calculating cost: {str(e)}") # Check if we've reached the XML tool call limit before yielding content if config.max_xml_tool_calls > 0 and xml_tool_call_count >= config.max_xml_tool_calls: @@ -321,7 +336,7 @@ class ResponseProcessor: if finish_reason == "xml_tool_limit_reached": logger.info("Stopping stream due to XML tool call limit") break - + # After streaming completes or is stopped due to limit, wait for any remaining tool executions if pending_tool_executions: logger.info(f"Waiting for {len(pending_tool_executions)} pending tool executions to complete") @@ -535,6 +550,19 @@ class ResponseProcessor: except Exception as e: logger.error(f"Error processing stream: {str(e)}", exc_info=True) yield {"type": "error", "message": str(e)} + + finally: + # track the cost and token count + await self.add_message( + thread_id=thread_id, + type="cost", + content={ + "cost": accumulated_cost, + "token_count": accumulated_token_count + }, + is_llm_message=False + ) + async def process_non_streaming_response( self, diff --git a/frontend/src/lib/api.ts b/frontend/src/lib/api.ts index 7f38762f..5b297a83 100644 --- a/frontend/src/lib/api.ts +++ b/frontend/src/lib/api.ts @@ -95,6 +95,7 @@ export type Thread = { export type Message = { role: string; content: string; + type: string; } export type AgentRun = { @@ -369,6 +370,7 @@ export const getMessages = async (threadId: string, hideToolMsgs: boolean = fals .from('messages') .select('*') .eq('thread_id', threadId) + .neq('type', 'cost') .order('created_at', { ascending: true }); if (error) {