mirror of https://github.com/kortix-ai/suna.git
revert accidental changes
parent 2d88091258
commit 60059eaa9b
@@ -22,7 +22,7 @@ from agent.prompt import get_system_prompt
 
 from utils.logger import logger
 from utils.auth_utils import get_account_id_from_thread
-from services.billing import check_billing_status, calculate_token_cost, handle_usage_with_credits
+from services.billing import check_billing_status
 from agent.tools.sb_vision_tool import SandboxVisionTool
 from agent.tools.sb_image_edit_tool import SandboxImageEditTool
 from agent.tools.sb_presentation_outline_tool import SandboxPresentationOutlineTool
@@ -438,53 +438,6 @@ class AgentRunner:
     def __init__(self, config: AgentConfig):
         self.config = config
 
-    async def handle_assistant_response_end(self, thread_id: str, message_id: str, content: dict, client):
-        """Handle billing/credit deduction for assistant response end messages.
-
-        This ensures 100% traceability by linking deductions to specific messages.
-        """
-        try:
-            if not isinstance(content, dict):
-                logger.debug(f"Skipping billing for message {message_id}: content is not a dict")
-                return
-
-            usage = content.get("usage", {})
-            if not usage:
-                logger.debug(f"Skipping billing for message {message_id}: no usage data")
-                return
-
-            prompt_tokens = int(usage.get("prompt_tokens", 0) or 0)
-            completion_tokens = int(usage.get("completion_tokens", 0) or 0)
-            model = content.get("model")
-
-            if prompt_tokens == 0 and completion_tokens == 0:
-                logger.debug(f"Skipping billing for message {message_id}: zero tokens")
-                return
-
-            # Calculate token cost
-            token_cost = calculate_token_cost(prompt_tokens, completion_tokens, model or "unknown")
-
-            # Get account_id for this thread (equals user_id for personal accounts)
-            thread_row = await client.table('threads').select('account_id').eq('thread_id', thread_id).limit(1).execute()
-            user_id = thread_row.data[0]['account_id'] if thread_row.data and len(thread_row.data) > 0 else None
-
-            if user_id and token_cost > 0:
-                # Deduct credits if applicable and record usage against this specific message
-                await handle_usage_with_credits(
-                    client,
-                    user_id,
-                    token_cost,
-                    thread_id=thread_id,
-                    message_id=message_id,
-                    model=model or "unknown"
-                )
-                logger.debug(f"💰 Processed billing for message {message_id}: {token_cost} credits (tokens: {prompt_tokens}+{completion_tokens}={prompt_tokens+completion_tokens})")
-            else:
-                logger.debug(f"Skipping billing for message {message_id}: user_id={user_id}, token_cost={token_cost}")
-
-        except Exception as e:
-            logger.error(f"Error handling billing for message {message_id}: {str(e)}", exc_info=True)
-
     async def setup(self):
         if not self.config.trace:
             self.config.trace = langfuse.trace(name="run_agent", session_id=self.config.thread_id, metadata={"project_id": self.config.project_id})
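For orientation: the block removed above is the callback-based billing hook from the accidental change, in which AgentRunner billed each assistant_response_end message itself and handed the bound method to ThreadManager. Below is a minimal, self-contained sketch of that pattern; MiniThreadManager, billing_hook, and the hard-coded message id are illustrative stand-ins rather than the repo's actual classes, and the real hook also receives a database client.

import asyncio
from typing import Awaitable, Callable, Optional

class MiniThreadManager:
    """Toy persistence layer that accepts an optional post-save hook."""

    def __init__(self, on_assistant_response_end: Optional[Callable[..., Awaitable[None]]] = None):
        self.on_assistant_response_end = on_assistant_response_end

    async def add_message(self, thread_id: str, type: str, content: dict) -> dict:
        saved_message = {"message_id": "msg-1", "content": content}  # stand-in for the DB insert
        # Fire the hook only for assistant_response_end messages, mirroring the removed code.
        if type == "assistant_response_end" and self.on_assistant_response_end:
            try:
                await self.on_assistant_response_end(
                    thread_id=thread_id,
                    message_id=saved_message["message_id"],
                    content=content,
                )
            except Exception as exc:
                print(f"callback failed: {exc}")  # the real code logs the error and keeps going
        return saved_message

async def billing_hook(thread_id: str, message_id: str, content: dict) -> None:
    usage = content.get("usage", {})
    print(f"bill message {message_id} in {thread_id}: {usage}")

async def main() -> None:
    tm = MiniThreadManager(on_assistant_response_end=billing_hook)
    await tm.add_message("thread-1", "assistant_response_end",
                         {"usage": {"prompt_tokens": 10, "completion_tokens": 5}})

asyncio.run(main())

The revert drops this indirection and puts the billing logic back inside ThreadManager.add_message, as the later hunks show.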
@@ -493,8 +446,7 @@ class AgentRunner:
             trace=self.config.trace,
             is_agent_builder=self.config.is_agent_builder or False,
             target_agent_id=self.config.target_agent_id,
-            agent_config=self.config.agent_config,
-            on_assistant_response_end=self.handle_assistant_response_end
+            agent_config=self.config.agent_config
         )
 
         self.client = await self.thread_manager.db.client

@@ -25,7 +25,7 @@ from utils.logger import logger
 from langfuse.client import StatefulGenerationClient, StatefulTraceClient
 from services.langfuse import langfuse
 from litellm.utils import token_counter
-
+from services.billing import calculate_token_cost, handle_usage_with_credits
 import re
 from datetime import datetime, timezone, timedelta
 import aiofiles
@@ -42,7 +42,7 @@ class ThreadManager:
     XML-based tool execution patterns.
     """
 
-    def __init__(self, trace: Optional[StatefulTraceClient] = None, is_agent_builder: bool = False, target_agent_id: Optional[str] = None, agent_config: Optional[dict] = None, on_assistant_response_end: Optional[callable] = None):
+    def __init__(self, trace: Optional[StatefulTraceClient] = None, is_agent_builder: bool = False, target_agent_id: Optional[str] = None, agent_config: Optional[dict] = None):
         """Initialize ThreadManager.
 
         Args:
@@ -50,7 +50,6 @@ class ThreadManager:
             is_agent_builder: Whether this is an agent builder session
             target_agent_id: ID of the agent being built (if in agent builder mode)
             agent_config: Optional agent configuration with version information
-            on_assistant_response_end: Optional callback for handling assistant response end events (e.g., billing)
         """
         self.db = DBConnection()
         self.tool_registry = ToolRegistry()
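With the parameter and its docstring entry removed, callers construct ThreadManager without any billing hook. A hypothetical call site after this revert is sketched below; the import path and argument values are assumptions for illustration, not taken from this diff.

# Assumed import path and placeholder arguments; adjust to the actual project layout.
from agentpress.thread_manager import ThreadManager

thread_manager = ThreadManager(
    trace=None,                 # or a langfuse StatefulTraceClient
    is_agent_builder=False,
    target_agent_id=None,
    agent_config={"agent_id": "example-agent", "version": "v1"},
)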
@@ -58,7 +57,6 @@ class ThreadManager:
         self.is_agent_builder = is_agent_builder
         self.target_agent_id = target_agent_id
         self.agent_config = agent_config
-        self.on_assistant_response_end = on_assistant_response_end
         if not self.trace:
             self.trace = langfuse.trace(name="anonymous:thread_manager")
         self.response_processor = ResponseProcessor(
@@ -176,19 +174,30 @@ class ThreadManager:
 
         if result.data and len(result.data) > 0 and isinstance(result.data[0], dict) and 'message_id' in result.data[0]:
             saved_message = result.data[0]
-
-            # Trigger callback for assistant_response_end if callback is provided
-            if type == "assistant_response_end" and self.on_assistant_response_end:
+            # If this is an assistant_response_end, attempt to deduct credits if over limit
+            if type == "assistant_response_end" and isinstance(content, dict):
                 try:
-                    await self.on_assistant_response_end(
-                        thread_id=thread_id,
-                        message_id=saved_message['message_id'],
-                        content=content,
-                        client=client
-                    )
-                except Exception as callback_e:
-                    logger.error(f"Error in assistant_response_end callback: {str(callback_e)}", exc_info=True)
-
+                    usage = content.get("usage", {}) if isinstance(content, dict) else {}
+                    prompt_tokens = int(usage.get("prompt_tokens", 0) or 0)
+                    completion_tokens = int(usage.get("completion_tokens", 0) or 0)
+                    model = content.get("model") if isinstance(content, dict) else None
+                    # Compute token cost
+                    token_cost = calculate_token_cost(prompt_tokens, completion_tokens, model or "unknown")
+                    # Fetch account_id for this thread, which equals user_id for personal accounts
+                    thread_row = await client.table('threads').select('account_id').eq('thread_id', thread_id).limit(1).execute()
+                    user_id = thread_row.data[0]['account_id'] if thread_row.data and len(thread_row.data) > 0 else None
+                    if user_id and token_cost > 0:
+                        # Deduct credits if applicable and record usage against this message
+                        await handle_usage_with_credits(
+                            client,
+                            user_id,
+                            token_cost,
+                            thread_id=thread_id,
+                            message_id=saved_message['message_id'],
+                            model=model or "unknown"
+                        )
+                except Exception as billing_e:
+                    logger.error(f"Error handling credit usage for message {saved_message.get('message_id')}: {str(billing_e)}", exc_info=True)
             return saved_message
         else:
             logger.error(f"Insert operation failed or did not return expected data structure for thread {thread_id}. Result data: {result.data}")
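The lines added back above recompute the cost from the message's usage payload and record it against the saved message inside add_message. A condensed, runnable sketch of that flow follows; calculate_token_cost and handle_usage_with_credits are stubbed with made-up behavior (the real implementations live in services.billing), and the account lookup from the threads table is replaced by a constant.

import asyncio

def calculate_token_cost(prompt_tokens: int, completion_tokens: int, model: str) -> float:
    # Stub: the real helper prices tokens per model.
    return (prompt_tokens + completion_tokens) * 1e-6

async def handle_usage_with_credits(client, user_id, token_cost, *, thread_id, message_id, model):
    # Stub: the real helper deducts credits and records usage against the message.
    print(f"deduct {token_cost:.6f} from {user_id} for {message_id} ({model})")

async def bill_assistant_response_end(client, thread_id: str, saved_message: dict, content: dict) -> None:
    usage = content.get("usage", {}) if isinstance(content, dict) else {}
    prompt_tokens = int(usage.get("prompt_tokens", 0) or 0)
    completion_tokens = int(usage.get("completion_tokens", 0) or 0)
    model = content.get("model") or "unknown"
    token_cost = calculate_token_cost(prompt_tokens, completion_tokens, model)
    user_id = "account-123"  # the real code resolves this from the threads table via `client`
    if user_id and token_cost > 0:
        await handle_usage_with_credits(
            client, user_id, token_cost,
            thread_id=thread_id, message_id=saved_message["message_id"], model=model,
        )

asyncio.run(bill_assistant_response_end(
    None, "thread-1", {"message_id": "msg-1"},
    {"usage": {"prompt_tokens": 120, "completion_tokens": 40}, "model": "example-model"},
))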
@@ -650,4 +659,4 @@ When using the tools:
             return await _run_once(temporary_message)
 
         # Otherwise return the auto-continue wrapper generator
-        return auto_continue_wrapper()
+        return auto_continue_wrapper()