revert accidental changes

This commit is contained in:
marko-kraemer 2025-08-19 16:36:28 -07:00
parent 2d88091258
commit 60059eaa9b
2 changed files with 28 additions and 67 deletions

View File

@@ -22,7 +22,7 @@ from agent.prompt import get_system_prompt
from utils.logger import logger
from utils.auth_utils import get_account_id_from_thread
from services.billing import check_billing_status, calculate_token_cost, handle_usage_with_credits
from services.billing import check_billing_status
from agent.tools.sb_vision_tool import SandboxVisionTool
from agent.tools.sb_image_edit_tool import SandboxImageEditTool
from agent.tools.sb_presentation_outline_tool import SandboxPresentationOutlineTool
@@ -438,53 +438,6 @@ class AgentRunner:
def __init__(self, config: AgentConfig):
    """Create a runner bound to ``config``.

    Args:
        config: Run configuration, stored unchanged as ``self.config``.
    """
    self.config = config
async def handle_assistant_response_end(self, thread_id: str, message_id: str, content: dict, client):
    """Bill the owning account for one finished assistant response.

    Token counts and the model name are read from ``content``. The computed
    cost is passed to ``handle_usage_with_credits`` together with
    ``message_id`` so the deduction is linked to this specific message.

    Nothing is charged when ``content`` is not a dict, when it carries no
    usage data, when both token counts are zero, when no ``account_id`` row
    is found for the thread, or when the computed cost is not positive.

    Args:
        thread_id: Thread whose ``account_id`` is looked up in ``threads``.
        message_id: Message the usage is recorded against.
        content: Message payload; its ``usage`` and ``model`` keys are read.
        client: Database client used for the lookup and handed to billing.

    Errors raised while processing are caught and logged rather than
    propagated to the caller.
    """
    try:
        if not isinstance(content, dict):
            logger.debug(f"Skipping billing for message {message_id}: content is not a dict")
            return

        usage_info = content.get("usage", {})
        if not usage_info:
            logger.debug(f"Skipping billing for message {message_id}: no usage data")
            return

        # Missing or None counts are treated as zero.
        prompt_tokens = int(usage_info.get("prompt_tokens", 0) or 0)
        completion_tokens = int(usage_info.get("completion_tokens", 0) or 0)
        model_name = content.get("model") or "unknown"

        if not (prompt_tokens or completion_tokens):
            logger.debug(f"Skipping billing for message {message_id}: zero tokens")
            return

        token_cost = calculate_token_cost(prompt_tokens, completion_tokens, model_name)

        # The thread's account_id is used as the user to bill
        # (equals user_id for personal accounts).
        account_lookup = await (
            client.table('threads')
            .select('account_id')
            .eq('thread_id', thread_id)
            .limit(1)
            .execute()
        )
        rows = account_lookup.data
        if rows and len(rows) > 0:
            user_id = rows[0]['account_id']
        else:
            user_id = None

        if not (user_id and token_cost > 0):
            logger.debug(f"Skipping billing for message {message_id}: user_id={user_id}, token_cost={token_cost}")
            return

        # Deduct credits if applicable and record usage against this message.
        await handle_usage_with_credits(
            client,
            user_id,
            token_cost,
            thread_id=thread_id,
            message_id=message_id,
            model=model_name,
        )
        logger.debug(f"💰 Processed billing for message {message_id}: {token_cost} credits (tokens: {prompt_tokens}+{completion_tokens}={prompt_tokens+completion_tokens})")
    except Exception as e:
        logger.error(f"Error handling billing for message {message_id}: {str(e)}", exc_info=True)
async def setup(self):
if not self.config.trace:
self.config.trace = langfuse.trace(name="run_agent", session_id=self.config.thread_id, metadata={"project_id": self.config.project_id})
@@ -493,8 +446,7 @@ class AgentRunner:
trace=self.config.trace,
is_agent_builder=self.config.is_agent_builder or False,
target_agent_id=self.config.target_agent_id,
agent_config=self.config.agent_config,
on_assistant_response_end=self.handle_assistant_response_end
agent_config=self.config.agent_config
)
self.client = await self.thread_manager.db.client

View File

@@ -25,7 +25,7 @@ from utils.logger import logger
from langfuse.client import StatefulGenerationClient, StatefulTraceClient
from services.langfuse import langfuse
from litellm.utils import token_counter
from services.billing import calculate_token_cost, handle_usage_with_credits
import re
from datetime import datetime, timezone, timedelta
import aiofiles
@@ -42,7 +42,7 @@ class ThreadManager:
XML-based tool execution patterns.
"""
def __init__(self, trace: Optional[StatefulTraceClient] = None, is_agent_builder: bool = False, target_agent_id: Optional[str] = None, agent_config: Optional[dict] = None, on_assistant_response_end: Optional[callable] = None):
def __init__(self, trace: Optional[StatefulTraceClient] = None, is_agent_builder: bool = False, target_agent_id: Optional[str] = None, agent_config: Optional[dict] = None):
"""Initialize ThreadManager.
Args:
@@ -50,7 +50,6 @@ class ThreadManager:
is_agent_builder: Whether this is an agent builder session
target_agent_id: ID of the agent being built (if in agent builder mode)
agent_config: Optional agent configuration with version information
on_assistant_response_end: Optional callback for handling assistant response end events (e.g., billing)
"""
self.db = DBConnection()
self.tool_registry = ToolRegistry()
@@ -58,7 +57,6 @@ class ThreadManager:
self.is_agent_builder = is_agent_builder
self.target_agent_id = target_agent_id
self.agent_config = agent_config
self.on_assistant_response_end = on_assistant_response_end
if not self.trace:
self.trace = langfuse.trace(name="anonymous:thread_manager")
self.response_processor = ResponseProcessor(
@@ -176,19 +174,30 @@ class ThreadManager:
if result.data and len(result.data) > 0 and isinstance(result.data[0], dict) and 'message_id' in result.data[0]:
saved_message = result.data[0]
# Trigger callback for assistant_response_end if callback is provided
if type == "assistant_response_end" and self.on_assistant_response_end:
# If this is an assistant_response_end, attempt to deduct credits if over limit
if type == "assistant_response_end" and isinstance(content, dict):
try:
await self.on_assistant_response_end(
thread_id=thread_id,
message_id=saved_message['message_id'],
content=content,
client=client
)
except Exception as callback_e:
logger.error(f"Error in assistant_response_end callback: {str(callback_e)}", exc_info=True)
usage = content.get("usage", {}) if isinstance(content, dict) else {}
prompt_tokens = int(usage.get("prompt_tokens", 0) or 0)
completion_tokens = int(usage.get("completion_tokens", 0) or 0)
model = content.get("model") if isinstance(content, dict) else None
# Compute token cost
token_cost = calculate_token_cost(prompt_tokens, completion_tokens, model or "unknown")
# Fetch account_id for this thread, which equals user_id for personal accounts
thread_row = await client.table('threads').select('account_id').eq('thread_id', thread_id).limit(1).execute()
user_id = thread_row.data[0]['account_id'] if thread_row.data and len(thread_row.data) > 0 else None
if user_id and token_cost > 0:
# Deduct credits if applicable and record usage against this message
await handle_usage_with_credits(
client,
user_id,
token_cost,
thread_id=thread_id,
message_id=saved_message['message_id'],
model=model or "unknown"
)
except Exception as billing_e:
logger.error(f"Error handling credit usage for message {saved_message.get('message_id')}: {str(billing_e)}", exc_info=True)
return saved_message
else:
logger.error(f"Insert operation failed or did not return expected data structure for thread {thread_id}. Result data: {result.data}")
@@ -650,4 +659,4 @@ When using the tools:
return await _run_once(temporary_message)
# Otherwise return the auto-continue wrapper generator
return auto_continue_wrapper()
return auto_continue_wrapper()