mirror of https://github.com/kortix-ai/suna.git
Compare commits
13 Commits
b784007ab7
...
6623e87ea9
Author | SHA1 | Date |
---|---|---|
|
6623e87ea9 | |
|
b3666e8aad | |
|
a14c2a1a2c | |
|
5d28b65111 | |
|
7628ced002 | |
|
b7b7eeb705 | |
|
a7d38c0096 | |
|
23574e37cf | |
|
09c4099ca5 | |
|
28da425ce8 | |
|
865b2f3633 | |
|
b3f1398c3d | |
|
d6706ead43 |
|
@ -162,7 +162,6 @@ cython_debug/
|
|||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||
#.idea/
|
||||
|
||||
# AgentPress
|
||||
/threads
|
||||
state.json
|
||||
/workspace/
|
||||
|
|
|
@ -163,7 +163,6 @@ cython_debug/
|
|||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||
#.idea/
|
||||
|
||||
# AgentPress
|
||||
/threads
|
||||
state.json
|
||||
/workspace/
|
||||
|
|
|
@ -15,9 +15,9 @@ from agentpress.thread_manager import ThreadManager
|
|||
from services.supabase import DBConnection
|
||||
from services import redis
|
||||
from agent.run import run_agent
|
||||
from utils.auth_utils import get_current_user_id, get_user_id_from_stream_auth, verify_thread_access
|
||||
from utils.auth_utils import get_current_user_id_from_jwt, get_user_id_from_stream_auth, verify_thread_access
|
||||
from utils.logger import logger
|
||||
from utils.billing import check_billing_status, get_account_id_from_thread
|
||||
from services.billing import check_billing_status
|
||||
from sandbox.sandbox import create_sandbox, get_or_start_sandbox
|
||||
from services.llm import make_llm_api_call
|
||||
|
||||
|
@ -348,7 +348,7 @@ async def get_or_create_project_sandbox(client, project_id: str):
|
|||
async def start_agent(
|
||||
thread_id: str,
|
||||
body: AgentStartRequest = Body(...),
|
||||
user_id: str = Depends(get_current_user_id)
|
||||
user_id: str = Depends(get_current_user_id_from_jwt)
|
||||
):
|
||||
"""Start an agent for a specific thread in the background."""
|
||||
global instance_id # Ensure instance_id is accessible
|
||||
|
@ -412,7 +412,7 @@ async def start_agent(
|
|||
return {"agent_run_id": agent_run_id, "status": "running"}
|
||||
|
||||
@router.post("/agent-run/{agent_run_id}/stop")
|
||||
async def stop_agent(agent_run_id: str, user_id: str = Depends(get_current_user_id)):
|
||||
async def stop_agent(agent_run_id: str, user_id: str = Depends(get_current_user_id_from_jwt)):
|
||||
"""Stop a running agent."""
|
||||
logger.info(f"Received request to stop agent run: {agent_run_id}")
|
||||
client = await db.client
|
||||
|
@ -421,7 +421,7 @@ async def stop_agent(agent_run_id: str, user_id: str = Depends(get_current_user_
|
|||
return {"status": "stopped"}
|
||||
|
||||
@router.get("/thread/{thread_id}/agent-runs")
|
||||
async def get_agent_runs(thread_id: str, user_id: str = Depends(get_current_user_id)):
|
||||
async def get_agent_runs(thread_id: str, user_id: str = Depends(get_current_user_id_from_jwt)):
|
||||
"""Get all agent runs for a thread."""
|
||||
logger.info(f"Fetching agent runs for thread: {thread_id}")
|
||||
client = await db.client
|
||||
|
@ -431,7 +431,7 @@ async def get_agent_runs(thread_id: str, user_id: str = Depends(get_current_user
|
|||
return {"agent_runs": agent_runs.data}
|
||||
|
||||
@router.get("/agent-run/{agent_run_id}")
|
||||
async def get_agent_run(agent_run_id: str, user_id: str = Depends(get_current_user_id)):
|
||||
async def get_agent_run(agent_run_id: str, user_id: str = Depends(get_current_user_id_from_jwt)):
|
||||
"""Get agent run status and responses."""
|
||||
logger.info(f"Fetching agent run details: {agent_run_id}")
|
||||
client = await db.client
|
||||
|
@ -859,7 +859,7 @@ async def initiate_agent_with_files(
|
|||
stream: Optional[bool] = Form(True),
|
||||
enable_context_manager: Optional[bool] = Form(False),
|
||||
files: List[UploadFile] = File(default=[]),
|
||||
user_id: str = Depends(get_current_user_id)
|
||||
user_id: str = Depends(get_current_user_id_from_jwt)
|
||||
):
|
||||
"""Initiate a new agent session with optional file attachments."""
|
||||
global instance_id # Ensure instance_id is accessible
|
||||
|
|
|
@ -20,7 +20,8 @@ from agent.tools.sb_browser_tool import SandboxBrowserTool
|
|||
from agent.tools.data_providers_tool import DataProvidersTool
|
||||
from agent.prompt import get_system_prompt
|
||||
from utils import logger
|
||||
from utils.billing import check_billing_status, get_account_id_from_thread
|
||||
from utils.auth_utils import get_account_id_from_thread
|
||||
from services.billing import check_billing_status
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
|
|
@ -16,6 +16,7 @@ from collections import OrderedDict
|
|||
# Import the agent API module
|
||||
from agent import api as agent_api
|
||||
from sandbox import api as sandbox_api
|
||||
from services import billing as billing_api
|
||||
|
||||
# Load environment variables (these will be available through config)
|
||||
load_dotenv()
|
||||
|
@ -132,6 +133,9 @@ app.include_router(agent_api.router, prefix="/api")
|
|||
# Include the sandbox router with a prefix
|
||||
app.include_router(sandbox_api.router, prefix="/api")
|
||||
|
||||
# Include the billing router with a prefix
|
||||
app.include_router(billing_api.router, prefix="/api")
|
||||
|
||||
@app.get("/api/health")
|
||||
async def health_check():
|
||||
"""Health check endpoint to verify API is working."""
|
||||
|
@ -152,5 +156,6 @@ if __name__ == "__main__":
|
|||
"api:app",
|
||||
host="0.0.0.0",
|
||||
port=8000,
|
||||
workers=workers
|
||||
workers=workers,
|
||||
reload=True
|
||||
)
|
|
@ -2873,6 +2873,21 @@ docs = ["myst-parser[linkify]", "sphinx", "sphinx-rtd-theme"]
|
|||
release = ["twine"]
|
||||
test = ["pylint", "pytest", "pytest-black", "pytest-cov", "pytest-pylint"]
|
||||
|
||||
[[package]]
|
||||
name = "stripe"
|
||||
version = "12.0.1"
|
||||
description = "Python bindings for the Stripe API"
|
||||
optional = false
|
||||
python-versions = ">=3.6"
|
||||
files = [
|
||||
{file = "stripe-12.0.1-py2.py3-none-any.whl", hash = "sha256:b10b19dbd0622868b98a7c6e879ebde704be96ad75c780944bca4069bb427988"},
|
||||
{file = "stripe-12.0.1.tar.gz", hash = "sha256:3fc7cc190946d8ebcc5b637e7e04f387d61b9c5156a89619a3ba90704ac09d4a"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
requests = {version = ">=2.20", markers = "python_version >= \"3.0\""}
|
||||
typing_extensions = {version = ">=4.5.0", markers = "python_version >= \"3.7\""}
|
||||
|
||||
[[package]]
|
||||
name = "supabase"
|
||||
version = "2.15.0"
|
||||
|
@ -3622,4 +3637,4 @@ testing = ["coverage[toml]", "zope.event", "zope.testing"]
|
|||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = "^3.11"
|
||||
content-hash = "622a06feff14fc27c612f15e50be3375531175462c46fa57c3bcf33851e2a9c3"
|
||||
content-hash = "160a62f76af02d841f0f1b60e2d96c5c8e91310182d728dfa82bb792fa098e95"
|
||||
|
|
|
@ -20,7 +20,7 @@ classifiers = [
|
|||
python = "^3.11"
|
||||
streamlit-quill = "0.0.3"
|
||||
python-dotenv = "1.0.1"
|
||||
litellm = "^1.44.0"
|
||||
litellm = "1.66.1"
|
||||
click = "8.1.7"
|
||||
questionary = "2.0.1"
|
||||
requests = "^2.31.0"
|
||||
|
@ -50,6 +50,7 @@ nest-asyncio = "^1.6.0"
|
|||
vncdotool = "^1.2.0"
|
||||
tavily-python = "^0.5.4"
|
||||
pytesseract = "^0.3.13"
|
||||
stripe = "^12.0.1"
|
||||
|
||||
[tool.poetry.scripts]
|
||||
agentpress = "agentpress.cli:main"
|
||||
|
@ -57,7 +58,6 @@ agentpress = "agentpress.cli:main"
|
|||
[[tool.poetry.packages]]
|
||||
include = "agentpress"
|
||||
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
daytona-sdk = "^0.14.0"
|
||||
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
streamlit-quill==0.0.3
|
||||
python-dotenv==1.0.1
|
||||
litellm>=1.66.2
|
||||
litellm==1.66.2
|
||||
click==8.1.7
|
||||
questionary==2.0.1
|
||||
requests>=2.31.0
|
||||
|
@ -30,4 +30,5 @@ nest-asyncio>=1.6.0
|
|||
vncdotool>=1.2.0
|
||||
pydantic
|
||||
tavily-python>=0.5.4
|
||||
pytesseract==0.3.13
|
||||
pytesseract==0.3.13
|
||||
stripe>=7.0.0
|
|
@ -6,7 +6,7 @@ from fastapi.responses import Response, JSONResponse
|
|||
from pydantic import BaseModel
|
||||
|
||||
from utils.logger import logger
|
||||
from utils.auth_utils import get_current_user_id, get_user_id_from_stream_auth, get_optional_user_id
|
||||
from utils.auth_utils import get_current_user_id_from_jwt, get_user_id_from_stream_auth, get_optional_user_id
|
||||
from sandbox.sandbox import get_or_start_sandbox
|
||||
from services.supabase import DBConnection
|
||||
from agent.api import get_or_create_project_sandbox
|
||||
|
|
|
@ -0,0 +1,748 @@
|
|||
"""
|
||||
Stripe Billing API implementation for Suna on top of Basejump. ONLY HAS SUPPOT FOR USER ACCOUNTS – no team accounts. As we are using the user_id as account_id as is the case with personal accounts. In personal accounts, the account_id equals the user_id. In team accounts, the account_id is unique.
|
||||
|
||||
stripe listen --forward-to localhost:8000/api/billing/webhook
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Depends, Request
|
||||
from typing import Optional, Dict, Any, List, Tuple
|
||||
import stripe
|
||||
from datetime import datetime, timezone
|
||||
from utils.logger import logger
|
||||
from utils.config import config, EnvMode
|
||||
from services.supabase import DBConnection
|
||||
from utils.auth_utils import get_current_user_id_from_jwt
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
# Initialize Stripe
|
||||
stripe.api_key = config.STRIPE_SECRET_KEY
|
||||
|
||||
# Initialize router
|
||||
router = APIRouter(prefix="/billing", tags=["billing"])
|
||||
|
||||
SUBSCRIPTION_TIERS = {
|
||||
config.STRIPE_FREE_TIER_ID: {'name': 'free', 'minutes': 10},
|
||||
config.STRIPE_TIER_2_20_ID: {'name': 'tier_2_20', 'minutes': 120}, # 2 hours
|
||||
config.STRIPE_TIER_6_50_ID: {'name': 'tier_6_50', 'minutes': 360}, # 6 hours
|
||||
config.STRIPE_TIER_12_100_ID: {'name': 'tier_12_100', 'minutes': 720}, # 12 hours
|
||||
config.STRIPE_TIER_25_200_ID: {'name': 'tier_25_200', 'minutes': 1500}, # 25 hours
|
||||
config.STRIPE_TIER_50_400_ID: {'name': 'tier_50_400', 'minutes': 3000}, # 50 hours
|
||||
config.STRIPE_TIER_125_800_ID: {'name': 'tier_125_800', 'minutes': 7500}, # 125 hours
|
||||
config.STRIPE_TIER_200_1000_ID: {'name': 'tier_200_1000', 'minutes': 12000}, # 200 hours
|
||||
}
|
||||
|
||||
# Pydantic models for request/response validation
|
||||
class CreateCheckoutSessionRequest(BaseModel):
|
||||
price_id: str
|
||||
success_url: str
|
||||
cancel_url: str
|
||||
|
||||
class CreatePortalSessionRequest(BaseModel):
|
||||
return_url: str
|
||||
|
||||
class SubscriptionStatus(BaseModel):
|
||||
status: str # e.g., 'active', 'trialing', 'past_due', 'scheduled_downgrade', 'no_subscription'
|
||||
plan_name: Optional[str] = None
|
||||
price_id: Optional[str] = None # Added price ID
|
||||
current_period_end: Optional[datetime] = None
|
||||
cancel_at_period_end: bool = False
|
||||
trial_end: Optional[datetime] = None
|
||||
minutes_limit: Optional[int] = None
|
||||
current_usage: Optional[float] = None
|
||||
# Fields for scheduled changes
|
||||
has_schedule: bool = False
|
||||
scheduled_plan_name: Optional[str] = None
|
||||
scheduled_price_id: Optional[str] = None # Added scheduled price ID
|
||||
scheduled_change_date: Optional[datetime] = None
|
||||
|
||||
# Helper functions
|
||||
async def get_stripe_customer_id(client, user_id: str) -> Optional[str]:
|
||||
"""Get the Stripe customer ID for a user."""
|
||||
result = await client.schema('basejump').from_('billing_customers') \
|
||||
.select('id') \
|
||||
.eq('account_id', user_id) \
|
||||
.execute()
|
||||
|
||||
if result.data and len(result.data) > 0:
|
||||
return result.data[0]['id']
|
||||
return None
|
||||
|
||||
async def create_stripe_customer(client, user_id: str, email: str) -> str:
|
||||
"""Create a new Stripe customer for a user."""
|
||||
# Create customer in Stripe
|
||||
customer = stripe.Customer.create(
|
||||
email=email,
|
||||
metadata={"user_id": user_id}
|
||||
)
|
||||
|
||||
# Store customer ID in Supabase
|
||||
await client.schema('basejump').from_('billing_customers').insert({
|
||||
'id': customer.id,
|
||||
'account_id': user_id,
|
||||
'email': email,
|
||||
'provider': 'stripe'
|
||||
}).execute()
|
||||
|
||||
return customer.id
|
||||
|
||||
async def get_user_subscription(user_id: str) -> Optional[Dict]:
|
||||
"""Get the current subscription for a user from Stripe."""
|
||||
try:
|
||||
# Get customer ID
|
||||
db = DBConnection()
|
||||
client = await db.client
|
||||
customer_id = await get_stripe_customer_id(client, user_id)
|
||||
|
||||
if not customer_id:
|
||||
return None
|
||||
|
||||
# Get all active subscriptions for the customer
|
||||
subscriptions = stripe.Subscription.list(
|
||||
customer=customer_id,
|
||||
status='active'
|
||||
)
|
||||
# print("Found subscriptions:", subscriptions)
|
||||
|
||||
# Check if we have any subscriptions
|
||||
if not subscriptions or not subscriptions.get('data'):
|
||||
return None
|
||||
|
||||
# Filter subscriptions to only include our product's subscriptions
|
||||
our_subscriptions = []
|
||||
for sub in subscriptions['data']:
|
||||
# Get the first subscription item
|
||||
if sub.get('items') and sub['items'].get('data') and len(sub['items']['data']) > 0:
|
||||
item = sub['items']['data'][0]
|
||||
if item.get('price') and item['price'].get('id') in [
|
||||
config.STRIPE_FREE_TIER_ID,
|
||||
config.STRIPE_TIER_2_20_ID,
|
||||
config.STRIPE_TIER_6_50_ID,
|
||||
config.STRIPE_TIER_12_100_ID,
|
||||
config.STRIPE_TIER_25_200_ID,
|
||||
config.STRIPE_TIER_50_400_ID,
|
||||
config.STRIPE_TIER_125_800_ID,
|
||||
config.STRIPE_TIER_200_1000_ID
|
||||
]:
|
||||
our_subscriptions.append(sub)
|
||||
|
||||
if not our_subscriptions:
|
||||
return None
|
||||
|
||||
# If there are multiple active subscriptions, we need to handle this
|
||||
if len(our_subscriptions) > 1:
|
||||
logger.warning(f"User {user_id} has multiple active subscriptions: {[sub['id'] for sub in our_subscriptions]}")
|
||||
|
||||
# Get the most recent subscription
|
||||
most_recent = max(our_subscriptions, key=lambda x: x['created'])
|
||||
|
||||
# Cancel all other subscriptions
|
||||
for sub in our_subscriptions:
|
||||
if sub['id'] != most_recent['id']:
|
||||
try:
|
||||
stripe.Subscription.modify(
|
||||
sub['id'],
|
||||
cancel_at_period_end=True
|
||||
)
|
||||
logger.info(f"Cancelled subscription {sub['id']} for user {user_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error cancelling subscription {sub['id']}: {str(e)}")
|
||||
|
||||
return most_recent
|
||||
|
||||
return our_subscriptions[0]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting subscription from Stripe: {str(e)}")
|
||||
return None
|
||||
|
||||
async def calculate_monthly_usage(client, user_id: str) -> float:
|
||||
"""Calculate total agent run minutes for the current month for a user."""
|
||||
# Get start of current month in UTC
|
||||
now = datetime.now(timezone.utc)
|
||||
start_of_month = datetime(now.year, now.month, 1, tzinfo=timezone.utc)
|
||||
|
||||
# First get all threads for this user
|
||||
threads_result = await client.table('threads') \
|
||||
.select('thread_id') \
|
||||
.eq('account_id', user_id) \
|
||||
.execute()
|
||||
|
||||
if not threads_result.data:
|
||||
return 0.0
|
||||
|
||||
thread_ids = [t['thread_id'] for t in threads_result.data]
|
||||
|
||||
# Then get all agent runs for these threads in current month
|
||||
runs_result = await client.table('agent_runs') \
|
||||
.select('started_at, completed_at') \
|
||||
.in_('thread_id', thread_ids) \
|
||||
.gte('started_at', start_of_month.isoformat()) \
|
||||
.execute()
|
||||
|
||||
if not runs_result.data:
|
||||
return 0.0
|
||||
|
||||
# Calculate total minutes
|
||||
total_seconds = 0
|
||||
now_ts = now.timestamp()
|
||||
|
||||
for run in runs_result.data:
|
||||
start_time = datetime.fromisoformat(run['started_at'].replace('Z', '+00:00')).timestamp()
|
||||
if run['completed_at']:
|
||||
end_time = datetime.fromisoformat(run['completed_at'].replace('Z', '+00:00')).timestamp()
|
||||
else:
|
||||
# For running jobs, use current time
|
||||
end_time = now_ts
|
||||
|
||||
total_seconds += (end_time - start_time)
|
||||
|
||||
return total_seconds / 60 # Convert to minutes
|
||||
|
||||
async def check_billing_status(client, user_id: str) -> Tuple[bool, str, Optional[Dict]]:
|
||||
"""
|
||||
Check if a user can run agents based on their subscription and usage.
|
||||
|
||||
Returns:
|
||||
Tuple[bool, str, Optional[Dict]]: (can_run, message, subscription_info)
|
||||
"""
|
||||
if config.ENV_MODE == EnvMode.LOCAL:
|
||||
logger.info("Running in local development mode - billing checks are disabled")
|
||||
return True, "Local development mode - billing disabled", {
|
||||
"price_id": "local_dev",
|
||||
"plan_name": "Local Development",
|
||||
"minutes_limit": "no limit"
|
||||
}
|
||||
|
||||
# Get current subscription
|
||||
subscription = await get_user_subscription(user_id)
|
||||
# print("Current subscription:", subscription)
|
||||
|
||||
# If no subscription, they can use free tier
|
||||
if not subscription:
|
||||
subscription = {
|
||||
'price_id': config.STRIPE_FREE_TIER_ID, # Free tier
|
||||
'plan_name': 'free'
|
||||
}
|
||||
|
||||
# Get tier info - default to free tier if not found
|
||||
tier_info = SUBSCRIPTION_TIERS.get(subscription.get('price_id', config.STRIPE_FREE_TIER_ID))
|
||||
if not tier_info:
|
||||
logger.warning(f"Unknown subscription tier: {subscription.get('price_id')}, defaulting to free tier")
|
||||
tier_info = SUBSCRIPTION_TIERS[config.STRIPE_FREE_TIER_ID]
|
||||
|
||||
# Calculate current month's usage
|
||||
current_usage = await calculate_monthly_usage(client, user_id)
|
||||
|
||||
# Check if within limits
|
||||
if current_usage >= tier_info['minutes']:
|
||||
return False, f"Monthly limit of {tier_info['minutes']} minutes reached. Please upgrade your plan or wait until next month.", subscription
|
||||
|
||||
return True, "OK", subscription
|
||||
|
||||
# API endpoints
|
||||
@router.post("/create-checkout-session")
|
||||
async def create_checkout_session(
|
||||
request: CreateCheckoutSessionRequest,
|
||||
current_user_id: str = Depends(get_current_user_id_from_jwt)
|
||||
):
|
||||
"""Create a Stripe Checkout session or modify an existing subscription."""
|
||||
try:
|
||||
# Get Supabase client
|
||||
db = DBConnection()
|
||||
client = await db.client
|
||||
|
||||
# Get user email from auth.users
|
||||
user_result = await client.auth.admin.get_user_by_id(current_user_id)
|
||||
if not user_result: raise HTTPException(status_code=404, detail="User not found")
|
||||
email = user_result.user.email
|
||||
|
||||
# Get or create Stripe customer
|
||||
customer_id = await get_stripe_customer_id(client, current_user_id)
|
||||
if not customer_id: customer_id = await create_stripe_customer(client, current_user_id, email)
|
||||
|
||||
# Get the target price and product ID
|
||||
try:
|
||||
price = stripe.Price.retrieve(request.price_id, expand=['product'])
|
||||
product_id = price['product']['id']
|
||||
except stripe.error.InvalidRequestError:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid price ID: {request.price_id}")
|
||||
|
||||
# Verify the price belongs to our product
|
||||
if product_id != config.STRIPE_PRODUCT_ID:
|
||||
raise HTTPException(status_code=400, detail="Price ID does not belong to the correct product.")
|
||||
|
||||
# Check for existing subscription for our product
|
||||
existing_subscription = await get_user_subscription(current_user_id)
|
||||
# print("Existing subscription for product:", existing_subscription)
|
||||
|
||||
if existing_subscription:
|
||||
# --- Handle Subscription Change (Upgrade or Downgrade) ---
|
||||
try:
|
||||
subscription_id = existing_subscription['id']
|
||||
subscription_item = existing_subscription['items']['data'][0]
|
||||
current_price_id = subscription_item['price']['id']
|
||||
|
||||
# Skip if already on this plan
|
||||
if current_price_id == request.price_id:
|
||||
return {
|
||||
"subscription_id": subscription_id,
|
||||
"status": "no_change",
|
||||
"message": "Already subscribed to this plan.",
|
||||
"details": {
|
||||
"is_upgrade": None,
|
||||
"effective_date": None,
|
||||
"current_price": round(price['unit_amount'] / 100, 2) if price.get('unit_amount') else 0,
|
||||
"new_price": round(price['unit_amount'] / 100, 2) if price.get('unit_amount') else 0,
|
||||
}
|
||||
}
|
||||
|
||||
# Get current and new price details
|
||||
current_price = stripe.Price.retrieve(current_price_id)
|
||||
new_price = price # Already retrieved
|
||||
is_upgrade = new_price['unit_amount'] > current_price['unit_amount']
|
||||
|
||||
if is_upgrade:
|
||||
# --- Handle Upgrade --- Immediate modification
|
||||
updated_subscription = stripe.Subscription.modify(
|
||||
subscription_id,
|
||||
items=[{
|
||||
'id': subscription_item['id'],
|
||||
'price': request.price_id,
|
||||
}],
|
||||
proration_behavior='always_invoice', # Prorate and charge immediately
|
||||
billing_cycle_anchor='now' # Reset billing cycle
|
||||
)
|
||||
|
||||
latest_invoice = None
|
||||
if updated_subscription.get('latest_invoice'):
|
||||
latest_invoice = stripe.Invoice.retrieve(updated_subscription['latest_invoice'])
|
||||
|
||||
return {
|
||||
"subscription_id": updated_subscription['id'],
|
||||
"status": "updated",
|
||||
"message": "Subscription upgraded successfully",
|
||||
"details": {
|
||||
"is_upgrade": True,
|
||||
"effective_date": "immediate",
|
||||
"current_price": round(current_price['unit_amount'] / 100, 2) if current_price.get('unit_amount') else 0,
|
||||
"new_price": round(new_price['unit_amount'] / 100, 2) if new_price.get('unit_amount') else 0,
|
||||
"invoice": {
|
||||
"id": latest_invoice['id'] if latest_invoice else None,
|
||||
"status": latest_invoice['status'] if latest_invoice else None,
|
||||
"amount_due": round(latest_invoice['amount_due'] / 100, 2) if latest_invoice else 0,
|
||||
"amount_paid": round(latest_invoice['amount_paid'] / 100, 2) if latest_invoice else 0
|
||||
} if latest_invoice else None
|
||||
}
|
||||
}
|
||||
else:
|
||||
# --- Handle Downgrade --- Use Subscription Schedule
|
||||
try:
|
||||
current_period_end_ts = subscription_item['current_period_end']
|
||||
|
||||
# Retrieve the subscription again to get the schedule ID if it exists
|
||||
# This ensures we have the latest state before creating/modifying schedule
|
||||
sub_with_schedule = stripe.Subscription.retrieve(subscription_id)
|
||||
schedule_id = sub_with_schedule.get('schedule')
|
||||
|
||||
# Get the current phase configuration from the schedule or subscription
|
||||
if schedule_id:
|
||||
schedule = stripe.SubscriptionSchedule.retrieve(schedule_id)
|
||||
# Find the current phase in the schedule
|
||||
# This logic assumes simple schedules; might need refinement for complex ones
|
||||
current_phase = None
|
||||
for phase in reversed(schedule['phases']):
|
||||
if phase['start_date'] <= datetime.now(timezone.utc).timestamp():
|
||||
current_phase = phase
|
||||
break
|
||||
if not current_phase: # Fallback if logic fails
|
||||
current_phase = schedule['phases'][-1]
|
||||
else:
|
||||
# If no schedule, the current subscription state defines the current phase
|
||||
current_phase = {
|
||||
'items': existing_subscription['items']['data'], # Use original items data
|
||||
'start_date': existing_subscription['current_period_start'], # Use sub start if no schedule
|
||||
# Add other relevant fields if needed for create/modify
|
||||
}
|
||||
|
||||
# Prepare the current phase data for the update/create
|
||||
# Ensure items is formatted correctly for the API
|
||||
current_phase_items_for_api = []
|
||||
for item in current_phase.get('items', []):
|
||||
price_data = item.get('price')
|
||||
quantity = item.get('quantity')
|
||||
price_id = None
|
||||
|
||||
# Safely extract price ID whether it's an object or just the ID string
|
||||
if isinstance(price_data, dict):
|
||||
price_id = price_data.get('id')
|
||||
elif isinstance(price_data, str):
|
||||
price_id = price_data
|
||||
|
||||
if price_id and quantity is not None:
|
||||
current_phase_items_for_api.append({'price': price_id, 'quantity': quantity})
|
||||
else:
|
||||
logger.warning(f"Skipping item in current phase due to missing price ID or quantity: {item}")
|
||||
|
||||
if not current_phase_items_for_api:
|
||||
raise ValueError("Could not determine valid items for the current phase.")
|
||||
|
||||
current_phase_update_data = {
|
||||
'items': current_phase_items_for_api,
|
||||
'start_date': current_phase['start_date'], # Preserve original start date
|
||||
'end_date': current_period_end_ts, # End this phase at period end
|
||||
'proration_behavior': 'none'
|
||||
# Include other necessary fields from current_phase if modifying?
|
||||
# e.g., 'billing_cycle_anchor', 'collection_method'? Usually inherited.
|
||||
}
|
||||
|
||||
# Define the new (downgrade) phase
|
||||
new_downgrade_phase_data = {
|
||||
'items': [{'price': request.price_id, 'quantity': 1}],
|
||||
'start_date': current_period_end_ts, # Start immediately after current phase ends
|
||||
'proration_behavior': 'none'
|
||||
# iterations defaults to 1, meaning it runs for one billing cycle
|
||||
# then schedule ends based on end_behavior
|
||||
}
|
||||
|
||||
# Update or Create Schedule
|
||||
if schedule_id:
|
||||
# Update existing schedule, replacing all future phases
|
||||
# print(f"Updating existing schedule {schedule_id}")
|
||||
logger.info(f"Updating existing schedule {schedule_id} for subscription {subscription_id}")
|
||||
logger.debug(f"Current phase data: {current_phase_update_data}")
|
||||
logger.debug(f"New phase data: {new_downgrade_phase_data}")
|
||||
updated_schedule = stripe.SubscriptionSchedule.modify(
|
||||
schedule_id,
|
||||
phases=[current_phase_update_data, new_downgrade_phase_data],
|
||||
end_behavior='release'
|
||||
)
|
||||
logger.info(f"Successfully updated schedule {updated_schedule['id']}")
|
||||
else:
|
||||
# Create a new schedule using the defined phases
|
||||
print(f"Creating new schedule for subscription {subscription_id}")
|
||||
logger.info(f"Creating new schedule for subscription {subscription_id}")
|
||||
# Deep debug logging - write subscription details to help diagnose issues
|
||||
logger.debug(f"Subscription details: {subscription_id}, current_period_end_ts: {current_period_end_ts}")
|
||||
logger.debug(f"Current price: {current_price_id}, New price: {request.price_id}")
|
||||
|
||||
try:
|
||||
updated_schedule = stripe.SubscriptionSchedule.create(
|
||||
from_subscription=subscription_id,
|
||||
phases=[
|
||||
{
|
||||
'start_date': current_phase['start_date'],
|
||||
'end_date': current_period_end_ts,
|
||||
'proration_behavior': 'none',
|
||||
'items': [
|
||||
{
|
||||
'price': current_price_id,
|
||||
'quantity': 1
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
'start_date': current_period_end_ts,
|
||||
'proration_behavior': 'none',
|
||||
'items': [
|
||||
{
|
||||
'price': request.price_id,
|
||||
'quantity': 1
|
||||
}
|
||||
]
|
||||
}
|
||||
],
|
||||
end_behavior='release'
|
||||
)
|
||||
# Don't try to link the schedule - that's handled by from_subscription
|
||||
logger.info(f"Created new schedule {updated_schedule['id']} from subscription {subscription_id}")
|
||||
# print(f"Created new schedule {updated_schedule['id']} from subscription {subscription_id}")
|
||||
|
||||
# Verify the schedule was created correctly
|
||||
fetched_schedule = stripe.SubscriptionSchedule.retrieve(updated_schedule['id'])
|
||||
logger.info(f"Schedule verification - Status: {fetched_schedule.get('status')}, Phase Count: {len(fetched_schedule.get('phases', []))}")
|
||||
logger.debug(f"Schedule details: {fetched_schedule}")
|
||||
except Exception as schedule_error:
|
||||
logger.exception(f"Failed to create schedule: {str(schedule_error)}")
|
||||
raise schedule_error # Re-raise to be caught by the outer try-except
|
||||
|
||||
return {
|
||||
"subscription_id": subscription_id,
|
||||
"schedule_id": updated_schedule['id'],
|
||||
"status": "scheduled",
|
||||
"message": "Subscription downgrade scheduled",
|
||||
"details": {
|
||||
"is_upgrade": False,
|
||||
"effective_date": "end_of_period",
|
||||
"current_price": round(current_price['unit_amount'] / 100, 2) if current_price.get('unit_amount') else 0,
|
||||
"new_price": round(new_price['unit_amount'] / 100, 2) if new_price.get('unit_amount') else 0,
|
||||
"effective_at": datetime.fromtimestamp(current_period_end_ts, tz=timezone.utc).isoformat()
|
||||
}
|
||||
}
|
||||
except Exception as e:
|
||||
logger.exception(f"Error handling subscription schedule for sub {subscription_id}: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"Error handling subscription schedule: {str(e)}")
|
||||
except Exception as e:
|
||||
logger.exception(f"Error updating subscription {existing_subscription.get('id') if existing_subscription else 'N/A'}: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"Error updating subscription: {str(e)}")
|
||||
else:
|
||||
# --- Create New Subscription via Checkout Session ---
|
||||
session = stripe.checkout.Session.create(
|
||||
customer=customer_id,
|
||||
payment_method_types=['card'],
|
||||
line_items=[{'price': request.price_id, 'quantity': 1}],
|
||||
mode='subscription',
|
||||
success_url=request.success_url,
|
||||
cancel_url=request.cancel_url,
|
||||
metadata={
|
||||
'user_id': current_user_id,
|
||||
'product_id': product_id
|
||||
}
|
||||
)
|
||||
return {"session_id": session['id'], "url": session['url'], "status": "new"}
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Error creating checkout session: {str(e)}")
|
||||
# Check if it's a Stripe error with more details
|
||||
if hasattr(e, 'json_body') and e.json_body and 'error' in e.json_body:
|
||||
error_detail = e.json_body['error'].get('message', str(e))
|
||||
else:
|
||||
error_detail = str(e)
|
||||
raise HTTPException(status_code=500, detail=f"Error creating checkout session: {error_detail}")
|
||||
|
||||
@router.post("/create-portal-session")
|
||||
async def create_portal_session(
|
||||
request: CreatePortalSessionRequest,
|
||||
current_user_id: str = Depends(get_current_user_id_from_jwt)
|
||||
):
|
||||
"""Create a Stripe Customer Portal session for subscription management."""
|
||||
try:
|
||||
# Get Supabase client
|
||||
db = DBConnection()
|
||||
client = await db.client
|
||||
|
||||
# Get customer ID
|
||||
customer_id = await get_stripe_customer_id(client, current_user_id)
|
||||
if not customer_id:
|
||||
raise HTTPException(status_code=404, detail="No billing customer found")
|
||||
|
||||
# Ensure the portal configuration has subscription_update enabled
|
||||
try:
|
||||
# First, check if we have a configuration that already enables subscription update
|
||||
configurations = stripe.billing_portal.Configuration.list(limit=100)
|
||||
active_config = None
|
||||
|
||||
# Look for a configuration with subscription_update enabled
|
||||
for config in configurations.get('data', []):
|
||||
features = config.get('features', {})
|
||||
subscription_update = features.get('subscription_update', {})
|
||||
if subscription_update.get('enabled', False):
|
||||
active_config = config
|
||||
logger.info(f"Found existing portal configuration with subscription_update enabled: {config['id']}")
|
||||
break
|
||||
|
||||
# If no config with subscription_update found, create one or update the active one
|
||||
if not active_config:
|
||||
# Find the active configuration or create a new one
|
||||
if configurations.get('data', []):
|
||||
default_config = configurations['data'][0]
|
||||
logger.info(f"Updating default portal configuration: {default_config['id']} to enable subscription_update")
|
||||
|
||||
active_config = stripe.billing_portal.Configuration.update(
|
||||
default_config['id'],
|
||||
features={
|
||||
'subscription_update': {
|
||||
'enabled': True,
|
||||
'proration_behavior': 'create_prorations',
|
||||
'default_allowed_updates': ['price']
|
||||
},
|
||||
# Preserve other features that may already be enabled
|
||||
'customer_update': default_config.get('features', {}).get('customer_update', {'enabled': True, 'allowed_updates': ['email', 'address']}),
|
||||
'invoice_history': {'enabled': True},
|
||||
'payment_method_update': {'enabled': True}
|
||||
}
|
||||
)
|
||||
else:
|
||||
# Create a new configuration with subscription_update enabled
|
||||
logger.info("Creating new portal configuration with subscription_update enabled")
|
||||
active_config = stripe.billing_portal.Configuration.create(
|
||||
business_profile={
|
||||
'headline': 'Subscription Management',
|
||||
'privacy_policy_url': config.FRONTEND_URL + '/privacy',
|
||||
'terms_of_service_url': config.FRONTEND_URL + '/terms'
|
||||
},
|
||||
features={
|
||||
'subscription_update': {
|
||||
'enabled': True,
|
||||
'proration_behavior': 'create_prorations',
|
||||
'default_allowed_updates': ['price']
|
||||
},
|
||||
'customer_update': {
|
||||
'enabled': True,
|
||||
'allowed_updates': ['email', 'address']
|
||||
},
|
||||
'invoice_history': {'enabled': True},
|
||||
'payment_method_update': {'enabled': True}
|
||||
}
|
||||
)
|
||||
|
||||
# Log the active configuration for debugging
|
||||
logger.info(f"Using portal configuration: {active_config['id']} with subscription_update: {active_config.get('features', {}).get('subscription_update', {}).get('enabled', False)}")
|
||||
|
||||
except Exception as config_error:
|
||||
logger.warning(f"Error configuring portal: {config_error}. Continuing with default configuration.")
|
||||
|
||||
# Create portal session using the proper configuration if available
|
||||
portal_params = {
|
||||
"customer": customer_id,
|
||||
"return_url": request.return_url
|
||||
}
|
||||
|
||||
# Add configuration_id if we found or created one with subscription_update enabled
|
||||
if active_config:
|
||||
portal_params["configuration"] = active_config['id']
|
||||
|
||||
# Create the session
|
||||
session = stripe.billing_portal.Session.create(**portal_params)
|
||||
|
||||
return {"url": session.url}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating portal session: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@router.get("/subscription")
|
||||
async def get_subscription(
|
||||
current_user_id: str = Depends(get_current_user_id_from_jwt)
|
||||
):
|
||||
"""Get the current subscription status for the current user, including scheduled changes."""
|
||||
try:
|
||||
# Get subscription from Stripe (this helper already handles filtering/cleanup)
|
||||
subscription = await get_user_subscription(current_user_id)
|
||||
# print("Subscription data for status:", subscription)
|
||||
|
||||
if not subscription:
|
||||
# Default to free tier status if no active subscription for our product
|
||||
free_tier_id = config.STRIPE_FREE_TIER_ID
|
||||
free_tier_info = SUBSCRIPTION_TIERS.get(free_tier_id)
|
||||
return SubscriptionStatus(
|
||||
status="no_subscription",
|
||||
plan_name=free_tier_info.get('name', 'free') if free_tier_info else 'free',
|
||||
price_id=free_tier_id,
|
||||
minutes_limit=free_tier_info.get('minutes') if free_tier_info else 0
|
||||
)
|
||||
|
||||
# Extract current plan details
|
||||
current_item = subscription['items']['data'][0]
|
||||
current_price_id = current_item['price']['id']
|
||||
current_tier_info = SUBSCRIPTION_TIERS.get(current_price_id)
|
||||
if not current_tier_info:
|
||||
# Fallback if somehow subscribed to an unknown price within our product
|
||||
logger.warning(f"User {current_user_id} subscribed to unknown price {current_price_id}. Defaulting info.")
|
||||
current_tier_info = {'name': 'unknown', 'minutes': 0}
|
||||
|
||||
# Calculate current usage
|
||||
db = DBConnection()
|
||||
client = await db.client
|
||||
current_usage = await calculate_monthly_usage(client, current_user_id)
|
||||
|
||||
status_response = SubscriptionStatus(
|
||||
status=subscription['status'], # 'active', 'trialing', etc.
|
||||
plan_name=subscription['plan'].get('nickname') or current_tier_info['name'],
|
||||
price_id=current_price_id,
|
||||
current_period_end=datetime.fromtimestamp(current_item['current_period_end'], tz=timezone.utc),
|
||||
cancel_at_period_end=subscription['cancel_at_period_end'],
|
||||
trial_end=datetime.fromtimestamp(subscription['trial_end'], tz=timezone.utc) if subscription.get('trial_end') else None,
|
||||
minutes_limit=current_tier_info['minutes'],
|
||||
current_usage=round(current_usage, 2),
|
||||
has_schedule=False # Default
|
||||
)
|
||||
|
||||
# Check for an attached schedule (indicates pending downgrade)
|
||||
schedule_id = subscription.get('schedule')
|
||||
if schedule_id:
|
||||
try:
|
||||
schedule = stripe.SubscriptionSchedule.retrieve(schedule_id)
|
||||
# Find the *next* phase after the current one
|
||||
next_phase = None
|
||||
current_phase_end = current_item['current_period_end']
|
||||
|
||||
for phase in schedule.get('phases', []):
|
||||
# Check if this phase starts exactly when the current one ends
|
||||
if phase.get('start_date') == current_phase_end:
|
||||
next_phase = phase
|
||||
break # Found the immediate next phase
|
||||
|
||||
if next_phase:
|
||||
scheduled_item = next_phase['items'][0] # Assuming single item
|
||||
scheduled_price_id = scheduled_item['price'] # Price ID might be string here
|
||||
scheduled_tier_info = SUBSCRIPTION_TIERS.get(scheduled_price_id)
|
||||
|
||||
status_response.has_schedule = True
|
||||
status_response.status = 'scheduled_downgrade' # Override status
|
||||
status_response.scheduled_plan_name = scheduled_tier_info.get('name', 'unknown') if scheduled_tier_info else 'unknown'
|
||||
status_response.scheduled_price_id = scheduled_price_id
|
||||
status_response.scheduled_change_date = datetime.fromtimestamp(next_phase['start_date'], tz=timezone.utc)
|
||||
|
||||
except Exception as schedule_error:
|
||||
logger.error(f"Error retrieving or parsing schedule {schedule_id} for sub {subscription['id']}: {schedule_error}")
|
||||
# Proceed without schedule info if retrieval fails
|
||||
|
||||
return status_response
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Error getting subscription status for user {current_user_id}: {str(e)}") # Use logger.exception
|
||||
raise HTTPException(status_code=500, detail="Error retrieving subscription status.")
|
||||
|
||||
@router.get("/check-status")
|
||||
async def check_status(
|
||||
current_user_id: str = Depends(get_current_user_id_from_jwt)
|
||||
):
|
||||
"""Check if the user can run agents based on their subscription and usage."""
|
||||
try:
|
||||
# Get Supabase client
|
||||
db = DBConnection()
|
||||
client = await db.client
|
||||
|
||||
can_run, message, subscription = await check_billing_status(client, current_user_id)
|
||||
|
||||
return {
|
||||
"can_run": can_run,
|
||||
"message": message,
|
||||
"subscription": subscription
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking billing status: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@router.post("/webhook")
|
||||
async def stripe_webhook(request: Request):
|
||||
"""Handle Stripe webhook events."""
|
||||
try:
|
||||
# Get the webhook secret from config
|
||||
webhook_secret = config.STRIPE_WEBHOOK_SECRET
|
||||
|
||||
# Get the webhook payload
|
||||
payload = await request.body()
|
||||
sig_header = request.headers.get('stripe-signature')
|
||||
|
||||
# Verify webhook signature
|
||||
try:
|
||||
event = stripe.Webhook.construct_event(
|
||||
payload, sig_header, webhook_secret
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail="Invalid payload")
|
||||
except stripe.error.SignatureVerificationError as e:
|
||||
raise HTTPException(status_code=400, detail="Invalid signature")
|
||||
|
||||
# Handle the event
|
||||
if event.type in ['customer.subscription.created', 'customer.subscription.updated', 'customer.subscription.deleted']:
|
||||
# We don't need to do anything here as we'll query Stripe directly
|
||||
pass
|
||||
|
||||
return {"status": "success"}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing webhook: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
|
@ -1,5 +0,0 @@
|
|||
{
|
||||
"imports": {
|
||||
"@supabase/supabase-js": "https://esm.sh/@supabase/supabase-js"
|
||||
}
|
||||
}
|
|
@ -1,284 +0,0 @@
|
|||
import {serve} from "https://deno.land/std@0.168.0/http/server.ts";
|
||||
import {stripeFunctionHandler} from "https://deno.land/x/basejump@v2.0.3/billing-functions/mod.ts";
|
||||
import { requireAuthorizedBillingUser } from "https://deno.land/x/basejump@v2.0.3/billing-functions/src/require-authorized-billing-user.ts";
|
||||
import getBillingStatus from "https://deno.land/x/basejump@v2.0.3/billing-functions/src/wrappers/get-billing-status.ts";
|
||||
import createSupabaseServiceClient from "https://deno.land/x/basejump@v2.0.3/billing-functions/lib/create-supabase-service-client.ts";
|
||||
import validateUrl from "https://deno.land/x/basejump@v2.0.3/billing-functions/lib/validate-url.ts";
|
||||
|
||||
import Stripe from "https://esm.sh/stripe@11.1.0?target=deno";
|
||||
|
||||
console.log("Starting billing functions...");
|
||||
|
||||
const defaultAllowedHost = Deno.env.get("ALLOWED_HOST") || "http://localhost:3000";
|
||||
const allowedHosts = [defaultAllowedHost, "https://www.suna.so", "https://suna.so", "https://staging.suna.so"];
|
||||
console.log("Default allowed host:", defaultAllowedHost);
|
||||
|
||||
export const corsHeaders = {
|
||||
"Access-Control-Allow-Origin": "*",
|
||||
"Access-Control-Allow-Headers":
|
||||
"authorization, x-client-info, apikey, content-type",
|
||||
};
|
||||
|
||||
console.log("Initializing Stripe client...");
|
||||
const stripeClient = new Stripe(Deno.env.get("STRIPE_API_KEY") as string, {
|
||||
// This is needed to use the Fetch API rather than relying on the Node http
|
||||
// package.
|
||||
apiVersion: "2022-11-15",
|
||||
httpClient: Stripe.createFetchHttpClient(),
|
||||
});
|
||||
console.log("Stripe client initialized");
|
||||
|
||||
console.log("Setting up stripe handler...");
|
||||
const stripeHandler = stripeFunctionHandler({
|
||||
stripeClient,
|
||||
defaultPlanId: Deno.env.get("STRIPE_DEFAULT_PLAN_ID") as string,
|
||||
defaultTrialDays: Deno.env.get("STRIPE_DEFAULT_TRIAL_DAYS") ? Number(Deno.env.get("STRIPE_DEFAULT_TRIAL_DAYS")) : undefined
|
||||
});
|
||||
console.log("Stripe handler configured");
|
||||
|
||||
serve(async (req) => {
|
||||
console.log("Received request:", req.method, req.url);
|
||||
|
||||
if (req.method === "OPTIONS") {
|
||||
console.log("Handling OPTIONS request");
|
||||
return new Response("ok", {headers: corsHeaders});
|
||||
}
|
||||
|
||||
try {
|
||||
const body = await req.json();
|
||||
console.log("Request body:", body);
|
||||
|
||||
if (!body.args?.account_id) {
|
||||
console.log("Missing account_id in request");
|
||||
return new Response(
|
||||
JSON.stringify({ error: "Account id is required" }),
|
||||
{
|
||||
status: 400,
|
||||
headers: {
|
||||
...corsHeaders,
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
switch (body.action) {
|
||||
case "get_plans":
|
||||
console.log("Getting plans");
|
||||
try {
|
||||
const plans = await stripeHandler.getPlans(body.args);
|
||||
console.log("Plans retrieved:", plans.length);
|
||||
return new Response(
|
||||
JSON.stringify(plans),
|
||||
{
|
||||
headers: {
|
||||
...corsHeaders,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
}
|
||||
);
|
||||
} catch (e) {
|
||||
console.log("Error getting plans:", e);
|
||||
return new Response(
|
||||
JSON.stringify({ error: "Failed to get plans" }),
|
||||
{
|
||||
status: 500,
|
||||
headers: {
|
||||
...corsHeaders,
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
case "get_billing_portal_url":
|
||||
console.log("Getting billing portal URL for account:", body.args.account_id);
|
||||
if (!validateUrl(body.args.return_url, allowedHosts)) {
|
||||
console.log("Invalid return URL:", body.args.return_url);
|
||||
return new Response(
|
||||
JSON.stringify({ error: "Return url is not allowed" }),
|
||||
{
|
||||
status: 400,
|
||||
headers: {
|
||||
...corsHeaders,
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
return await requireAuthorizedBillingUser(req, {
|
||||
accountId: body.args.account_id,
|
||||
authorizedRoles: ["owner"],
|
||||
async onBillableAndAuthorized(roleInfo) {
|
||||
console.log("User authorized for billing portal, role info:", roleInfo);
|
||||
try {
|
||||
const response = await stripeHandler.getBillingPortalUrl({
|
||||
accountId: roleInfo.account_id,
|
||||
subscriptionId: roleInfo.billing_subscription_id,
|
||||
customerId: roleInfo.billing_customer_id,
|
||||
returnUrl: body.args.return_url,
|
||||
});
|
||||
console.log("Billing portal URL generated");
|
||||
|
||||
return new Response(
|
||||
JSON.stringify({
|
||||
billing_enabled: roleInfo.billing_enabled,
|
||||
...response,
|
||||
}),
|
||||
{
|
||||
headers: {
|
||||
...corsHeaders,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
}
|
||||
);
|
||||
} catch (e) {
|
||||
console.log("Error getting billing portal URL:", e);
|
||||
return new Response(
|
||||
JSON.stringify({ error: "Failed to generate billing portal URL" }),
|
||||
{
|
||||
status: 500,
|
||||
headers: {
|
||||
...corsHeaders,
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
}
|
||||
);
|
||||
}
|
||||
},
|
||||
});
|
||||
|
||||
case "get_new_subscription_url":
|
||||
console.log("Getting new subscription URL for account:", body.args.account_id);
|
||||
if (!validateUrl(body.args.success_url, allowedHosts) || !validateUrl(body.args.cancel_url, allowedHosts)) {
|
||||
console.log("Invalid success or cancel URL:", body.args.success_url, body.args.cancel_url);
|
||||
return new Response(
|
||||
JSON.stringify({ error: "Success or cancel url is not allowed" }),
|
||||
{
|
||||
status: 400,
|
||||
headers: {
|
||||
...corsHeaders,
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
return await requireAuthorizedBillingUser(req, {
|
||||
accountId: body.args.account_id,
|
||||
authorizedRoles: ["owner"],
|
||||
async onBillableAndAuthorized(roleInfo) {
|
||||
console.log("User authorized for new subscription, role info:", roleInfo);
|
||||
try {
|
||||
const response = await stripeHandler.getNewSubscriptionUrl({
|
||||
accountId: roleInfo.account_id,
|
||||
planId: body.args.plan_id,
|
||||
successUrl: body.args.success_url,
|
||||
cancelUrl: body.args.cancel_url,
|
||||
billingEmail: roleInfo.billing_email,
|
||||
customerId: roleInfo.billing_customer_id,
|
||||
});
|
||||
console.log("New subscription URL generated");
|
||||
|
||||
return new Response(
|
||||
JSON.stringify({
|
||||
billing_enabled: roleInfo.billing_enabled,
|
||||
...response,
|
||||
}),
|
||||
{
|
||||
headers: {
|
||||
...corsHeaders,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
}
|
||||
);
|
||||
} catch (e) {
|
||||
console.log("Error getting new subscription URL:", e);
|
||||
return new Response(
|
||||
JSON.stringify({ error: "Failed to generate new subscription URL" }),
|
||||
{
|
||||
status: 500,
|
||||
headers: {
|
||||
...corsHeaders,
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
}
|
||||
);
|
||||
}
|
||||
},
|
||||
});
|
||||
|
||||
case "get_billing_status":
|
||||
console.log("Getting billing status for account:", body.args.account_id);
|
||||
return await requireAuthorizedBillingUser(req, {
|
||||
accountId: body.args.account_id,
|
||||
authorizedRoles: ["owner"],
|
||||
async onBillableAndAuthorized(roleInfo) {
|
||||
console.log("User authorized, role info:", roleInfo);
|
||||
const supabaseClient = createSupabaseServiceClient();
|
||||
console.log("Getting billing status...");
|
||||
try {
|
||||
const response = await getBillingStatus(
|
||||
supabaseClient,
|
||||
roleInfo,
|
||||
stripeHandler
|
||||
);
|
||||
console.log("Billing status response:", response);
|
||||
|
||||
return new Response(
|
||||
JSON.stringify({
|
||||
...response,
|
||||
status: response.status || "not_setup",
|
||||
billing_enabled: roleInfo.billing_enabled,
|
||||
}),
|
||||
{
|
||||
headers: {
|
||||
...corsHeaders,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
}
|
||||
);
|
||||
} catch (e) {
|
||||
console.log("Error getting billing status:", e);
|
||||
return new Response(
|
||||
JSON.stringify({ error: "Internal server error" }),
|
||||
{
|
||||
status: 500,
|
||||
headers: {
|
||||
...corsHeaders,
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
}
|
||||
);
|
||||
}
|
||||
},
|
||||
});
|
||||
|
||||
default:
|
||||
console.log("Invalid action requested:", body.action);
|
||||
return new Response(
|
||||
JSON.stringify({ error: "Invalid action" }),
|
||||
{
|
||||
status: 400,
|
||||
headers: {
|
||||
...corsHeaders,
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
}
|
||||
);
|
||||
}
|
||||
} catch (e) {
|
||||
console.log("Error processing request:", e);
|
||||
return new Response(
|
||||
JSON.stringify({ error: "Internal server error" }),
|
||||
{
|
||||
status: 500,
|
||||
headers: {
|
||||
...corsHeaders,
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
}
|
||||
);
|
||||
}
|
||||
});
|
|
@ -1,5 +0,0 @@
|
|||
{
|
||||
"imports": {
|
||||
"@supabase/supabase-js": "https://esm.sh/@supabase/supabase-js"
|
||||
}
|
||||
}
|
|
@ -1,24 +0,0 @@
|
|||
import {serve} from "https://deno.land/std@0.168.0/http/server.ts";
|
||||
import {billingWebhooksWrapper, stripeWebhookHandler} from "https://deno.land/x/basejump@v2.0.3/billing-functions/mod.ts";
|
||||
|
||||
|
||||
import Stripe from "https://esm.sh/stripe@11.1.0?target=deno";
|
||||
|
||||
const stripeClient = new Stripe(Deno.env.get("STRIPE_API_KEY") as string, {
|
||||
// This is needed to use the Fetch API rather than relying on the Node http
|
||||
// package.
|
||||
apiVersion: "2022-11-15",
|
||||
httpClient: Stripe.createFetchHttpClient(),
|
||||
});
|
||||
|
||||
const stripeResponse = stripeWebhookHandler({
|
||||
stripeClient,
|
||||
stripeWebhookSigningSecret: Deno.env.get("STRIPE_WEBHOOK_SIGNING_SECRET") as string,
|
||||
});
|
||||
|
||||
const webhookEndpoint = billingWebhooksWrapper(stripeResponse);
|
||||
|
||||
serve(async (req) => {
|
||||
const response = await webhookEndpoint(req);
|
||||
return response;
|
||||
});
|
|
@ -1,133 +0,0 @@
|
|||
"""
|
||||
Tests for direct tool execution in AgentPress.
|
||||
|
||||
This module tests the performance difference between sequential and parallel
|
||||
tool execution strategies by directly calling the execution methods without thread overhead.
|
||||
"""
|
||||
|
||||
import os
|
||||
import asyncio
|
||||
import sys
|
||||
from dotenv import load_dotenv
|
||||
from agentpress.thread_manager import ThreadManager
|
||||
from agentpress.response_processor import ProcessorConfig
|
||||
from agent.tools.wait_tool import WaitTool
|
||||
from agentpress.tool import ToolResult
|
||||
|
||||
# Load environment variables
|
||||
load_dotenv()
|
||||
|
||||
async def test_direct_execution():
|
||||
"""Directly test sequential vs parallel execution without thread overhead."""
|
||||
print("\n" + "="*80)
|
||||
print("🧪 TESTING DIRECT TOOL EXECUTION: PARALLEL VS SEQUENTIAL")
|
||||
print("="*80 + "\n")
|
||||
|
||||
# Initialize ThreadManager and register tools
|
||||
thread_manager = ThreadManager()
|
||||
thread_manager.add_tool(WaitTool)
|
||||
|
||||
# Create wait tool calls
|
||||
wait_tool_calls = [
|
||||
{"name": "wait", "arguments": {"seconds": 2, "message": "Wait tool 1"}},
|
||||
{"name": "wait", "arguments": {"seconds": 2, "message": "Wait tool 2"}},
|
||||
{"name": "wait", "arguments": {"seconds": 2, "message": "Wait tool 3"}}
|
||||
]
|
||||
|
||||
# Expected values for validation
|
||||
expected_tool_count = len(wait_tool_calls)
|
||||
|
||||
# Test sequential execution
|
||||
print("🔄 Testing Sequential Execution")
|
||||
print("-"*60)
|
||||
sequential_start = asyncio.get_event_loop().time()
|
||||
sequential_results = await thread_manager.response_processor._execute_tools(
|
||||
wait_tool_calls,
|
||||
execution_strategy="sequential"
|
||||
)
|
||||
sequential_end = asyncio.get_event_loop().time()
|
||||
sequential_time = sequential_end - sequential_start
|
||||
|
||||
print(f"Sequential execution completed in {sequential_time:.2f} seconds")
|
||||
|
||||
# Validate sequential results - results are a list of tuples (tool_call, tool_result)
|
||||
assert len(sequential_results) == expected_tool_count, f"❌ Expected {expected_tool_count} tool results, got {len(sequential_results)} in sequential execution"
|
||||
assert all(isinstance(result_tuple, tuple) and len(result_tuple) == 2 for result_tuple in sequential_results), "❌ Not all sequential results are tuples of (tool_call, result)"
|
||||
assert all(isinstance(result_tuple[1], ToolResult) for result_tuple in sequential_results), "❌ Not all sequential result values are ToolResult instances"
|
||||
assert all(result_tuple[1].success for result_tuple in sequential_results), "❌ Not all sequential tool executions were successful"
|
||||
print("✅ PASS: Sequential execution completed all tool calls successfully")
|
||||
print()
|
||||
|
||||
# Test parallel execution
|
||||
print("⚡ Testing Parallel Execution")
|
||||
print("-"*60)
|
||||
parallel_start = asyncio.get_event_loop().time()
|
||||
parallel_results = await thread_manager.response_processor._execute_tools(
|
||||
wait_tool_calls,
|
||||
execution_strategy="parallel"
|
||||
)
|
||||
parallel_end = asyncio.get_event_loop().time()
|
||||
parallel_time = parallel_end - parallel_start
|
||||
|
||||
print(f"Parallel execution completed in {parallel_time:.2f} seconds")
|
||||
|
||||
# Validate parallel results - results are a list of tuples (tool_call, tool_result)
|
||||
assert len(parallel_results) == expected_tool_count, f"❌ Expected {expected_tool_count} tool results, got {len(parallel_results)} in parallel execution"
|
||||
assert all(isinstance(result_tuple, tuple) and len(result_tuple) == 2 for result_tuple in parallel_results), "❌ Not all parallel results are tuples of (tool_call, result)"
|
||||
assert all(isinstance(result_tuple[1], ToolResult) for result_tuple in parallel_results), "❌ Not all parallel result values are ToolResult instances"
|
||||
assert all(result_tuple[1].success for result_tuple in parallel_results), "❌ Not all parallel tool executions were successful"
|
||||
print("✅ PASS: Parallel execution completed all tool calls successfully")
|
||||
print()
|
||||
|
||||
# Report results
|
||||
print("\n" + "="*80)
|
||||
print(f"🧮 RESULTS SUMMARY")
|
||||
print("="*80)
|
||||
print(f"Sequential: {sequential_time:.2f} seconds")
|
||||
print(f"Parallel: {parallel_time:.2f} seconds")
|
||||
|
||||
# Calculate and validate speedup
|
||||
speedup = sequential_time / parallel_time if parallel_time > 0 else 0
|
||||
print(f"Speedup: {speedup:.2f}x faster")
|
||||
|
||||
# Validate speedup is significant (at least 1.5x faster)
|
||||
min_expected_speedup = 1.5
|
||||
assert speedup >= min_expected_speedup, f"❌ Expected parallel execution to be at least {min_expected_speedup}x faster than sequential, but got {speedup:.2f}x"
|
||||
print(f"✅ PASS: Parallel execution is {speedup:.2f}x faster than sequential as expected")
|
||||
|
||||
# Ideal speedup should be close to the number of tools (3x)
|
||||
# But allow for some overhead (at least 1.5x)
|
||||
theoretical_max_speedup = len(wait_tool_calls)
|
||||
print(f"Note: Theoretical max speedup: {theoretical_max_speedup:.1f}x")
|
||||
|
||||
print("\n" + "="*80)
|
||||
print("✅ ALL TESTS PASSED")
|
||||
print("="*80)
|
||||
|
||||
# Return results for potential further analysis
|
||||
return {
|
||||
"sequential": {
|
||||
"time": sequential_time,
|
||||
"results": sequential_results
|
||||
},
|
||||
"parallel": {
|
||||
"time": parallel_time,
|
||||
"results": parallel_results
|
||||
},
|
||||
"speedup": speedup
|
||||
}
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
asyncio.run(test_direct_execution())
|
||||
print("\n✅ Test completed successfully")
|
||||
sys.exit(0)
|
||||
except AssertionError as e:
|
||||
print(f"\n\n❌ Test failed: {str(e)}")
|
||||
sys.exit(1)
|
||||
except KeyboardInterrupt:
|
||||
print("\n\n❌ Test interrupted by user")
|
||||
sys.exit(1)
|
||||
except Exception as e:
|
||||
print(f"\n\n❌ Error during test: {str(e)}")
|
||||
sys.exit(1)
|
|
@ -1,193 +0,0 @@
|
|||
"""
|
||||
Raw streaming test to analyze tool call streaming behavior.
|
||||
|
||||
This script specifically tests how raw streaming chunks are delivered from the Anthropic API
|
||||
with tool calls containing large JSON payloads.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import sys
|
||||
import os
|
||||
from typing import Dict, Any
|
||||
|
||||
from anthropic import AsyncAnthropic
|
||||
from utils.logger import logger
|
||||
|
||||
# Example tool schema for Anthropic format
|
||||
CREATE_FILE_TOOL = {
|
||||
"name": "create_file",
|
||||
"description": "Create a new file with the provided contents at a given path in the workspace",
|
||||
"input_schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"file_path": {
|
||||
"type": "string",
|
||||
"description": "Path to the file to be created"
|
||||
},
|
||||
"file_contents": {
|
||||
"type": "string",
|
||||
"description": "The content to write to the file"
|
||||
}
|
||||
},
|
||||
"required": ["file_path", "file_contents"]
|
||||
}
|
||||
}
|
||||
|
||||
async def test_raw_streaming():
|
||||
"""Test tool calling with streaming to observe raw chunk behavior using Anthropic SDK directly."""
|
||||
# Setup conversation with a prompt likely to generate large file payloads
|
||||
messages = [
|
||||
{"role": "user", "content": "Create a CSS file with a comprehensive set of styles for a modern responsive website."}
|
||||
]
|
||||
|
||||
print("\n=== Testing Raw Streaming Tool Call Behavior ===\n")
|
||||
|
||||
try:
|
||||
# Get API key from environment
|
||||
api_key = os.environ.get("ANTHROPIC_API_KEY")
|
||||
if not api_key:
|
||||
logger.error("ANTHROPIC_API_KEY environment variable not set")
|
||||
return
|
||||
|
||||
# Initialize Anthropic client
|
||||
client = AsyncAnthropic(api_key=api_key)
|
||||
|
||||
# Make API call with tool in streaming mode
|
||||
print("Sending streaming request...")
|
||||
stream = await client.messages.create(
|
||||
model="claude-3-5-sonnet-latest",
|
||||
max_tokens=4096,
|
||||
temperature=0.0,
|
||||
system="You are a helpful assistant with access to file management tools.",
|
||||
messages=messages,
|
||||
tools=[CREATE_FILE_TOOL],
|
||||
tool_choice={"type": "tool", "name": "create_file"},
|
||||
stream=True
|
||||
)
|
||||
|
||||
# Process streaming response
|
||||
print("\nResponse stream started. Processing raw chunks:\n")
|
||||
|
||||
# Stream statistics
|
||||
chunk_count = 0
|
||||
tool_call_chunks = 0
|
||||
accumulated_tool_input = ""
|
||||
current_tool_name = None
|
||||
accumulated_content = ""
|
||||
|
||||
# Process each chunk with ZERO buffering
|
||||
print("\n--- BEGINNING STREAM OUTPUT ---\n", flush=True)
|
||||
sys.stdout.flush()
|
||||
|
||||
# Process each event in the stream
|
||||
async for event in stream:
|
||||
chunk_count += 1
|
||||
|
||||
# Immediate debug output for every chunk
|
||||
print(f"\n[CHUNK {chunk_count}] Type: {event.type}", end="", flush=True)
|
||||
sys.stdout.flush()
|
||||
|
||||
# Process based on event type
|
||||
if event.type == "message_start":
|
||||
print(f" Message ID: {event.message.id}", end="", flush=True)
|
||||
|
||||
elif event.type == "content_block_start":
|
||||
print(f" Content block start: {event.content_block.type}", end="", flush=True)
|
||||
|
||||
elif event.type == "content_block_delta":
|
||||
if hasattr(event.delta, "text") and event.delta.text:
|
||||
text = event.delta.text
|
||||
accumulated_content += text
|
||||
print(f" Content: {repr(text)}", end="", flush=True)
|
||||
|
||||
elif event.type == "tool_use":
|
||||
current_tool_name = event.tool_use.name
|
||||
print(f" Tool use: {current_tool_name}", end="", flush=True)
|
||||
|
||||
# If input is available immediately
|
||||
if hasattr(event.tool_use, "input") and event.tool_use.input:
|
||||
tool_call_chunks += 1
|
||||
input_json = json.dumps(event.tool_use.input)
|
||||
input_len = len(input_json)
|
||||
print(f" Input[{input_len}]: {input_json[:50]}...", end="", flush=True)
|
||||
accumulated_tool_input = input_json
|
||||
|
||||
elif event.type == "tool_use_delta":
|
||||
if hasattr(event.delta, "input") and event.delta.input:
|
||||
tool_call_chunks += 1
|
||||
# For streaming tool inputs, we get partial updates
|
||||
# The delta.input is a dictionary with partial updates to specific fields
|
||||
input_json = json.dumps(event.delta.input)
|
||||
input_len = len(input_json)
|
||||
print(f" Input delta[{input_len}]: {input_json[:50]}...", end="", flush=True)
|
||||
|
||||
# Try to merge the deltas
|
||||
try:
|
||||
if accumulated_tool_input:
|
||||
# Parse existing accumulated JSON
|
||||
existing_input = json.loads(accumulated_tool_input)
|
||||
# Update with new delta
|
||||
existing_input.update(event.delta.input)
|
||||
accumulated_tool_input = json.dumps(existing_input)
|
||||
else:
|
||||
accumulated_tool_input = input_json
|
||||
except json.JSONDecodeError:
|
||||
# If we can't parse JSON yet, just append the raw delta
|
||||
accumulated_tool_input += input_json
|
||||
|
||||
elif event.type == "message_delta":
|
||||
if hasattr(event.delta, "stop_reason") and event.delta.stop_reason:
|
||||
print(f"\n--- FINISH REASON: {event.delta.stop_reason} ---", flush=True)
|
||||
|
||||
elif event.type == "message_stop":
|
||||
# Access stop_reason directly from the event
|
||||
if hasattr(event, "stop_reason"):
|
||||
print(f"\n--- MESSAGE STOP: {event.stop_reason} ---", flush=True)
|
||||
else:
|
||||
print("\n--- MESSAGE STOP ---", flush=True)
|
||||
|
||||
# Force flush after every chunk
|
||||
sys.stdout.flush()
|
||||
|
||||
print("\n\n--- END STREAM OUTPUT ---\n", flush=True)
|
||||
sys.stdout.flush()
|
||||
|
||||
# Summary after all chunks processed
|
||||
print("\n=== Streaming Summary ===")
|
||||
print(f"Total chunks: {chunk_count}")
|
||||
print(f"Tool call chunks: {tool_call_chunks}")
|
||||
|
||||
if current_tool_name:
|
||||
print(f"\nTool name: {current_tool_name}")
|
||||
|
||||
if accumulated_content:
|
||||
print(f"\nAccumulated content:")
|
||||
print(accumulated_content)
|
||||
|
||||
# Try to parse accumulated arguments as JSON
|
||||
try:
|
||||
if accumulated_tool_input:
|
||||
print(f"\nTotal accumulated tool input length: {len(accumulated_tool_input)}")
|
||||
input_obj = json.loads(accumulated_tool_input)
|
||||
print(f"\nSuccessfully parsed accumulated tool input as JSON")
|
||||
if 'file_path' in input_obj:
|
||||
print(f"file_path: {input_obj['file_path']}")
|
||||
if 'file_contents' in input_obj:
|
||||
contents = input_obj['file_contents']
|
||||
print(f"file_contents length: {len(contents)}")
|
||||
print(f"file_contents preview: {contents[:100]}...")
|
||||
except json.JSONDecodeError as e:
|
||||
print(f"\nError parsing accumulated tool input: {e}")
|
||||
print(f"Tool input start: {accumulated_tool_input[:100]}...")
|
||||
print(f"Tool input end: {accumulated_tool_input[-100:]}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in streaming test: {str(e)}", exc_info=True)
|
||||
|
||||
async def main():
|
||||
"""Run the raw streaming test."""
|
||||
await test_raw_streaming()
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
|
@ -1,236 +0,0 @@
|
|||
"""
|
||||
Simple test script for LLM API with tool calling functionality.
|
||||
|
||||
This script tests basic tool calling with both streaming and non-streaming to verify functionality.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from typing import Dict, Any
|
||||
|
||||
from services.llm import make_llm_api_call
|
||||
from utils.logger import logger
|
||||
|
||||
# Example tool schema from files_tool.py
|
||||
CREATE_FILE_SCHEMA = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "create_file",
|
||||
"description": "Create a new file with the provided contents at a given path in the workspace",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"file_path": {
|
||||
"type": "string",
|
||||
"description": "Path to the file to be created"
|
||||
},
|
||||
"file_contents": {
|
||||
"type": "string",
|
||||
"description": "The content to write to the file"
|
||||
}
|
||||
},
|
||||
"required": ["file_path", "file_contents"]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async def test_simple_tool_call():
|
||||
"""Test a simple non-streaming tool call to verify functionality."""
|
||||
# Setup conversation
|
||||
messages = [
|
||||
{"role": "system", "content": "You are a helpful assistant with access to file management tools."},
|
||||
{"role": "user", "content": "Create an HTML file named hello.html with a simple Hello World message."}
|
||||
]
|
||||
|
||||
print("\n=== Testing non-streaming tool call ===\n")
|
||||
|
||||
try:
|
||||
# Make API call with tool
|
||||
response = await make_llm_api_call(
|
||||
messages=messages,
|
||||
model_name="gpt-4o",
|
||||
temperature=0.0,
|
||||
tools=[CREATE_FILE_SCHEMA],
|
||||
tool_choice={"type": "function", "function": {"name": "create_file"}}
|
||||
)
|
||||
|
||||
# Print basic response info
|
||||
print(f"Response model: {response.model}")
|
||||
print(f"Response type: {type(response)}")
|
||||
|
||||
# Check if the response has tool calls
|
||||
assistant_message = response.choices[0].message
|
||||
print(f"\nAssistant message content: {assistant_message.content}")
|
||||
|
||||
if hasattr(assistant_message, 'tool_calls') and assistant_message.tool_calls:
|
||||
print("\nTool calls detected:")
|
||||
|
||||
for i, tool_call in enumerate(assistant_message.tool_calls):
|
||||
print(f"\nTool call {i+1}:")
|
||||
print(f" ID: {tool_call.id}")
|
||||
print(f" Type: {tool_call.type}")
|
||||
print(f" Function: {tool_call.function.name}")
|
||||
print(f" Arguments:")
|
||||
|
||||
try:
|
||||
args = json.loads(tool_call.function.arguments)
|
||||
print(json.dumps(args, indent=4))
|
||||
|
||||
# Access and print specific arguments
|
||||
if tool_call.function.name == "create_file":
|
||||
print(f"\nFile path: {args.get('file_path')}")
|
||||
print(f"File contents length: {len(args.get('file_contents', ''))}")
|
||||
print(f"File contents preview: {args.get('file_contents', '')[:100]}...")
|
||||
except Exception as e:
|
||||
print(f"Error parsing arguments: {e}")
|
||||
else:
|
||||
print("\nNo tool calls found in the response.")
|
||||
print(f"Full response: {response}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in test: {str(e)}", exc_info=True)
|
||||
|
||||
async def test_streaming_tool_call():
|
||||
"""Test tool calling with streaming to observe behavior."""
|
||||
# Setup conversation
|
||||
messages = [
|
||||
{"role": "system", "content": "You are a helpful assistant with access to file management tools. YOU ALWAYS USE MULTIPLE TOOL FUNCTION CALLS AT ONCE. YOU NEVER USE ONE TOOL FUNCTION CALL AT A TIME."},
|
||||
{"role": "user", "content": "Create 10 random files with different extensions and content."}
|
||||
]
|
||||
|
||||
print("\n=== Testing streaming tool call ===\n")
|
||||
|
||||
try:
|
||||
# Make API call with tool in streaming mode
|
||||
print("Sending streaming request...")
|
||||
stream_response = await make_llm_api_call(
|
||||
messages=messages,
|
||||
model_name="anthropic/claude-3-5-sonnet-latest",
|
||||
temperature=0.0,
|
||||
tools=[CREATE_FILE_SCHEMA],
|
||||
tool_choice="auto",
|
||||
stream=True
|
||||
)
|
||||
|
||||
# Process streaming response
|
||||
print("\nResponse stream started. Processing chunks:\n")
|
||||
|
||||
# Stream statistics
|
||||
chunk_count = 0
|
||||
content_chunks = 0
|
||||
tool_call_chunks = 0
|
||||
accumulated_content = ""
|
||||
|
||||
# Storage for accumulated tool calls
|
||||
tool_calls = []
|
||||
last_chunk = None # Variable to store the last chunk
|
||||
|
||||
# Process each chunk
|
||||
async for chunk in stream_response:
|
||||
chunk_count += 1
|
||||
last_chunk = chunk # Keep track of the last chunk
|
||||
|
||||
# Print chunk number and type
|
||||
print(f"\n--- Chunk {chunk_count} ---")
|
||||
print(f"Chunk type: {type(chunk)}")
|
||||
|
||||
if not hasattr(chunk, 'choices') or not chunk.choices:
|
||||
print("No choices in chunk")
|
||||
continue
|
||||
|
||||
delta = chunk.choices[0].delta
|
||||
|
||||
# Process content if present
|
||||
if hasattr(delta, 'content') and delta.content is not None:
|
||||
content_chunks += 1
|
||||
accumulated_content += delta.content
|
||||
print(f"Content: {delta.content}")
|
||||
|
||||
# Look for tool calls
|
||||
if hasattr(delta, 'tool_calls') and delta.tool_calls:
|
||||
tool_call_chunks += 1
|
||||
print("Tool call detected in chunk!")
|
||||
|
||||
for tool_call in delta.tool_calls:
|
||||
print(f"Tool call: {tool_call.model_dump()}")
|
||||
|
||||
# Track tool call parts
|
||||
tool_call_index = tool_call.index if hasattr(tool_call, 'index') else 0
|
||||
|
||||
# Initialize tool call if new
|
||||
while len(tool_calls) <= tool_call_index:
|
||||
tool_calls.append({
|
||||
"id": "",
|
||||
"type": "function",
|
||||
"function": {"name": "", "arguments": ""}
|
||||
})
|
||||
|
||||
# Update tool call ID if present
|
||||
if hasattr(tool_call, 'id') and tool_call.id:
|
||||
tool_calls[tool_call_index]["id"] = tool_call.id
|
||||
|
||||
# Update function name if present
|
||||
if hasattr(tool_call, 'function'):
|
||||
if hasattr(tool_call.function, 'name') and tool_call.function.name:
|
||||
tool_calls[tool_call_index]["function"]["name"] = tool_call.function.name
|
||||
|
||||
# Update function arguments if present
|
||||
if hasattr(tool_call.function, 'arguments') and tool_call.function.arguments:
|
||||
tool_calls[tool_call_index]["function"]["arguments"] += tool_call.function.arguments
|
||||
|
||||
# Summary after all chunks processed
|
||||
print("\n=== Streaming Summary ===")
|
||||
print(f"Total chunks: {chunk_count}")
|
||||
print(f"Content chunks: {content_chunks}")
|
||||
print(f"Tool call chunks: {tool_call_chunks}")
|
||||
|
||||
if accumulated_content:
|
||||
print(f"\nAccumulated content: {accumulated_content}")
|
||||
|
||||
if tool_calls:
|
||||
print("\nAccumulated tool calls:")
|
||||
for i, tool_call in enumerate(tool_calls):
|
||||
print(f"\nTool call {i+1}:")
|
||||
print(f" ID: {tool_call['id']}")
|
||||
print(f" Type: {tool_call['type']}")
|
||||
print(f" Function: {tool_call['function']['name']}")
|
||||
print(f" Arguments: {tool_call['function']['arguments']}")
|
||||
|
||||
# Try to parse arguments
|
||||
try:
|
||||
args = json.loads(tool_call['function']['arguments'])
|
||||
print("\nParsed arguments:")
|
||||
print(json.dumps(args, indent=4))
|
||||
except Exception as e:
|
||||
print(f"Error parsing arguments: {str(e)}")
|
||||
else:
|
||||
print("\nNo tool calls accumulated from streaming response.")
|
||||
|
||||
# --- Added logging for last chunk and finish reason ---
|
||||
finish_reason = None
|
||||
if last_chunk:
|
||||
try:
|
||||
if hasattr(last_chunk, 'choices') and last_chunk.choices:
|
||||
finish_reason = last_chunk.choices[0].finish_reason
|
||||
last_chunk_data = last_chunk.model_dump() if hasattr(last_chunk, 'model_dump') else vars(last_chunk)
|
||||
print("\n--- Last Chunk Received ---")
|
||||
print(f"Finish Reason: {finish_reason}")
|
||||
print(f"Raw Last Chunk Data: {json.dumps(last_chunk_data, indent=2)}")
|
||||
except Exception as log_ex:
|
||||
print("\n--- Error logging last chunk ---")
|
||||
print(f"Error: {log_ex}")
|
||||
print(f"Last Chunk (repr): {repr(last_chunk)}")
|
||||
else:
|
||||
print("\n--- No last chunk recorded ---")
|
||||
# --- End added logging ---
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in streaming test: {str(e)}", exc_info=True)
|
||||
|
||||
async def main():
|
||||
"""Run both tests for comparison."""
|
||||
# await test_simple_tool_call()
|
||||
await test_streaming_tool_call()
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
|
@ -1,237 +0,0 @@
|
|||
"""
|
||||
Tests for tool execution strategies in AgentPress.
|
||||
|
||||
This module tests both sequential and parallel execution strategies using the WaitTool
|
||||
in a realistic thread with XML tool calls.
|
||||
"""
|
||||
|
||||
import os
|
||||
import asyncio
|
||||
import sys
|
||||
from unittest.mock import AsyncMock, patch
|
||||
from dotenv import load_dotenv
|
||||
from agentpress.thread_manager import ThreadManager
|
||||
from agentpress.response_processor import ProcessorConfig
|
||||
from agent.tools.wait_tool import WaitTool
|
||||
|
||||
# Load environment variables
|
||||
load_dotenv()
|
||||
|
||||
TOOL_XML_SEQUENTIAL = """
|
||||
Here are some examples of using the wait tool:
|
||||
|
||||
<wait seconds="2">This is sequential wait 1</wait>
|
||||
<wait seconds="2">This is sequential wait 2</wait>
|
||||
<wait seconds="2">This is sequential wait 3</wait>
|
||||
|
||||
Now wait sequence:
|
||||
<wait-sequence count="3" seconds="1" label="Sequential Test" />
|
||||
"""
|
||||
|
||||
TOOL_XML_PARALLEL = """
|
||||
Here are some examples of using the wait tool:
|
||||
|
||||
<wait seconds="2">This is parallel wait 1</wait>
|
||||
<wait seconds="2">This is parallel wait 2</wait>
|
||||
<wait seconds="2">This is parallel wait 3</wait>
|
||||
|
||||
Now wait sequence:
|
||||
<wait-sequence count="3" seconds="1" label="Parallel Test" />
|
||||
"""
|
||||
|
||||
# Create a simple mock function that logs instead of accessing the database
|
||||
async def mock_add_message(thread_id, message):
|
||||
print(f"MOCK: Adding message to thread {thread_id}")
|
||||
print(f"MOCK: Message role: {message.get('role')}")
|
||||
print(f"MOCK: Content length: {len(message.get('content', ''))}")
|
||||
return {"id": "mock-message-id", "thread_id": thread_id}
|
||||
|
||||
async def test_execution_strategies():
|
||||
"""Test both sequential and parallel execution strategies in a thread."""
|
||||
print("\n" + "="*80)
|
||||
print("🧪 TESTING TOOL EXECUTION STRATEGIES")
|
||||
print("="*80 + "\n")
|
||||
|
||||
# Initialize ThreadManager and register tools
|
||||
thread_manager = ThreadManager()
|
||||
thread_manager.add_tool(WaitTool)
|
||||
|
||||
# Mock both ThreadManager's and ResponseProcessor's add_message method
|
||||
thread_manager.add_message = AsyncMock(side_effect=mock_add_message)
|
||||
# This is crucial - the ResponseProcessor receives add_message as a callback
|
||||
thread_manager.response_processor.add_message = AsyncMock(side_effect=mock_add_message)
|
||||
|
||||
# Create a test thread - we'll use a dummy ID since we're mocking the database
|
||||
thread_id = "test-thread-id"
|
||||
print(f"🧵 Using test thread: {thread_id}\n")
|
||||
|
||||
# Set up the get_llm_messages mock
|
||||
original_get_llm_messages = thread_manager.get_llm_messages
|
||||
thread_manager.get_llm_messages = AsyncMock()
|
||||
|
||||
# Test both strategies
|
||||
test_cases = [
|
||||
{"name": "Sequential", "strategy": "sequential", "content": TOOL_XML_SEQUENTIAL},
|
||||
{"name": "Parallel", "strategy": "parallel", "content": TOOL_XML_PARALLEL}
|
||||
]
|
||||
|
||||
# Expected values for validation - this varies based on XML parsing
|
||||
# For reliable testing, we look at <wait> tags which we know are being parsed
|
||||
expected_wait_count = 3 # 3 wait tags per test
|
||||
test_results = {}
|
||||
|
||||
for test in test_cases:
|
||||
print("\n" + "-"*60)
|
||||
print(f"🔍 Testing {test['name']} Execution Strategy")
|
||||
print("-"*60 + "\n")
|
||||
|
||||
# Setup mock for get_llm_messages to return our test content
|
||||
thread_manager.get_llm_messages.return_value = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a testing assistant that will execute wait commands."
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": test["content"]
|
||||
}
|
||||
]
|
||||
|
||||
# Simulate adding message (mocked)
|
||||
print(f"MOCK: Adding test message with {test['name']} execution strategy content")
|
||||
await thread_manager.add_message(
|
||||
thread_id=thread_id,
|
||||
type="assistant",
|
||||
content={
|
||||
"role": "assistant",
|
||||
"content": test["content"]
|
||||
},
|
||||
is_llm_message=True
|
||||
)
|
||||
|
||||
start_time = asyncio.get_event_loop().time()
|
||||
print(f"⏱️ Starting execution with {test['strategy']} strategy at {start_time:.2f}s")
|
||||
|
||||
# Process the response with appropriate strategy
|
||||
config = ProcessorConfig(
|
||||
xml_tool_calling=True,
|
||||
native_tool_calling=False,
|
||||
execute_tools=True,
|
||||
execute_on_stream=False,
|
||||
tool_execution_strategy=test["strategy"]
|
||||
)
|
||||
|
||||
# Get the last message to process (mocked)
|
||||
messages = await thread_manager.get_llm_messages(thread_id)
|
||||
last_message = messages[-1]
|
||||
|
||||
# Create a simple non-streaming response object
|
||||
class MockResponse:
|
||||
def __init__(self, content):
|
||||
self.choices = [type('obj', (object,), {
|
||||
'message': type('obj', (object,), {
|
||||
'content': content
|
||||
})
|
||||
})]
|
||||
|
||||
mock_response = MockResponse(last_message["content"])
|
||||
|
||||
# Process using the response processor
|
||||
tool_execution_count = 0
|
||||
wait_tool_count = 0
|
||||
tool_results = []
|
||||
|
||||
async for chunk in thread_manager.response_processor.process_non_streaming_response(
|
||||
llm_response=mock_response,
|
||||
thread_id=thread_id,
|
||||
config=config
|
||||
):
|
||||
if chunk.get('type') == 'tool_result':
|
||||
tool_name = chunk.get('name', '')
|
||||
tool_execution_count += 1
|
||||
if tool_name == 'wait':
|
||||
wait_tool_count += 1
|
||||
|
||||
elapsed = asyncio.get_event_loop().time() - start_time
|
||||
print(f"⏱️ [{elapsed:.2f}s] Tool result: {chunk['name']}")
|
||||
print(f" {chunk['result']}")
|
||||
print()
|
||||
tool_results.append(chunk)
|
||||
|
||||
end_time = asyncio.get_event_loop().time()
|
||||
elapsed = end_time - start_time
|
||||
print(f"\n⏱️ {test['name']} execution completed in {elapsed:.2f} seconds")
|
||||
print(f"🔢 Total tool executions: {tool_execution_count}")
|
||||
print(f"🔢 Wait tool executions: {wait_tool_count}")
|
||||
|
||||
# Store results for validation
|
||||
test_results[test['name']] = {
|
||||
'execution_time': elapsed,
|
||||
'tool_count': tool_execution_count,
|
||||
'wait_count': wait_tool_count,
|
||||
'tool_results': tool_results
|
||||
}
|
||||
|
||||
# Assert correct number of wait tools executions (this is more reliable than total count)
|
||||
assert wait_tool_count == expected_wait_count, f"❌ Expected {expected_wait_count} wait tool executions, got {wait_tool_count} in {test['name']} strategy"
|
||||
print(f"✅ PASS: {test['name']} executed {wait_tool_count} wait tools as expected")
|
||||
|
||||
# Restore original get_llm_messages method
|
||||
thread_manager.get_llm_messages = original_get_llm_messages
|
||||
|
||||
# Additional assertions for both test cases
|
||||
assert 'Sequential' in test_results, "❌ Sequential test not completed"
|
||||
assert 'Parallel' in test_results, "❌ Parallel test not completed"
|
||||
|
||||
# Validate parallel is faster than sequential for multiple wait tools
|
||||
sequential_time = test_results['Sequential']['execution_time']
|
||||
parallel_time = test_results['Parallel']['execution_time']
|
||||
speedup = sequential_time / parallel_time if parallel_time > 0 else 0
|
||||
|
||||
# Parallel should be faster than sequential (at least 1.5x speedup expected)
|
||||
print(f"\n⏱️ Execution time comparison:")
|
||||
print(f" Sequential: {sequential_time:.2f}s")
|
||||
print(f" Parallel: {parallel_time:.2f}s")
|
||||
print(f" Speedup: {speedup:.2f}x")
|
||||
|
||||
min_expected_speedup = 1.5
|
||||
assert speedup >= min_expected_speedup, f"❌ Expected parallel execution to be at least {min_expected_speedup}x faster than sequential, but got {speedup:.2f}x"
|
||||
print(f"✅ PASS: Parallel execution is {speedup:.2f}x faster than sequential")
|
||||
|
||||
# Check if all results have a status field
|
||||
all_have_status = all(
|
||||
'status' in result
|
||||
for test_data in test_results.values()
|
||||
for result in test_data['tool_results']
|
||||
)
|
||||
|
||||
# If results have a status field, check if they're all successful
|
||||
if all_have_status:
|
||||
all_successful = all(
|
||||
result.get('status') == 'success'
|
||||
for test_data in test_results.values()
|
||||
for result in test_data['tool_results']
|
||||
)
|
||||
assert all_successful, "❌ Not all tool executions were successful"
|
||||
print("✅ PASS: All tool executions completed successfully")
|
||||
|
||||
print("\n" + "="*80)
|
||||
print("✅ ALL TESTS PASSED")
|
||||
print("="*80 + "\n")
|
||||
|
||||
return test_results
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
asyncio.run(test_execution_strategies())
|
||||
print("\n✅ Test completed successfully")
|
||||
sys.exit(0)
|
||||
except AssertionError as e:
|
||||
print(f"\n\n❌ Test failed: {str(e)}")
|
||||
sys.exit(1)
|
||||
except KeyboardInterrupt:
|
||||
print("\n\n❌ Test interrupted by user")
|
||||
sys.exit(1)
|
||||
except Exception as e:
|
||||
print(f"\n\n❌ Error during test: {str(e)}")
|
||||
sys.exit(1)
|
|
@ -1,282 +0,0 @@
|
|||
"""
|
||||
Tests for XML tool execution in streaming and non-streaming modes.
|
||||
|
||||
This module tests XML tool execution with execute_on_stream set to TRUE and FALSE,
|
||||
to ensure both modes work correctly with the WaitTool.
|
||||
"""
|
||||
|
||||
import os
|
||||
import asyncio
|
||||
import sys
|
||||
from unittest.mock import AsyncMock, patch
|
||||
from dotenv import load_dotenv
|
||||
from agentpress.thread_manager import ThreadManager
|
||||
from agentpress.response_processor import ProcessorConfig
|
||||
from agent.tools.wait_tool import WaitTool
|
||||
|
||||
# Load environment variables
|
||||
load_dotenv()
|
||||
|
||||
# XML content with wait tool calls
|
||||
XML_CONTENT = """
|
||||
Here are some examples of using the wait tool:
|
||||
|
||||
<wait seconds="1">This is wait 1</wait>
|
||||
<wait seconds="1">This is wait 2</wait>
|
||||
<wait seconds="1">This is wait 3</wait>
|
||||
|
||||
Now wait sequence:
|
||||
<wait-sequence count="2" seconds="1" label="Test" />
|
||||
"""
|
||||
|
||||
class MockStreamingResponse:
|
||||
"""Mock streaming response from an LLM."""
|
||||
|
||||
def __init__(self, content):
|
||||
self.content = content
|
||||
self.chunk_size = 20 # Small chunks to simulate streaming
|
||||
|
||||
async def __aiter__(self):
|
||||
# Split content into chunks to simulate streaming
|
||||
for i in range(0, len(self.content), self.chunk_size):
|
||||
chunk = self.content[i:i+self.chunk_size]
|
||||
yield type('obj', (object,), {
|
||||
'choices': [type('obj', (object,), {
|
||||
'delta': type('obj', (object,), {
|
||||
'content': chunk
|
||||
})
|
||||
})]
|
||||
})
|
||||
# Simulate some network delay
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
class MockNonStreamingResponse:
|
||||
"""Mock non-streaming response from an LLM."""
|
||||
|
||||
def __init__(self, content):
|
||||
self.choices = [type('obj', (object,), {
|
||||
'message': type('obj', (object,), {
|
||||
'content': content
|
||||
})
|
||||
})]
|
||||
|
||||
# Create a simple mock function that logs instead of accessing the database
|
||||
async def mock_add_message(thread_id, message):
|
||||
print(f"MOCK: Adding message to thread {thread_id}")
|
||||
print(f"MOCK: Message role: {message.get('role')}")
|
||||
print(f"MOCK: Content length: {len(message.get('content', ''))}")
|
||||
return {"id": "mock-message-id", "thread_id": thread_id}
|
||||
|
||||
async def test_xml_streaming_execution():
|
||||
"""Test XML tool execution in both streaming and non-streaming modes."""
|
||||
print("\n" + "="*80)
|
||||
print("🧪 TESTING XML TOOL EXECUTION: STREAMING VS NON-STREAMING")
|
||||
print("="*80 + "\n")
|
||||
|
||||
# Initialize ThreadManager and register tools
|
||||
thread_manager = ThreadManager()
|
||||
thread_manager.add_tool(WaitTool)
|
||||
|
||||
# Mock both ThreadManager's and ResponseProcessor's add_message method
|
||||
thread_manager.add_message = AsyncMock(side_effect=mock_add_message)
|
||||
thread_manager.response_processor.add_message = AsyncMock(side_effect=mock_add_message)
|
||||
|
||||
# Set up the get_llm_messages mock
|
||||
original_get_llm_messages = thread_manager.get_llm_messages
|
||||
thread_manager.get_llm_messages = AsyncMock()
|
||||
|
||||
# Test cases for streaming and non-streaming
|
||||
test_cases = [
|
||||
{"name": "Non-Streaming", "execute_on_stream": False},
|
||||
{"name": "Streaming", "execute_on_stream": True}
|
||||
]
|
||||
|
||||
# Expected values for validation - focus specifically on wait tools
|
||||
expected_wait_count = 3 # 3 wait tags in the XML content
|
||||
test_results = {}
|
||||
|
||||
for test in test_cases:
|
||||
# Create a test thread ID - we're mocking so no actual creation
|
||||
thread_id = f"test-thread-{test['name'].lower()}"
|
||||
|
||||
print("\n" + "-"*60)
|
||||
print(f"🔍 Testing XML Tool Execution - {test['name']} Mode")
|
||||
print("-"*60 + "\n")
|
||||
|
||||
# Setup mock for get_llm_messages to return test content
|
||||
thread_manager.get_llm_messages.return_value = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a testing assistant that will execute wait commands."
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": XML_CONTENT
|
||||
}
|
||||
]
|
||||
|
||||
# Simulate adding system message (mocked)
|
||||
print(f"MOCK: Adding system message to thread {thread_id}")
|
||||
await thread_manager.add_message(
|
||||
thread_id=thread_id,
|
||||
type="system",
|
||||
content={
|
||||
"role": "system",
|
||||
"content": "You are a testing assistant that will execute wait commands."
|
||||
},
|
||||
is_llm_message=False
|
||||
)
|
||||
|
||||
# Simulate adding message with XML content (mocked)
|
||||
print(f"MOCK: Adding message with XML content to thread {thread_id}")
|
||||
await thread_manager.add_message(
|
||||
thread_id=thread_id,
|
||||
type="assistant",
|
||||
content={
|
||||
"role": "assistant",
|
||||
"content": XML_CONTENT
|
||||
},
|
||||
is_llm_message=True
|
||||
)
|
||||
|
||||
print(f"🧵 Using test thread: {thread_id}")
|
||||
print(f"⚙️ execute_on_stream: {test['execute_on_stream']}")
|
||||
|
||||
# Prepare the response processor config
|
||||
config = ProcessorConfig(
|
||||
xml_tool_calling=True,
|
||||
native_tool_calling=False,
|
||||
execute_tools=True,
|
||||
execute_on_stream=test['execute_on_stream'],
|
||||
tool_execution_strategy="sequential"
|
||||
)
|
||||
|
||||
# Get the last message to process (using mock)
|
||||
messages = await thread_manager.get_llm_messages(thread_id)
|
||||
last_message = messages[-1]
|
||||
|
||||
# Process response based on mode
|
||||
start_time = asyncio.get_event_loop().time()
|
||||
|
||||
print(f"⏱️ Starting execution at {start_time:.2f}s")
|
||||
tool_execution_count = 0
|
||||
wait_tool_count = 0
|
||||
tool_results = []
|
||||
|
||||
if test['execute_on_stream']:
|
||||
# Create streaming response
|
||||
streaming_response = MockStreamingResponse(last_message["content"])
|
||||
|
||||
# Process streaming response
|
||||
async for chunk in thread_manager.response_processor.process_streaming_response(
|
||||
llm_response=streaming_response,
|
||||
thread_id=thread_id,
|
||||
config=config
|
||||
):
|
||||
if chunk.get('type') == 'tool_result':
|
||||
elapsed = asyncio.get_event_loop().time() - start_time
|
||||
tool_name = chunk.get('name', '')
|
||||
tool_execution_count += 1
|
||||
if tool_name == 'wait':
|
||||
wait_tool_count += 1
|
||||
|
||||
print(f"⏱️ [{elapsed:.2f}s] Tool result: {chunk['name']}")
|
||||
print(f" {chunk['result']}")
|
||||
print()
|
||||
tool_results.append(chunk)
|
||||
else:
|
||||
# Create non-streaming response
|
||||
non_streaming_response = MockNonStreamingResponse(last_message["content"])
|
||||
|
||||
# Process non-streaming response
|
||||
async for chunk in thread_manager.response_processor.process_non_streaming_response(
|
||||
llm_response=non_streaming_response,
|
||||
thread_id=thread_id,
|
||||
config=config
|
||||
):
|
||||
if chunk.get('type') == 'tool_result':
|
||||
elapsed = asyncio.get_event_loop().time() - start_time
|
||||
tool_name = chunk.get('name', '')
|
||||
tool_execution_count += 1
|
||||
if tool_name == 'wait':
|
||||
wait_tool_count += 1
|
||||
|
||||
print(f"⏱️ [{elapsed:.2f}s] Tool result: {chunk['name']}")
|
||||
print(f" {chunk['result']}")
|
||||
print()
|
||||
tool_results.append(chunk)
|
||||
|
||||
end_time = asyncio.get_event_loop().time()
|
||||
elapsed = end_time - start_time
|
||||
|
||||
print(f"\n⏱️ {test['name']} execution completed in {elapsed:.2f} seconds")
|
||||
print(f"🔢 Total tool executions: {tool_execution_count}")
|
||||
print(f"🔢 Wait tool executions: {wait_tool_count}")
|
||||
|
||||
# Store results for validation
|
||||
test_results[test['name']] = {
|
||||
'execution_time': elapsed,
|
||||
'tool_count': tool_execution_count,
|
||||
'wait_count': wait_tool_count,
|
||||
'tool_results': tool_results
|
||||
}
|
||||
|
||||
# Assert correct number of wait tool executions
|
||||
assert wait_tool_count == expected_wait_count, f"❌ Expected {expected_wait_count} wait tool executions, got {wait_tool_count} in {test['name']} mode"
|
||||
print(f"✅ PASS: {test['name']} executed {wait_tool_count} wait tools as expected")
|
||||
|
||||
# Restore original get_llm_messages method
|
||||
thread_manager.get_llm_messages = original_get_llm_messages
|
||||
|
||||
# Additional assertions for both test cases
|
||||
assert 'Non-Streaming' in test_results, "❌ Non-streaming test not completed"
|
||||
assert 'Streaming' in test_results, "❌ Streaming test not completed"
|
||||
|
||||
# Validate streaming has different timing characteristics than non-streaming
|
||||
non_streaming_time = test_results['Non-Streaming']['execution_time']
|
||||
streaming_time = test_results['Streaming']['execution_time']
|
||||
|
||||
# Streaming should have different timing due to the nature of execution
|
||||
# We don't assert strict timing as it can vary, but we validate the tests ran successfully
|
||||
print(f"\n⏱️ Execution time comparison:")
|
||||
print(f" Non-Streaming: {non_streaming_time:.2f}s")
|
||||
print(f" Streaming: {streaming_time:.2f}s")
|
||||
print(f" Time difference: {abs(non_streaming_time - streaming_time):.2f}s")
|
||||
|
||||
# Check if all results have a status field
|
||||
all_have_status = all(
|
||||
'status' in result
|
||||
for test_data in test_results.values()
|
||||
for result in test_data['tool_results']
|
||||
)
|
||||
|
||||
# If results have a status field, check if they're all successful
|
||||
if all_have_status:
|
||||
all_successful = all(
|
||||
result.get('status') == 'success'
|
||||
for test_data in test_results.values()
|
||||
for result in test_data['tool_results']
|
||||
)
|
||||
assert all_successful, "❌ Not all tool executions were successful"
|
||||
print("✅ PASS: All tool executions completed successfully")
|
||||
|
||||
print("\n" + "="*80)
|
||||
print("✅ ALL TESTS PASSED")
|
||||
print("="*80 + "\n")
|
||||
|
||||
return test_results
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
asyncio.run(test_xml_streaming_execution())
|
||||
print("\n✅ Test completed successfully")
|
||||
sys.exit(0)
|
||||
except AssertionError as e:
|
||||
print(f"\n\n❌ Test failed: {str(e)}")
|
||||
sys.exit(1)
|
||||
except KeyboardInterrupt:
|
||||
print("\n\n❌ Test interrupted by user")
|
||||
sys.exit(1)
|
||||
except Exception as e:
|
||||
print(f"\n\n❌ Error during test: {str(e)}")
|
||||
sys.exit(1)
|
|
@ -5,7 +5,7 @@ from jwt.exceptions import PyJWTError
|
|||
from utils.logger import logger
|
||||
|
||||
# This function extracts the user ID from Supabase JWT
|
||||
async def get_current_user_id(request: Request) -> str:
|
||||
async def get_current_user_id_from_jwt(request: Request) -> str:
|
||||
"""
|
||||
Extract and verify the user ID from the JWT in the Authorization header.
|
||||
|
||||
|
@ -56,6 +56,45 @@ async def get_current_user_id(request: Request) -> str:
|
|||
headers={"WWW-Authenticate": "Bearer"}
|
||||
)
|
||||
|
||||
async def get_account_id_from_thread(client, thread_id: str) -> str:
|
||||
"""
|
||||
Extract and verify the account ID from the thread.
|
||||
|
||||
Args:
|
||||
client: The Supabase client
|
||||
thread_id: The ID of the thread
|
||||
|
||||
Returns:
|
||||
str: The account ID associated with the thread
|
||||
|
||||
Raises:
|
||||
HTTPException: If the thread is not found or if there's an error
|
||||
"""
|
||||
try:
|
||||
response = await client.table('threads').select('account_id').eq('thread_id', thread_id).execute()
|
||||
|
||||
if not response.data or len(response.data) == 0:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Thread not found"
|
||||
)
|
||||
|
||||
account_id = response.data[0].get('account_id')
|
||||
|
||||
if not account_id:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Thread has no associated account"
|
||||
)
|
||||
|
||||
return account_id
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Error retrieving thread information: {str(e)}"
|
||||
)
|
||||
|
||||
async def get_user_id_from_stream_auth(
|
||||
request: Request,
|
||||
token: Optional[str] = None
|
||||
|
@ -105,6 +144,7 @@ async def get_user_id_from_stream_auth(
|
|||
detail="No valid authentication credentials found",
|
||||
headers={"WWW-Authenticate": "Bearer"}
|
||||
)
|
||||
|
||||
async def verify_thread_access(client, thread_id: str, user_id: str):
|
||||
"""
|
||||
Verify that a user has access to a specific thread based on account membership.
|
||||
|
|
|
@ -1,125 +0,0 @@
|
|||
from datetime import datetime, timezone
|
||||
from typing import Dict, Optional, Tuple
|
||||
from utils.logger import logger
|
||||
from utils.config import config, EnvMode
|
||||
|
||||
# Define subscription tiers and their monthly limits (in minutes)
|
||||
SUBSCRIPTION_TIERS = {
|
||||
'price_1RGJ9GG6l1KZGqIroxSqgphC': {'name': 'free', 'minutes': 8},
|
||||
'price_1RGJ9LG6l1KZGqIrd9pwzeNW': {'name': 'base', 'minutes': 300},
|
||||
'price_1RGJ9JG6l1KZGqIrVUU4ZRv6': {'name': 'extra', 'minutes': 2400}
|
||||
}
|
||||
|
||||
async def get_account_subscription(client, account_id: str) -> Optional[Dict]:
|
||||
"""Get the current subscription for an account."""
|
||||
result = await client.schema('basejump').from_('billing_subscriptions') \
|
||||
.select('*') \
|
||||
.eq('account_id', account_id) \
|
||||
.eq('status', 'active') \
|
||||
.order('created', desc=True) \
|
||||
.limit(1) \
|
||||
.execute()
|
||||
|
||||
if result.data and len(result.data) > 0:
|
||||
return result.data[0]
|
||||
return None
|
||||
|
||||
async def calculate_monthly_usage(client, account_id: str) -> float:
|
||||
"""Calculate total agent run minutes for the current month for an account."""
|
||||
# Get start of current month in UTC
|
||||
now = datetime.now(timezone.utc)
|
||||
start_of_month = datetime(now.year, now.month, 1, tzinfo=timezone.utc)
|
||||
|
||||
# First get all threads for this account
|
||||
threads_result = await client.table('threads') \
|
||||
.select('thread_id') \
|
||||
.eq('account_id', account_id) \
|
||||
.execute()
|
||||
|
||||
if not threads_result.data:
|
||||
return 0.0
|
||||
|
||||
thread_ids = [t['thread_id'] for t in threads_result.data]
|
||||
|
||||
# Then get all agent runs for these threads in current month
|
||||
runs_result = await client.table('agent_runs') \
|
||||
.select('started_at, completed_at') \
|
||||
.in_('thread_id', thread_ids) \
|
||||
.gte('started_at', start_of_month.isoformat()) \
|
||||
.execute()
|
||||
|
||||
if not runs_result.data:
|
||||
return 0.0
|
||||
|
||||
# Calculate total minutes
|
||||
total_seconds = 0
|
||||
now_ts = now.timestamp()
|
||||
|
||||
for run in runs_result.data:
|
||||
start_time = datetime.fromisoformat(run['started_at'].replace('Z', '+00:00')).timestamp()
|
||||
if run['completed_at']:
|
||||
end_time = datetime.fromisoformat(run['completed_at'].replace('Z', '+00:00')).timestamp()
|
||||
else:
|
||||
# For running jobs, use current time
|
||||
end_time = now_ts
|
||||
|
||||
total_seconds += (end_time - start_time)
|
||||
|
||||
return total_seconds / 60 # Convert to minutes
|
||||
|
||||
async def check_billing_status(client, account_id: str) -> Tuple[bool, str, Optional[Dict]]:
|
||||
"""
|
||||
Check if an account can run agents based on their subscription and usage.
|
||||
|
||||
Returns:
|
||||
Tuple[bool, str, Optional[Dict]]: (can_run, message, subscription_info)
|
||||
"""
|
||||
if config.ENV_MODE == EnvMode.LOCAL:
|
||||
logger.info("Running in local development mode - billing checks are disabled")
|
||||
return True, "Local development mode - billing disabled", {
|
||||
"price_id": "local_dev",
|
||||
"plan_name": "Local Development",
|
||||
"minutes_limit": "no limit"
|
||||
}
|
||||
|
||||
# For staging/production, check subscription status
|
||||
|
||||
# Get current subscription
|
||||
subscription = await get_account_subscription(client, account_id)
|
||||
|
||||
# If no subscription, they can use free tier
|
||||
if not subscription:
|
||||
subscription = {
|
||||
'price_id': 'price_1RGJ9GG6l1KZGqIroxSqgphC', # Free tier
|
||||
'plan_name': 'free'
|
||||
}
|
||||
|
||||
# if not subscription or subscription.get('price_id') is None or subscription.get('price_id') == 'price_1RGJ9GG6l1KZGqIroxSqgphC':
|
||||
# return False, "You are not subscribed to any plan. Please upgrade your plan to continue.", subscription
|
||||
|
||||
# Get tier info
|
||||
tier_info = SUBSCRIPTION_TIERS.get(subscription['price_id'])
|
||||
if not tier_info:
|
||||
return False, "Invalid subscription tier", subscription
|
||||
|
||||
# Calculate current month's usage
|
||||
current_usage = await calculate_monthly_usage(client, account_id)
|
||||
|
||||
# Check if within limits
|
||||
if current_usage >= tier_info['minutes']:
|
||||
return False, f"Monthly limit of {tier_info['minutes']} minutes reached. Please upgrade your plan or wait until next month.", subscription
|
||||
|
||||
return True, "OK", subscription
|
||||
|
||||
# Helper function to get account ID from thread
|
||||
async def get_account_id_from_thread(client, thread_id: str) -> Optional[str]:
|
||||
"""Get the account ID associated with a thread."""
|
||||
result = await client.table('threads') \
|
||||
.select('account_id') \
|
||||
.eq('thread_id', thread_id) \
|
||||
.limit(1) \
|
||||
.execute()
|
||||
|
||||
if result.data and len(result.data) > 0:
|
||||
return result.data[0]['account_id']
|
||||
return None
|
|
@ -39,6 +39,75 @@ class Configuration:
|
|||
# Environment mode
|
||||
ENV_MODE: EnvMode = EnvMode.LOCAL
|
||||
|
||||
# Subscription tier IDs - Production
|
||||
STRIPE_FREE_TIER_ID_PROD: str = 'price_1RILb4G6l1KZGqIrK4QLrx9i'
|
||||
STRIPE_TIER_2_20_ID_PROD: str = 'price_1RILb4G6l1KZGqIrhomjgDnO'
|
||||
STRIPE_TIER_6_50_ID_PROD: str = 'price_1RILb4G6l1KZGqIr5q0sybWn'
|
||||
STRIPE_TIER_12_100_ID_PROD: str = 'price_1RILb4G6l1KZGqIr5Y20ZLHm'
|
||||
STRIPE_TIER_25_200_ID_PROD: str = 'price_1RILb4G6l1KZGqIrGAD8rNjb'
|
||||
STRIPE_TIER_50_400_ID_PROD: str = 'price_1RILb4G6l1KZGqIruNBUMTF1'
|
||||
STRIPE_TIER_125_800_ID_PROD: str = 'price_1RILb3G6l1KZGqIrbJA766tN'
|
||||
STRIPE_TIER_200_1000_ID_PROD: str = 'price_1RILb3G6l1KZGqIrmauYPOiN'
|
||||
|
||||
# Subscription tier IDs - Staging
|
||||
STRIPE_FREE_TIER_ID_STAGING: str = 'price_1RIGvuG6l1KZGqIrw14abxeL'
|
||||
STRIPE_TIER_2_20_ID_STAGING: str = 'price_1RIGvuG6l1KZGqIrCRu0E4Gi'
|
||||
STRIPE_TIER_6_50_ID_STAGING: str = 'price_1RIGvuG6l1KZGqIrvjlz5p5V'
|
||||
STRIPE_TIER_12_100_ID_STAGING: str = 'price_1RIGvuG6l1KZGqIrT6UfgblC'
|
||||
STRIPE_TIER_25_200_ID_STAGING: str = 'price_1RIGvuG6l1KZGqIrOVLKlOMj'
|
||||
STRIPE_TIER_50_400_ID_STAGING: str = 'price_1RIKNgG6l1KZGqIrvsat5PW7'
|
||||
STRIPE_TIER_125_800_ID_STAGING: str = 'price_1RIKNrG6l1KZGqIrjKT0yGvI'
|
||||
STRIPE_TIER_200_1000_ID_STAGING: str = 'price_1RIKQ2G6l1KZGqIrum9n8SI7'
|
||||
|
||||
# Computed subscription tier IDs based on environment
|
||||
@property
|
||||
def STRIPE_FREE_TIER_ID(self) -> str:
|
||||
if self.ENV_MODE == EnvMode.STAGING:
|
||||
return self.STRIPE_FREE_TIER_ID_STAGING
|
||||
return self.STRIPE_FREE_TIER_ID_PROD
|
||||
|
||||
@property
|
||||
def STRIPE_TIER_2_20_ID(self) -> str:
|
||||
if self.ENV_MODE == EnvMode.STAGING:
|
||||
return self.STRIPE_TIER_2_20_ID_STAGING
|
||||
return self.STRIPE_TIER_2_20_ID_PROD
|
||||
|
||||
@property
|
||||
def STRIPE_TIER_6_50_ID(self) -> str:
|
||||
if self.ENV_MODE == EnvMode.STAGING:
|
||||
return self.STRIPE_TIER_6_50_ID_STAGING
|
||||
return self.STRIPE_TIER_6_50_ID_PROD
|
||||
|
||||
@property
|
||||
def STRIPE_TIER_12_100_ID(self) -> str:
|
||||
if self.ENV_MODE == EnvMode.STAGING:
|
||||
return self.STRIPE_TIER_12_100_ID_STAGING
|
||||
return self.STRIPE_TIER_12_100_ID_PROD
|
||||
|
||||
@property
|
||||
def STRIPE_TIER_25_200_ID(self) -> str:
|
||||
if self.ENV_MODE == EnvMode.STAGING:
|
||||
return self.STRIPE_TIER_25_200_ID_STAGING
|
||||
return self.STRIPE_TIER_25_200_ID_PROD
|
||||
|
||||
@property
|
||||
def STRIPE_TIER_50_400_ID(self) -> str:
|
||||
if self.ENV_MODE == EnvMode.STAGING:
|
||||
return self.STRIPE_TIER_50_400_ID_STAGING
|
||||
return self.STRIPE_TIER_50_400_ID_PROD
|
||||
|
||||
@property
|
||||
def STRIPE_TIER_125_800_ID(self) -> str:
|
||||
if self.ENV_MODE == EnvMode.STAGING:
|
||||
return self.STRIPE_TIER_125_800_ID_STAGING
|
||||
return self.STRIPE_TIER_125_800_ID_PROD
|
||||
|
||||
@property
|
||||
def STRIPE_TIER_200_1000_ID(self) -> str:
|
||||
if self.ENV_MODE == EnvMode.STAGING:
|
||||
return self.STRIPE_TIER_200_1000_ID_STAGING
|
||||
return self.STRIPE_TIER_200_1000_ID_PROD
|
||||
|
||||
# LLM API keys
|
||||
ANTHROPIC_API_KEY: str = None
|
||||
OPENAI_API_KEY: Optional[str] = None
|
||||
|
@ -80,9 +149,19 @@ class Configuration:
|
|||
|
||||
# Stripe configuration
|
||||
STRIPE_SECRET_KEY: Optional[str] = None
|
||||
STRIPE_WEBHOOK_SECRET: Optional[str] = None
|
||||
STRIPE_DEFAULT_PLAN_ID: Optional[str] = None
|
||||
STRIPE_DEFAULT_TRIAL_DAYS: int = 14
|
||||
|
||||
# Stripe Product IDs
|
||||
STRIPE_PRODUCT_ID_PROD: str = 'prod_SCl7AQ2C8kK1CD' # Production product ID
|
||||
STRIPE_PRODUCT_ID_STAGING: str = 'prod_SCgIj3G7yPOAWY' # Staging product ID
|
||||
|
||||
@property
|
||||
def STRIPE_PRODUCT_ID(self) -> str:
|
||||
if self.ENV_MODE == EnvMode.STAGING:
|
||||
return self.STRIPE_PRODUCT_ID_STAGING
|
||||
return self.STRIPE_PRODUCT_ID_PROD
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize configuration by loading from environment variables."""
|
||||
|
|
|
@ -1,5 +1,3 @@
|
|||
version: '3.8'
|
||||
|
||||
services:
|
||||
redis:
|
||||
image: redis:7-alpine
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
import {createClient} from "@/lib/supabase/server";
|
||||
import AccountBillingStatus from "@/components/basejump/account-billing-status";
|
||||
import AccountBillingStatus from "@/components/billing/account-billing-status";
|
||||
|
||||
const returnUrl = process.env.NEXT_PUBLIC_URL as string;
|
||||
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
import React from 'react';
|
||||
import {createClient} from "@/lib/supabase/server";
|
||||
import AccountBillingStatus from "@/components/basejump/account-billing-status";
|
||||
import AccountBillingStatus from "@/components/billing/account-billing-status";
|
||||
import { Alert, AlertTitle, AlertDescription } from "@/components/ui/alert";
|
||||
|
||||
const returnUrl = process.env.NEXT_PUBLIC_URL as string;
|
||||
|
|
|
@ -7,7 +7,7 @@ import { Button } from '@/components/ui/button';
|
|||
import {
|
||||
ArrowDown, CheckCircle, CircleDashed, AlertTriangle, Info, File, ChevronRight
|
||||
} from 'lucide-react';
|
||||
import { addUserMessage, getMessages, startAgent, stopAgent, getAgentRuns, getProject, getThread, updateProject, Project, Message as BaseApiMessageType, BillingError } from '@/lib/api';
|
||||
import { addUserMessage, getMessages, startAgent, stopAgent, getAgentRuns, getProject, getThread, updateProject, Project, Message as BaseApiMessageType, BillingError, checkBillingStatus } from '@/lib/api';
|
||||
import { toast } from 'sonner';
|
||||
import { Skeleton } from "@/components/ui/skeleton";
|
||||
import { ChatInput } from '@/components/thread/chat-input';
|
||||
|
@ -20,10 +20,9 @@ import { Markdown } from '@/components/ui/markdown';
|
|||
import { cn } from "@/lib/utils";
|
||||
import { useIsMobile } from "@/hooks/use-mobile";
|
||||
import { BillingErrorAlert } from '@/components/billing/usage-limit-alert';
|
||||
import { SUBSCRIPTION_PLANS } from '@/components/billing/plan-comparison';
|
||||
import { createClient } from '@/lib/supabase/client';
|
||||
import { isLocalMode } from "@/lib/config";
|
||||
|
||||
|
||||
import { UnifiedMessage, ParsedContent, ParsedMetadata, ThreadParams } from '@/components/thread/types';
|
||||
import { getToolIcon, extractPrimaryParam, safeJsonParse } from '@/components/thread/utils';
|
||||
|
||||
|
@ -1040,122 +1039,62 @@ export default function ThreadPage({ params }: { params: Promise<ThreadParams> }
|
|||
}
|
||||
}, [agentStatus, threadId, isLoading, streamHookStatus]);
|
||||
|
||||
// Check billing status when agent completes
|
||||
const checkBillingStatus = useCallback(async () => {
|
||||
// Update the checkBillingStatus function
|
||||
const checkBillingLimits = useCallback(async () => {
|
||||
// Skip billing checks in local development mode
|
||||
if (isLocalMode()) {
|
||||
console.log("Running in local development mode - billing checks are disabled");
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!project?.account_id) return;
|
||||
|
||||
const supabase = createClient();
|
||||
|
||||
try {
|
||||
// Check subscription status
|
||||
const { data: subscriptionData } = await supabase
|
||||
.schema('basejump')
|
||||
.from('billing_subscriptions')
|
||||
.select('price_id')
|
||||
.eq('account_id', project.account_id)
|
||||
.eq('status', 'active')
|
||||
.single();
|
||||
const result = await checkBillingStatus();
|
||||
|
||||
const currentPlanId = subscriptionData?.price_id || SUBSCRIPTION_PLANS.FREE;
|
||||
|
||||
// Only check usage limits for free tier users
|
||||
if (currentPlanId === SUBSCRIPTION_PLANS.FREE) {
|
||||
// Calculate usage
|
||||
const startOfMonth = new Date();
|
||||
startOfMonth.setDate(1);
|
||||
startOfMonth.setHours(0, 0, 0, 0);
|
||||
|
||||
// Get threads for this account
|
||||
const { data: threadsData } = await supabase
|
||||
.from('threads')
|
||||
.select('thread_id')
|
||||
.eq('account_id', project.account_id);
|
||||
|
||||
const threadIds = threadsData?.map(t => t.thread_id) || [];
|
||||
|
||||
// Get agent runs for those threads
|
||||
const { data: agentRunData } = await supabase
|
||||
.from('agent_runs')
|
||||
.select('started_at, completed_at')
|
||||
.in('thread_id', threadIds)
|
||||
.gte('started_at', startOfMonth.toISOString());
|
||||
|
||||
let totalSeconds = 0;
|
||||
if (agentRunData) {
|
||||
totalSeconds = agentRunData.reduce((acc, run) => {
|
||||
const start = new Date(run.started_at);
|
||||
const end = run.completed_at ? new Date(run.completed_at) : new Date();
|
||||
const seconds = (end.getTime() - start.getTime()) / 1000;
|
||||
return acc + seconds;
|
||||
}, 0);
|
||||
}
|
||||
|
||||
// Convert to hours for display
|
||||
const hours = totalSeconds / 3600;
|
||||
const minutesUsed = totalSeconds / 60;
|
||||
|
||||
// The free plan has a 10 minute limit as defined in backend/utils/billing.py
|
||||
const FREE_PLAN_LIMIT_MINUTES = 10;
|
||||
const FREE_PLAN_LIMIT_HOURS = FREE_PLAN_LIMIT_MINUTES / 60;
|
||||
|
||||
// Show alert if over limit
|
||||
if (minutesUsed > FREE_PLAN_LIMIT_MINUTES) {
|
||||
console.log("Usage limit exceeded:", {
|
||||
minutesUsed,
|
||||
hoursUsed: hours,
|
||||
limit: FREE_PLAN_LIMIT_MINUTES
|
||||
});
|
||||
setBillingData({
|
||||
currentUsage: Number(hours.toFixed(2)),
|
||||
limit: FREE_PLAN_LIMIT_HOURS,
|
||||
message: `You've used ${Math.floor(minutesUsed)} minutes on the Free plan. The limit is ${FREE_PLAN_LIMIT_MINUTES} minutes per month.`,
|
||||
accountId: project.account_id || null
|
||||
});
|
||||
setShowBillingAlert(true);
|
||||
return true; // Return true if over limit
|
||||
}
|
||||
if (!result.can_run) {
|
||||
setBillingData({
|
||||
currentUsage: result.subscription?.minutes_limit || 0,
|
||||
limit: result.subscription?.minutes_limit || 0,
|
||||
message: result.message || 'Usage limit reached',
|
||||
accountId: project?.account_id || null
|
||||
});
|
||||
setShowBillingAlert(true);
|
||||
return true;
|
||||
}
|
||||
return false; // Return false if not over limit
|
||||
return false;
|
||||
} catch (err) {
|
||||
console.error('Error checking billing status:', err);
|
||||
return false;
|
||||
}
|
||||
}, [project?.account_id]);
|
||||
|
||||
// Update useEffect to check billing when agent completes
|
||||
// Update useEffect to use the renamed function
|
||||
useEffect(() => {
|
||||
const previousStatus = previousAgentStatus.current;
|
||||
|
||||
// Check if agent just completed (status changed from running to idle)
|
||||
if (previousStatus === 'running' && agentStatus === 'idle') {
|
||||
checkBillingStatus();
|
||||
checkBillingLimits();
|
||||
}
|
||||
|
||||
// Store current status for next comparison
|
||||
previousAgentStatus.current = agentStatus;
|
||||
}, [agentStatus, checkBillingStatus]);
|
||||
}, [agentStatus, checkBillingLimits]);
|
||||
|
||||
// Add new useEffect to check billing limits when page first loads or project changes
|
||||
// Update other useEffect to use the renamed function
|
||||
useEffect(() => {
|
||||
if (project?.account_id && initialLoadCompleted.current) {
|
||||
console.log("Checking billing status on page load");
|
||||
checkBillingStatus();
|
||||
checkBillingLimits();
|
||||
}
|
||||
}, [project?.account_id, checkBillingStatus, initialLoadCompleted]);
|
||||
|
||||
// Also check after messages are loaded to ensure we have the complete state
|
||||
}, [project?.account_id, checkBillingLimits, initialLoadCompleted]);
|
||||
|
||||
// Update the last useEffect to use the renamed function
|
||||
useEffect(() => {
|
||||
if (messagesLoadedRef.current && project?.account_id && !isLoading) {
|
||||
console.log("Checking billing status after messages loaded");
|
||||
checkBillingStatus();
|
||||
checkBillingLimits();
|
||||
}
|
||||
}, [messagesLoadedRef.current, checkBillingStatus, project?.account_id, isLoading]);
|
||||
}, [messagesLoadedRef.current, checkBillingLimits, project?.account_id, isLoading]);
|
||||
|
||||
if (isLoading && !initialLoadCompleted.current) {
|
||||
return (
|
||||
|
|
|
@ -14,7 +14,7 @@ import { Tooltip, TooltipContent, TooltipTrigger } from "@/components/ui/tooltip
|
|||
import { useBillingError } from "@/hooks/useBillingError";
|
||||
import { BillingErrorAlert } from "@/components/billing/usage-limit-alert";
|
||||
import { useAccounts } from "@/hooks/use-accounts";
|
||||
import { isLocalMode } from "@/lib/config";
|
||||
import { isLocalMode, config } from "@/lib/config";
|
||||
import { toast } from "sonner";
|
||||
|
||||
// Constant for localStorage key to ensure consistency
|
||||
|
@ -103,7 +103,7 @@ function DashboardContent() {
|
|||
limit: error.detail.limit as number | undefined,
|
||||
// Include subscription details if available in the error, otherwise provide defaults
|
||||
subscription: error.detail.subscription || {
|
||||
price_id: "price_1RGJ9GG6l1KZGqIroxSqgphC", // Default to Free tier
|
||||
price_id: config.SUBSCRIPTION_TIERS.FREE.priceId, // Default to Free tier
|
||||
plan_name: "Free"
|
||||
}
|
||||
});
|
||||
|
|
|
@ -6,7 +6,7 @@ import {
|
|||
SidebarInset,
|
||||
SidebarProvider,
|
||||
} from "@/components/ui/sidebar"
|
||||
import { PricingAlert } from "@/components/billing/pricing-alert"
|
||||
// import { PricingAlert } from "@/components/billing/pricing-alert"
|
||||
import { MaintenanceAlert } from "@/components/maintenance-alert"
|
||||
import { useAccounts } from "@/hooks/use-accounts"
|
||||
import { useAuth } from "@/components/AuthProvider"
|
||||
|
@ -22,7 +22,7 @@ interface DashboardLayoutProps {
|
|||
export default function DashboardLayout({
|
||||
children,
|
||||
}: DashboardLayoutProps) {
|
||||
const [showPricingAlert, setShowPricingAlert] = useState(false)
|
||||
// const [showPricingAlert, setShowPricingAlert] = useState(false)
|
||||
const [showMaintenanceAlert, setShowMaintenanceAlert] = useState(false)
|
||||
const [isApiHealthy, setIsApiHealthy] = useState(true)
|
||||
const [isCheckingHealth, setIsCheckingHealth] = useState(true)
|
||||
|
@ -32,7 +32,7 @@ export default function DashboardLayout({
|
|||
const router = useRouter()
|
||||
|
||||
useEffect(() => {
|
||||
setShowPricingAlert(false)
|
||||
// setShowPricingAlert(false)
|
||||
setShowMaintenanceAlert(false)
|
||||
}, [])
|
||||
|
||||
|
@ -91,12 +91,12 @@ export default function DashboardLayout({
|
|||
</div>
|
||||
</SidebarInset>
|
||||
|
||||
<PricingAlert
|
||||
{/* <PricingAlert
|
||||
open={showPricingAlert}
|
||||
onOpenChange={setShowPricingAlert}
|
||||
closeable={false}
|
||||
accountId={personalAccount?.account_id}
|
||||
/>
|
||||
/> */}
|
||||
|
||||
<MaintenanceAlert
|
||||
open={showMaintenanceAlert}
|
||||
|
|
|
@ -27,7 +27,8 @@ export async function signIn(prevState: any, formData: FormData) {
|
|||
return { message: error.message || "Could not authenticate user" };
|
||||
}
|
||||
|
||||
return redirect(returnUrl || "/dashboard");
|
||||
// Use client-side navigation instead of server-side redirect
|
||||
return { success: true, redirectTo: returnUrl || "/dashboard" };
|
||||
}
|
||||
|
||||
export async function signUp(prevState: any, formData: FormData) {
|
||||
|
@ -73,7 +74,8 @@ export async function signUp(prevState: any, formData: FormData) {
|
|||
return { message: "Account created! Check your email to confirm your registration." };
|
||||
}
|
||||
|
||||
return redirect(returnUrl || "/dashboard");
|
||||
// Use client-side navigation instead of server-side redirect
|
||||
return { success: true, redirectTo: returnUrl || "/dashboard" };
|
||||
}
|
||||
|
||||
export async function forgotPassword(prevState: any, formData: FormData) {
|
||||
|
|
|
@ -101,8 +101,19 @@ function LoginContent() {
|
|||
const handleSignIn = async (prevState: any, formData: FormData) => {
|
||||
if (returnUrl) {
|
||||
formData.append("returnUrl", returnUrl);
|
||||
} else {
|
||||
formData.append("returnUrl", "/dashboard");
|
||||
}
|
||||
return signIn(prevState, formData);
|
||||
const result = await signIn(prevState, formData);
|
||||
|
||||
// Check for success and redirectTo properties
|
||||
if (result && typeof result === 'object' && 'success' in result && result.success && 'redirectTo' in result) {
|
||||
// Use window.location for hard navigation to avoid stale state
|
||||
window.location.href = result.redirectTo as string;
|
||||
return null; // Return null to prevent normal form action completion
|
||||
}
|
||||
|
||||
return result;
|
||||
};
|
||||
|
||||
const handleSignUp = async (prevState: any, formData: FormData) => {
|
||||
|
@ -119,6 +130,13 @@ function LoginContent() {
|
|||
|
||||
const result = await signUp(prevState, formData);
|
||||
|
||||
// Check for success and redirectTo properties (direct login case)
|
||||
if (result && typeof result === 'object' && 'success' in result && result.success && 'redirectTo' in result) {
|
||||
// Use window.location for hard navigation to avoid stale state
|
||||
window.location.href = result.redirectTo as string;
|
||||
return null; // Return null to prevent normal form action completion
|
||||
}
|
||||
|
||||
// Check if registration was successful but needs email verification
|
||||
if (result && typeof result === 'object' && 'message' in result) {
|
||||
const resultMessage = result.message as string;
|
||||
|
@ -166,9 +184,10 @@ function LoginContent() {
|
|||
|
||||
const resetRegistrationSuccess = () => {
|
||||
setRegistrationSuccess(false);
|
||||
// Remove message from URL
|
||||
// Remove message from URL and set mode to signin
|
||||
const params = new URLSearchParams(window.location.search);
|
||||
params.delete('message');
|
||||
params.set('mode', 'signin');
|
||||
|
||||
const newUrl =
|
||||
window.location.pathname +
|
||||
|
|
|
@ -68,6 +68,9 @@ export default function GoogleSignIn({ returnUrl }: GoogleSignInProps) {
|
|||
try {
|
||||
setIsLoading(true);
|
||||
const supabase = createClient();
|
||||
|
||||
console.log('Starting Google sign in process');
|
||||
|
||||
const { error } = await supabase.auth.signInWithIdToken({
|
||||
provider: 'google',
|
||||
token: response.credential,
|
||||
|
@ -75,10 +78,13 @@ export default function GoogleSignIn({ returnUrl }: GoogleSignInProps) {
|
|||
|
||||
if (error) throw error;
|
||||
|
||||
// Add a small delay before redirecting to ensure localStorage is properly saved
|
||||
console.log('Google sign in successful, preparing redirect to:', returnUrl || "/dashboard");
|
||||
|
||||
// Add a longer delay before redirecting to ensure localStorage is properly saved
|
||||
setTimeout(() => {
|
||||
console.log('Executing redirect now to:', returnUrl || "/dashboard");
|
||||
window.location.href = returnUrl || "/dashboard";
|
||||
}, 100);
|
||||
}, 500); // Increased from 100ms to 500ms
|
||||
} catch (error) {
|
||||
console.error('Error signing in with Google:', error);
|
||||
setIsLoading(false);
|
||||
|
|
|
@ -1,178 +0,0 @@
|
|||
import { createClient } from "@/lib/supabase/server";
|
||||
import { SubmitButton } from "../ui/submit-button";
|
||||
import { manageSubscription } from "@/lib/actions/billing";
|
||||
import { PlanComparison, SUBSCRIPTION_PLANS } from "../billing/plan-comparison";
|
||||
import { isLocalMode } from "@/lib/config";
|
||||
|
||||
type Props = {
|
||||
accountId: string;
|
||||
returnUrl: string;
|
||||
}
|
||||
|
||||
export default async function AccountBillingStatus({ accountId, returnUrl }: Props) {
|
||||
// In local development mode, show a simplified component
|
||||
if (isLocalMode()) {
|
||||
return (
|
||||
<div className="rounded-xl border shadow-sm bg-card p-6">
|
||||
<h2 className="text-xl font-semibold mb-4">Billing Status</h2>
|
||||
<div className="p-4 mb-4 bg-muted/30 border border-border rounded-lg text-center">
|
||||
<p className="text-sm text-muted-foreground">
|
||||
Running in local development mode - billing features are disabled
|
||||
</p>
|
||||
<p className="text-xs text-muted-foreground mt-2">
|
||||
Agent usage limits are not enforced in this environment
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
const supabaseClient = await createClient();
|
||||
|
||||
// Get account subscription and usage data
|
||||
const { data: subscriptionData } = await supabaseClient
|
||||
.schema('basejump')
|
||||
.from('billing_subscriptions')
|
||||
.select('*')
|
||||
.eq('account_id', accountId)
|
||||
.eq('status', 'active')
|
||||
.limit(1)
|
||||
.order('created_at', { ascending: false })
|
||||
.single();
|
||||
|
||||
// Get agent runs for this account
|
||||
// Get the account's threads
|
||||
const { data: threads } = await supabaseClient
|
||||
.from('threads')
|
||||
.select('thread_id')
|
||||
.eq('account_id', accountId);
|
||||
|
||||
const threadIds = threads?.map(t => t.thread_id) || [];
|
||||
|
||||
// Get current month usage
|
||||
const now = new Date();
|
||||
const startOfMonth = new Date(now.getFullYear(), now.getMonth(), 1);
|
||||
const isoStartOfMonth = startOfMonth.toISOString();
|
||||
|
||||
let totalAgentTime = 0;
|
||||
let usageDisplay = "No usage this month";
|
||||
|
||||
if (threadIds.length > 0) {
|
||||
const { data: agentRuns } = await supabaseClient
|
||||
.from('agent_runs')
|
||||
.select('started_at, completed_at')
|
||||
.in('thread_id', threadIds)
|
||||
.gte('started_at', isoStartOfMonth);
|
||||
|
||||
if (agentRuns && agentRuns.length > 0) {
|
||||
const nowTimestamp = now.getTime();
|
||||
|
||||
totalAgentTime = agentRuns.reduce((total, run) => {
|
||||
const startTime = new Date(run.started_at).getTime();
|
||||
const endTime = run.completed_at
|
||||
? new Date(run.completed_at).getTime()
|
||||
: nowTimestamp;
|
||||
|
||||
return total + (endTime - startTime) / 1000; // In seconds
|
||||
}, 0);
|
||||
|
||||
// Convert to minutes
|
||||
const totalMinutes = Math.round(totalAgentTime / 60);
|
||||
usageDisplay = `${totalMinutes} minutes`;
|
||||
}
|
||||
}
|
||||
|
||||
const isPlan = (planId?: string) => {
|
||||
return subscriptionData?.price_id === planId;
|
||||
};
|
||||
|
||||
const planName = isPlan(SUBSCRIPTION_PLANS.FREE)
|
||||
? "Free"
|
||||
: isPlan(SUBSCRIPTION_PLANS.PRO)
|
||||
? "Pro"
|
||||
: isPlan(SUBSCRIPTION_PLANS.ENTERPRISE)
|
||||
? "Enterprise"
|
||||
: "Unknown";
|
||||
|
||||
return (
|
||||
<div className="rounded-xl border shadow-sm bg-card p-6">
|
||||
<h2 className="text-xl font-semibold mb-4">Billing Status</h2>
|
||||
|
||||
{subscriptionData ? (
|
||||
<>
|
||||
<div className="mb-6">
|
||||
<div className="rounded-lg border bg-background p-4 grid grid-cols-1 md:grid-cols-2 gap-4">
|
||||
<div>
|
||||
<div className="flex justify-between items-center">
|
||||
<span className="text-sm font-medium text-foreground/90">Current Plan</span>
|
||||
<span className="text-sm font-medium text-card-title">{planName}</span>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="flex justify-between items-center">
|
||||
<span className="text-sm font-medium text-foreground/90">Agent Usage This Month</span>
|
||||
<span className="text-sm font-medium text-card-title">{usageDisplay}</span>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Plans Comparison */}
|
||||
<PlanComparison
|
||||
accountId={accountId}
|
||||
returnUrl={returnUrl}
|
||||
className="mb-6"
|
||||
/>
|
||||
|
||||
{/* Manage Subscription Button */}
|
||||
<form>
|
||||
<input type="hidden" name="accountId" value={accountId} />
|
||||
<input type="hidden" name="returnUrl" value={returnUrl} />
|
||||
<SubmitButton
|
||||
pendingText="Loading..."
|
||||
formAction={manageSubscription}
|
||||
className="w-full bg-primary text-white hover:bg-primary/90 shadow-md hover:shadow-lg transition-all"
|
||||
>
|
||||
Manage Subscription
|
||||
</SubmitButton>
|
||||
</form>
|
||||
</>
|
||||
) : (
|
||||
<>
|
||||
<div className="mb-6">
|
||||
<div className="rounded-lg border bg-background p-4 gap-4">
|
||||
<div className="flex justify-between items-center">
|
||||
<span className="text-sm font-medium text-foreground/90">Current Plan</span>
|
||||
<span className="text-sm font-medium text-card-title">Free</span>
|
||||
</div>
|
||||
|
||||
<div className="flex justify-between items-center">
|
||||
<span className="text-sm font-medium text-foreground/90">Agent Usage This Month</span>
|
||||
<span className="text-sm font-medium text-card-title">{usageDisplay}</span>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Plans Comparison */}
|
||||
<PlanComparison
|
||||
accountId={accountId}
|
||||
returnUrl={returnUrl}
|
||||
className="mb-6"
|
||||
/>
|
||||
|
||||
{/* Manage Subscription Button */}
|
||||
<form>
|
||||
<input type="hidden" name="accountId" value={accountId} />
|
||||
<input type="hidden" name="returnUrl" value={returnUrl} />
|
||||
<SubmitButton
|
||||
pendingText="Loading..."
|
||||
formAction={manageSubscription}
|
||||
className="w-full bg-primary text-white hover:bg-primary/90 shadow-md hover:shadow-lg transition-all"
|
||||
>
|
||||
Manage Subscription
|
||||
</SubmitButton>
|
||||
</form>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
)
|
||||
}
|
|
@ -0,0 +1,180 @@
|
|||
'use client';
|
||||
|
||||
import { useEffect, useState } from 'react';
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { PricingSection } from "@/components/home/sections/pricing-section";
|
||||
import { isLocalMode } from "@/lib/config";
|
||||
import { getSubscription, createPortalSession, SubscriptionStatus } from "@/lib/api";
|
||||
import { useAuth } from "@/components/AuthProvider";
|
||||
import { Skeleton } from "@/components/ui/skeleton";
|
||||
|
||||
type Props = {
|
||||
accountId: string;
|
||||
returnUrl: string;
|
||||
}
|
||||
|
||||
export default function AccountBillingStatus({ accountId, returnUrl }: Props) {
|
||||
const { session, isLoading: authLoading } = useAuth();
|
||||
const [subscriptionData, setSubscriptionData] = useState<SubscriptionStatus | null>(null);
|
||||
const [isLoading, setIsLoading] = useState(true);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
const [isManaging, setIsManaging] = useState(false);
|
||||
|
||||
useEffect(() => {
|
||||
async function fetchSubscription() {
|
||||
if (authLoading || !session) return;
|
||||
|
||||
try {
|
||||
const data = await getSubscription();
|
||||
setSubscriptionData(data);
|
||||
setError(null);
|
||||
} catch (err) {
|
||||
console.error('Failed to get subscription:', err);
|
||||
setError(err instanceof Error ? err.message : 'Failed to load subscription data');
|
||||
} finally {
|
||||
setIsLoading(false);
|
||||
}
|
||||
}
|
||||
|
||||
fetchSubscription();
|
||||
}, [session, authLoading]);
|
||||
|
||||
const handleManageSubscription = async () => {
|
||||
try {
|
||||
setIsManaging(true);
|
||||
const { url } = await createPortalSession({ return_url: returnUrl });
|
||||
window.location.href = url;
|
||||
} catch (err) {
|
||||
console.error('Failed to create portal session:', err);
|
||||
setError(err instanceof Error ? err.message : 'Failed to create portal session');
|
||||
} finally {
|
||||
setIsManaging(false);
|
||||
}
|
||||
};
|
||||
|
||||
// In local development mode, show a simplified component
|
||||
if (isLocalMode()) {
|
||||
return (
|
||||
<div className="rounded-xl border shadow-sm bg-card p-6">
|
||||
<h2 className="text-xl font-semibold mb-4">Billing Status</h2>
|
||||
<div className="p-4 mb-4 bg-muted/30 border border-border rounded-lg text-center">
|
||||
<p className="text-sm text-muted-foreground">
|
||||
Running in local development mode - billing features are disabled
|
||||
</p>
|
||||
<p className="text-xs text-muted-foreground mt-2">
|
||||
Agent usage limits are not enforced in this environment
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
// Show loading state
|
||||
if (isLoading || authLoading) {
|
||||
return (
|
||||
<div className="rounded-xl border shadow-sm bg-card p-6">
|
||||
<h2 className="text-xl font-semibold mb-4">Billing Status</h2>
|
||||
<div className="space-y-4">
|
||||
<Skeleton className="h-20 w-full" />
|
||||
<Skeleton className="h-40 w-full" />
|
||||
<Skeleton className="h-10 w-full" />
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
// Show error state
|
||||
if (error) {
|
||||
return (
|
||||
<div className="rounded-xl border shadow-sm bg-card p-6">
|
||||
<h2 className="text-xl font-semibold mb-4">Billing Status</h2>
|
||||
<div className="p-4 mb-4 bg-destructive/10 border border-destructive/20 rounded-lg text-center">
|
||||
<p className="text-sm text-destructive">
|
||||
Error loading billing status: {error}
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
const isPlan = (planId?: string) => {
|
||||
return subscriptionData?.plan_name === planId;
|
||||
};
|
||||
|
||||
const planName = isPlan('free')
|
||||
? "Free"
|
||||
: isPlan('base')
|
||||
? "Pro"
|
||||
: isPlan('extra')
|
||||
? "Enterprise"
|
||||
: "Unknown";
|
||||
|
||||
return (
|
||||
<div className="rounded-xl border shadow-sm bg-card p-6">
|
||||
<h2 className="text-xl font-semibold mb-4">Billing Status</h2>
|
||||
|
||||
{subscriptionData ? (
|
||||
<>
|
||||
<div className="mb-6">
|
||||
<div className="rounded-lg border bg-background p-4 grid grid-cols-1 md:grid-cols-2 gap-4">
|
||||
<div className="flex justify-between items-center">
|
||||
<span className="text-sm font-medium text-foreground/90">Agent Usage This Month</span>
|
||||
<span className="text-sm font-medium text-card-title">
|
||||
{subscriptionData.current_usage?.toFixed(2) || '0'} / {subscriptionData.minutes_limit || '0'} minutes
|
||||
</span>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Plans Comparison */}
|
||||
<PricingSection
|
||||
returnUrl={returnUrl}
|
||||
showTitleAndTabs={false}
|
||||
/>
|
||||
|
||||
{/* Manage Subscription Button */}
|
||||
<Button
|
||||
onClick={handleManageSubscription}
|
||||
disabled={isManaging}
|
||||
className="w-full bg-primary text-white hover:bg-primary/90 shadow-md hover:shadow-lg transition-all"
|
||||
>
|
||||
{isManaging ? "Loading..." : "Manage Subscription"}
|
||||
</Button>
|
||||
</>
|
||||
) : (
|
||||
<>
|
||||
<div className="mb-6">
|
||||
<div className="rounded-lg border bg-background p-4 gap-4">
|
||||
<div className="flex justify-between items-center">
|
||||
<span className="text-sm font-medium text-foreground/90">Current Plan</span>
|
||||
<span className="text-sm font-medium text-card-title">Free</span>
|
||||
</div>
|
||||
|
||||
<div className="flex justify-between items-center">
|
||||
<span className="text-sm font-medium text-foreground/90">Agent Usage This Month</span>
|
||||
<span className="text-sm font-medium text-card-title">
|
||||
{subscriptionData?.current_usage?.toFixed(2) || '0'} / {subscriptionData?.minutes_limit || '0'} minutes
|
||||
</span>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Plans Comparison */}
|
||||
<PricingSection
|
||||
returnUrl={returnUrl}
|
||||
showTitleAndTabs={false}
|
||||
/>
|
||||
|
||||
{/* Manage Subscription Button */}
|
||||
<Button
|
||||
onClick={handleManageSubscription}
|
||||
disabled={isManaging}
|
||||
className="w-full bg-primary text-white hover:bg-primary/90 shadow-md hover:shadow-lg transition-all"
|
||||
>
|
||||
{isManaging ? "Loading..." : "Manage Subscription"}
|
||||
</Button>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
|
@ -1,263 +0,0 @@
|
|||
'use client';
|
||||
|
||||
import { createClient } from "@/lib/supabase/client";
|
||||
import { useEffect, useState } from "react";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { motion } from "motion/react";
|
||||
import { setupNewSubscription } from "@/lib/actions/billing";
|
||||
import { SubmitButton } from "@/components/ui/submit-button";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { siteConfig } from "@/lib/home";
|
||||
import { isLocalMode } from "@/lib/config";
|
||||
|
||||
// Create SUBSCRIPTION_PLANS using stripePriceId from siteConfig
|
||||
export const SUBSCRIPTION_PLANS = {
|
||||
FREE: siteConfig.cloudPricingItems.find(item => item.name === 'Free')?.stripePriceId || '',
|
||||
PRO: siteConfig.cloudPricingItems.find(item => item.name === 'Pro')?.stripePriceId || '',
|
||||
ENTERPRISE: siteConfig.cloudPricingItems.find(item => item.name === 'Enterprise')?.stripePriceId || '',
|
||||
};
|
||||
|
||||
// Price display animation component
|
||||
const PriceDisplay = ({ tier, isCompact }: { tier: typeof siteConfig.cloudPricingItems[number]; isCompact?: boolean }) => {
|
||||
return (
|
||||
<motion.span
|
||||
key={tier.price}
|
||||
className={isCompact ? "text-xl font-semibold" : "text-3xl font-semibold"}
|
||||
initial={{
|
||||
opacity: 0,
|
||||
x: 10,
|
||||
filter: "blur(5px)",
|
||||
}}
|
||||
animate={{ opacity: 1, x: 0, filter: "blur(0px)" }}
|
||||
transition={{ duration: 0.25, ease: [0.4, 0, 0.2, 1] }}
|
||||
>
|
||||
{tier.price}
|
||||
</motion.span>
|
||||
);
|
||||
};
|
||||
|
||||
interface PlanComparisonProps {
|
||||
accountId?: string | null;
|
||||
returnUrl?: string;
|
||||
isManaged?: boolean;
|
||||
onPlanSelect?: (planId: string) => void;
|
||||
className?: string;
|
||||
isCompact?: boolean; // When true, uses vertical stacked layout for modals
|
||||
}
|
||||
|
||||
export function PlanComparison({
|
||||
accountId,
|
||||
returnUrl = typeof window !== 'undefined' ? window.location.href : '',
|
||||
isManaged = true,
|
||||
onPlanSelect,
|
||||
className = "",
|
||||
isCompact = false
|
||||
}: PlanComparisonProps) {
|
||||
const [currentPlanId, setCurrentPlanId] = useState<string | undefined>();
|
||||
|
||||
useEffect(() => {
|
||||
async function fetchCurrentPlan() {
|
||||
if (accountId) {
|
||||
const supabase = createClient();
|
||||
const { data } = await supabase
|
||||
.schema('basejump')
|
||||
.from('billing_subscriptions')
|
||||
.select('price_id')
|
||||
.eq('account_id', accountId)
|
||||
.eq('status', 'active')
|
||||
.single();
|
||||
|
||||
setCurrentPlanId(data?.price_id || SUBSCRIPTION_PLANS.FREE);
|
||||
} else {
|
||||
setCurrentPlanId(SUBSCRIPTION_PLANS.FREE);
|
||||
}
|
||||
}
|
||||
|
||||
fetchCurrentPlan();
|
||||
}, [accountId]);
|
||||
|
||||
// For local development mode, show a message instead
|
||||
if (isLocalMode()) {
|
||||
return (
|
||||
<div className={cn("p-4 bg-muted/30 border border-border rounded-lg text-center", className)}>
|
||||
<p className="text-sm text-muted-foreground">
|
||||
Running in local development mode - billing features are disabled
|
||||
</p>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div
|
||||
className={cn(
|
||||
"grid gap-3 w-full mx-auto",
|
||||
isCompact
|
||||
? "grid-cols-1 max-w-md"
|
||||
: "grid-cols-1 md:grid-cols-3 max-w-6xl",
|
||||
className
|
||||
)}
|
||||
>
|
||||
{siteConfig.cloudPricingItems.map((tier) => {
|
||||
const isCurrentPlan = currentPlanId === SUBSCRIPTION_PLANS[tier.name.toUpperCase() as keyof typeof SUBSCRIPTION_PLANS];
|
||||
|
||||
return (
|
||||
<div
|
||||
key={tier.name}
|
||||
className={cn(
|
||||
"rounded-lg bg-background border border-border",
|
||||
isCompact ? "p-3 text-sm" : "p-5",
|
||||
isCurrentPlan && (isCompact ? "ring-1 ring-primary" : "ring-2 ring-primary")
|
||||
)}
|
||||
>
|
||||
{isCompact ? (
|
||||
// Compact layout for modal
|
||||
<>
|
||||
<div className="flex justify-between mb-2">
|
||||
<div>
|
||||
<div className="flex items-center gap-1">
|
||||
<h3 className="font-medium">{tier.name}</h3>
|
||||
{tier.isPopular && (
|
||||
<span className="bg-primary/10 text-primary text-[10px] font-medium px-1.5 py-0.5 rounded-full">
|
||||
Popular
|
||||
</span>
|
||||
)}
|
||||
{isCurrentPlan && (
|
||||
<span className="bg-secondary/10 text-secondary text-[10px] font-medium px-1.5 py-0.5 rounded-full">
|
||||
Current
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
<div className="text-xs text-muted-foreground mt-0.5">{tier.description}</div>
|
||||
</div>
|
||||
|
||||
<div className="text-right">
|
||||
<div className="flex items-baseline">
|
||||
<PriceDisplay tier={tier} isCompact={true} />
|
||||
<span className="text-xs text-muted-foreground ml-1">
|
||||
{tier.price !== "$0" ? "/mo" : ""}
|
||||
</span>
|
||||
</div>
|
||||
<div className="text-[10px] text-muted-foreground mt-0.5">
|
||||
{tier.hours}/month
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="mb-2.5">
|
||||
<div className="text-[10px] text-muted-foreground leading-tight max-h-[40px] overflow-y-auto pr-1">
|
||||
{tier.features.map((feature, index) => (
|
||||
<span key={index} className="whitespace-normal">
|
||||
{index > 0 && ' • '}
|
||||
{feature}
|
||||
</span>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
</>
|
||||
) : (
|
||||
// Standard layout for normal view
|
||||
<>
|
||||
<div className="flex items-center justify-between mb-4">
|
||||
<h3 className="text-lg font-medium">{tier.name}</h3>
|
||||
<div className="flex gap-1">
|
||||
{tier.isPopular && (
|
||||
<span className="bg-primary/10 text-primary text-xs font-medium px-2 py-0.5 rounded-full">
|
||||
Popular
|
||||
</span>
|
||||
)}
|
||||
{isCurrentPlan && (
|
||||
<span className="bg-secondary/10 text-secondary text-xs font-medium px-2 py-0.5 rounded-full">
|
||||
Current
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="flex items-baseline mb-1">
|
||||
<PriceDisplay tier={tier} />
|
||||
<span className="text-muted-foreground ml-2">
|
||||
{tier.price !== "$0" ? "/month" : ""}
|
||||
</span>
|
||||
</div>
|
||||
|
||||
<div className="inline-flex items-center rounded-full border px-2.5 py-0.5 text-xs font-medium bg-secondary/10 text-secondary mb-4">
|
||||
{tier.hours}/month
|
||||
</div>
|
||||
|
||||
<p className="text-muted-foreground mb-6">{tier.description}</p>
|
||||
|
||||
<div className="mb-6">
|
||||
<div className="text-sm text-muted-foreground space-y-2">
|
||||
{tier.features.map((feature, index) => (
|
||||
<div key={index} className="flex items-start gap-2">
|
||||
<div className="size-5 rounded-full bg-primary/10 flex items-center justify-center mt-0.5">
|
||||
<svg
|
||||
width="12"
|
||||
height="12"
|
||||
viewBox="0 0 12 12"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
className="text-primary"
|
||||
>
|
||||
<path
|
||||
d="M2.5 6L5 8.5L9.5 4"
|
||||
stroke="currentColor"
|
||||
strokeWidth="2"
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
/>
|
||||
</svg>
|
||||
</div>
|
||||
<span>{feature}</span>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
|
||||
<form>
|
||||
<input type="hidden" name="accountId" value={accountId} />
|
||||
<input type="hidden" name="returnUrl" value={returnUrl} />
|
||||
<input type="hidden" name="planId" value={SUBSCRIPTION_PLANS[tier.name.toUpperCase() as keyof typeof SUBSCRIPTION_PLANS]} />
|
||||
{isManaged ? (
|
||||
<SubmitButton
|
||||
pendingText="..."
|
||||
formAction={setupNewSubscription}
|
||||
// disabled={isCurrentPlan}
|
||||
className={cn(
|
||||
"w-full font-medium transition-colors",
|
||||
isCompact
|
||||
? "h-7 rounded-md text-xs"
|
||||
: "h-10 rounded-full text-sm",
|
||||
isCurrentPlan
|
||||
? "bg-muted text-muted-foreground hover:bg-muted"
|
||||
: tier.buttonColor
|
||||
)}
|
||||
>
|
||||
{isCurrentPlan ? "Current Plan" : (tier.name === "Free" ? tier.buttonText : "Upgrade")}
|
||||
</SubmitButton>
|
||||
) : (
|
||||
<Button
|
||||
className={cn(
|
||||
"w-full font-medium transition-colors",
|
||||
isCompact
|
||||
? "h-7 rounded-md text-xs"
|
||||
: "h-10 rounded-full text-sm",
|
||||
isCurrentPlan
|
||||
? "bg-muted text-muted-foreground hover:bg-muted"
|
||||
: tier.buttonColor
|
||||
)}
|
||||
disabled={isCurrentPlan}
|
||||
onClick={() => onPlanSelect?.(SUBSCRIPTION_PLANS[tier.name.toUpperCase() as keyof typeof SUBSCRIPTION_PLANS])}
|
||||
>
|
||||
{isCurrentPlan ? "Current Plan" : (tier.name === "Free" ? tier.buttonText : "Upgrade")}
|
||||
</Button>
|
||||
)}
|
||||
</form>
|
||||
</div>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
);
|
||||
}
|
|
@ -1,251 +0,0 @@
|
|||
"use client"
|
||||
|
||||
import { X, Zap, Github, Check } from "lucide-react"
|
||||
import Link from "next/link"
|
||||
import { AnimatePresence, motion } from "motion/react"
|
||||
import { Button } from "@/components/ui/button"
|
||||
import { Portal } from "@/components/ui/portal"
|
||||
import { cn } from "@/lib/utils"
|
||||
import { setupNewSubscription } from "@/lib/actions/billing"
|
||||
import { SubmitButton } from "@/components/ui/submit-button"
|
||||
import { siteConfig } from "@/lib/home"
|
||||
import { isLocalMode } from "@/lib/config"
|
||||
import { createClient } from "@/lib/supabase/client"
|
||||
import { useEffect, useState } from "react"
|
||||
import { SUBSCRIPTION_PLANS } from "./plan-comparison"
|
||||
|
||||
interface PricingAlertProps {
|
||||
open: boolean
|
||||
onOpenChange: (open: boolean) => void
|
||||
closeable?: boolean
|
||||
accountId?: string | null | undefined
|
||||
}
|
||||
|
||||
export function PricingAlert({ open, onOpenChange, closeable = true, accountId }: PricingAlertProps) {
|
||||
const returnUrl = typeof window !== 'undefined' ? window.location.href : '';
|
||||
const [hasActiveSubscription, setHasActiveSubscription] = useState(false);
|
||||
const [isLoading, setIsLoading] = useState(true);
|
||||
|
||||
// Check if user has an active subscription
|
||||
useEffect(() => {
|
||||
async function checkSubscription() {
|
||||
if (!accountId) {
|
||||
setHasActiveSubscription(false);
|
||||
setIsLoading(false);
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
const supabase = createClient();
|
||||
const { data } = await supabase
|
||||
.schema('basejump')
|
||||
.from('billing_subscriptions')
|
||||
.select('price_id')
|
||||
.eq('account_id', accountId)
|
||||
.eq('status', 'active')
|
||||
.single();
|
||||
|
||||
// Check if the user has a paid subscription (not free tier)
|
||||
const isPaidSubscription = data?.price_id &&
|
||||
data.price_id !== SUBSCRIPTION_PLANS.FREE;
|
||||
|
||||
setHasActiveSubscription(isPaidSubscription);
|
||||
} catch (error) {
|
||||
console.error("Error checking subscription:", error);
|
||||
setHasActiveSubscription(false);
|
||||
} finally {
|
||||
setIsLoading(false);
|
||||
}
|
||||
}
|
||||
|
||||
checkSubscription();
|
||||
}, [accountId]);
|
||||
|
||||
// Skip rendering in local development mode or if user has an active subscription
|
||||
if (isLocalMode() || !open || hasActiveSubscription || isLoading) return null;
|
||||
|
||||
// Filter plans to show only Pro and Enterprise
|
||||
const premiumPlans = siteConfig.cloudPricingItems.filter(plan =>
|
||||
plan.name === 'Pro' || plan.name === 'Enterprise'
|
||||
);
|
||||
|
||||
return (
|
||||
<Portal>
|
||||
<AnimatePresence>
|
||||
{open && (
|
||||
<>
|
||||
<motion.div
|
||||
initial={{ opacity: 0 }}
|
||||
animate={{ opacity: 1 }}
|
||||
exit={{ opacity: 0 }}
|
||||
transition={{ duration: 0.2 }}
|
||||
className="fixed inset-0 z-[9999] flex items-center justify-center overflow-y-auto py-8 px-4"
|
||||
>
|
||||
{/* Backdrop */}
|
||||
<motion.div
|
||||
initial={{ opacity: 0 }}
|
||||
animate={{ opacity: 1 }}
|
||||
exit={{ opacity: 0 }}
|
||||
transition={{ duration: 0.2 }}
|
||||
className="fixed inset-0 bg-black/60 backdrop-blur-sm"
|
||||
onClick={closeable ? () => onOpenChange(false) : undefined}
|
||||
aria-hidden="true"
|
||||
/>
|
||||
|
||||
{/* Modal */}
|
||||
<motion.div
|
||||
initial={{ opacity: 0, scale: 0.95, y: 20 }}
|
||||
animate={{ opacity: 1, scale: 1, y: 0 }}
|
||||
exit={{ opacity: 0, scale: 0.95, y: 20 }}
|
||||
transition={{ duration: 0.2, ease: [0.4, 0, 0.2, 1] }}
|
||||
className={cn(
|
||||
"relative bg-background rounded-xl shadow-2xl w-full max-w-3xl mx-3 border border-border"
|
||||
)}
|
||||
role="dialog"
|
||||
aria-modal="true"
|
||||
aria-labelledby="pricing-modal-title"
|
||||
>
|
||||
<div className="p-6">
|
||||
{/* Close button */}
|
||||
{closeable && (
|
||||
<button
|
||||
onClick={() => onOpenChange(false)}
|
||||
className="absolute top-4 right-4 text-muted-foreground hover:text-foreground transition-colors"
|
||||
aria-label="Close dialog"
|
||||
>
|
||||
<X className="h-5 w-5" />
|
||||
</button>
|
||||
)}
|
||||
|
||||
{/* Header */}
|
||||
<div className="text-center mb-8">
|
||||
<div className="inline-flex items-center justify-center p-2 bg-primary/10 rounded-full mb-3">
|
||||
<Zap className="h-5 w-5 text-primary" />
|
||||
</div>
|
||||
<h2 id="pricing-modal-title" className="text-2xl font-medium tracking-tight mb-2">
|
||||
Choose Your Suna Experience
|
||||
</h2>
|
||||
<p className="text-muted-foreground max-w-lg mx-auto">
|
||||
Due to overwhelming demand and AI costs, we're currently focusing on delivering
|
||||
our best experience to dedicated users. Select your preferred option below.
|
||||
</p>
|
||||
</div>
|
||||
|
||||
{/* Plan comparison - 3 column layout */}
|
||||
<div className="grid md:grid-cols-3 gap-4 mb-6">
|
||||
{/* Self-Host Option */}
|
||||
<div className="rounded-xl bg-[#F3F4F6] dark:bg-[#F9FAFB]/[0.02] border border-border hover:border-muted-foreground/30 transition-all duration-300">
|
||||
<div className="flex flex-col gap-4 p-4">
|
||||
<p className="text-sm flex items-center">Open Source</p>
|
||||
<div className="flex items-baseline mt-2">
|
||||
<span className="text-2xl font-semibold">Self-host</span>
|
||||
</div>
|
||||
<p className="text-sm mt-2">Full control with your own infrastructure</p>
|
||||
<div className="inline-flex items-center rounded-full border px-2.5 py-0.5 text-xs font-semibold bg-primary/10 border-primary/20 text-primary w-fit">
|
||||
∞ hours / month
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="px-4 pb-4">
|
||||
<div className="flex items-start gap-2 mb-3">
|
||||
<Check className="h-4 w-4 text-green-500 flex-shrink-0 mt-0.5" />
|
||||
<span className="text-xs text-muted-foreground">No usage limitations</span>
|
||||
</div>
|
||||
<Link
|
||||
href="https://github.com/kortix-ai/suna"
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
className="h-10 w-full flex items-center justify-center gap-2 text-sm font-normal tracking-wide rounded-full px-4 cursor-pointer transition-all ease-out active:scale-95 bg-secondary/10 text-secondary shadow-[0px_1px_2px_0px_rgba(255,255,255,0.16)_inset,0px_3px_3px_-1.5px_rgba(16,24,40,0.24),0px_1px_1px_-0.5px_rgba(16,24,40,0.20)]"
|
||||
>
|
||||
<Github className="h-4 w-4" />
|
||||
<span>View on GitHub</span>
|
||||
</Link>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Pro Plan */}
|
||||
<div className="rounded-xl md:shadow-[0px_61px_24px_-10px_rgba(0,0,0,0.01),0px_34px_20px_-8px_rgba(0,0,0,0.05),0px_15px_15px_-6px_rgba(0,0,0,0.09),0px_4px_8px_-2px_rgba(0,0,0,0.10),0px_0px_0px_1px_rgba(0,0,0,0.08)] bg-accent relative transform hover:scale-105 transition-all duration-300">
|
||||
<div className="absolute -top-3 -right-3">
|
||||
<span className="bg-gradient-to-b from-secondary/50 from-[1.92%] to-secondary to-[100%] text-white h-6 inline-flex w-fit items-center justify-center px-3 rounded-full text-xs font-medium shadow-[0px_6px_6px_-3px_rgba(0,0,0,0.08),0px_3px_3px_-1.5px_rgba(0,0,0,0.08),0px_1px_1px_-0.5px_rgba(0,0,0,0.08),0px_0px_0px_1px_rgba(255,255,255,0.12)_inset,0px_1px_0px_0px_rgba(255,255,255,0.12)_inset]">
|
||||
Most Popular
|
||||
</span>
|
||||
</div>
|
||||
<div className="flex flex-col gap-4 p-4">
|
||||
<p className="text-sm flex items-center font-medium">Pro</p>
|
||||
<div className="flex items-baseline mt-2">
|
||||
<span className="text-2xl font-semibold">{premiumPlans[0]?.price || "$19"}</span>
|
||||
<span className="ml-2">/month</span>
|
||||
</div>
|
||||
<p className="text-sm mt-2">Supercharge your productivity with {premiumPlans[0]?.hours || "500 hours"} of Suna</p>
|
||||
<div className="inline-flex items-center rounded-full border px-2.5 py-0.5 text-xs font-semibold bg-primary/10 border-primary/20 text-primary w-fit">
|
||||
{premiumPlans[0]?.hours || "500 hours"}/month
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="px-4 pb-4">
|
||||
<div className="flex items-start gap-2 mb-3">
|
||||
<Check className="h-4 w-4 text-green-500 flex-shrink-0 mt-0.5" />
|
||||
<span className="text-xs text-muted-foreground">Perfect for individuals and small teams</span>
|
||||
</div>
|
||||
<form>
|
||||
<input type="hidden" name="accountId" value={accountId || ''} />
|
||||
<input type="hidden" name="returnUrl" value={returnUrl} />
|
||||
<input type="hidden" name="planId" value={
|
||||
premiumPlans[0]?.stripePriceId || ''
|
||||
} />
|
||||
<SubmitButton
|
||||
pendingText="..."
|
||||
formAction={setupNewSubscription}
|
||||
className="h-10 w-full flex items-center justify-center text-sm font-medium tracking-wide rounded-full px-4 cursor-pointer transition-all ease-out active:scale-95 bg-primary text-primary-foreground shadow-[inset_0_1px_2px_rgba(255,255,255,0.25),0_3px_3px_-1.5px_rgba(16,24,40,0.06),0_1px_1px_rgba(16,24,40,0.08)]"
|
||||
>
|
||||
Get Started Now
|
||||
</SubmitButton>
|
||||
</form>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Enterprise Plan */}
|
||||
<div className="rounded-xl bg-[#F3F4F6] dark:bg-[#F9FAFB]/[0.02] border border-border hover:border-muted-foreground/30 transition-all duration-300">
|
||||
<div className="flex flex-col gap-4 p-4">
|
||||
<p className="text-sm flex items-center font-medium">Enterprise</p>
|
||||
<div className="flex items-baseline mt-2">
|
||||
<span className="text-2xl font-semibold">{premiumPlans[1]?.price || "$99"}</span>
|
||||
<span className="ml-2">/month</span>
|
||||
</div>
|
||||
<p className="text-sm mt-2">Unlock boundless potential with {premiumPlans[1]?.hours || "2000 hours"} of Suna</p>
|
||||
<div className="inline-flex items-center rounded-full border px-2.5 py-0.5 text-xs font-semibold bg-primary/10 border-primary/20 text-primary w-fit">
|
||||
{premiumPlans[1]?.hours || "2000 hours"}/month
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="px-4 pb-4">
|
||||
<div className="flex items-start gap-2 mb-3">
|
||||
<Check className="h-4 w-4 text-green-500 flex-shrink-0 mt-0.5" />
|
||||
<span className="text-xs text-muted-foreground">Ideal for larger organizations and power users</span>
|
||||
</div>
|
||||
<form>
|
||||
<input type="hidden" name="accountId" value={accountId || ''} />
|
||||
<input type="hidden" name="returnUrl" value={returnUrl} />
|
||||
<input type="hidden" name="planId" value={
|
||||
premiumPlans[1]?.stripePriceId || ''
|
||||
} />
|
||||
<SubmitButton
|
||||
pendingText="..."
|
||||
formAction={setupNewSubscription}
|
||||
className="h-10 w-full flex items-center justify-center text-sm font-normal tracking-wide rounded-full px-4 cursor-pointer transition-all ease-out active:scale-95 bg-gradient-to-b from-secondary/50 from-[1.92%] to-secondary to-[100%] text-white shadow-[0px_1px_2px_0px_rgba(255,255,255,0.16)_inset,0px_3px_3px_-1.5px_rgba(16,24,40,0.24),0px_1px_1px_-0.5px_rgba(16,24,40,0.20)]"
|
||||
>
|
||||
Upgrade to Enterprise
|
||||
</SubmitButton>
|
||||
</form>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</motion.div>
|
||||
</motion.div>
|
||||
</>
|
||||
)}
|
||||
</AnimatePresence>
|
||||
</Portal>
|
||||
)
|
||||
}
|
|
@ -1,18 +1,15 @@
|
|||
import { AlertCircle, X } from "lucide-react";
|
||||
'use client';
|
||||
|
||||
import { AlertTriangle } from "lucide-react";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { Portal } from "@/components/ui/portal";
|
||||
import { PlanComparison } from "./plan-comparison";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { motion, AnimatePresence } from "motion/react";
|
||||
import { isLocalMode } from "@/lib/config";
|
||||
import { useRouter } from "next/navigation";
|
||||
|
||||
interface BillingErrorAlertProps {
|
||||
message?: string;
|
||||
currentUsage?: number;
|
||||
limit?: number;
|
||||
accountId: string | null | undefined;
|
||||
onDismiss?: () => void;
|
||||
className?: string;
|
||||
accountId?: string | null;
|
||||
onDismiss: () => void;
|
||||
isOpen: boolean;
|
||||
}
|
||||
|
||||
|
@ -22,125 +19,42 @@ export function BillingErrorAlert({
|
|||
limit,
|
||||
accountId,
|
||||
onDismiss,
|
||||
className = "",
|
||||
isOpen
|
||||
}: BillingErrorAlertProps) {
|
||||
const returnUrl = typeof window !== 'undefined' ? window.location.href : '';
|
||||
|
||||
// Skip rendering in local development mode
|
||||
if (isLocalMode() || !isOpen) return null;
|
||||
const router = useRouter();
|
||||
|
||||
if (!isOpen) return null;
|
||||
|
||||
return (
|
||||
<Portal>
|
||||
<AnimatePresence>
|
||||
{isOpen && (
|
||||
<>
|
||||
<motion.div
|
||||
initial={{ opacity: 0 }}
|
||||
animate={{ opacity: 1 }}
|
||||
exit={{ opacity: 0 }}
|
||||
transition={{ duration: 0.2 }}
|
||||
className="fixed inset-0 z-[9999] flex items-center justify-center overflow-y-auto py-4"
|
||||
>
|
||||
{/* Backdrop */}
|
||||
<motion.div
|
||||
initial={{ opacity: 0 }}
|
||||
animate={{ opacity: 1 }}
|
||||
exit={{ opacity: 0 }}
|
||||
transition={{ duration: 0.2 }}
|
||||
className="fixed inset-0 bg-black/40 backdrop-blur-sm"
|
||||
<div className="fixed bottom-4 right-4 z-50">
|
||||
<div className="bg-destructive/10 border border-destructive/20 rounded-lg p-4 shadow-lg max-w-md">
|
||||
<div className="flex items-start gap-3">
|
||||
<div className="flex-shrink-0">
|
||||
<AlertTriangle className="h-5 w-5 text-destructive" />
|
||||
</div>
|
||||
<div className="flex-1">
|
||||
<h3 className="text-sm font-medium text-destructive mb-1">Usage Limit Reached</h3>
|
||||
<p className="text-sm text-muted-foreground mb-3">{message}</p>
|
||||
<div className="flex gap-2">
|
||||
<Button
|
||||
variant="outline"
|
||||
size="sm"
|
||||
onClick={onDismiss}
|
||||
aria-hidden="true"
|
||||
/>
|
||||
|
||||
{/* Modal */}
|
||||
<motion.div
|
||||
initial={{ opacity: 0, scale: 0.95, y: 20 }}
|
||||
animate={{ opacity: 1, scale: 1, y: 0 }}
|
||||
exit={{ opacity: 0, scale: 0.95, y: 20 }}
|
||||
transition={{ duration: 0.2, ease: [0.4, 0, 0.2, 1] }}
|
||||
className={cn(
|
||||
"relative bg-background rounded-lg shadow-xl w-full max-w-sm mx-3",
|
||||
className
|
||||
)}
|
||||
role="dialog"
|
||||
aria-modal="true"
|
||||
aria-labelledby="billing-modal-title"
|
||||
className="text-xs"
|
||||
>
|
||||
<div className="p-4">
|
||||
{/* Close button */}
|
||||
{onDismiss && (
|
||||
<button
|
||||
onClick={onDismiss}
|
||||
className="absolute top-2 right-2 text-muted-foreground hover:text-foreground transition-colors"
|
||||
aria-label="Close dialog"
|
||||
>
|
||||
<X className="h-4 w-4" />
|
||||
</button>
|
||||
)}
|
||||
|
||||
{/* Header */}
|
||||
<div className="text-center mb-4">
|
||||
<div className="inline-flex items-center justify-center p-1.5 bg-destructive/10 rounded-full mb-2">
|
||||
<AlertCircle className="h-4 w-4 text-destructive" />
|
||||
</div>
|
||||
<h2 id="billing-modal-title" className="text-lg font-medium tracking-tight mb-1">
|
||||
Usage Limit Reached
|
||||
</h2>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
{message || "You've reached your monthly usage limit."}
|
||||
</p>
|
||||
</div>
|
||||
|
||||
{/* Usage Stats */}
|
||||
{currentUsage !== undefined && limit !== undefined && (
|
||||
<div className="mb-4 p-3 bg-muted/30 border border-border rounded-lg">
|
||||
<div className="flex justify-between items-center mb-2">
|
||||
<div>
|
||||
<p className="text-xs font-medium text-muted-foreground">Usage</p>
|
||||
<p className="text-base font-semibold">{(currentUsage * 60).toFixed(0)}m</p>
|
||||
</div>
|
||||
<div className="text-right">
|
||||
<p className="text-xs font-medium text-muted-foreground">Limit</p>
|
||||
<p className="text-base font-semibold">{(limit * 60).toFixed(0)}m</p>
|
||||
</div>
|
||||
</div>
|
||||
<div className="w-full h-1.5 bg-background rounded-full overflow-hidden">
|
||||
<motion.div
|
||||
initial={{ width: 0 }}
|
||||
animate={{ width: `${Math.min((currentUsage / limit) * 100, 100)}%` }}
|
||||
transition={{ duration: 0.5, ease: [0.4, 0, 0.2, 1] }}
|
||||
className="h-full bg-destructive rounded-full"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Plans Comparison */}
|
||||
<PlanComparison
|
||||
accountId={accountId}
|
||||
returnUrl={returnUrl}
|
||||
className="mb-3"
|
||||
isCompact={true}
|
||||
/>
|
||||
|
||||
{/* Dismiss Button */}
|
||||
{onDismiss && (
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="sm"
|
||||
className="w-full text-muted-foreground hover:text-foreground text-xs h-7"
|
||||
onClick={onDismiss}
|
||||
>
|
||||
Continue with Current Plan
|
||||
</Button>
|
||||
)}
|
||||
</div>
|
||||
</motion.div>
|
||||
</motion.div>
|
||||
</>
|
||||
)}
|
||||
</AnimatePresence>
|
||||
</Portal>
|
||||
Dismiss
|
||||
</Button>
|
||||
<Button
|
||||
size="sm"
|
||||
onClick={() => router.push(`/settings/billing?accountId=${accountId}`)}
|
||||
className="text-xs"
|
||||
>
|
||||
Upgrade Plan
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
|
@ -25,7 +25,7 @@ import {
|
|||
import { BillingErrorAlert } from '@/components/billing/usage-limit-alert';
|
||||
import { useBillingError } from "@/hooks/useBillingError";
|
||||
import { useAccounts } from "@/hooks/use-accounts";
|
||||
import { isLocalMode } from "@/lib/config";
|
||||
import { isLocalMode, config } from "@/lib/config";
|
||||
import { toast } from "sonner";
|
||||
|
||||
// Custom dialog overlay with blur effect
|
||||
|
@ -140,7 +140,7 @@ export function HeroSection() {
|
|||
currentUsage: error.detail.currentUsage as number | undefined,
|
||||
limit: error.detail.limit as number | undefined,
|
||||
subscription: error.detail.subscription || {
|
||||
price_id: "price_1RGJ9GG6l1KZGqIroxSqgphC", // Default Free
|
||||
price_id: config.SUBSCRIPTION_TIERS.FREE.priceId, // Default Free
|
||||
plan_name: "Free"
|
||||
}
|
||||
});
|
||||
|
|
|
@ -24,7 +24,7 @@ export function OpenSourceSection() {
|
|||
<div className="flex flex-col gap-6">
|
||||
<div className="flex items-center gap-2 text-primary font-medium">
|
||||
<Github className="h-5 w-5" />
|
||||
<span>Kortix/Suna</span>
|
||||
<span>kortix-ai/suna</span>
|
||||
</div>
|
||||
<div className="relative">
|
||||
<h3 className="text-2xl font-semibold tracking-tight">
|
||||
|
|
|
@ -1,20 +1,72 @@
|
|||
"use client";
|
||||
|
||||
import { SectionHeader } from "@/components/home/section-header";
|
||||
import type { PricingTier } from "@/lib/home";
|
||||
import { siteConfig } from "@/lib/home";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { motion } from "motion/react";
|
||||
import { useState } from "react";
|
||||
import { Github, GitFork, File, Terminal } from "lucide-react";
|
||||
import { useState, useEffect, useRef } from "react";
|
||||
import { CheckIcon } from "lucide-react";
|
||||
import Link from "next/link";
|
||||
import {
|
||||
Select,
|
||||
SelectContent,
|
||||
SelectItem,
|
||||
SelectTrigger,
|
||||
SelectValue,
|
||||
} from "@/components/ui/select";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { getSubscription, createCheckoutSession, SubscriptionStatus, CreateCheckoutSessionResponse } from "@/lib/api";
|
||||
import { toast } from "sonner";
|
||||
import { isLocalMode } from "@/lib/config";
|
||||
|
||||
interface TabsProps {
|
||||
// Constants
|
||||
const DEFAULT_SELECTED_PLAN = "6 hours";
|
||||
export const SUBSCRIPTION_PLANS = {
|
||||
FREE: 'free',
|
||||
PRO: 'base',
|
||||
ENTERPRISE: 'extra',
|
||||
};
|
||||
|
||||
// Types
|
||||
type ButtonVariant = "default" | "secondary" | "ghost" | "outline" | "link" | null;
|
||||
|
||||
interface PricingTabsProps {
|
||||
activeTab: "cloud" | "self-hosted";
|
||||
setActiveTab: (tab: "cloud" | "self-hosted") => void;
|
||||
className?: string;
|
||||
}
|
||||
|
||||
function PricingTabs({ activeTab, setActiveTab, className }: TabsProps) {
|
||||
interface PriceDisplayProps {
|
||||
price: string;
|
||||
isCompact?: boolean;
|
||||
}
|
||||
|
||||
interface CustomPriceDisplayProps {
|
||||
price: string;
|
||||
}
|
||||
|
||||
interface UpgradePlan {
|
||||
hours: string;
|
||||
price: string;
|
||||
stripePriceId: string;
|
||||
}
|
||||
|
||||
interface PricingTierProps {
|
||||
tier: PricingTier;
|
||||
isCompact?: boolean;
|
||||
currentSubscription: SubscriptionStatus | null;
|
||||
isLoading: Record<string, boolean>;
|
||||
isFetchingPlan: boolean;
|
||||
selectedPlan?: string;
|
||||
onPlanSelect?: (planId: string) => void;
|
||||
onSubscriptionUpdate?: () => void;
|
||||
isAuthenticated?: boolean;
|
||||
returnUrl: string;
|
||||
}
|
||||
|
||||
// Components
|
||||
function PricingTabs({ activeTab, setActiveTab, className }: PricingTabsProps) {
|
||||
return (
|
||||
<div
|
||||
className={cn(
|
||||
|
@ -60,21 +112,460 @@ function PricingTabs({ activeTab, setActiveTab, className }: TabsProps) {
|
|||
);
|
||||
}
|
||||
|
||||
export function PricingSection() {
|
||||
const [deploymentType, setDeploymentType] = useState<"cloud" | "self-hosted">(
|
||||
"cloud",
|
||||
function PriceDisplay({ price, isCompact }: PriceDisplayProps) {
|
||||
return (
|
||||
<motion.span
|
||||
key={price}
|
||||
className={isCompact ? "text-xl font-semibold" : "text-4xl font-semibold"}
|
||||
initial={{
|
||||
opacity: 0,
|
||||
x: 10,
|
||||
filter: "blur(5px)",
|
||||
}}
|
||||
animate={{ opacity: 1, x: 0, filter: "blur(0px)" }}
|
||||
transition={{ duration: 0.25, ease: [0.4, 0, 0.2, 1] }}
|
||||
>
|
||||
{price}
|
||||
</motion.span>
|
||||
);
|
||||
}
|
||||
|
||||
function CustomPriceDisplay({ price }: CustomPriceDisplayProps) {
|
||||
return (
|
||||
<motion.span
|
||||
key={price}
|
||||
className="text-4xl font-semibold"
|
||||
initial={{
|
||||
opacity: 0,
|
||||
x: 10,
|
||||
filter: "blur(5px)",
|
||||
}}
|
||||
animate={{ opacity: 1, x: 0, filter: "blur(0px)" }}
|
||||
transition={{ duration: 0.25, ease: [0.4, 0, 0.2, 1] }}
|
||||
>
|
||||
{price}
|
||||
</motion.span>
|
||||
);
|
||||
}
|
||||
|
||||
function PricingTier({
|
||||
tier,
|
||||
isCompact = false,
|
||||
currentSubscription,
|
||||
isLoading,
|
||||
isFetchingPlan,
|
||||
selectedPlan,
|
||||
onPlanSelect,
|
||||
onSubscriptionUpdate,
|
||||
isAuthenticated = false,
|
||||
returnUrl
|
||||
}: PricingTierProps) {
|
||||
const [localSelectedPlan, setLocalSelectedPlan] = useState(selectedPlan || DEFAULT_SELECTED_PLAN);
|
||||
const hasInitialized = useRef(false);
|
||||
|
||||
// Auto-select the correct plan only on initial load
|
||||
useEffect(() => {
|
||||
if (!hasInitialized.current && tier.name === "Custom" && tier.upgradePlans && currentSubscription?.price_id) {
|
||||
const matchingPlan = tier.upgradePlans.find(plan => plan.stripePriceId === currentSubscription.price_id);
|
||||
if (matchingPlan) {
|
||||
setLocalSelectedPlan(matchingPlan.hours);
|
||||
}
|
||||
hasInitialized.current = true;
|
||||
}
|
||||
}, [currentSubscription, tier.name, tier.upgradePlans]);
|
||||
|
||||
// Only refetch when plan is selected
|
||||
const handlePlanSelect = (value: string) => {
|
||||
setLocalSelectedPlan(value);
|
||||
if (tier.name === "Custom" && onSubscriptionUpdate) {
|
||||
onSubscriptionUpdate();
|
||||
}
|
||||
};
|
||||
|
||||
const handleSubscribe = async (planStripePriceId: string) => {
|
||||
if (!isAuthenticated) {
|
||||
window.location.href = '/auth';
|
||||
return;
|
||||
}
|
||||
|
||||
if (isLoading[planStripePriceId]) {
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
// For custom tier, get the selected plan's stripePriceId
|
||||
let finalPriceId = planStripePriceId;
|
||||
if (tier.name === "Custom" && tier.upgradePlans) {
|
||||
const selectedPlan = tier.upgradePlans.find(plan => plan.hours === localSelectedPlan);
|
||||
if (selectedPlan?.stripePriceId) {
|
||||
finalPriceId = selectedPlan.stripePriceId;
|
||||
}
|
||||
}
|
||||
|
||||
onPlanSelect?.(finalPriceId);
|
||||
|
||||
const response: CreateCheckoutSessionResponse = await createCheckoutSession({
|
||||
price_id: finalPriceId,
|
||||
success_url: returnUrl,
|
||||
cancel_url: returnUrl
|
||||
});
|
||||
|
||||
console.log('Subscription action response:', response);
|
||||
|
||||
switch (response.status) {
|
||||
case 'new':
|
||||
case 'checkout_created':
|
||||
if (response.url) {
|
||||
window.location.href = response.url;
|
||||
} else {
|
||||
console.error("Error: Received status 'checkout_created' but no checkout URL.");
|
||||
toast.error('Failed to initiate subscription. Please try again.');
|
||||
}
|
||||
break;
|
||||
case 'upgraded':
|
||||
case 'updated':
|
||||
const upgradeMessage = response.details?.is_upgrade
|
||||
? `Subscription upgraded from $${response.details.current_price} to $${response.details.new_price}`
|
||||
: 'Subscription updated successfully';
|
||||
toast.success(upgradeMessage);
|
||||
if (onSubscriptionUpdate) onSubscriptionUpdate();
|
||||
break;
|
||||
case 'downgrade_scheduled':
|
||||
case 'scheduled':
|
||||
const effectiveDate = response.effective_date
|
||||
? new Date(response.effective_date).toLocaleDateString()
|
||||
: 'the end of your billing period';
|
||||
|
||||
const downgradeMessage = response.details?.is_upgrade === false
|
||||
? `Subscription downgrade scheduled from $${response.details.current_price} to $${response.details.new_price}`
|
||||
: 'Subscription change scheduled';
|
||||
|
||||
toast.success(
|
||||
<div>
|
||||
<p>{downgradeMessage}</p>
|
||||
<p className="text-sm mt-1">
|
||||
Your plan will change on {effectiveDate}.
|
||||
</p>
|
||||
</div>
|
||||
);
|
||||
if (onSubscriptionUpdate) onSubscriptionUpdate();
|
||||
break;
|
||||
case 'no_change':
|
||||
toast.info(response.message || 'You are already on this plan.');
|
||||
break;
|
||||
default:
|
||||
console.warn('Received unexpected status from createCheckoutSession:', response.status);
|
||||
toast.error('An unexpected error occurred. Please try again.');
|
||||
}
|
||||
|
||||
} catch (error: any) {
|
||||
console.error('Error processing subscription:', error);
|
||||
const errorMessage = error?.response?.data?.detail || error?.message || 'Failed to process subscription. Please try again.';
|
||||
toast.error(errorMessage);
|
||||
}
|
||||
};
|
||||
|
||||
const getPriceValue = (tier: typeof siteConfig.cloudPricingItems[0], selectedHours?: string): string => {
|
||||
if (tier.upgradePlans && selectedHours) {
|
||||
const plan = tier.upgradePlans.find(plan => plan.hours === selectedHours);
|
||||
if (plan) {
|
||||
return plan.price;
|
||||
}
|
||||
}
|
||||
return tier.price;
|
||||
};
|
||||
|
||||
const getDisplayedHours = (tier: typeof siteConfig.cloudPricingItems[0]) => {
|
||||
if (tier.name === "Custom" && localSelectedPlan) {
|
||||
return localSelectedPlan;
|
||||
}
|
||||
return tier.hours;
|
||||
};
|
||||
|
||||
const getSelectedPlanPriceId = (tier: typeof siteConfig.cloudPricingItems[0]): string => {
|
||||
if (tier.name === "Custom" && tier.upgradePlans) {
|
||||
const selectedPlan = tier.upgradePlans.find(plan => plan.hours === localSelectedPlan);
|
||||
return selectedPlan?.stripePriceId || tier.stripePriceId;
|
||||
}
|
||||
return tier.stripePriceId;
|
||||
};
|
||||
|
||||
const getSelectedPlanPrice = (tier: typeof siteConfig.cloudPricingItems[0]): string => {
|
||||
if (tier.name === "Custom" && tier.upgradePlans) {
|
||||
const selectedPlan = tier.upgradePlans.find(plan => plan.hours === localSelectedPlan);
|
||||
return selectedPlan?.price || tier.price;
|
||||
}
|
||||
return tier.price;
|
||||
};
|
||||
|
||||
const tierPriceId = getSelectedPlanPriceId(tier);
|
||||
const isCurrentActivePlan = isAuthenticated && (
|
||||
// For custom tier, check if the selected plan matches the current subscription
|
||||
tier.name === "Custom"
|
||||
? tier.upgradePlans?.some(plan =>
|
||||
plan.hours === localSelectedPlan &&
|
||||
plan.stripePriceId === currentSubscription?.price_id
|
||||
)
|
||||
: currentSubscription?.price_id === tierPriceId
|
||||
);
|
||||
const isScheduled = isAuthenticated && currentSubscription?.has_schedule;
|
||||
const isScheduledTargetPlan = isScheduled && (
|
||||
// For custom tier, check if the selected plan matches the scheduled subscription
|
||||
tier.name === "Custom"
|
||||
? tier.upgradePlans?.some(plan =>
|
||||
plan.hours === localSelectedPlan &&
|
||||
plan.stripePriceId === currentSubscription?.scheduled_price_id
|
||||
)
|
||||
: currentSubscription?.scheduled_price_id === tierPriceId
|
||||
);
|
||||
const isPlanLoading = isLoading[tierPriceId];
|
||||
|
||||
let buttonText = isAuthenticated ? "Select Plan" : "Hire Suna";
|
||||
let buttonDisabled = isPlanLoading;
|
||||
let buttonVariant: ButtonVariant = null;
|
||||
let ringClass = "";
|
||||
let statusBadge = null;
|
||||
let buttonClassName = "";
|
||||
|
||||
if (isAuthenticated) {
|
||||
if (isCurrentActivePlan) {
|
||||
buttonText = "Current Plan";
|
||||
buttonDisabled = true;
|
||||
buttonVariant = "secondary";
|
||||
ringClass = isCompact ? "ring-1 ring-primary" : "ring-2 ring-primary";
|
||||
buttonClassName = "bg-primary/5 hover:bg-primary/10 text-primary";
|
||||
statusBadge = (
|
||||
<span className="bg-primary/10 text-primary text-[10px] font-medium px-1.5 py-0.5 rounded-full">
|
||||
Current
|
||||
</span>
|
||||
);
|
||||
} else if (isScheduledTargetPlan) {
|
||||
buttonText = "Scheduled";
|
||||
buttonDisabled = true;
|
||||
buttonVariant = "outline";
|
||||
ringClass = isCompact ? "ring-1 ring-yellow-500" : "ring-2 ring-yellow-500";
|
||||
buttonClassName = "bg-yellow-500/5 hover:bg-yellow-500/10 text-yellow-600 border-yellow-500/20";
|
||||
statusBadge = (
|
||||
<span className="bg-yellow-500/10 text-yellow-600 text-[10px] font-medium px-1.5 py-0.5 rounded-full">
|
||||
Scheduled
|
||||
</span>
|
||||
);
|
||||
} else if (isScheduled && currentSubscription?.price_id === tierPriceId) {
|
||||
buttonText = "Change Scheduled";
|
||||
buttonVariant = "secondary";
|
||||
ringClass = isCompact ? "ring-1 ring-primary" : "ring-2 ring-primary";
|
||||
buttonClassName = "bg-primary/5 hover:bg-primary/10 text-primary";
|
||||
statusBadge = (
|
||||
<span className="bg-yellow-500/10 text-yellow-600 text-[10px] font-medium px-1.5 py-0.5 rounded-full">
|
||||
Downgrade Pending
|
||||
</span>
|
||||
);
|
||||
} else {
|
||||
// For custom tier, find the current plan in upgradePlans
|
||||
const currentTier = tier.name === "Custom" && tier.upgradePlans
|
||||
? tier.upgradePlans.find(p => p.stripePriceId === currentSubscription?.price_id)
|
||||
: siteConfig.cloudPricingItems.find(p => p.stripePriceId === currentSubscription?.price_id);
|
||||
|
||||
// Find the highest active plan from upgradePlans
|
||||
const highestActivePlan = siteConfig.cloudPricingItems.reduce((highest, item) => {
|
||||
if (item.upgradePlans) {
|
||||
const activePlan = item.upgradePlans.find(p => p.stripePriceId === currentSubscription?.price_id);
|
||||
if (activePlan) {
|
||||
const activeAmount = parseFloat(activePlan.price.replace(/[^\d.]/g, '') || '0') * 100;
|
||||
const highestAmount = parseFloat(highest?.price?.replace(/[^\d.]/g, '') || '0') * 100;
|
||||
return activeAmount > highestAmount ? activePlan : highest;
|
||||
}
|
||||
}
|
||||
return highest;
|
||||
}, null as { price: string; hours: string; stripePriceId: string } | null);
|
||||
|
||||
const currentPriceString = currentSubscription ? (highestActivePlan?.price || currentTier?.price || '$0') : '$0';
|
||||
const selectedPriceString = getSelectedPlanPrice(tier);
|
||||
const currentAmount = currentPriceString === '$0' ? 0 : parseFloat(currentPriceString.replace(/[^\d.]/g, '') || '0') * 100;
|
||||
const targetAmount = selectedPriceString === '$0' ? 0 : parseFloat(selectedPriceString.replace(/[^\d.]/g, '') || '0') * 100;
|
||||
|
||||
if (currentAmount === 0 && targetAmount === 0 && currentSubscription?.status !== 'no_subscription') {
|
||||
buttonText = "Select Plan";
|
||||
buttonDisabled = true;
|
||||
buttonVariant = "secondary";
|
||||
buttonClassName = "bg-primary/5 hover:bg-primary/10 text-primary";
|
||||
} else {
|
||||
buttonText = targetAmount > currentAmount ? "Upgrade" : "Downgrade";
|
||||
buttonVariant = tier.buttonColor as ButtonVariant;
|
||||
buttonClassName = targetAmount > currentAmount
|
||||
? "bg-primary hover:bg-primary/90 text-primary-foreground"
|
||||
: "bg-primary/5 hover:bg-primary/10 text-primary";
|
||||
}
|
||||
}
|
||||
|
||||
if (isPlanLoading) {
|
||||
buttonText = "Loading...";
|
||||
buttonClassName = "opacity-70 cursor-not-allowed";
|
||||
}
|
||||
} else {
|
||||
// Non-authenticated state styling
|
||||
buttonVariant = tier.buttonColor as ButtonVariant;
|
||||
buttonClassName = tier.buttonColor === "default"
|
||||
? "bg-primary hover:bg-primary/90 text-white"
|
||||
: "bg-secondary hover:bg-secondary/90 text-white";
|
||||
}
|
||||
|
||||
|
||||
|
||||
return (
|
||||
<div
|
||||
className={cn(
|
||||
"rounded-xl flex flex-col relative h-fit min-h-[400px] min-[650px]:h-full min-[900px]:h-fit",
|
||||
tier.isPopular
|
||||
? "md:shadow-[0px_61px_24px_-10px_rgba(0,0,0,0.01),0px_34px_20px_-8px_rgba(0,0,0,0.05),0px_15px_15px_-6px_rgba(0,0,0,0.09),0px_4px_8px_-2px_rgba(0,0,0,0.10),0px_0px_0px_1px_rgba(0,0,0,0.08)] bg-accent"
|
||||
: "bg-[#F3F4F6] dark:bg-[#F9FAFB]/[0.02] border border-border",
|
||||
ringClass
|
||||
)}
|
||||
>
|
||||
<div className="flex flex-col gap-4 p-4">
|
||||
<p className="text-sm flex items-center gap-2">
|
||||
{tier.name}
|
||||
{tier.isPopular && (
|
||||
<span className="bg-gradient-to-b from-secondary/50 from-[1.92%] to-secondary to-[100%] text-white h-6 inline-flex w-fit items-center justify-center px-2 rounded-full text-sm shadow-[0px_6px_6px_-3px_rgba(0,0,0,0.08),0px_3px_3px_-1.5px_rgba(0,0,0,0.08),0px_1px_1px_-0.5px_rgba(0,0,0,0.08),0px_0px_0px_1px_rgba(255,255,255,0.12)_inset,0px_1px_0px_0px_rgba(255,255,255,0.12)_inset]">
|
||||
Popular
|
||||
</span>
|
||||
)}
|
||||
{isAuthenticated && statusBadge}
|
||||
</p>
|
||||
<div className="flex items-baseline mt-2">
|
||||
{tier.name === "Custom" ? (
|
||||
<CustomPriceDisplay price={getPriceValue(tier, localSelectedPlan)} />
|
||||
) : (
|
||||
<PriceDisplay price={tier.price} />
|
||||
)}
|
||||
<span className="ml-2">
|
||||
{tier.price !== "$0" ? "/month" : ""}
|
||||
</span>
|
||||
</div>
|
||||
<p className="text-sm mt-2">{tier.description}</p>
|
||||
|
||||
{tier.name === "Custom" && tier.upgradePlans ? (
|
||||
<div className="w-full space-y-2">
|
||||
<p className="text-xs font-medium text-muted-foreground">Customize your monthly usage</p>
|
||||
<Select
|
||||
value={localSelectedPlan}
|
||||
onValueChange={handlePlanSelect}
|
||||
>
|
||||
<SelectTrigger className="w-full bg-white dark:bg-background">
|
||||
<SelectValue placeholder="Select a plan" />
|
||||
</SelectTrigger>
|
||||
<SelectContent>
|
||||
{tier.upgradePlans.map((plan) => (
|
||||
<SelectItem
|
||||
key={plan.hours}
|
||||
value={plan.hours}
|
||||
className={localSelectedPlan === plan.hours ? "font-medium bg-primary/5" : ""}
|
||||
>
|
||||
{plan.hours} - {plan.price}
|
||||
</SelectItem>
|
||||
))}
|
||||
</SelectContent>
|
||||
</Select>
|
||||
<div className="inline-flex items-center rounded-full border px-2.5 py-0.5 text-xs font-semibold bg-primary/10 border-primary/20 text-primary w-fit">
|
||||
{localSelectedPlan}/month
|
||||
</div>
|
||||
</div>
|
||||
) : (
|
||||
<div className="inline-flex items-center rounded-full border px-2.5 py-0.5 text-xs font-semibold bg-primary/10 border-primary/20 text-primary w-fit">
|
||||
{getDisplayedHours(tier)}/month
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
<div className="p-4 flex-grow">
|
||||
{tier.features && tier.features.length > 0 && (
|
||||
<ul className="space-y-3">
|
||||
{tier.features.map((feature) => (
|
||||
<li key={feature} className="flex items-center gap-2">
|
||||
<div className="size-5 rounded-full border border-primary/20 flex items-center justify-center">
|
||||
<CheckIcon className="size-3 text-primary" />
|
||||
</div>
|
||||
<span className="text-sm">{feature}</span>
|
||||
</li>
|
||||
))}
|
||||
</ul>
|
||||
)}
|
||||
</div>
|
||||
|
||||
<div className="mt-auto p-4">
|
||||
<Button
|
||||
onClick={() => handleSubscribe(tierPriceId)}
|
||||
disabled={buttonDisabled}
|
||||
variant={buttonVariant || "default"}
|
||||
className={cn(
|
||||
"w-full font-medium transition-all duration-200",
|
||||
isCompact
|
||||
? "h-7 rounded-md text-xs"
|
||||
: "h-10 rounded-full text-sm",
|
||||
buttonClassName,
|
||||
isPlanLoading && "animate-pulse"
|
||||
)}
|
||||
>
|
||||
{buttonText}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
interface PricingSectionProps {
|
||||
returnUrl?: string;
|
||||
showTitleAndTabs?: boolean;
|
||||
}
|
||||
|
||||
export function PricingSection({
|
||||
returnUrl = typeof window !== 'undefined' ? window.location.href : '/',
|
||||
showTitleAndTabs = true
|
||||
}: PricingSectionProps) {
|
||||
const [deploymentType, setDeploymentType] = useState<"cloud" | "self-hosted">("cloud");
|
||||
const [currentSubscription, setCurrentSubscription] = useState<SubscriptionStatus | null>(null);
|
||||
const [isLoading, setIsLoading] = useState<Record<string, boolean>>({});
|
||||
const [isFetchingPlan, setIsFetchingPlan] = useState(true);
|
||||
const [isAuthenticated, setIsAuthenticated] = useState(false);
|
||||
|
||||
const fetchCurrentPlan = async () => {
|
||||
setIsFetchingPlan(true);
|
||||
try {
|
||||
const subscriptionData = await getSubscription();
|
||||
console.log("Fetched Subscription Status:", subscriptionData);
|
||||
setCurrentSubscription(subscriptionData);
|
||||
setIsAuthenticated(true);
|
||||
} catch (error) {
|
||||
console.error('Error fetching subscription:', error);
|
||||
setCurrentSubscription(null);
|
||||
setIsAuthenticated(false);
|
||||
} finally {
|
||||
setIsFetchingPlan(false);
|
||||
}
|
||||
};
|
||||
|
||||
const handlePlanSelect = (planId: string) => {
|
||||
setIsLoading(prev => ({ ...prev, [planId]: true }));
|
||||
};
|
||||
|
||||
const handleSubscriptionUpdate = () => {
|
||||
fetchCurrentPlan();
|
||||
setTimeout(() => {
|
||||
setIsLoading({});
|
||||
}, 1000);
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
fetchCurrentPlan();
|
||||
}, []);
|
||||
|
||||
// Handle tab change
|
||||
const handleTabChange = (tab: "cloud" | "self-hosted") => {
|
||||
if (tab === "self-hosted") {
|
||||
// Scroll to the open-source section when self-hosted tab is clicked
|
||||
const openSourceSection = document.getElementById("open-source");
|
||||
if (openSourceSection) {
|
||||
// Get the position of the section and scroll to a position slightly above it
|
||||
const rect = openSourceSection.getBoundingClientRect();
|
||||
const scrollTop = window.pageYOffset || document.documentElement.scrollTop;
|
||||
const offsetPosition = scrollTop + rect.top - 100; // 100px offset from the top
|
||||
const offsetPosition = scrollTop + rect.top - 100;
|
||||
|
||||
window.scrollTo({
|
||||
top: offsetPosition,
|
||||
|
@ -82,312 +573,64 @@ export function PricingSection() {
|
|||
});
|
||||
}
|
||||
} else {
|
||||
// Set the deployment type to cloud for cloud tab
|
||||
setDeploymentType(tab);
|
||||
}
|
||||
};
|
||||
|
||||
// Update price animation
|
||||
const PriceDisplay = ({
|
||||
tier,
|
||||
}: {
|
||||
tier: typeof siteConfig.cloudPricingItems[0];
|
||||
}) => {
|
||||
const price = tier.price;
|
||||
|
||||
if (isLocalMode()) {
|
||||
return (
|
||||
<motion.span
|
||||
key={price}
|
||||
className="text-4xl font-semibold"
|
||||
initial={{
|
||||
opacity: 0,
|
||||
x: 10,
|
||||
filter: "blur(5px)",
|
||||
}}
|
||||
animate={{ opacity: 1, x: 0, filter: "blur(0px)" }}
|
||||
transition={{ duration: 0.25, ease: [0.4, 0, 0.2, 1] }}
|
||||
>
|
||||
{price}
|
||||
</motion.span>
|
||||
);
|
||||
};
|
||||
|
||||
const SelfHostedContent = () => (
|
||||
<div className="rounded-xl bg-[#F3F4F6] dark:bg-[#F9FAFB]/[0.02] border border-border p-8 w-full max-w-6xl mx-auto">
|
||||
<div className="flex flex-col gap-6">
|
||||
<div className="inline-flex h-10 w-fit items-center justify-center gap-2 rounded-full bg-secondary/10 text-secondary px-4">
|
||||
<Github className="h-5 w-5" />
|
||||
<span className="text-sm font-medium">100% Open Source</span>
|
||||
</div>
|
||||
|
||||
<div className="space-y-2">
|
||||
<h3 className="text-2xl font-semibold tracking-tight">
|
||||
Self-Hosted Version
|
||||
</h3>
|
||||
<p className="text-muted-foreground">
|
||||
Set up and run the platform on your own infrastructure with complete control over your data and deployment.
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<div className="grid grid-cols-1 md:grid-cols-3 gap-6 mt-6">
|
||||
<div className="rounded-xl border border-border bg-background p-6 flex flex-col gap-4">
|
||||
<div className="size-10 flex items-center justify-center rounded-full bg-primary/10 text-primary">
|
||||
<GitFork className="h-5 w-5" />
|
||||
</div>
|
||||
<h4 className="text-lg font-medium">Fork & Setup</h4>
|
||||
<p className="text-sm text-muted-foreground">
|
||||
Fork the repository and follow our step-by-step setup guide to deploy on your own infrastructure.
|
||||
</p>
|
||||
<Link
|
||||
href="https://github.com/Kortix-ai/Suna"
|
||||
target="_blank"
|
||||
className="text-sm text-primary flex items-center gap-1 mt-auto"
|
||||
>
|
||||
Fork on GitHub
|
||||
<svg width="12" height="12" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path d="M7 17L17 7M17 7H8M17 7V16" stroke="currentColor" strokeWidth="2" strokeLinecap="round" strokeLinejoin="round"/>
|
||||
</svg>
|
||||
</Link>
|
||||
</div>
|
||||
|
||||
<div className="rounded-xl border border-border bg-background p-6 flex flex-col gap-4">
|
||||
<div className="size-10 flex items-center justify-center rounded-full bg-primary/10 text-primary">
|
||||
<File className="h-5 w-5" />
|
||||
</div>
|
||||
<h4 className="text-lg font-medium">Documentation</h4>
|
||||
<p className="text-sm text-muted-foreground">
|
||||
Comprehensive documentation with detailed instructions for installation, configuration, and customization.
|
||||
</p>
|
||||
<Link
|
||||
href="#"
|
||||
className="text-sm text-primary flex items-center gap-1 mt-auto"
|
||||
>
|
||||
Read the docs
|
||||
<svg width="12" height="12" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path d="M7 17L17 7M17 7H8M17 7V16" stroke="currentColor" strokeWidth="2" strokeLinecap="round" strokeLinejoin="round"/>
|
||||
</svg>
|
||||
</Link>
|
||||
</div>
|
||||
|
||||
<div className="rounded-xl border border-border bg-background p-6 flex flex-col gap-4">
|
||||
<div className="size-10 flex items-center justify-center rounded-full bg-primary/10 text-primary">
|
||||
<Terminal className="h-5 w-5" />
|
||||
</div>
|
||||
<h4 className="text-lg font-medium">Custom APIs</h4>
|
||||
<p className="text-sm text-muted-foreground">
|
||||
Connect to your preferred language models and customize the platform to fit your specific requirements.
|
||||
</p>
|
||||
<Link
|
||||
href="#"
|
||||
className="text-sm text-primary flex items-center gap-1 mt-auto"
|
||||
>
|
||||
API Reference
|
||||
<svg width="12" height="12" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path d="M7 17L17 7M17 7H8M17 7V16" stroke="currentColor" strokeWidth="2" strokeLinecap="round" strokeLinejoin="round"/>
|
||||
</svg>
|
||||
</Link>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="border-t border-border pt-6 mt-4">
|
||||
<h4 className="font-medium mb-4">Key benefits of self-hosting</h4>
|
||||
<div className="grid grid-cols-1 md:grid-cols-2 gap-4">
|
||||
<ul className="space-y-3">
|
||||
{["Complete data privacy", "Full control over infrastructure", "Custom model integration", "Unlimited usage"].map((feature) => (
|
||||
<li key={feature} className="flex items-center gap-2">
|
||||
<div className="size-5 rounded-full border border-primary/20 flex items-center justify-center bg-muted-foreground/10">
|
||||
<div className="size-3 flex items-center justify-center">
|
||||
<svg width="8" height="7" viewBox="0 0 8 7" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path d="M1.5 3.48828L3.375 5.36328L6.5 0.988281" stroke="currentColor" strokeWidth="1.5" strokeLinecap="round" strokeLinejoin="round"/>
|
||||
</svg>
|
||||
</div>
|
||||
</div>
|
||||
<span className="text-sm">{feature}</span>
|
||||
</li>
|
||||
))}
|
||||
</ul>
|
||||
<ul className="space-y-3">
|
||||
{["No usage fees", "Apache 2.0 license", "Community support", "Security customization"].map((feature) => (
|
||||
<li key={feature} className="flex items-center gap-2">
|
||||
<div className="size-5 rounded-full border border-primary/20 flex items-center justify-center bg-muted-foreground/10">
|
||||
<div className="size-3 flex items-center justify-center">
|
||||
<svg width="8" height="7" viewBox="0 0 8 7" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path d="M1.5 3.48828L3.375 5.36328L6.5 0.988281" stroke="currentColor" strokeWidth="1.5" strokeLinecap="round" strokeLinejoin="round"/>
|
||||
</svg>
|
||||
</div>
|
||||
</div>
|
||||
<span className="text-sm">{feature}</span>
|
||||
</li>
|
||||
))}
|
||||
</ul>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="flex flex-col sm:flex-row gap-4 mt-4">
|
||||
<Link
|
||||
href="https://github.com/Kortix-ai/Suna"
|
||||
target="_blank"
|
||||
className="inline-flex h-11 items-center justify-center gap-2 text-sm font-medium tracking-wide rounded-full bg-primary text-white px-6 shadow-md hover:bg-primary/90 transition-all"
|
||||
>
|
||||
<Github className="h-4 w-4" />
|
||||
View on GitHub
|
||||
</Link>
|
||||
<Link
|
||||
href="#"
|
||||
className="inline-flex h-11 items-center justify-center gap-2 text-sm font-medium tracking-wide rounded-full bg-secondary/10 text-secondary px-6 border border-secondary/20 hover:bg-secondary/20 transition-all"
|
||||
>
|
||||
Read Documentation
|
||||
</Link>
|
||||
</div>
|
||||
<div className="p-4 bg-muted/30 border border-border rounded-lg text-center">
|
||||
<p className="text-sm text-muted-foreground">
|
||||
Running in local development mode - billing features are disabled
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<section
|
||||
id="pricing"
|
||||
className="flex flex-col items-center justify-center gap-10 pb-20 w-full relative"
|
||||
>
|
||||
<SectionHeader>
|
||||
<h2 className="text-3xl md:text-4xl font-medium tracking-tighter text-center text-balance">
|
||||
General Intelligence available today
|
||||
</h2>
|
||||
<p className="text-muted-foreground text-center text-balance font-medium">
|
||||
You can self-host Suna or use our cloud for managed service.
|
||||
</p>
|
||||
</SectionHeader>
|
||||
<div className="relative w-full h-full">
|
||||
<div className="absolute -top-14 left-1/2 -translate-x-1/2">
|
||||
<PricingTabs
|
||||
activeTab={deploymentType}
|
||||
setActiveTab={handleTabChange}
|
||||
className="mx-auto"
|
||||
/>
|
||||
</div>
|
||||
|
||||
{deploymentType === "cloud" && (
|
||||
<div className="grid min-[650px]:grid-cols-2 min-[900px]:grid-cols-3 gap-4 w-full max-w-6xl mx-auto px-6">
|
||||
{siteConfig.cloudPricingItems.map((tier) => (
|
||||
<div
|
||||
key={tier.name}
|
||||
className={cn(
|
||||
"rounded-xl grid grid-rows-[180px_auto_1fr] relative h-fit min-[650px]:h-full min-[900px]:h-fit",
|
||||
tier.isPopular
|
||||
? "md:shadow-[0px_61px_24px_-10px_rgba(0,0,0,0.01),0px_34px_20px_-8px_rgba(0,0,0,0.05),0px_15px_15px_-6px_rgba(0,0,0,0.09),0px_4px_8px_-2px_rgba(0,0,0,0.10),0px_0px_0px_1px_rgba(0,0,0,0.08)] bg-accent"
|
||||
: "bg-[#F3F4F6] dark:bg-[#F9FAFB]/[0.02] border border-border",
|
||||
)}
|
||||
>
|
||||
<div className="flex flex-col gap-4 p-4">
|
||||
<p className="text-sm flex items-center">
|
||||
{tier.name}
|
||||
{tier.isPopular && (
|
||||
<span className="bg-gradient-to-b from-secondary/50 from-[1.92%] to-secondary to-[100%] text-white h-6 inline-flex w-fit items-center justify-center px-2 rounded-full text-sm ml-2 shadow-[0px_6px_6px_-3px_rgba(0,0,0,0.08),0px_3px_3px_-1.5px_rgba(0,0,0,0.08),0px_1px_1px_-0.5px_rgba(0,0,0,0.08),0px_0px_0px_1px_rgba(255,255,255,0.12)_inset,0px_1px_0px_0px_rgba(255,255,255,0.12)_inset]">
|
||||
Popular
|
||||
</span>
|
||||
)}
|
||||
</p>
|
||||
<div className="flex items-baseline mt-2">
|
||||
<PriceDisplay tier={tier} />
|
||||
<span className="ml-2">
|
||||
{tier.price !== "$0" ? "/month" : ""}
|
||||
</span>
|
||||
</div>
|
||||
<p className="text-sm mt-2">{tier.description}</p>
|
||||
<div className="inline-flex items-center rounded-full border px-2.5 py-0.5 text-xs font-semibold bg-primary/10 border-primary/20 text-primary w-fit">
|
||||
{tier.hours}/month
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="flex flex-col gap-2 p-4">
|
||||
{tier.buttonText === "Hire Suna" ? (
|
||||
<Link
|
||||
href="/auth"
|
||||
className={`h-10 w-full flex items-center justify-center text-sm font-normal tracking-wide rounded-full px-4 cursor-pointer transition-all ease-out active:scale-95 ${
|
||||
tier.isPopular
|
||||
? `${tier.buttonColor} shadow-[inset_0_1px_2px_rgba(255,255,255,0.25),0_3px_3px_-1.5px_rgba(16,24,40,0.06),0_1px_1px_rgba(16,24,40,0.08)]`
|
||||
: `${tier.buttonColor} shadow-[0px_1px_2px_0px_rgba(255,255,255,0.16)_inset,0px_3px_3px_-1.5px_rgba(16,24,40,0.24),0px_1px_1px_-0.5px_rgba(16,24,40,0.20)]`
|
||||
}`}
|
||||
>
|
||||
{tier.buttonText}
|
||||
</Link>
|
||||
) : (
|
||||
<button
|
||||
className={`h-10 w-full flex items-center justify-center text-sm font-normal tracking-wide rounded-full px-4 cursor-pointer transition-all ease-out active:scale-95 ${
|
||||
tier.isPopular
|
||||
? `${tier.buttonColor} shadow-[inset_0_1px_2px_rgba(255,255,255,0.25),0_3px_3px_-1.5px_rgba(16,24,40,0.06),0_1px_1px_rgba(16,24,40,0.08)]`
|
||||
: `${tier.buttonColor} shadow-[0px_1px_2px_0px_rgba(255,255,255,0.16)_inset,0px_3px_3px_-1.5px_rgba(16,24,40,0.24),0px_1px_1px_-0.5px_rgba(16,24,40,0.20)]`
|
||||
}`}
|
||||
>
|
||||
{tier.buttonText}
|
||||
</button>
|
||||
)}
|
||||
</div>
|
||||
{/* <hr className="border-border dark:border-white/20" /> */}
|
||||
<div className="p-4">
|
||||
{/*
|
||||
{tier.name !== "Free" && (
|
||||
<p className="text-sm mb-4">
|
||||
Everything in {tier.name === "Pro" ? "Free" : "Pro"} +
|
||||
</p>
|
||||
)}
|
||||
<ul className="space-y-3">
|
||||
{tier.features.filter(feature => !feature.startsWith('//'))
|
||||
.map((feature) => (
|
||||
<li key={feature} className="flex items-center gap-2">
|
||||
<div
|
||||
className={cn(
|
||||
"size-5 rounded-full border border-primary/20 flex items-center justify-center",
|
||||
tier.isPopular &&
|
||||
"bg-muted-foreground/40 border-border",
|
||||
)}
|
||||
>
|
||||
<div className="size-3 flex items-center justify-center">
|
||||
<svg
|
||||
width="8"
|
||||
height="7"
|
||||
viewBox="0 0 8 7"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
className="block dark:hidden"
|
||||
>
|
||||
<path
|
||||
d="M1.5 3.48828L3.375 5.36328L6.5 0.988281"
|
||||
stroke="#101828"
|
||||
strokeWidth="1.5"
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
/>
|
||||
</svg>
|
||||
|
||||
<svg
|
||||
width="8"
|
||||
height="7"
|
||||
viewBox="0 0 8 7"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
className="hidden dark:block"
|
||||
>
|
||||
<path
|
||||
d="M1.5 3.48828L3.375 5.36328L6.5 0.988281"
|
||||
stroke="#FAFAFA"
|
||||
strokeWidth="1.5"
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
/>
|
||||
</svg>
|
||||
</div>
|
||||
</div>
|
||||
<span className="text-sm">{feature}</span>
|
||||
</li>
|
||||
))}
|
||||
</ul>
|
||||
*/}
|
||||
</div>
|
||||
</div>
|
||||
))}
|
||||
{showTitleAndTabs && (
|
||||
<>
|
||||
<SectionHeader>
|
||||
<h2 className="text-3xl md:text-4xl font-medium tracking-tighter text-center text-balance">
|
||||
Choose the right plan for your needs
|
||||
</h2>
|
||||
<p className="text-muted-foreground text-center text-balance font-medium">
|
||||
Start with our free plan or upgrade to a premium plan for more usage hours
|
||||
</p>
|
||||
</SectionHeader>
|
||||
<div className="relative w-full h-full">
|
||||
<div className="absolute -top-14 left-1/2 -translate-x-1/2">
|
||||
<PricingTabs
|
||||
activeTab={deploymentType}
|
||||
setActiveTab={handleTabChange}
|
||||
className="mx-auto"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
|
||||
{deploymentType === "cloud" && (
|
||||
<div className="grid min-[650px]:grid-cols-2 min-[900px]:grid-cols-3 gap-4 w-full max-w-6xl mx-auto px-6">
|
||||
{siteConfig.cloudPricingItems.map((tier) => (
|
||||
<PricingTier
|
||||
key={tier.name}
|
||||
tier={tier}
|
||||
currentSubscription={currentSubscription}
|
||||
isLoading={isLoading}
|
||||
isFetchingPlan={isFetchingPlan}
|
||||
onPlanSelect={handlePlanSelect}
|
||||
onSubscriptionUpdate={handleSubscriptionUpdate}
|
||||
isAuthenticated={isAuthenticated}
|
||||
returnUrl={returnUrl}
|
||||
/>
|
||||
))}
|
||||
</div>
|
||||
)}
|
||||
</section>
|
||||
);
|
||||
}
|
||||
}
|
|
@ -1,55 +0,0 @@
|
|||
"use server";
|
||||
|
||||
import { redirect } from "next/navigation";
|
||||
import { createClient } from "../supabase/server";
|
||||
import handleEdgeFunctionError from "../supabase/handle-edge-error";
|
||||
|
||||
export async function setupNewSubscription(prevState: any, formData: FormData) {
|
||||
const accountId = formData.get("accountId") as string;
|
||||
const returnUrl = formData.get("returnUrl") as string;
|
||||
const planId = formData.get("planId") as string;
|
||||
const supabaseClient = await createClient();
|
||||
|
||||
const { data, error } = await supabaseClient.functions.invoke('billing-functions', {
|
||||
body: {
|
||||
action: "get_new_subscription_url",
|
||||
args: {
|
||||
account_id: accountId,
|
||||
success_url: returnUrl,
|
||||
cancel_url: returnUrl,
|
||||
plan_id: planId
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
if (error) {
|
||||
return await handleEdgeFunctionError(error);
|
||||
}
|
||||
|
||||
redirect(data.url);
|
||||
};
|
||||
|
||||
export async function manageSubscription(prevState: any, formData: FormData) {
|
||||
const accountId = formData.get("accountId") as string;
|
||||
const returnUrl = formData.get("returnUrl") as string;
|
||||
const supabaseClient = await createClient();
|
||||
|
||||
const { data, error } = await supabaseClient.functions.invoke('billing-functions', {
|
||||
body: {
|
||||
action: "get_billing_portal_url",
|
||||
args: {
|
||||
account_id: accountId,
|
||||
return_url: returnUrl
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
console.log(data);
|
||||
|
||||
if (error) {
|
||||
console.error(error);
|
||||
return await handleEdgeFunctionError(error);
|
||||
}
|
||||
|
||||
redirect(data.url);
|
||||
};
|
|
@ -1177,3 +1177,202 @@ export const checkApiHealth = async (): Promise<HealthCheckResponse> => {
|
|||
}
|
||||
};
|
||||
|
||||
// Billing API Types
|
||||
export interface CreateCheckoutSessionRequest {
|
||||
price_id: string;
|
||||
success_url: string;
|
||||
cancel_url: string;
|
||||
}
|
||||
|
||||
export interface CreatePortalSessionRequest {
|
||||
return_url: string;
|
||||
}
|
||||
|
||||
export interface SubscriptionStatus {
|
||||
status: string; // Includes 'active', 'trialing', 'past_due', 'scheduled_downgrade', 'no_subscription'
|
||||
plan_name?: string;
|
||||
price_id?: string; // Added
|
||||
current_period_end?: string; // ISO Date string
|
||||
cancel_at_period_end: boolean;
|
||||
trial_end?: string; // ISO Date string
|
||||
minutes_limit?: number;
|
||||
current_usage?: number;
|
||||
// Fields for scheduled changes
|
||||
has_schedule: boolean;
|
||||
scheduled_plan_name?: string;
|
||||
scheduled_price_id?: string; // Added
|
||||
scheduled_change_date?: string; // ISO Date string - Deprecate? Check backend usage
|
||||
schedule_effective_date?: string; // ISO Date string - Added for consistency
|
||||
}
|
||||
|
||||
export interface BillingStatusResponse {
|
||||
can_run: boolean;
|
||||
message: string;
|
||||
subscription: {
|
||||
price_id: string;
|
||||
plan_name: string;
|
||||
minutes_limit?: number;
|
||||
};
|
||||
}
|
||||
|
||||
export interface CreateCheckoutSessionResponse {
|
||||
status: 'upgraded' | 'downgrade_scheduled' | 'checkout_created' | 'no_change' | 'new' | 'updated' | 'scheduled';
|
||||
subscription_id?: string;
|
||||
schedule_id?: string;
|
||||
session_id?: string;
|
||||
url?: string;
|
||||
effective_date?: string;
|
||||
message?: string;
|
||||
details?: {
|
||||
is_upgrade?: boolean;
|
||||
effective_date?: string;
|
||||
current_price?: number;
|
||||
new_price?: number;
|
||||
invoice?: {
|
||||
id: string;
|
||||
status: string;
|
||||
amount_due: number;
|
||||
amount_paid: number;
|
||||
};
|
||||
};
|
||||
}
|
||||
|
||||
// Billing API Functions
|
||||
export const createCheckoutSession = async (request: CreateCheckoutSessionRequest): Promise<CreateCheckoutSessionResponse> => {
|
||||
try {
|
||||
const supabase = createClient();
|
||||
const { data: { session } } = await supabase.auth.getSession();
|
||||
|
||||
if (!session?.access_token) {
|
||||
throw new Error('No access token available');
|
||||
}
|
||||
|
||||
const response = await fetch(`${API_URL}/billing/create-checkout-session`, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'Authorization': `Bearer ${session.access_token}`,
|
||||
},
|
||||
body: JSON.stringify(request),
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const errorText = await response.text().catch(() => 'No error details available');
|
||||
console.error(`Error creating checkout session: ${response.status} ${response.statusText}`, errorText);
|
||||
throw new Error(`Error creating checkout session: ${response.statusText} (${response.status})`);
|
||||
}
|
||||
|
||||
const data = await response.json();
|
||||
console.log('Checkout session response:', data);
|
||||
|
||||
// Handle all possible statuses
|
||||
switch (data.status) {
|
||||
case 'upgraded':
|
||||
case 'updated':
|
||||
case 'downgrade_scheduled':
|
||||
case 'scheduled':
|
||||
case 'no_change':
|
||||
return data;
|
||||
case 'new':
|
||||
case 'checkout_created':
|
||||
if (!data.url) {
|
||||
throw new Error('No checkout URL provided');
|
||||
}
|
||||
return data;
|
||||
default:
|
||||
console.warn('Unexpected status from createCheckoutSession:', data.status);
|
||||
return data;
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Failed to create checkout session:', error);
|
||||
throw error;
|
||||
}
|
||||
};
|
||||
|
||||
export const createPortalSession = async (request: CreatePortalSessionRequest): Promise<{ url: string }> => {
|
||||
try {
|
||||
const supabase = createClient();
|
||||
const { data: { session } } = await supabase.auth.getSession();
|
||||
|
||||
if (!session?.access_token) {
|
||||
throw new Error('No access token available');
|
||||
}
|
||||
|
||||
const response = await fetch(`${API_URL}/billing/create-portal-session`, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'Authorization': `Bearer ${session.access_token}`,
|
||||
},
|
||||
body: JSON.stringify(request),
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const errorText = await response.text().catch(() => 'No error details available');
|
||||
console.error(`Error creating portal session: ${response.status} ${response.statusText}`, errorText);
|
||||
throw new Error(`Error creating portal session: ${response.statusText} (${response.status})`);
|
||||
}
|
||||
|
||||
return response.json();
|
||||
} catch (error) {
|
||||
console.error('Failed to create portal session:', error);
|
||||
throw error;
|
||||
}
|
||||
};
|
||||
|
||||
export const getSubscription = async (): Promise<SubscriptionStatus> => {
|
||||
try {
|
||||
const supabase = createClient();
|
||||
const { data: { session } } = await supabase.auth.getSession();
|
||||
|
||||
if (!session?.access_token) {
|
||||
throw new Error('No access token available');
|
||||
}
|
||||
|
||||
const response = await fetch(`${API_URL}/billing/subscription`, {
|
||||
headers: {
|
||||
'Authorization': `Bearer ${session.access_token}`,
|
||||
},
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const errorText = await response.text().catch(() => 'No error details available');
|
||||
console.error(`Error getting subscription: ${response.status} ${response.statusText}`, errorText);
|
||||
throw new Error(`Error getting subscription: ${response.statusText} (${response.status})`);
|
||||
}
|
||||
|
||||
return response.json();
|
||||
} catch (error) {
|
||||
console.error('Failed to get subscription:', error);
|
||||
throw error;
|
||||
}
|
||||
};
|
||||
|
||||
export const checkBillingStatus = async (): Promise<BillingStatusResponse> => {
|
||||
try {
|
||||
const supabase = createClient();
|
||||
const { data: { session } } = await supabase.auth.getSession();
|
||||
|
||||
if (!session?.access_token) {
|
||||
throw new Error('No access token available');
|
||||
}
|
||||
|
||||
const response = await fetch(`${API_URL}/billing/check-status`, {
|
||||
headers: {
|
||||
'Authorization': `Bearer ${session.access_token}`,
|
||||
},
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const errorText = await response.text().catch(() => 'No error details available');
|
||||
console.error(`Error checking billing status: ${response.status} ${response.statusText}`, errorText);
|
||||
throw new Error(`Error checking billing status: ${response.statusText} (${response.status})`);
|
||||
}
|
||||
|
||||
return response.json();
|
||||
} catch (error) {
|
||||
console.error('Failed to check billing status:', error);
|
||||
throw error;
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
@ -5,12 +5,103 @@ export enum EnvMode {
|
|||
PRODUCTION = 'production',
|
||||
}
|
||||
|
||||
// Subscription tier structure
|
||||
export interface SubscriptionTierData {
|
||||
priceId: string;
|
||||
name: string;
|
||||
}
|
||||
|
||||
// Subscription tiers structure
|
||||
export interface SubscriptionTiers {
|
||||
FREE: SubscriptionTierData;
|
||||
TIER_2_20: SubscriptionTierData;
|
||||
TIER_6_50: SubscriptionTierData;
|
||||
TIER_12_100: SubscriptionTierData;
|
||||
TIER_25_200: SubscriptionTierData;
|
||||
TIER_50_400: SubscriptionTierData;
|
||||
TIER_125_800: SubscriptionTierData;
|
||||
TIER_200_1000: SubscriptionTierData;
|
||||
}
|
||||
|
||||
// Configuration object
|
||||
interface Config {
|
||||
ENV_MODE: EnvMode;
|
||||
IS_LOCAL: boolean;
|
||||
SUBSCRIPTION_TIERS: SubscriptionTiers;
|
||||
}
|
||||
|
||||
// Production tier IDs
|
||||
const PROD_TIERS: SubscriptionTiers = {
|
||||
FREE: {
|
||||
priceId: 'price_1RILb4G6l1KZGqIrK4QLrx9i',
|
||||
name: 'Free',
|
||||
},
|
||||
TIER_2_20: {
|
||||
priceId: 'price_1RILb4G6l1KZGqIrhomjgDnO',
|
||||
name: '2h/$20',
|
||||
},
|
||||
TIER_6_50: {
|
||||
priceId: 'price_1RILb4G6l1KZGqIr5q0sybWn',
|
||||
name: '6h/$50',
|
||||
},
|
||||
TIER_12_100: {
|
||||
priceId: 'price_1RILb4G6l1KZGqIr5Y20ZLHm',
|
||||
name: '12h/$100',
|
||||
},
|
||||
TIER_25_200: {
|
||||
priceId: 'price_1RILb4G6l1KZGqIrGAD8rNjb',
|
||||
name: '25h/$200',
|
||||
},
|
||||
TIER_50_400: {
|
||||
priceId: 'price_1RILb4G6l1KZGqIruNBUMTF1',
|
||||
name: '50h/$400',
|
||||
},
|
||||
TIER_125_800: {
|
||||
priceId: 'price_1RILb3G6l1KZGqIrbJA766tN',
|
||||
name: '125h/$800',
|
||||
},
|
||||
TIER_200_1000: {
|
||||
priceId: 'price_1RILb3G6l1KZGqIrmauYPOiN',
|
||||
name: '200h/$1000',
|
||||
}
|
||||
} as const;
|
||||
|
||||
// Staging tier IDs
|
||||
const STAGING_TIERS: SubscriptionTiers = {
|
||||
FREE: {
|
||||
priceId: 'price_1RIGvuG6l1KZGqIrw14abxeL',
|
||||
name: 'Free',
|
||||
},
|
||||
TIER_2_20: {
|
||||
priceId: 'price_1RIGvuG6l1KZGqIrCRu0E4Gi',
|
||||
name: '2h/$20',
|
||||
},
|
||||
TIER_6_50: {
|
||||
priceId: 'price_1RIGvuG6l1KZGqIrvjlz5p5V',
|
||||
name: '6h/$50',
|
||||
},
|
||||
TIER_12_100: {
|
||||
priceId: 'price_1RIGvuG6l1KZGqIrT6UfgblC',
|
||||
name: '12h/$100',
|
||||
},
|
||||
TIER_25_200: {
|
||||
priceId: 'price_1RIGvuG6l1KZGqIrOVLKlOMj',
|
||||
name: '25h/$200',
|
||||
},
|
||||
TIER_50_400: {
|
||||
priceId: 'price_1RIKNgG6l1KZGqIrvsat5PW7',
|
||||
name: '50h/$400',
|
||||
},
|
||||
TIER_125_800: {
|
||||
priceId: 'price_1RIKNrG6l1KZGqIrjKT0yGvI',
|
||||
name: '125h/$800',
|
||||
},
|
||||
TIER_200_1000: {
|
||||
priceId: 'price_1RIKQ2G6l1KZGqIrum9n8SI7',
|
||||
name: '200h/$1000',
|
||||
}
|
||||
} as const;
|
||||
|
||||
// Determine the environment mode from environment variables
|
||||
const getEnvironmentMode = (): EnvMode => {
|
||||
// Get the environment mode from the environment variable, if set
|
||||
|
@ -47,9 +138,13 @@ const currentEnvMode = getEnvironmentMode();
|
|||
export const config: Config = {
|
||||
ENV_MODE: currentEnvMode,
|
||||
IS_LOCAL: currentEnvMode === EnvMode.LOCAL,
|
||||
SUBSCRIPTION_TIERS: currentEnvMode === EnvMode.STAGING ? STAGING_TIERS : PROD_TIERS,
|
||||
};
|
||||
|
||||
// Helper function to check if we're in local mode (for component conditionals)
|
||||
export const isLocalMode = (): boolean => {
|
||||
return config.IS_LOCAL;
|
||||
};
|
||||
};
|
||||
|
||||
// Export subscription tier type for typing elsewhere
|
||||
export type SubscriptionTier = keyof typeof PROD_TIERS;
|
|
@ -6,6 +6,7 @@ import { FlickeringGrid } from "@/components/home/ui/flickering-grid";
|
|||
import { Globe } from "@/components/home/ui/globe";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { motion } from "motion/react";
|
||||
import { config } from '@/lib/config';
|
||||
|
||||
export const Highlight = ({
|
||||
children,
|
||||
|
@ -28,6 +29,25 @@ export const Highlight = ({
|
|||
|
||||
export const BLUR_FADE_DELAY = 0.15;
|
||||
|
||||
interface UpgradePlan {
|
||||
hours: string;
|
||||
price: string;
|
||||
stripePriceId: string;
|
||||
}
|
||||
|
||||
export interface PricingTier {
|
||||
name: string;
|
||||
price: string;
|
||||
description: string;
|
||||
buttonText: string;
|
||||
buttonColor: string;
|
||||
isPopular: boolean;
|
||||
hours: string;
|
||||
features: string[];
|
||||
stripePriceId: string;
|
||||
upgradePlans: UpgradePlan[];
|
||||
}
|
||||
|
||||
export const siteConfig = {
|
||||
name: "Kortix Suna",
|
||||
description: "The Generalist AI Agent that can act on your behalf.",
|
||||
|
@ -79,54 +99,53 @@ export const siteConfig = {
|
|||
{
|
||||
name: "Free",
|
||||
price: "$0",
|
||||
description: "For individual use and exploration",
|
||||
description: "Get started with",
|
||||
buttonText: "Hire Suna",
|
||||
buttonColor: "bg-secondary text-white",
|
||||
isPopular: false,
|
||||
hours: "10 min",
|
||||
features: [
|
||||
"10 minutes",
|
||||
// "Community support",
|
||||
// "Single user",
|
||||
// "Standard response time",
|
||||
"Public Projects",
|
||||
],
|
||||
stripePriceId: 'price_1RGJ9GG6l1KZGqIroxSqgphC',
|
||||
stripePriceId: config.SUBSCRIPTION_TIERS.FREE.priceId,
|
||||
upgradePlans: [],
|
||||
},
|
||||
{
|
||||
name: "Pro",
|
||||
price: "$29",
|
||||
description: "For professionals and small teams",
|
||||
price: "$20",
|
||||
description: "Everything in Free, plus:",
|
||||
buttonText: "Hire Suna",
|
||||
buttonColor: "bg-primary text-white dark:text-black",
|
||||
isPopular: true,
|
||||
hours: "4 hours",
|
||||
hours: "2 hours",
|
||||
features: [
|
||||
"4 hours usage per month",
|
||||
// "Priority support",
|
||||
// "Advanced features",
|
||||
// "5 team members",
|
||||
// "Custom integrations",
|
||||
"2 hours",
|
||||
"Private projects",
|
||||
"Team functionality (coming soon)",
|
||||
],
|
||||
stripePriceId: 'price_1RGJ9LG6l1KZGqIrd9pwzeNW',
|
||||
stripePriceId: config.SUBSCRIPTION_TIERS.TIER_2_20.priceId,
|
||||
upgradePlans: [],
|
||||
},
|
||||
{
|
||||
name: "Enterprise",
|
||||
price: "$199",
|
||||
description: "For organizations with complex needs",
|
||||
name: "Custom",
|
||||
price: "$50",
|
||||
description: "Everything in Pro, plus:",
|
||||
buttonText: "Hire Suna",
|
||||
buttonColor: "bg-secondary text-white",
|
||||
isPopular: false,
|
||||
hours: "40 hours",
|
||||
hours: "6 hours",
|
||||
features: [
|
||||
"40 hours usage per month",
|
||||
// "Dedicated support",
|
||||
// "SSO & advanced security",
|
||||
// "Unlimited team members",
|
||||
// "Service level agreement",
|
||||
// "Custom AI model training",
|
||||
"Unlimited seats",
|
||||
],
|
||||
showContactSales: true,
|
||||
stripePriceId: 'price_1RGJ9JG6l1KZGqIrVUU4ZRv6',
|
||||
upgradePlans: [
|
||||
{ hours: "6 hours", price: "$50", stripePriceId: config.SUBSCRIPTION_TIERS.TIER_6_50.priceId },
|
||||
{ hours: "12 hours", price: "$100", stripePriceId: config.SUBSCRIPTION_TIERS.TIER_12_100.priceId },
|
||||
{ hours: "25 hours", price: "$200", stripePriceId: config.SUBSCRIPTION_TIERS.TIER_25_200.priceId },
|
||||
{ hours: "50 hours", price: "$400", stripePriceId: config.SUBSCRIPTION_TIERS.TIER_50_400.priceId },
|
||||
{ hours: "125 hours", price: "$800", stripePriceId: config.SUBSCRIPTION_TIERS.TIER_125_800.priceId },
|
||||
{ hours: "200 hours", price: "$1000", stripePriceId: config.SUBSCRIPTION_TIERS.TIER_200_1000.priceId },
|
||||
],
|
||||
stripePriceId: config.SUBSCRIPTION_TIERS.TIER_6_50.priceId,
|
||||
},
|
||||
],
|
||||
companyShowcase: {
|
||||
|
|
|
@ -13,8 +13,8 @@ export const createClient = async () => {
|
|||
supabaseUrl = `http://${supabaseUrl}`;
|
||||
}
|
||||
|
||||
console.log('[SERVER] Supabase URL:', supabaseUrl);
|
||||
console.log('[SERVER] Supabase Anon Key:', supabaseAnonKey);
|
||||
// console.log('[SERVER] Supabase URL:', supabaseUrl);
|
||||
// console.log('[SERVER] Supabase Anon Key:', supabaseAnonKey);
|
||||
|
||||
return createServerClient(
|
||||
supabaseUrl,
|
||||
|
|
Loading…
Reference in New Issue