Merge branch 'kortix-ai:main' into refactor/slider

This commit is contained in:
Krishav 2025-08-19 19:06:11 +05:30 committed by GitHub
commit e7385698c1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 1288 additions and 124 deletions

View File

@ -24,8 +24,12 @@ from services.supabase import DBConnection
from utils.logger import logger
from langfuse.client import StatefulGenerationClient, StatefulTraceClient
from services.langfuse import langfuse
import datetime
from litellm.utils import token_counter
from services.billing import calculate_token_cost, handle_usage_with_credits
import re
from datetime import datetime, timezone, timedelta
import aiofiles
import yaml
# Type alias for tool choice
ToolChoice = Literal["auto", "required", "none"]
@ -169,7 +173,32 @@ class ThreadManager:
logger.debug(f"Successfully added message to thread {thread_id}")
if result.data and len(result.data) > 0 and isinstance(result.data[0], dict) and 'message_id' in result.data[0]:
return result.data[0]
saved_message = result.data[0]
# If this is an assistant_response_end, attempt to deduct credits if over limit
if type == "assistant_response_end" and isinstance(content, dict):
try:
usage = content.get("usage", {}) if isinstance(content, dict) else {}
prompt_tokens = int(usage.get("prompt_tokens", 0) or 0)
completion_tokens = int(usage.get("completion_tokens", 0) or 0)
model = content.get("model") if isinstance(content, dict) else None
# Compute token cost
token_cost = calculate_token_cost(prompt_tokens, completion_tokens, model or "unknown")
# Fetch account_id for this thread, which equals user_id for personal accounts
thread_row = await client.table('threads').select('account_id').eq('thread_id', thread_id).limit(1).execute()
user_id = thread_row.data[0]['account_id'] if thread_row.data and len(thread_row.data) > 0 else None
if user_id and token_cost > 0:
# Deduct credits if applicable and record usage against this message
await handle_usage_with_credits(
client,
user_id,
token_cost,
thread_id=thread_id,
message_id=saved_message['message_id'],
model=model or "unknown"
)
except Exception as billing_e:
logger.error(f"Error handling credit usage for message {saved_message.get('message_id')}: {str(billing_e)}", exc_info=True)
return saved_message
else:
logger.error(f"Insert operation failed or did not return expected data structure for thread {thread_id}. Result data: {result.data}")
return None
@ -455,7 +484,7 @@ When using the tools:
if generation:
generation.update(
input=prepared_messages,
start_time=datetime.datetime.now(datetime.timezone.utc),
start_time=datetime.now(timezone.utc),
model=llm_model,
model_parameters={
"max_tokens": llm_max_tokens,

View File

@ -769,36 +769,6 @@ async def create_composio_trigger(req: CreateComposioTriggerRequest, current_use
except Exception:
pass
# If still missing, fetch from list_active
if not composio_trigger_id:
try:
params_lookup = {
"limit": 50,
"slug": req.slug,
"userId": composio_user_id,
}
if req.connected_account_id:
params_lookup["connectedAccountId"] = req.connected_account_id
list_url = f"{COMPOSIO_API_BASE}/api/v3/trigger_instances/active"
async with httpx.AsyncClient(timeout=15) as http_client:
lr = await http_client.get(list_url, headers=headers, params=params_lookup)
if lr.status_code == 200:
ldata = lr.json()
items = ldata.get("items") if isinstance(ldata, dict) else (ldata if isinstance(ldata, list) else [])
if items:
composio_trigger_id = _extract_id(items[0] if isinstance(items[0], dict) else getattr(items[0], "__dict__", {}))
try:
logger.debug(
"Composio list_active fallback",
slug=req.slug,
matched=len(items) if isinstance(items, list) else 0,
extracted_id=composio_trigger_id,
)
except Exception:
pass
except Exception:
pass
if not composio_trigger_id:
raise HTTPException(status_code=500, detail="Failed to get Composio trigger id from response")
@ -858,6 +828,17 @@ async def create_composio_trigger(req: CreateComposioTriggerRequest, current_use
async def composio_webhook(request: Request):
"""Shared Composio webhook endpoint. Verifies secret, matches triggers, and enqueues execution."""
try:
# Read raw body first (can only be done once)
try:
body = await request.body()
body_str = body.decode('utf-8') if body else ""
logger.info("Composio webhook raw body", body=body_str, body_length=len(body) if body else 0)
except Exception as e:
logger.info("Composio webhook body read failed", error=str(e))
body_str = ""
# Minimal request diagnostics (no secrets)
try:
client_ip = request.client.host if request.client else None
@ -865,26 +846,28 @@ async def composio_webhook(request: Request):
has_auth = bool(request.headers.get("authorization"))
has_x_secret = bool(request.headers.get("x-composio-secret") or request.headers.get("X-Composio-Secret"))
has_x_trigger = bool(request.headers.get("x-trigger-secret") or request.headers.get("X-Trigger-Secret"))
# Peek payload meta safely
payload_preview = {}
# Parse payload for logging
payload_preview = {"keys": []}
try:
_p = await request.json()
payload_preview = {
"keys": list(_p.keys()) if isinstance(_p, dict) else [],
"id": _p.get("id") if isinstance(_p, dict) else None,
"triggerSlug": _p.get("triggerSlug") if isinstance(_p, dict) else None,
}
if body_str:
_p = json.loads(body_str)
payload_preview = {
"keys": list(_p.keys()) if isinstance(_p, dict) else [],
"id": _p.get("id") if isinstance(_p, dict) else None,
"triggerSlug": _p.get("triggerSlug") if isinstance(_p, dict) else None,
}
except Exception:
payload_preview = {"keys": []}
logger.debug(
"Composio webhook incoming",
client_ip=client_ip,
header_names=header_names,
has_authorization=has_auth,
has_x_composio_secret=has_x_secret,
has_x_trigger_secret=has_x_trigger,
payload_meta=payload_preview,
)
logger.debug(
"Composio webhook incoming",
client_ip=client_ip,
header_names=header_names,
has_authorization=has_auth,
has_x_composio_secret=has_x_secret,
has_x_trigger_secret=has_x_trigger,
payload_meta=payload_preview,
)
except Exception:
pass
@ -896,8 +879,9 @@ async def composio_webhook(request: Request):
# Use robust verifier (tries ASCII/HEX/B64 keys and id.ts.body/ts.body)
await verify_composio(request, "COMPOSIO_WEBHOOK_SECRET")
# Parse payload for processing
try:
payload = await request.json()
payload = json.loads(body_str) if body_str else {}
except Exception:
payload = {}

View File

@ -82,7 +82,6 @@ class ProfileService:
'updated_at': datetime.now(timezone.utc).isoformat()
}).eq('account_id', account_id)\
.eq('mcp_qualified_name', mcp_qualified_name)\
.eq('is_active', True)\
.execute()
result = await client.table('user_mcp_credential_profiles').insert({
@ -106,7 +105,6 @@ class ProfileService:
client = await self._db.client
result = await client.table('user_mcp_credential_profiles').select('*')\
.eq('profile_id', profile_id)\
.eq('is_active', True)\
.execute()
if not result.data:
@ -128,7 +126,6 @@ class ProfileService:
result = await client.table('user_mcp_credential_profiles').select('*')\
.eq('account_id', account_id)\
.eq('mcp_qualified_name', mcp_qualified_name)\
.eq('is_active', True)\
.order('is_default', desc=True)\
.order('created_at', desc=True)\
.execute()
@ -139,7 +136,6 @@ class ProfileService:
client = await self._db.client
result = await client.table('user_mcp_credential_profiles').select('*')\
.eq('account_id', account_id)\
.eq('is_active', True)\
.order('created_at', desc=True)\
.execute()
@ -172,7 +168,6 @@ class ProfileService:
'updated_at': datetime.now(timezone.utc).isoformat()
}).eq('account_id', account_id)\
.eq('mcp_qualified_name', profile.mcp_qualified_name)\
.eq('is_active', True)\
.execute()
result = await client.table('user_mcp_credential_profiles').update({
@ -186,18 +181,15 @@ class ProfileService:
if success:
logger.debug(f"Set profile {profile_id} as default")
return success
return success
async def delete_profile(self, account_id: str, profile_id: str) -> bool:
logger.debug(f"Deleting profile {profile_id}")
client = await self._db.client
result = await client.table('user_mcp_credential_profiles').update({
'is_active': False,
'updated_at': datetime.now(timezone.utc).isoformat()
}).eq('profile_id', profile_id)\
result = await client.table('user_mcp_credential_profiles').delete()\
.eq('profile_id', profile_id)\
.eq('account_id', account_id)\
.eq('is_active', True)\
.execute()
success = len(result.data) > 0

View File

@ -26,15 +26,25 @@ stripe.api_key = config.STRIPE_SECRET_KEY
# Token price multiplier
TOKEN_PRICE_MULTIPLIER = 1.5
# Initialize router
# Minimum credits required to allow a new request when over subscription limit
CREDIT_MIN_START_DOLLARS = 0.20
# Credit packages with Stripe price IDs
CREDIT_PACKAGES = {
'credits_10': {'amount': 10, 'price': 10, 'stripe_price_id': config.STRIPE_CREDITS_10_PRICE_ID},
'credits_25': {'amount': 25, 'price': 25, 'stripe_price_id': config.STRIPE_CREDITS_25_PRICE_ID},
# Uncomment these when you create the additional price IDs in Stripe:
# 'credits_50': {'amount': 50, 'price': 50, 'stripe_price_id': config.STRIPE_CREDITS_50_PRICE_ID},
# 'credits_100': {'amount': 100, 'price': 100, 'stripe_price_id': config.STRIPE_CREDITS_100_PRICE_ID},
# 'credits_250': {'amount': 250, 'price': 250, 'stripe_price_id': config.STRIPE_CREDITS_250_PRICE_ID},
# 'credits_500': {'amount': 500, 'price': 500, 'stripe_price_id': config.STRIPE_CREDITS_500_PRICE_ID},
# 'credits_1000': {'amount': 1000, 'price': 1000, 'stripe_price_id': config.STRIPE_CREDITS_1000_PRICE_ID},
}
router = APIRouter(prefix="/billing", tags=["billing"])
# Plan validation functions
def get_plan_info(price_id: str) -> dict:
"""Get plan information including tier level and type."""
# Production plans mapping
PLAN_TIERS = {
# Monthly plans
config.STRIPE_TIER_2_20_ID: {'tier': 1, 'type': 'monthly', 'name': '2h/$20'},
config.STRIPE_TIER_6_50_ID: {'tier': 2, 'type': 'monthly', 'name': '6h/$50'},
config.STRIPE_TIER_12_100_ID: {'tier': 3, 'type': 'monthly', 'name': '12h/$100'},
@ -163,6 +173,37 @@ class SubscriptionStatus(BaseModel):
# Subscription data for frontend components
subscription_id: Optional[str] = None
subscription: Optional[Dict] = None
# Credit information
credit_balance: Optional[float] = None
can_purchase_credits: bool = False
class PurchaseCreditsRequest(BaseModel):
amount_dollars: float # Amount of credits to purchase in dollars
success_url: str
cancel_url: str
class CreditBalance(BaseModel):
balance_dollars: float
total_purchased: float
total_used: float
last_updated: Optional[datetime] = None
can_purchase_credits: bool = False # True only for highest tier users
class CreditPurchase(BaseModel):
id: str
amount_dollars: float
status: str
created_at: datetime
completed_at: Optional[datetime] = None
stripe_payment_intent_id: Optional[str] = None
class CreditUsage(BaseModel):
id: str
amount_dollars: float
description: Optional[str] = None
created_at: datetime
thread_id: Optional[str] = None
message_id: Optional[str] = None
# Helper functions
async def get_stripe_customer_id(client: SupabaseClient, user_id: str) -> Optional[str]:
@ -355,12 +396,12 @@ async def calculate_monthly_usage(client, user_id: str) -> float:
execution_time = end_time - start_time
logger.debug(f"Calculate monthly usage took {execution_time:.3f} seconds, total cost: {total_cost}")
await Cache.set(f"monthly_usage:{user_id}", total_cost, ttl=2 * 60)
await Cache.set(f"monthly_usage:{user_id}", total_cost, ttl=5)
return total_cost
async def get_usage_logs(client, user_id: str, page: int = 0, items_per_page: int = 1000) -> Dict:
"""Get detailed usage logs for a user with pagination."""
"""Get detailed usage logs for a user with pagination, including credit usage info."""
# Get start of current month in UTC
now = datetime.now(timezone.utc)
start_of_month = datetime(now.year, now.month, 1, tzinfo=timezone.utc)
@ -420,6 +461,37 @@ async def get_usage_logs(client, user_id: str, page: int = 0, items_per_page: in
if not messages_result.data:
return {"logs": [], "has_more": False}
# Get the user's subscription tier info for credit checking
subscription = await get_user_subscription(user_id)
price_id = config.STRIPE_FREE_TIER_ID # Default to free
if subscription and subscription.get('items'):
items = subscription['items'].get('data', [])
if items:
price_id = items[0]['price']['id']
tier_info = SUBSCRIPTION_TIERS.get(price_id, SUBSCRIPTION_TIERS[config.STRIPE_FREE_TIER_ID])
subscription_limit = tier_info['cost']
# Get credit usage records for this month to match with messages
credit_usage_result = await client.table('credit_usage') \
.select('message_id, amount_dollars, created_at') \
.eq('user_id', user_id) \
.gte('created_at', start_of_month.isoformat()) \
.execute()
# Create a map of message_id to credit usage
credit_usage_map = {}
if credit_usage_result.data:
for usage in credit_usage_result.data:
if usage.get('message_id'):
credit_usage_map[usage['message_id']] = {
'amount': float(usage['amount_dollars']),
'created_at': usage['created_at']
}
# Track cumulative usage to determine when credits started being used
cumulative_cost = 0.0
# Process messages into usage log entries
processed_logs = []
@ -444,13 +516,19 @@ async def get_usage_logs(client, user_id: str, page: int = 0, items_per_page: in
model
)
cumulative_cost += estimated_cost
# Safely extract project_id from threads relationship
project_id = 'unknown'
if message.get('threads') and isinstance(message['threads'], list) and len(message['threads']) > 0:
project_id = message['threads'][0].get('project_id', 'unknown')
processed_logs.append({
'message_id': message.get('message_id', 'unknown'),
# Check if credits were used for this message
message_id = message.get('message_id')
credit_used = credit_usage_map.get(message_id, {})
log_entry = {
'message_id': message_id or 'unknown',
'thread_id': message.get('thread_id', 'unknown'),
'created_at': message.get('created_at', None),
'content': {
@ -462,8 +540,14 @@ async def get_usage_logs(client, user_id: str, page: int = 0, items_per_page: in
},
'total_tokens': total_tokens,
'estimated_cost': estimated_cost,
'project_id': project_id
})
'project_id': project_id,
# Add credit usage info
'credit_used': credit_used.get('amount', 0) if credit_used else 0,
'payment_method': 'credits' if credit_used else 'subscription',
'was_over_limit': cumulative_cost > subscription_limit if not credit_used else True
}
processed_logs.append(log_entry)
except Exception as e:
logger.warning(f"Error processing usage log entry for message {message.get('message_id', 'unknown')}: {str(e)}")
continue
@ -473,7 +557,9 @@ async def get_usage_logs(client, user_id: str, page: int = 0, items_per_page: in
return {
"logs": processed_logs,
"has_more": has_more
"has_more": has_more,
"subscription_limit": subscription_limit,
"cumulative_cost": cumulative_cost
}
@ -619,6 +705,7 @@ async def get_subscription_tier(client, user_id: str) -> str:
async def check_billing_status(client, user_id: str) -> Tuple[bool, str, Optional[Dict]]:
"""
Check if a user can run agents based on their subscription and usage.
Now also checks credit balance if subscription limit is exceeded.
Returns:
Tuple[bool, str, Optional[Dict]]: (can_run, message, subscription_info)
@ -658,8 +745,25 @@ async def check_billing_status(client, user_id: str) -> Tuple[bool, str, Optiona
# Calculate current month's usage
current_usage = await calculate_monthly_usage(client, user_id)
# Check if subscription limit is exceeded
if current_usage >= tier_info['cost']:
return False, f"Monthly limit of {tier_info['cost']} dollars reached. Please upgrade your plan or wait until next month.", subscription
# Check if user has credits available
credit_balance = await get_user_credit_balance(client, user_id)
if credit_balance.balance_dollars >= CREDIT_MIN_START_DOLLARS:
# User has enough credits cushion; they can continue
return True, f"Subscription limit reached, using credits. Balance: ${credit_balance.balance_dollars:.2f}", subscription
else:
# Not enough credits to safely start a new request
if credit_balance.can_purchase_credits:
return False, (
f"Monthly limit of ${tier_info['cost']} reached. You need at least ${CREDIT_MIN_START_DOLLARS:.2f} in credits to continue. "
f"Current balance: ${credit_balance.balance_dollars:.2f}."
), subscription
else:
return False, (
f"Monthly limit of ${tier_info['cost']} reached and credits are unavailable. Please upgrade your plan or wait until next month."
), subscription
return True, "OK", subscription
@ -744,6 +848,194 @@ async def check_subscription_commitment(subscription_id: str) -> dict:
'can_cancel': True
}
async def is_user_on_highest_tier(user_id: str) -> bool:
"""Check if user is on the highest subscription tier (200h/$1000)."""
try:
subscription = await get_user_subscription(user_id)
if not subscription:
logger.debug(f"User {user_id} has no subscription")
return False
# Extract price ID from subscription
price_id = None
if subscription.get('items') and subscription['items'].get('data') and len(subscription['items']['data']) > 0:
price_id = subscription['items']['data'][0]['price']['id']
logger.info(f"User {user_id} subscription price_id: {price_id}")
# Check if it's one of the highest tier price IDs (200h/$1000 only)
highest_tier_price_ids = [
config.STRIPE_TIER_200_1000_ID, # Monthly highest tier
config.STRIPE_TIER_200_1000_YEARLY_ID, # Yearly highest tier
config.STRIPE_TIER_25_200_ID_STAGING,
config.STRIPE_TIER_25_200_YEARLY_ID_STAGING,
config.STRIPE_TIER_2_20_ID_STAGING,
config.STRIPE_TIER_2_20_YEARLY_ID_STAGING,
]
is_highest = price_id in highest_tier_price_ids
logger.info(f"User {user_id} is_highest_tier: {is_highest}, price_id: {price_id}, checked against: {highest_tier_price_ids}")
return is_highest
except Exception as e:
logger.error(f"Error checking if user is on highest tier: {str(e)}")
return False
async def get_user_credit_balance(client: SupabaseClient, user_id: str) -> CreditBalance:
"""Get the credit balance for a user."""
try:
# Get balance from database - use execute() instead of single() to handle no records
result = await client.table('credit_balance') \
.select('*') \
.eq('user_id', user_id) \
.execute()
if result.data and len(result.data) > 0:
data = result.data[0]
is_highest_tier = await is_user_on_highest_tier(user_id)
return CreditBalance(
balance_dollars=float(data.get('balance_dollars', 0)),
total_purchased=float(data.get('total_purchased', 0)),
total_used=float(data.get('total_used', 0)),
last_updated=data.get('last_updated'),
can_purchase_credits=is_highest_tier
)
else:
# No balance record exists yet - this is normal for users who haven't purchased credits
is_highest_tier = await is_user_on_highest_tier(user_id)
return CreditBalance(
balance_dollars=0.0,
total_purchased=0.0,
total_used=0.0,
can_purchase_credits=is_highest_tier
)
except Exception as e:
logger.error(f"Error getting credit balance for user {user_id}: {str(e)}")
return CreditBalance(
balance_dollars=0.0,
total_purchased=0.0,
total_used=0.0,
can_purchase_credits=False
)
async def add_credits_to_balance(client: SupabaseClient, user_id: str, amount: float, purchase_id: str = None) -> float:
"""Add credits to a user's balance."""
try:
# Use the database function to add credits
result = await client.rpc('add_credits', {
'p_user_id': user_id,
'p_amount': amount,
'p_purchase_id': purchase_id
}).execute()
if result.data is not None:
return float(result.data)
return 0.0
except Exception as e:
logger.error(f"Error adding credits for user {user_id}: {str(e)}")
raise
async def use_credits_from_balance(
client: SupabaseClient,
user_id: str,
amount: float,
description: str = None,
thread_id: str = None,
message_id: str = None
) -> bool:
"""Deduct credits from a user's balance."""
try:
# Use the database function to use credits
result = await client.rpc('use_credits', {
'p_user_id': user_id,
'p_amount': amount,
'p_description': description,
'p_thread_id': thread_id,
'p_message_id': message_id
}).execute()
if result.data is not None:
return bool(result.data)
return False
except Exception as e:
logger.error(f"Error using credits for user {user_id}: {str(e)}")
return False
async def handle_usage_with_credits(
client: SupabaseClient,
user_id: str,
token_cost: float,
thread_id: str = None,
message_id: str = None,
model: str = None
) -> Tuple[bool, str]:
"""
Handle token usage that may require credits if subscription limit is exceeded.
This should be called after each agent response to track and deduct from credits if needed.
Returns:
Tuple[bool, str]: (success, message)
"""
try:
# Get current subscription tier and limits
subscription = await get_user_subscription(user_id)
# Get tier info
price_id = config.STRIPE_FREE_TIER_ID # Default to free
if subscription and subscription.get('items'):
items = subscription['items'].get('data', [])
if items:
price_id = items[0]['price']['id']
tier_info = SUBSCRIPTION_TIERS.get(price_id, SUBSCRIPTION_TIERS[config.STRIPE_FREE_TIER_ID])
# Get current month's usage
current_usage = await calculate_monthly_usage(client, user_id)
# Check if this usage would exceed the subscription limit
new_total_usage = current_usage + token_cost
if new_total_usage > tier_info['cost']:
# Calculate overage amount
overage_amount = token_cost # The entire cost if already over limit
if current_usage < tier_info['cost']:
# If this is the transaction that pushes over the limit
overage_amount = new_total_usage - tier_info['cost']
# Try to use credits for the overage
credit_balance = await get_user_credit_balance(client, user_id)
if credit_balance.balance_dollars >= overage_amount:
# Deduct from credits
success = await use_credits_from_balance(
client,
user_id,
overage_amount,
description=f"Token overage for model {model or 'unknown'}",
thread_id=thread_id,
message_id=message_id
)
if success:
logger.debug(f"Used ${overage_amount:.4f} credits for user {user_id} overage")
return True, f"Used ${overage_amount:.4f} from credits (Balance: ${credit_balance.balance_dollars - overage_amount:.2f})"
else:
return False, "Failed to deduct credits"
else:
# Insufficient credits
if credit_balance.can_purchase_credits:
return False, f"Insufficient credits. Balance: ${credit_balance.balance_dollars:.2f}, Required: ${overage_amount:.4f}. Purchase more credits to continue."
else:
return False, f"Monthly limit exceeded and no credits available. Upgrade to the highest tier to purchase credits."
# Within subscription limits, no credits needed
return True, "Within subscription limits"
except Exception as e:
logger.error(f"Error handling usage with credits: {str(e)}")
return False, f"Error processing usage: {str(e)}"
# API endpoints
@router.post("/create-checkout-session")
async def create_checkout_session(
@ -1171,7 +1463,7 @@ async def create_portal_session(
async def get_subscription(
current_user_id: str = Depends(get_current_user_id_from_jwt)
):
"""Get the current subscription status for the current user, including scheduled changes."""
"""Get the current subscription status for the current user, including scheduled changes and credit balance."""
try:
# Get subscription from Stripe (this helper already handles filtering/cleanup)
subscription = await get_user_subscription(current_user_id)
@ -1181,6 +1473,9 @@ async def get_subscription(
db = DBConnection()
client = await db.client
current_usage = await calculate_monthly_usage(client, current_user_id)
# Get credit balance
credit_balance_info = await get_user_credit_balance(client, current_user_id)
if not subscription:
# Default to free tier status if no active subscription for our product
@ -1192,7 +1487,9 @@ async def get_subscription(
price_id=free_tier_id,
minutes_limit=free_tier_info.get('minutes') if free_tier_info else 0,
cost_limit=free_tier_info.get('cost') if free_tier_info else 0,
current_usage=current_usage
current_usage=current_usage,
credit_balance=credit_balance_info.balance_dollars,
can_purchase_credits=credit_balance_info.can_purchase_credits
)
# Extract current plan details
@ -1222,7 +1519,9 @@ async def get_subscription(
'cancel_at_period_end': subscription['cancel_at_period_end'],
'cancel_at': subscription.get('cancel_at'),
'current_period_end': current_item['current_period_end']
}
},
credit_balance=credit_balance_info.balance_dollars,
can_purchase_credits=credit_balance_info.can_purchase_credits
)
# Check for an attached schedule (indicates pending downgrade)
@ -1265,7 +1564,7 @@ async def get_subscription(
async def check_status(
current_user_id: str = Depends(get_current_user_id_from_jwt)
):
"""Check if the user can run agents based on their subscription and usage."""
"""Check if the user can run agents based on their subscription, usage, and credit balance."""
try:
# Get Supabase client
db = DBConnection()
@ -1273,10 +1572,15 @@ async def check_status(
can_run, message, subscription = await check_billing_status(client, current_user_id)
# Get credit balance for additional info
credit_balance = await get_user_credit_balance(client, current_user_id)
return {
"can_run": can_run,
"message": message,
"subscription": subscription
"subscription": subscription,
"credit_balance": credit_balance.balance_dollars,
"can_purchase_credits": credit_balance.can_purchase_credits
}
except Exception as e:
@ -1307,7 +1611,77 @@ async def stripe_webhook(request: Request):
logger.error(f"Invalid webhook signature: {str(e)}")
raise HTTPException(status_code=400, detail="Invalid signature")
# Handle the event
# Get database connection
db = DBConnection()
client = await db.client
# Handle credit purchase completion
if event.type == 'checkout.session.completed':
session = event.data.object
# Check if this is a credit purchase
if session.get('metadata', {}).get('type') == 'credit_purchase':
user_id = session['metadata']['user_id']
credit_amount = float(session['metadata']['credit_amount'])
payment_intent_id = session.get('payment_intent')
logger.debug(f"Processing credit purchase for user {user_id}: ${credit_amount}")
try:
# Update the purchase record status
purchase_update = await client.table('credit_purchases') \
.update({
'status': 'completed',
'completed_at': datetime.now(timezone.utc).isoformat(),
'stripe_payment_intent_id': payment_intent_id
}) \
.eq('stripe_payment_intent_id', payment_intent_id) \
.execute()
if not purchase_update.data:
# If no record found by payment_intent_id, try by session_id in metadata (PostgREST JSON operator requires filter)
purchase_update = await client.table('credit_purchases') \
.update({
'status': 'completed',
'completed_at': datetime.now(timezone.utc).isoformat(),
'stripe_payment_intent_id': payment_intent_id
}) \
.filter('metadata->>session_id', 'eq', session['id']) \
.execute()
# Add credits to user's balance
purchase_id = purchase_update.data[0]['id'] if purchase_update.data else None
new_balance = await add_credits_to_balance(client, user_id, credit_amount, purchase_id)
logger.info(f"Successfully added ${credit_amount} credits to user {user_id}. New balance: ${new_balance}")
# Clear cache for this user
await Cache.delete(f"monthly_usage:{user_id}")
await Cache.delete(f"user_subscription:{user_id}")
except Exception as e:
logger.error(f"Error processing credit purchase: {str(e)}")
# Don't fail the webhook, but log the error
return {"status": "success", "message": "Credit purchase processed"}
# Handle payment failed for credit purchases
if event.type == 'payment_intent.payment_failed':
payment_intent = event.data.object
# Check if this is related to a credit purchase
if payment_intent.get('metadata', {}).get('type') == 'credit_purchase':
user_id = payment_intent['metadata']['user_id']
# Update purchase record to failed
await client.table('credit_purchases') \
.update({'status': 'failed'}) \
.eq('stripe_payment_intent_id', payment_intent['id']) \
.execute()
logger.debug(f"Credit purchase failed for user {user_id}")
# Handle the existing subscription events
if event.type in ['customer.subscription.created', 'customer.subscription.updated', 'customer.subscription.deleted']:
# Extract the subscription and customer information
subscription = event.data.object
@ -1317,10 +1691,6 @@ async def stripe_webhook(request: Request):
logger.warning(f"No customer ID found in subscription event: {event.type}")
return {"status": "error", "message": "No customer ID found"}
# Get database connection
db = DBConnection()
client = await db.client
if event.type == 'customer.subscription.created':
# Update customer active status for new subscriptions
if subscription.get('status') in ['active', 'trialing']:
@ -1833,3 +2203,219 @@ async def reactivate_subscription(
except Exception as e:
logger.error(f"Error reactivating subscription: {str(e)}")
raise HTTPException(status_code=500, detail="Error processing reactivation request")
@router.post("/purchase-credits")
async def purchase_credits(
request: PurchaseCreditsRequest,
current_user_id: str = Depends(get_current_user_id_from_jwt)
):
"""
Create a Stripe checkout session for purchasing credits.
Only available for users on the highest subscription tier.
"""
try:
# Check if user is on the highest tier
is_highest_tier = await is_user_on_highest_tier(current_user_id)
if not is_highest_tier:
raise HTTPException(
status_code=403,
detail="Credit purchases are only available for users on the highest subscription tier ($1000/month)."
)
# Validate amount
if request.amount_dollars < 10:
raise HTTPException(status_code=400, detail="Minimum credit purchase is $10")
if request.amount_dollars > 5000:
raise HTTPException(status_code=400, detail="Maximum credit purchase is $5000")
# Get Supabase client
db = DBConnection()
client = await db.client
# Get user email
user_result = await client.auth.admin.get_user_by_id(current_user_id)
if not user_result:
raise HTTPException(status_code=404, detail="User not found")
email = user_result.user.email
# Get or create Stripe customer
customer_id = await get_stripe_customer_id(client, current_user_id)
if not customer_id:
customer_id = await create_stripe_customer(client, current_user_id, email)
# Check if we have a pre-configured price ID for this amount
matching_package = None
for package_key, package_info in CREDIT_PACKAGES.items():
if package_info['amount'] == request.amount_dollars and package_info.get('stripe_price_id'):
matching_package = package_info
break
# Create a checkout session
if matching_package and matching_package['stripe_price_id']:
# Use pre-configured price ID
session = await stripe.checkout.Session.create_async(
customer=customer_id,
payment_method_types=['card'],
line_items=[{
'price': matching_package['stripe_price_id'],
'quantity': 1,
}],
mode='payment',
success_url=request.success_url,
cancel_url=request.cancel_url,
metadata={
'user_id': current_user_id,
'credit_amount': str(request.amount_dollars),
'type': 'credit_purchase'
}
)
else:
session = await stripe.checkout.Session.create_async(
customer=customer_id,
payment_method_types=['card'],
line_items=[{
'price_data': {
'currency': 'usd',
'product_data': {
'name': f'Suna AI Credits',
'description': f'${request.amount_dollars:.2f} in usage credits for Suna AI',
},
'unit_amount': int(request.amount_dollars * 100),
},
'quantity': 1,
}],
mode='payment',
success_url=request.success_url,
cancel_url=request.cancel_url,
metadata={
'user_id': current_user_id,
'credit_amount': str(request.amount_dollars),
'type': 'credit_purchase'
}
)
# Record the pending purchase in database
purchase_record = await client.table('credit_purchases').insert({
'user_id': current_user_id,
'amount_dollars': request.amount_dollars,
'status': 'pending',
'stripe_payment_intent_id': session.payment_intent,
'description': f'Credit purchase via Stripe Checkout',
'metadata': {
'session_id': session.id,
'checkout_url': session.url,
'success_url': request.success_url,
'cancel_url': request.cancel_url
}
}).execute()
return {
"session_id": session.id,
"url": session.url,
"purchase_id": purchase_record.data[0]['id'] if purchase_record.data else None
}
except HTTPException:
raise
except Exception as e:
logger.error(f"Error creating credit purchase session: {str(e)}")
raise HTTPException(status_code=500, detail=f"Error creating checkout session: {str(e)}")
@router.get("/credit-balance")
async def get_credit_balance_endpoint(
current_user_id: str = Depends(get_current_user_id_from_jwt)
):
"""Get the current credit balance for the user."""
try:
db = DBConnection()
client = await db.client
balance = await get_user_credit_balance(client, current_user_id)
return balance
except Exception as e:
logger.error(f"Error getting credit balance: {str(e)}")
raise HTTPException(status_code=500, detail="Error retrieving credit balance")
@router.get("/credit-history")
async def get_credit_history(
page: int = 0,
items_per_page: int = 50,
current_user_id: str = Depends(get_current_user_id_from_jwt)
):
"""Get credit purchase and usage history for the user."""
try:
db = DBConnection()
client = await db.client
# Get purchases
purchases_result = await client.table('credit_purchases') \
.select('*') \
.eq('user_id', current_user_id) \
.eq('status', 'completed') \
.order('created_at', desc=True) \
.range(page * items_per_page, (page + 1) * items_per_page - 1) \
.execute()
# Get usage
usage_result = await client.table('credit_usage') \
.select('*') \
.eq('user_id', current_user_id) \
.order('created_at', desc=True) \
.range(page * items_per_page, (page + 1) * items_per_page - 1) \
.execute()
# Format response
purchases = [
CreditPurchase(
id=p['id'],
amount_dollars=float(p['amount_dollars']),
status=p['status'],
created_at=p['created_at'],
completed_at=p.get('completed_at'),
stripe_payment_intent_id=p.get('stripe_payment_intent_id')
)
for p in (purchases_result.data or [])
]
usage = [
CreditUsage(
id=u['id'],
amount_dollars=float(u['amount_dollars']),
description=u.get('description'),
created_at=u['created_at'],
thread_id=u.get('thread_id'),
message_id=u.get('message_id')
)
for u in (usage_result.data or [])
]
return {
"purchases": purchases,
"usage": usage,
"page": page,
"has_more": len(purchases) == items_per_page or len(usage) == items_per_page
}
except Exception as e:
logger.error(f"Error getting credit history: {str(e)}")
raise HTTPException(status_code=500, detail="Error retrieving credit history")
@router.get("/can-purchase-credits")
async def can_purchase_credits(
current_user_id: str = Depends(get_current_user_id_from_jwt)
):
"""Check if the current user can purchase credits (must be on highest tier)."""
try:
is_highest_tier = await is_user_on_highest_tier(current_user_id)
return {
"can_purchase": is_highest_tier,
"reason": "Credit purchases are available" if is_highest_tier else "Must be on the highest subscription tier ($1000/month) to purchase credits"
}
except Exception as e:
logger.error(f"Error checking credit purchase eligibility: {str(e)}")
raise HTTPException(status_code=500, detail="Error checking eligibility")

View File

@ -0,0 +1,172 @@
CREATE TABLE IF NOT EXISTS public.credit_purchases (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
user_id UUID NOT NULL REFERENCES auth.users(id) ON DELETE CASCADE,
amount_dollars DECIMAL(10, 2) NOT NULL CHECK (amount_dollars > 0),
stripe_payment_intent_id TEXT UNIQUE,
stripe_charge_id TEXT,
status TEXT NOT NULL DEFAULT 'pending' CHECK (status IN ('pending', 'completed', 'failed', 'refunded')),
description TEXT,
metadata JSONB DEFAULT '{}',
created_at TIMESTAMPTZ DEFAULT NOW(),
completed_at TIMESTAMPTZ,
expires_at TIMESTAMPTZ,
CONSTRAINT credit_purchases_amount_positive CHECK (amount_dollars > 0)
);
CREATE TABLE IF NOT EXISTS public.credit_balance (
user_id UUID PRIMARY KEY REFERENCES auth.users(id) ON DELETE CASCADE,
balance_dollars DECIMAL(10, 2) NOT NULL DEFAULT 0 CHECK (balance_dollars >= 0),
total_purchased DECIMAL(10, 2) NOT NULL DEFAULT 0 CHECK (total_purchased >= 0),
total_used DECIMAL(10, 2) NOT NULL DEFAULT 0 CHECK (total_used >= 0),
last_updated TIMESTAMPTZ DEFAULT NOW(),
metadata JSONB DEFAULT '{}'
);
CREATE TABLE IF NOT EXISTS public.credit_usage (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
user_id UUID NOT NULL REFERENCES auth.users(id) ON DELETE CASCADE,
amount_dollars DECIMAL(10, 2) NOT NULL CHECK (amount_dollars > 0),
thread_id UUID REFERENCES public.threads(thread_id) ON DELETE SET NULL,
message_id UUID REFERENCES public.messages(message_id) ON DELETE SET NULL,
description TEXT,
usage_type TEXT DEFAULT 'token_overage' CHECK (usage_type IN ('token_overage', 'manual_deduction', 'adjustment')),
created_at TIMESTAMPTZ DEFAULT NOW(),
subscription_tier TEXT,
metadata JSONB DEFAULT '{}'
);
CREATE INDEX IF NOT EXISTS idx_credit_purchases_user_id ON public.credit_purchases(user_id);
CREATE INDEX IF NOT EXISTS idx_credit_purchases_status ON public.credit_purchases(status);
CREATE INDEX IF NOT EXISTS idx_credit_purchases_created_at ON public.credit_purchases(created_at DESC);
CREATE INDEX IF NOT EXISTS idx_credit_purchases_stripe_payment_intent ON public.credit_purchases(stripe_payment_intent_id);
CREATE INDEX IF NOT EXISTS idx_credit_usage_user_id ON public.credit_usage(user_id);
CREATE INDEX IF NOT EXISTS idx_credit_usage_created_at ON public.credit_usage(created_at DESC);
CREATE INDEX IF NOT EXISTS idx_credit_usage_thread_id ON public.credit_usage(thread_id);
ALTER TABLE public.credit_purchases ENABLE ROW LEVEL SECURITY;
ALTER TABLE public.credit_balance ENABLE ROW LEVEL SECURITY;
ALTER TABLE public.credit_usage ENABLE ROW LEVEL SECURITY;
CREATE POLICY "Users can view their own credit purchases" ON public.credit_purchases
FOR SELECT USING (auth.uid() = user_id);
CREATE POLICY "Service role can manage all credit purchases" ON public.credit_purchases
FOR ALL USING (auth.role() = 'service_role');
CREATE POLICY "Users can view their own credit balance" ON public.credit_balance
FOR SELECT USING (auth.uid() = user_id);
CREATE POLICY "Service role can manage all credit balances" ON public.credit_balance
FOR ALL USING (auth.role() = 'service_role');
CREATE POLICY "Users can view their own credit usage" ON public.credit_usage
FOR SELECT USING (auth.uid() = user_id);
CREATE POLICY "Service role can manage all credit usage" ON public.credit_usage
FOR ALL USING (auth.role() = 'service_role');
CREATE OR REPLACE FUNCTION public.add_credits(
p_user_id UUID,
p_amount DECIMAL,
p_purchase_id UUID DEFAULT NULL
)
RETURNS DECIMAL
LANGUAGE plpgsql
SECURITY DEFINER
AS $$
DECLARE
new_balance DECIMAL;
BEGIN
INSERT INTO public.credit_balance (user_id, balance_dollars, total_purchased)
VALUES (p_user_id, p_amount, p_amount)
ON CONFLICT (user_id) DO UPDATE
SET
balance_dollars = credit_balance.balance_dollars + p_amount,
total_purchased = credit_balance.total_purchased + p_amount,
last_updated = NOW()
RETURNING balance_dollars INTO new_balance;
RETURN new_balance;
END;
$$;
CREATE OR REPLACE FUNCTION public.use_credits(
p_user_id UUID,
p_amount DECIMAL,
p_description TEXT DEFAULT NULL,
p_thread_id UUID DEFAULT NULL,
p_message_id UUID DEFAULT NULL
)
RETURNS BOOLEAN
LANGUAGE plpgsql
SECURITY DEFINER
AS $$
DECLARE
current_balance DECIMAL;
success BOOLEAN := FALSE;
BEGIN
SELECT balance_dollars INTO current_balance
FROM public.credit_balance
WHERE user_id = p_user_id
FOR UPDATE;
IF current_balance IS NOT NULL AND current_balance >= p_amount THEN
UPDATE public.credit_balance
SET
balance_dollars = balance_dollars - p_amount,
total_used = total_used + p_amount,
last_updated = NOW()
WHERE user_id = p_user_id;
INSERT INTO public.credit_usage (
user_id,
amount_dollars,
description,
thread_id,
message_id,
usage_type
)
VALUES (
p_user_id,
p_amount,
p_description,
p_thread_id,
p_message_id,
'token_overage'
);
success := TRUE;
END IF;
RETURN success;
END;
$$;
CREATE OR REPLACE FUNCTION public.get_credit_balance(p_user_id UUID)
RETURNS DECIMAL
LANGUAGE plpgsql
SECURITY DEFINER
AS $$
DECLARE
balance DECIMAL;
BEGIN
SELECT balance_dollars INTO balance
FROM public.credit_balance
WHERE user_id = p_user_id;
RETURN COALESCE(balance, 0);
END;
$$;
GRANT SELECT ON public.credit_purchases TO authenticated;
GRANT SELECT ON public.credit_balance TO authenticated;
GRANT SELECT ON public.credit_usage TO authenticated;
GRANT ALL ON public.credit_purchases TO service_role;
GRANT ALL ON public.credit_balance TO service_role;
GRANT ALL ON public.credit_usage TO service_role;
GRANT EXECUTE ON FUNCTION public.add_credits TO service_role;
GRANT EXECUTE ON FUNCTION public.use_credits TO service_role;
GRANT EXECUTE ON FUNCTION public.get_credit_balance TO authenticated, service_role;

View File

@ -84,9 +84,11 @@ class TriggerService:
updated_at=now
)
setup_success = await provider_service.setup_trigger(trigger)
if not setup_success:
raise ValueError(f"Failed to setup trigger with provider: {provider_id}")
# Skip setup_trigger for Composio since triggers are already enabled when created
if provider_id != "composio":
setup_success = await provider_service.setup_trigger(trigger)
if not setup_success:
raise ValueError(f"Failed to setup trigger with provider: {provider_id}")
await self._save_trigger(trigger)
@ -283,12 +285,25 @@ class TriggerService:
async def _log_trigger_event(self, event: TriggerEvent, result: TriggerResult) -> None:
client = await self._db.client
# Ensure raw_data is JSON serializable
try:
if isinstance(event.raw_data, bytes):
event_data = event.raw_data.decode('utf-8', errors='replace')
elif isinstance(event.raw_data, str):
event_data = event.raw_data
else:
event_data = str(event.raw_data)
except Exception as e:
logger.warning(f"Failed to serialize raw_data: {e}")
event_data = str(event.raw_data) if event.raw_data else "{}"
await client.table('trigger_event_logs').insert({
'log_id': str(uuid.uuid4()),
'trigger_id': event.trigger_id,
'agent_id': event.agent_id,
'trigger_type': event.trigger_type.value,
'event_data': event.raw_data,
'event_data': event_data,
'success': result.success,
'should_execute_agent': result.should_execute_agent,
'should_execute_workflow': result.should_execute_workflow,

View File

@ -87,9 +87,25 @@ class Configuration:
STRIPE_TIER_6_42_YEARLY_COMMITMENT_ID_STAGING: str = 'price_1RqYH1G6l1KZGqIrWDKh8xIU' # $42.50/month
STRIPE_TIER_25_170_YEARLY_COMMITMENT_ID_STAGING: str = 'price_1RqYHbG6l1KZGqIrAUVf8KpG' # $170/month
# Credit package price IDs - Production
STRIPE_CREDITS_10_PRICE_ID_PROD: str = 'price_1RxmQUG6l1KZGqIru453O1zW'
STRIPE_CREDITS_25_PRICE_ID_PROD: str = 'price_1RxmQlG6l1KZGqIr3hS5WtGg'
STRIPE_CREDITS_50_PRICE_ID_PROD: str = 'price_1RxmQvG6l1KZGqIrLbMZ3D6r'
STRIPE_CREDITS_100_PRICE_ID_PROD: str = 'price_1RxmR3G6l1KZGqIrpLwFCGac'
STRIPE_CREDITS_250_PRICE_ID_PROD: str = 'price_1RxmRAG6l1KZGqIrtBIMsZAj'
STRIPE_CREDITS_500_PRICE_ID_PROD: str = 'price_1RxmRGG6l1KZGqIrSyvl6w1G'
# Credit package price IDs - Staging
STRIPE_CREDITS_10_PRICE_ID_STAGING: str = 'price_1RxXOvG6l1KZGqIrMqsiYQvk'
STRIPE_CREDITS_25_PRICE_ID_STAGING: str = 'price_1RxXPNG6l1KZGqIrQprPgDme'
STRIPE_CREDITS_50_PRICE_ID_STAGING: str = 'price_1RxmNhG6l1KZGqIrTq2zPtgi'
STRIPE_CREDITS_100_PRICE_ID_STAGING: str = 'price_1RxmNwG6l1KZGqIrnliwPDM6'
STRIPE_CREDITS_250_PRICE_ID_STAGING: str = 'price_1RxmO6G6l1KZGqIrBF8Kx87G'
STRIPE_CREDITS_500_PRICE_ID_STAGING: str = 'price_1RxmOFG6l1KZGqIrn4wgORnH'
# Computed subscription tier IDs based on environment
@property
def STRIPE_FREE_TIER_ID(self) -> str:
def STRIPE_FREE_TIER_ID(self) -> str:
if self.ENV_MODE == EnvMode.STAGING:
return self.STRIPE_FREE_TIER_ID_STAGING
return self.STRIPE_FREE_TIER_ID_PROD
@ -198,6 +214,19 @@ class Configuration:
return self.STRIPE_TIER_25_170_YEARLY_COMMITMENT_ID_STAGING
return self.STRIPE_TIER_25_170_YEARLY_COMMITMENT_ID_PROD
# Credit package price ID properties
@property
def STRIPE_CREDITS_10_PRICE_ID(self) -> str:
if self.ENV_MODE == EnvMode.STAGING:
return self.STRIPE_CREDITS_10_PRICE_ID_STAGING
return self.STRIPE_CREDITS_10_PRICE_ID_PROD
@property
def STRIPE_CREDITS_25_PRICE_ID(self) -> str:
if self.ENV_MODE == EnvMode.STAGING:
return self.STRIPE_CREDITS_25_PRICE_ID_STAGING
return self.STRIPE_CREDITS_25_PRICE_ID_PROD
# LLM API keys
ANTHROPIC_API_KEY: Optional[str] = None
OPENAI_API_KEY: Optional[str] = None

View File

@ -2,6 +2,10 @@
import { useMemo, useState } from 'react';
import { BillingModal } from '@/components/billing/billing-modal';
import {
CreditBalanceDisplay,
CreditPurchaseModal
} from '@/components/billing/credit-purchase';
import { useAccounts } from '@/hooks/use-accounts';
import { Skeleton } from '@/components/ui/skeleton';
import { Alert, AlertTitle, AlertDescription } from '@/components/ui/alert';
@ -15,6 +19,7 @@ const returnUrl = process.env.NEXT_PUBLIC_URL as string;
export default function PersonalAccountBillingPage() {
const { data: accounts, isLoading, error } = useAccounts();
const [showBillingModal, setShowBillingModal] = useState(false);
const [showCreditPurchaseModal, setShowCreditPurchaseModal] = useState(false);
const {
data: subscriptionData,
@ -119,6 +124,17 @@ export default function PersonalAccountBillingPage() {
</div>
)}
{/* Credit Balance Display - Only show for users who can purchase credits */}
{subscriptionData?.can_purchase_credits && (
<div className="mb-6">
<CreditBalanceDisplay
balance={subscriptionData.credit_balance || 0}
canPurchase={subscriptionData.can_purchase_credits}
onPurchaseClick={() => setShowCreditPurchaseModal(true)}
/>
</div>
)}
<div className='flex justify-center items-center gap-4'>
<Button
variant="outline"
@ -139,6 +155,18 @@ export default function PersonalAccountBillingPage() {
</>
)}
</div>
{/* Credit Purchase Modal */}
<CreditPurchaseModal
open={showCreditPurchaseModal}
onOpenChange={setShowCreditPurchaseModal}
currentBalance={subscriptionData?.credit_balance || 0}
canPurchase={subscriptionData?.can_purchase_credits || false}
onPurchaseComplete={() => {
// Optionally refresh subscription data here
window.location.reload();
}}
/>
</div>
);
}

View File

@ -9,6 +9,7 @@ import {
} from '@/components/ui/dialog';
import { Button } from '@/components/ui/button';
import { PricingSection } from '@/components/home/sections/pricing-section';
import { CreditBalanceDisplay, CreditPurchaseModal } from '@/components/billing/credit-purchase';
import { isLocalMode } from '@/lib/config';
import {
getSubscription,
@ -32,6 +33,7 @@ export function BillingModal({ open, onOpenChange, returnUrl = typeof window !==
const [isLoading, setIsLoading] = useState(true);
const [error, setError] = useState<string | null>(null);
const [isManaging, setIsManaging] = useState(false);
const [showCreditPurchaseModal, setShowCreditPurchaseModal] = useState(false);
useEffect(() => {
async function fetchSubscription() {
@ -141,6 +143,17 @@ export function BillingModal({ open, onOpenChange, returnUrl = typeof window !==
</div>
)}
{/* Credit Balance Display - Only show for users who can purchase credits */}
{subscriptionData?.can_purchase_credits && (
<div className="mb-6">
<CreditBalanceDisplay
balance={subscriptionData.credit_balance || 0}
canPurchase={subscriptionData.can_purchase_credits}
onPurchaseClick={() => setShowCreditPurchaseModal(true)}
/>
</div>
)}
<PricingSection returnUrl={returnUrl} showTitleAndTabs={false} />
{subscriptionData && (
@ -155,6 +168,19 @@ export function BillingModal({ open, onOpenChange, returnUrl = typeof window !==
</>
)}
</DialogContent>
{/* Credit Purchase Modal */}
<CreditPurchaseModal
open={showCreditPurchaseModal}
onOpenChange={setShowCreditPurchaseModal}
currentBalance={subscriptionData?.credit_balance || 0}
canPurchase={subscriptionData?.can_purchase_credits || false}
onPurchaseComplete={() => {
// Refresh subscription data
getSubscription().then(setSubscriptionData);
setShowCreditPurchaseModal(false);
}}
/>
</Dialog>
);
}

View File

@ -0,0 +1,257 @@
'use client';
import { useState } from 'react';
import {
Dialog,
DialogContent,
DialogDescription,
DialogHeader,
DialogTitle,
} from '@/components/ui/dialog';
import { Button } from '@/components/ui/button';
import { Card, CardContent, CardHeader, CardTitle } from '@/components/ui/card';
import { Label } from '@/components/ui/label';
import { Alert, AlertDescription } from '@/components/ui/alert';
import { Badge } from '@/components/ui/badge';
import { Loader2, CreditCard, AlertCircle, Zap, AlertCircleIcon } from 'lucide-react';
import { apiClient, backendApi } from '@/lib/api-client';
import { toast } from 'sonner';
interface CreditPurchaseProps {
open: boolean;
onOpenChange: (open: boolean) => void;
currentBalance?: number;
canPurchase: boolean;
onPurchaseComplete?: () => void;
}
interface CreditPackage {
amount: number;
price: number;
popular?: boolean;
}
const CREDIT_PACKAGES: CreditPackage[] = [
{ amount: 10, price: 10 },
{ amount: 25, price: 25 },
{ amount: 50, price: 50 },
{ amount: 100, price: 100, popular: true },
{ amount: 250, price: 250 },
{ amount: 500, price: 500 },
];
export function CreditPurchaseModal({
open,
onOpenChange,
currentBalance = 0,
canPurchase,
onPurchaseComplete
}: CreditPurchaseProps) {
const [selectedPackage, setSelectedPackage] = useState<CreditPackage | null>(null);
const [customAmount, setCustomAmount] = useState<string>('');
const [isProcessing, setIsProcessing] = useState(false);
const [error, setError] = useState<string | null>(null);
const handlePurchase = async (amount: number) => {
if (amount < 10) {
setError('Minimum purchase amount is $10');
return;
}
if (amount > 5000) {
setError('Maximum purchase amount is $5000');
return;
}
setIsProcessing(true);
setError(null);
try {
const response = await backendApi.post('/billing/purchase-credits', {
amount_dollars: amount,
success_url: `${window.location.origin}/dashboard?credit_purchase=success`,
cancel_url: `${window.location.origin}/dashboard?credit_purchase=cancelled`
});
if (response.data.url) {
window.location.href = response.data.url;
} else {
throw new Error('No checkout URL received');
}
} catch (err: any) {
console.error('Credit purchase error:', err);
const errorMessage = err.response?.data?.detail || err.message || 'Failed to create checkout session';
setError(errorMessage);
toast.error(errorMessage);
} finally {
setIsProcessing(false);
}
};
const handlePackageSelect = (pkg: CreditPackage) => {
setSelectedPackage(pkg);
setCustomAmount('');
setError(null);
};
const handleCustomAmountChange = (value: string) => {
setCustomAmount(value);
setSelectedPackage(null);
setError(null);
};
const handleConfirmPurchase = () => {
const amount = selectedPackage ? selectedPackage.amount : parseFloat(customAmount);
if (!isNaN(amount)) {
handlePurchase(amount);
} else {
setError('Please select a package or enter a valid amount');
}
};
if (!canPurchase) {
return (
<Dialog open={open} onOpenChange={onOpenChange}>
<DialogContent className="sm:max-w-md">
<DialogHeader>
<DialogTitle>Credits Not Available</DialogTitle>
<DialogDescription>
Credit purchases are only available for users on the highest subscription tier ($1000/month).
</DialogDescription>
</DialogHeader>
<Alert>
<AlertCircle className="h-4 w-4" />
<AlertDescription>
Please upgrade your subscription to the highest tier to unlock credit purchases for unlimited usage.
</AlertDescription>
</Alert>
<div className="flex justify-end">
<Button variant="outline" onClick={() => onOpenChange(false)}>
Close
</Button>
</div>
</DialogContent>
</Dialog>
);
}
return (
<Dialog open={open} onOpenChange={onOpenChange}>
<DialogContent className="sm:max-w-2xl">
<DialogHeader>
<DialogTitle className="flex items-center gap-2">
<Zap className="h-5 w-5 text-amber-500 dark:text-amber-400" />
Purchase Credits
</DialogTitle>
<DialogDescription>
Add credits to your account for usage beyond your subscription limit.
</DialogDescription>
</DialogHeader>
{currentBalance > 0 && (
<Alert className="text-purple-600 dark:text-purple-400 bg-purple-50 border-purple-200 dark:bg-purple-950 dark:border-purple-800">
<AlertCircleIcon className="h-4 w-4" />
<AlertDescription className="text-purple-600 dark:text-purple-400">
Current balance: <strong>${currentBalance.toFixed(2)}</strong>
</AlertDescription>
</Alert>
)}
<div className="space-y-4">
<div>
<Label className="text-base font-semibold mb-3 block">Select a Package</Label>
<div className="grid grid-cols-2 sm:grid-cols-3 gap-3">
{CREDIT_PACKAGES.map((pkg) => (
<Card
key={pkg.amount}
className={`cursor-pointer transition-all ${
selectedPackage?.amount === pkg.amount
? 'ring-2 ring-primary'
: 'hover:shadow-md'
}`}
onClick={() => handlePackageSelect(pkg)}
>
<CardContent className="p-4 text-center relative">
{pkg.popular && (
<Badge className="absolute -top-2 -right-2" variant="default">
Popular
</Badge>
)}
<div className="text-2xl font-bold">${pkg.amount}</div>
<div className="text-sm text-muted-foreground">credits</div>
</CardContent>
</Card>
))}
</div>
</div>
{error && (
<Alert variant="destructive">
<AlertCircle className="h-4 w-4" />
<AlertDescription>{error}</AlertDescription>
</Alert>
)}
</div>
<div className="flex justify-end gap-3 mt-6">
<Button
variant="outline"
onClick={() => onOpenChange(false)}
disabled={isProcessing}
>
Cancel
</Button>
<Button
onClick={handleConfirmPurchase}
disabled={isProcessing || (!selectedPackage && !customAmount)}
>
{isProcessing ? (
<>
<Loader2 className="h-4 w-4 animate-spin" />
Processing...
</>
) : (
<>
<CreditCard className="h-4 w-4" />
Purchase Credits
</>
)}
</Button>
</div>
</DialogContent>
</Dialog>
);
}
export function CreditBalanceDisplay({ balance, canPurchase, onPurchaseClick }: {
balance: number;
canPurchase: boolean;
onPurchaseClick?: () => void;
}) {
return (
<Card>
<CardHeader className="pb-3">
<CardTitle className="text-sm font-medium flex items-center justify-between">
<span className="flex items-center gap-2">
<Zap className="h-4 w-4 text-yellow-500" />
Credit Balance
</span>
{canPurchase && onPurchaseClick && (
<Button
size="sm"
variant="outline"
onClick={onPurchaseClick}
>
Add Credits
</Button>
)}
</CardTitle>
</CardHeader>
<CardContent>
<div className="text-2xl font-bold">
${balance.toFixed(2)}
</div>
<p className="text-xs text-muted-foreground mt-1">
{canPurchase
? 'Available for usage beyond subscription limits'
: 'Upgrade to highest tier to purchase credits'
}
</p>
</CardContent>
</Card>
);
}

View File

@ -25,12 +25,12 @@ import {
import { Badge } from '@/components/ui/badge';
import { Skeleton } from '@/components/ui/skeleton';
import { Button } from '@/components/ui/button';
import { ExternalLink, Loader2 } from 'lucide-react';
import { ExternalLink, Loader2, AlertCircle } from 'lucide-react';
import Link from 'next/link';
import { OpenInNewWindowIcon } from '@radix-ui/react-icons';
import { useUsageLogs } from '@/hooks/react-query/subscriptions/use-billing';
import { UsageLogEntry } from '@/lib/api';
import { Alert, AlertDescription, AlertTitle } from '@/components/ui/alert';
interface DailyUsage {
@ -86,6 +86,11 @@ export default function UsageLogs({ accountId }: Props) {
return `$${cost.toFixed(4)}`;
};
const formatCreditAmount = (amount: number) => {
if (amount === 0) return null;
return `$${amount.toFixed(4)}`;
};
const formatDateOnly = (dateString: string) => {
return new Date(dateString).toLocaleDateString('en-US', {
weekday: 'long',
@ -200,8 +205,23 @@ export default function UsageLogs({ accountId }: Props) {
0,
);
// Get subscription limit from the first page data
const subscriptionLimit = currentPageData?.subscription_limit || 0;
return (
<div className="space-y-6">
{/* Show credit usage info if user has gone over limit */}
{subscriptionLimit > 0 && totalUsage > subscriptionLimit && (
<Alert>
<AlertCircle className="h-4 w-4" />
<AlertTitle>Credits Being Used</AlertTitle>
<AlertDescription>
You've exceeded your monthly subscription limit of ${subscriptionLimit.toFixed(2)}.
Additional usage is being deducted from your credit balance.
</AlertDescription>
</Alert>
)}
{/* Usage Logs Accordion */}
<Card>
<CardHeader>
@ -253,52 +273,69 @@ export default function UsageLogs({ accountId }: Props) {
<div className="rounded-md border mt-4">
<Table>
<TableHeader>
<TableRow>
<TableHead>Time</TableHead>
<TableHead>Model</TableHead>
<TableHead className="text-right">
Tokens
</TableHead>
<TableHead className="text-right">Cost</TableHead>
<TableHead className="text-center">
Thread
</TableHead>
<TableRow className="hover:bg-transparent">
<TableHead className="w-[180px] text-xs">Time</TableHead>
<TableHead className="text-xs">Model</TableHead>
<TableHead className="text-xs text-right">Prompt</TableHead>
<TableHead className="text-xs text-right">Completion</TableHead>
<TableHead className="text-xs text-right">Total</TableHead>
<TableHead className="text-xs text-right">Cost</TableHead>
<TableHead className="text-xs text-right">Payment</TableHead>
<TableHead className="w-[100px] text-xs text-center">Thread</TableHead>
</TableRow>
</TableHeader>
<TableBody>
{day.logs.map((log) => (
<TableRow key={log.message_id}>
<TableCell className="font-mono text-sm">
{new Date(
log.created_at,
).toLocaleTimeString()}
{day.logs.map((log, index) => (
<TableRow
key={`${log.message_id}_${index}`}
className="hover:bg-muted/50 group"
>
<TableCell className="font-mono text-xs text-muted-foreground">
{new Date(log.created_at).toLocaleTimeString()}
</TableCell>
<TableCell>
<Badge className="font-mono text-xs">
{log.content.model}
<TableCell className="text-xs">
<Badge variant="secondary" className="font-mono text-xs">
{log.content.model.replace('openrouter/', '').replace('anthropic/', '')}
</Badge>
</TableCell>
<TableCell className="text-right font-mono font-medium text-sm">
{log.content.usage.prompt_tokens.toLocaleString()}{' '}
-&gt;{' '}
<TableCell className="text-right font-mono text-xs">
{log.content.usage.prompt_tokens.toLocaleString()}
</TableCell>
<TableCell className="text-right font-mono text-xs">
{log.content.usage.completion_tokens.toLocaleString()}
</TableCell>
<TableCell className="text-right font-mono font-medium text-sm">
<TableCell className="text-right font-mono text-xs">
{log.total_tokens.toLocaleString()}
</TableCell>
<TableCell className="text-right font-mono text-xs">
{formatCost(log.estimated_cost)}
</TableCell>
<TableCell className="text-right text-xs">
{log.payment_method === 'credits' ? (
<div className="flex items-center justify-end gap-2">
<Badge variant="outline" className="text-xs">
Credits
</Badge>
{log.credit_used && log.credit_used > 0 && (
<span className="text-xs text-muted-foreground">
-{formatCreditAmount(log.credit_used)}
</span>
)}
</div>
) : (
<Badge variant="secondary" className="text-xs">
Subscription
</Badge>
)}
</TableCell>
<TableCell className="text-center">
<Button
variant="ghost"
size="sm"
onClick={() =>
handleThreadClick(
log.thread_id,
log.project_id,
)
}
className="h-8 w-8 p-0"
onClick={() => handleThreadClick(log.thread_id, log.project_id)}
className="h-6 px-2 text-xs opacity-0 group-hover:opacity-100 transition-opacity"
>
<ExternalLink className="h-4 w-4" />
<ExternalLink className="h-3 w-3" />
</Button>
</TableCell>
</TableRow>

View File

@ -1601,6 +1601,9 @@ export interface SubscriptionStatus {
cancel_at_period_end: boolean;
current_period_end: number; // timestamp
};
// Credit information
credit_balance?: number;
can_purchase_credits?: boolean;
}
export interface CommitmentInfo {
@ -1668,8 +1671,12 @@ export interface UsageLogEntry {
model: string;
};
total_tokens: number;
estimated_cost: number;
estimated_cost: number | string;
project_id: string;
// Credit usage fields
credit_used?: number;
payment_method?: 'credits' | 'subscription';
was_over_limit?: boolean;
}
// Usage logs response interface
@ -1677,6 +1684,8 @@ export interface UsageLogsResponse {
logs: UsageLogEntry[];
has_more: boolean;
message?: string;
subscription_limit?: number;
cumulative_cost?: number;
}
export interface BillingStatusResponse {