Compare commits

...

13 Commits

Author SHA1 Message Date
marko-kraemer 6623e87ea9 fix login redirect 2025-04-27 04:48:09 +01:00
marko-kraemer b3666e8aad update prod configs 2025-04-27 04:15:55 +01:00
marko-kraemer a14c2a1a2c update prod configs 2025-04-27 04:15:48 +01:00
marko-kraemer 5d28b65111 fix dependency 2025-04-27 04:10:38 +01:00
Marko Kraemer 7628ced002
Merge pull request #155 from kortix-ai/fix/billing
v1 Fix/billing
2025-04-26 19:49:35 -07:00
marko-kraemer b7b7eeb705 serious wip 2025-04-27 03:44:58 +01:00
marko-kraemer a7d38c0096 serious wip 2025-04-27 03:20:49 +01:00
marko-kraemer 23574e37cf python billing 2025-04-27 00:47:31 +01:00
marko-kraemer 09c4099ca5 temp wip, downgrade & upgrade w. credit not implemented 2025-04-26 20:54:41 +01:00
marko-kraemer 28da425ce8 wip 2025-04-26 18:56:52 +01:00
marko-kraemer 865b2f3633 wip 2025-04-26 18:51:25 +01:00
marko-kraemer b3f1398c3d wip 2025-04-26 16:55:57 +01:00
marko-kraemer d6706ead43 wip 2025-04-26 16:37:03 +01:00
44 changed files with 2083 additions and 2853 deletions

1
.gitignore vendored
View File

@ -162,7 +162,6 @@ cython_debug/
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
# AgentPress
/threads
state.json
/workspace/

1
backend/.gitignore vendored
View File

@ -163,7 +163,6 @@ cython_debug/
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
# AgentPress
/threads
state.json
/workspace/

View File

@ -15,9 +15,9 @@ from agentpress.thread_manager import ThreadManager
from services.supabase import DBConnection
from services import redis
from agent.run import run_agent
from utils.auth_utils import get_current_user_id, get_user_id_from_stream_auth, verify_thread_access
from utils.auth_utils import get_current_user_id_from_jwt, get_user_id_from_stream_auth, verify_thread_access
from utils.logger import logger
from utils.billing import check_billing_status, get_account_id_from_thread
from services.billing import check_billing_status
from sandbox.sandbox import create_sandbox, get_or_start_sandbox
from services.llm import make_llm_api_call
@ -348,7 +348,7 @@ async def get_or_create_project_sandbox(client, project_id: str):
async def start_agent(
thread_id: str,
body: AgentStartRequest = Body(...),
user_id: str = Depends(get_current_user_id)
user_id: str = Depends(get_current_user_id_from_jwt)
):
"""Start an agent for a specific thread in the background."""
global instance_id # Ensure instance_id is accessible
@ -412,7 +412,7 @@ async def start_agent(
return {"agent_run_id": agent_run_id, "status": "running"}
@router.post("/agent-run/{agent_run_id}/stop")
async def stop_agent(agent_run_id: str, user_id: str = Depends(get_current_user_id)):
async def stop_agent(agent_run_id: str, user_id: str = Depends(get_current_user_id_from_jwt)):
"""Stop a running agent."""
logger.info(f"Received request to stop agent run: {agent_run_id}")
client = await db.client
@ -421,7 +421,7 @@ async def stop_agent(agent_run_id: str, user_id: str = Depends(get_current_user_
return {"status": "stopped"}
@router.get("/thread/{thread_id}/agent-runs")
async def get_agent_runs(thread_id: str, user_id: str = Depends(get_current_user_id)):
async def get_agent_runs(thread_id: str, user_id: str = Depends(get_current_user_id_from_jwt)):
"""Get all agent runs for a thread."""
logger.info(f"Fetching agent runs for thread: {thread_id}")
client = await db.client
@ -431,7 +431,7 @@ async def get_agent_runs(thread_id: str, user_id: str = Depends(get_current_user
return {"agent_runs": agent_runs.data}
@router.get("/agent-run/{agent_run_id}")
async def get_agent_run(agent_run_id: str, user_id: str = Depends(get_current_user_id)):
async def get_agent_run(agent_run_id: str, user_id: str = Depends(get_current_user_id_from_jwt)):
"""Get agent run status and responses."""
logger.info(f"Fetching agent run details: {agent_run_id}")
client = await db.client
@ -859,7 +859,7 @@ async def initiate_agent_with_files(
stream: Optional[bool] = Form(True),
enable_context_manager: Optional[bool] = Form(False),
files: List[UploadFile] = File(default=[]),
user_id: str = Depends(get_current_user_id)
user_id: str = Depends(get_current_user_id_from_jwt)
):
"""Initiate a new agent session with optional file attachments."""
global instance_id # Ensure instance_id is accessible

View File

@ -20,7 +20,8 @@ from agent.tools.sb_browser_tool import SandboxBrowserTool
from agent.tools.data_providers_tool import DataProvidersTool
from agent.prompt import get_system_prompt
from utils import logger
from utils.billing import check_billing_status, get_account_id_from_thread
from utils.auth_utils import get_account_id_from_thread
from services.billing import check_billing_status
load_dotenv()

View File

@ -16,6 +16,7 @@ from collections import OrderedDict
# Import the agent API module
from agent import api as agent_api
from sandbox import api as sandbox_api
from services import billing as billing_api
# Load environment variables (these will be available through config)
load_dotenv()
@ -132,6 +133,9 @@ app.include_router(agent_api.router, prefix="/api")
# Include the sandbox router with a prefix
app.include_router(sandbox_api.router, prefix="/api")
# Include the billing router with a prefix
app.include_router(billing_api.router, prefix="/api")
@app.get("/api/health")
async def health_check():
"""Health check endpoint to verify API is working."""
@ -152,5 +156,6 @@ if __name__ == "__main__":
"api:app",
host="0.0.0.0",
port=8000,
workers=workers
workers=workers,
reload=True
)

17
backend/poetry.lock generated
View File

@ -2873,6 +2873,21 @@ docs = ["myst-parser[linkify]", "sphinx", "sphinx-rtd-theme"]
release = ["twine"]
test = ["pylint", "pytest", "pytest-black", "pytest-cov", "pytest-pylint"]
[[package]]
name = "stripe"
version = "12.0.1"
description = "Python bindings for the Stripe API"
optional = false
python-versions = ">=3.6"
files = [
{file = "stripe-12.0.1-py2.py3-none-any.whl", hash = "sha256:b10b19dbd0622868b98a7c6e879ebde704be96ad75c780944bca4069bb427988"},
{file = "stripe-12.0.1.tar.gz", hash = "sha256:3fc7cc190946d8ebcc5b637e7e04f387d61b9c5156a89619a3ba90704ac09d4a"},
]
[package.dependencies]
requests = {version = ">=2.20", markers = "python_version >= \"3.0\""}
typing_extensions = {version = ">=4.5.0", markers = "python_version >= \"3.7\""}
[[package]]
name = "supabase"
version = "2.15.0"
@ -3622,4 +3637,4 @@ testing = ["coverage[toml]", "zope.event", "zope.testing"]
[metadata]
lock-version = "2.0"
python-versions = "^3.11"
content-hash = "622a06feff14fc27c612f15e50be3375531175462c46fa57c3bcf33851e2a9c3"
content-hash = "160a62f76af02d841f0f1b60e2d96c5c8e91310182d728dfa82bb792fa098e95"

View File

@ -20,7 +20,7 @@ classifiers = [
python = "^3.11"
streamlit-quill = "0.0.3"
python-dotenv = "1.0.1"
litellm = "^1.44.0"
litellm = "1.66.1"
click = "8.1.7"
questionary = "2.0.1"
requests = "^2.31.0"
@ -50,6 +50,7 @@ nest-asyncio = "^1.6.0"
vncdotool = "^1.2.0"
tavily-python = "^0.5.4"
pytesseract = "^0.3.13"
stripe = "^12.0.1"
[tool.poetry.scripts]
agentpress = "agentpress.cli:main"
@ -57,7 +58,6 @@ agentpress = "agentpress.cli:main"
[[tool.poetry.packages]]
include = "agentpress"
[tool.poetry.group.dev.dependencies]
daytona-sdk = "^0.14.0"

View File

@ -1,6 +1,6 @@
streamlit-quill==0.0.3
python-dotenv==1.0.1
litellm>=1.66.2
litellm==1.66.2
click==8.1.7
questionary==2.0.1
requests>=2.31.0
@ -30,4 +30,5 @@ nest-asyncio>=1.6.0
vncdotool>=1.2.0
pydantic
tavily-python>=0.5.4
pytesseract==0.3.13
pytesseract==0.3.13
stripe>=7.0.0

View File

@ -6,7 +6,7 @@ from fastapi.responses import Response, JSONResponse
from pydantic import BaseModel
from utils.logger import logger
from utils.auth_utils import get_current_user_id, get_user_id_from_stream_auth, get_optional_user_id
from utils.auth_utils import get_current_user_id_from_jwt, get_user_id_from_stream_auth, get_optional_user_id
from sandbox.sandbox import get_or_start_sandbox
from services.supabase import DBConnection
from agent.api import get_or_create_project_sandbox

748
backend/services/billing.py Normal file
View File

@ -0,0 +1,748 @@
"""
Stripe Billing API implementation for Suna on top of Basejump. ONLY HAS SUPPOT FOR USER ACCOUNTS no team accounts. As we are using the user_id as account_id as is the case with personal accounts. In personal accounts, the account_id equals the user_id. In team accounts, the account_id is unique.
stripe listen --forward-to localhost:8000/api/billing/webhook
"""
from fastapi import APIRouter, HTTPException, Depends, Request
from typing import Optional, Dict, Any, List, Tuple
import stripe
from datetime import datetime, timezone
from utils.logger import logger
from utils.config import config, EnvMode
from services.supabase import DBConnection
from utils.auth_utils import get_current_user_id_from_jwt
from pydantic import BaseModel, Field
# Initialize Stripe
stripe.api_key = config.STRIPE_SECRET_KEY
# Initialize router
router = APIRouter(prefix="/billing", tags=["billing"])
SUBSCRIPTION_TIERS = {
config.STRIPE_FREE_TIER_ID: {'name': 'free', 'minutes': 10},
config.STRIPE_TIER_2_20_ID: {'name': 'tier_2_20', 'minutes': 120}, # 2 hours
config.STRIPE_TIER_6_50_ID: {'name': 'tier_6_50', 'minutes': 360}, # 6 hours
config.STRIPE_TIER_12_100_ID: {'name': 'tier_12_100', 'minutes': 720}, # 12 hours
config.STRIPE_TIER_25_200_ID: {'name': 'tier_25_200', 'minutes': 1500}, # 25 hours
config.STRIPE_TIER_50_400_ID: {'name': 'tier_50_400', 'minutes': 3000}, # 50 hours
config.STRIPE_TIER_125_800_ID: {'name': 'tier_125_800', 'minutes': 7500}, # 125 hours
config.STRIPE_TIER_200_1000_ID: {'name': 'tier_200_1000', 'minutes': 12000}, # 200 hours
}
# Pydantic models for request/response validation
class CreateCheckoutSessionRequest(BaseModel):
price_id: str
success_url: str
cancel_url: str
class CreatePortalSessionRequest(BaseModel):
return_url: str
class SubscriptionStatus(BaseModel):
status: str # e.g., 'active', 'trialing', 'past_due', 'scheduled_downgrade', 'no_subscription'
plan_name: Optional[str] = None
price_id: Optional[str] = None # Added price ID
current_period_end: Optional[datetime] = None
cancel_at_period_end: bool = False
trial_end: Optional[datetime] = None
minutes_limit: Optional[int] = None
current_usage: Optional[float] = None
# Fields for scheduled changes
has_schedule: bool = False
scheduled_plan_name: Optional[str] = None
scheduled_price_id: Optional[str] = None # Added scheduled price ID
scheduled_change_date: Optional[datetime] = None
# Helper functions
async def get_stripe_customer_id(client, user_id: str) -> Optional[str]:
"""Get the Stripe customer ID for a user."""
result = await client.schema('basejump').from_('billing_customers') \
.select('id') \
.eq('account_id', user_id) \
.execute()
if result.data and len(result.data) > 0:
return result.data[0]['id']
return None
async def create_stripe_customer(client, user_id: str, email: str) -> str:
"""Create a new Stripe customer for a user."""
# Create customer in Stripe
customer = stripe.Customer.create(
email=email,
metadata={"user_id": user_id}
)
# Store customer ID in Supabase
await client.schema('basejump').from_('billing_customers').insert({
'id': customer.id,
'account_id': user_id,
'email': email,
'provider': 'stripe'
}).execute()
return customer.id
async def get_user_subscription(user_id: str) -> Optional[Dict]:
"""Get the current subscription for a user from Stripe."""
try:
# Get customer ID
db = DBConnection()
client = await db.client
customer_id = await get_stripe_customer_id(client, user_id)
if not customer_id:
return None
# Get all active subscriptions for the customer
subscriptions = stripe.Subscription.list(
customer=customer_id,
status='active'
)
# print("Found subscriptions:", subscriptions)
# Check if we have any subscriptions
if not subscriptions or not subscriptions.get('data'):
return None
# Filter subscriptions to only include our product's subscriptions
our_subscriptions = []
for sub in subscriptions['data']:
# Get the first subscription item
if sub.get('items') and sub['items'].get('data') and len(sub['items']['data']) > 0:
item = sub['items']['data'][0]
if item.get('price') and item['price'].get('id') in [
config.STRIPE_FREE_TIER_ID,
config.STRIPE_TIER_2_20_ID,
config.STRIPE_TIER_6_50_ID,
config.STRIPE_TIER_12_100_ID,
config.STRIPE_TIER_25_200_ID,
config.STRIPE_TIER_50_400_ID,
config.STRIPE_TIER_125_800_ID,
config.STRIPE_TIER_200_1000_ID
]:
our_subscriptions.append(sub)
if not our_subscriptions:
return None
# If there are multiple active subscriptions, we need to handle this
if len(our_subscriptions) > 1:
logger.warning(f"User {user_id} has multiple active subscriptions: {[sub['id'] for sub in our_subscriptions]}")
# Get the most recent subscription
most_recent = max(our_subscriptions, key=lambda x: x['created'])
# Cancel all other subscriptions
for sub in our_subscriptions:
if sub['id'] != most_recent['id']:
try:
stripe.Subscription.modify(
sub['id'],
cancel_at_period_end=True
)
logger.info(f"Cancelled subscription {sub['id']} for user {user_id}")
except Exception as e:
logger.error(f"Error cancelling subscription {sub['id']}: {str(e)}")
return most_recent
return our_subscriptions[0]
except Exception as e:
logger.error(f"Error getting subscription from Stripe: {str(e)}")
return None
async def calculate_monthly_usage(client, user_id: str) -> float:
"""Calculate total agent run minutes for the current month for a user."""
# Get start of current month in UTC
now = datetime.now(timezone.utc)
start_of_month = datetime(now.year, now.month, 1, tzinfo=timezone.utc)
# First get all threads for this user
threads_result = await client.table('threads') \
.select('thread_id') \
.eq('account_id', user_id) \
.execute()
if not threads_result.data:
return 0.0
thread_ids = [t['thread_id'] for t in threads_result.data]
# Then get all agent runs for these threads in current month
runs_result = await client.table('agent_runs') \
.select('started_at, completed_at') \
.in_('thread_id', thread_ids) \
.gte('started_at', start_of_month.isoformat()) \
.execute()
if not runs_result.data:
return 0.0
# Calculate total minutes
total_seconds = 0
now_ts = now.timestamp()
for run in runs_result.data:
start_time = datetime.fromisoformat(run['started_at'].replace('Z', '+00:00')).timestamp()
if run['completed_at']:
end_time = datetime.fromisoformat(run['completed_at'].replace('Z', '+00:00')).timestamp()
else:
# For running jobs, use current time
end_time = now_ts
total_seconds += (end_time - start_time)
return total_seconds / 60 # Convert to minutes
async def check_billing_status(client, user_id: str) -> Tuple[bool, str, Optional[Dict]]:
"""
Check if a user can run agents based on their subscription and usage.
Returns:
Tuple[bool, str, Optional[Dict]]: (can_run, message, subscription_info)
"""
if config.ENV_MODE == EnvMode.LOCAL:
logger.info("Running in local development mode - billing checks are disabled")
return True, "Local development mode - billing disabled", {
"price_id": "local_dev",
"plan_name": "Local Development",
"minutes_limit": "no limit"
}
# Get current subscription
subscription = await get_user_subscription(user_id)
# print("Current subscription:", subscription)
# If no subscription, they can use free tier
if not subscription:
subscription = {
'price_id': config.STRIPE_FREE_TIER_ID, # Free tier
'plan_name': 'free'
}
# Get tier info - default to free tier if not found
tier_info = SUBSCRIPTION_TIERS.get(subscription.get('price_id', config.STRIPE_FREE_TIER_ID))
if not tier_info:
logger.warning(f"Unknown subscription tier: {subscription.get('price_id')}, defaulting to free tier")
tier_info = SUBSCRIPTION_TIERS[config.STRIPE_FREE_TIER_ID]
# Calculate current month's usage
current_usage = await calculate_monthly_usage(client, user_id)
# Check if within limits
if current_usage >= tier_info['minutes']:
return False, f"Monthly limit of {tier_info['minutes']} minutes reached. Please upgrade your plan or wait until next month.", subscription
return True, "OK", subscription
# API endpoints
@router.post("/create-checkout-session")
async def create_checkout_session(
request: CreateCheckoutSessionRequest,
current_user_id: str = Depends(get_current_user_id_from_jwt)
):
"""Create a Stripe Checkout session or modify an existing subscription."""
try:
# Get Supabase client
db = DBConnection()
client = await db.client
# Get user email from auth.users
user_result = await client.auth.admin.get_user_by_id(current_user_id)
if not user_result: raise HTTPException(status_code=404, detail="User not found")
email = user_result.user.email
# Get or create Stripe customer
customer_id = await get_stripe_customer_id(client, current_user_id)
if not customer_id: customer_id = await create_stripe_customer(client, current_user_id, email)
# Get the target price and product ID
try:
price = stripe.Price.retrieve(request.price_id, expand=['product'])
product_id = price['product']['id']
except stripe.error.InvalidRequestError:
raise HTTPException(status_code=400, detail=f"Invalid price ID: {request.price_id}")
# Verify the price belongs to our product
if product_id != config.STRIPE_PRODUCT_ID:
raise HTTPException(status_code=400, detail="Price ID does not belong to the correct product.")
# Check for existing subscription for our product
existing_subscription = await get_user_subscription(current_user_id)
# print("Existing subscription for product:", existing_subscription)
if existing_subscription:
# --- Handle Subscription Change (Upgrade or Downgrade) ---
try:
subscription_id = existing_subscription['id']
subscription_item = existing_subscription['items']['data'][0]
current_price_id = subscription_item['price']['id']
# Skip if already on this plan
if current_price_id == request.price_id:
return {
"subscription_id": subscription_id,
"status": "no_change",
"message": "Already subscribed to this plan.",
"details": {
"is_upgrade": None,
"effective_date": None,
"current_price": round(price['unit_amount'] / 100, 2) if price.get('unit_amount') else 0,
"new_price": round(price['unit_amount'] / 100, 2) if price.get('unit_amount') else 0,
}
}
# Get current and new price details
current_price = stripe.Price.retrieve(current_price_id)
new_price = price # Already retrieved
is_upgrade = new_price['unit_amount'] > current_price['unit_amount']
if is_upgrade:
# --- Handle Upgrade --- Immediate modification
updated_subscription = stripe.Subscription.modify(
subscription_id,
items=[{
'id': subscription_item['id'],
'price': request.price_id,
}],
proration_behavior='always_invoice', # Prorate and charge immediately
billing_cycle_anchor='now' # Reset billing cycle
)
latest_invoice = None
if updated_subscription.get('latest_invoice'):
latest_invoice = stripe.Invoice.retrieve(updated_subscription['latest_invoice'])
return {
"subscription_id": updated_subscription['id'],
"status": "updated",
"message": "Subscription upgraded successfully",
"details": {
"is_upgrade": True,
"effective_date": "immediate",
"current_price": round(current_price['unit_amount'] / 100, 2) if current_price.get('unit_amount') else 0,
"new_price": round(new_price['unit_amount'] / 100, 2) if new_price.get('unit_amount') else 0,
"invoice": {
"id": latest_invoice['id'] if latest_invoice else None,
"status": latest_invoice['status'] if latest_invoice else None,
"amount_due": round(latest_invoice['amount_due'] / 100, 2) if latest_invoice else 0,
"amount_paid": round(latest_invoice['amount_paid'] / 100, 2) if latest_invoice else 0
} if latest_invoice else None
}
}
else:
# --- Handle Downgrade --- Use Subscription Schedule
try:
current_period_end_ts = subscription_item['current_period_end']
# Retrieve the subscription again to get the schedule ID if it exists
# This ensures we have the latest state before creating/modifying schedule
sub_with_schedule = stripe.Subscription.retrieve(subscription_id)
schedule_id = sub_with_schedule.get('schedule')
# Get the current phase configuration from the schedule or subscription
if schedule_id:
schedule = stripe.SubscriptionSchedule.retrieve(schedule_id)
# Find the current phase in the schedule
# This logic assumes simple schedules; might need refinement for complex ones
current_phase = None
for phase in reversed(schedule['phases']):
if phase['start_date'] <= datetime.now(timezone.utc).timestamp():
current_phase = phase
break
if not current_phase: # Fallback if logic fails
current_phase = schedule['phases'][-1]
else:
# If no schedule, the current subscription state defines the current phase
current_phase = {
'items': existing_subscription['items']['data'], # Use original items data
'start_date': existing_subscription['current_period_start'], # Use sub start if no schedule
# Add other relevant fields if needed for create/modify
}
# Prepare the current phase data for the update/create
# Ensure items is formatted correctly for the API
current_phase_items_for_api = []
for item in current_phase.get('items', []):
price_data = item.get('price')
quantity = item.get('quantity')
price_id = None
# Safely extract price ID whether it's an object or just the ID string
if isinstance(price_data, dict):
price_id = price_data.get('id')
elif isinstance(price_data, str):
price_id = price_data
if price_id and quantity is not None:
current_phase_items_for_api.append({'price': price_id, 'quantity': quantity})
else:
logger.warning(f"Skipping item in current phase due to missing price ID or quantity: {item}")
if not current_phase_items_for_api:
raise ValueError("Could not determine valid items for the current phase.")
current_phase_update_data = {
'items': current_phase_items_for_api,
'start_date': current_phase['start_date'], # Preserve original start date
'end_date': current_period_end_ts, # End this phase at period end
'proration_behavior': 'none'
# Include other necessary fields from current_phase if modifying?
# e.g., 'billing_cycle_anchor', 'collection_method'? Usually inherited.
}
# Define the new (downgrade) phase
new_downgrade_phase_data = {
'items': [{'price': request.price_id, 'quantity': 1}],
'start_date': current_period_end_ts, # Start immediately after current phase ends
'proration_behavior': 'none'
# iterations defaults to 1, meaning it runs for one billing cycle
# then schedule ends based on end_behavior
}
# Update or Create Schedule
if schedule_id:
# Update existing schedule, replacing all future phases
# print(f"Updating existing schedule {schedule_id}")
logger.info(f"Updating existing schedule {schedule_id} for subscription {subscription_id}")
logger.debug(f"Current phase data: {current_phase_update_data}")
logger.debug(f"New phase data: {new_downgrade_phase_data}")
updated_schedule = stripe.SubscriptionSchedule.modify(
schedule_id,
phases=[current_phase_update_data, new_downgrade_phase_data],
end_behavior='release'
)
logger.info(f"Successfully updated schedule {updated_schedule['id']}")
else:
# Create a new schedule using the defined phases
print(f"Creating new schedule for subscription {subscription_id}")
logger.info(f"Creating new schedule for subscription {subscription_id}")
# Deep debug logging - write subscription details to help diagnose issues
logger.debug(f"Subscription details: {subscription_id}, current_period_end_ts: {current_period_end_ts}")
logger.debug(f"Current price: {current_price_id}, New price: {request.price_id}")
try:
updated_schedule = stripe.SubscriptionSchedule.create(
from_subscription=subscription_id,
phases=[
{
'start_date': current_phase['start_date'],
'end_date': current_period_end_ts,
'proration_behavior': 'none',
'items': [
{
'price': current_price_id,
'quantity': 1
}
]
},
{
'start_date': current_period_end_ts,
'proration_behavior': 'none',
'items': [
{
'price': request.price_id,
'quantity': 1
}
]
}
],
end_behavior='release'
)
# Don't try to link the schedule - that's handled by from_subscription
logger.info(f"Created new schedule {updated_schedule['id']} from subscription {subscription_id}")
# print(f"Created new schedule {updated_schedule['id']} from subscription {subscription_id}")
# Verify the schedule was created correctly
fetched_schedule = stripe.SubscriptionSchedule.retrieve(updated_schedule['id'])
logger.info(f"Schedule verification - Status: {fetched_schedule.get('status')}, Phase Count: {len(fetched_schedule.get('phases', []))}")
logger.debug(f"Schedule details: {fetched_schedule}")
except Exception as schedule_error:
logger.exception(f"Failed to create schedule: {str(schedule_error)}")
raise schedule_error # Re-raise to be caught by the outer try-except
return {
"subscription_id": subscription_id,
"schedule_id": updated_schedule['id'],
"status": "scheduled",
"message": "Subscription downgrade scheduled",
"details": {
"is_upgrade": False,
"effective_date": "end_of_period",
"current_price": round(current_price['unit_amount'] / 100, 2) if current_price.get('unit_amount') else 0,
"new_price": round(new_price['unit_amount'] / 100, 2) if new_price.get('unit_amount') else 0,
"effective_at": datetime.fromtimestamp(current_period_end_ts, tz=timezone.utc).isoformat()
}
}
except Exception as e:
logger.exception(f"Error handling subscription schedule for sub {subscription_id}: {str(e)}")
raise HTTPException(status_code=500, detail=f"Error handling subscription schedule: {str(e)}")
except Exception as e:
logger.exception(f"Error updating subscription {existing_subscription.get('id') if existing_subscription else 'N/A'}: {str(e)}")
raise HTTPException(status_code=500, detail=f"Error updating subscription: {str(e)}")
else:
# --- Create New Subscription via Checkout Session ---
session = stripe.checkout.Session.create(
customer=customer_id,
payment_method_types=['card'],
line_items=[{'price': request.price_id, 'quantity': 1}],
mode='subscription',
success_url=request.success_url,
cancel_url=request.cancel_url,
metadata={
'user_id': current_user_id,
'product_id': product_id
}
)
return {"session_id": session['id'], "url": session['url'], "status": "new"}
except Exception as e:
logger.exception(f"Error creating checkout session: {str(e)}")
# Check if it's a Stripe error with more details
if hasattr(e, 'json_body') and e.json_body and 'error' in e.json_body:
error_detail = e.json_body['error'].get('message', str(e))
else:
error_detail = str(e)
raise HTTPException(status_code=500, detail=f"Error creating checkout session: {error_detail}")
@router.post("/create-portal-session")
async def create_portal_session(
request: CreatePortalSessionRequest,
current_user_id: str = Depends(get_current_user_id_from_jwt)
):
"""Create a Stripe Customer Portal session for subscription management."""
try:
# Get Supabase client
db = DBConnection()
client = await db.client
# Get customer ID
customer_id = await get_stripe_customer_id(client, current_user_id)
if not customer_id:
raise HTTPException(status_code=404, detail="No billing customer found")
# Ensure the portal configuration has subscription_update enabled
try:
# First, check if we have a configuration that already enables subscription update
configurations = stripe.billing_portal.Configuration.list(limit=100)
active_config = None
# Look for a configuration with subscription_update enabled
for config in configurations.get('data', []):
features = config.get('features', {})
subscription_update = features.get('subscription_update', {})
if subscription_update.get('enabled', False):
active_config = config
logger.info(f"Found existing portal configuration with subscription_update enabled: {config['id']}")
break
# If no config with subscription_update found, create one or update the active one
if not active_config:
# Find the active configuration or create a new one
if configurations.get('data', []):
default_config = configurations['data'][0]
logger.info(f"Updating default portal configuration: {default_config['id']} to enable subscription_update")
active_config = stripe.billing_portal.Configuration.update(
default_config['id'],
features={
'subscription_update': {
'enabled': True,
'proration_behavior': 'create_prorations',
'default_allowed_updates': ['price']
},
# Preserve other features that may already be enabled
'customer_update': default_config.get('features', {}).get('customer_update', {'enabled': True, 'allowed_updates': ['email', 'address']}),
'invoice_history': {'enabled': True},
'payment_method_update': {'enabled': True}
}
)
else:
# Create a new configuration with subscription_update enabled
logger.info("Creating new portal configuration with subscription_update enabled")
active_config = stripe.billing_portal.Configuration.create(
business_profile={
'headline': 'Subscription Management',
'privacy_policy_url': config.FRONTEND_URL + '/privacy',
'terms_of_service_url': config.FRONTEND_URL + '/terms'
},
features={
'subscription_update': {
'enabled': True,
'proration_behavior': 'create_prorations',
'default_allowed_updates': ['price']
},
'customer_update': {
'enabled': True,
'allowed_updates': ['email', 'address']
},
'invoice_history': {'enabled': True},
'payment_method_update': {'enabled': True}
}
)
# Log the active configuration for debugging
logger.info(f"Using portal configuration: {active_config['id']} with subscription_update: {active_config.get('features', {}).get('subscription_update', {}).get('enabled', False)}")
except Exception as config_error:
logger.warning(f"Error configuring portal: {config_error}. Continuing with default configuration.")
# Create portal session using the proper configuration if available
portal_params = {
"customer": customer_id,
"return_url": request.return_url
}
# Add configuration_id if we found or created one with subscription_update enabled
if active_config:
portal_params["configuration"] = active_config['id']
# Create the session
session = stripe.billing_portal.Session.create(**portal_params)
return {"url": session.url}
except Exception as e:
logger.error(f"Error creating portal session: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
@router.get("/subscription")
async def get_subscription(
current_user_id: str = Depends(get_current_user_id_from_jwt)
):
"""Get the current subscription status for the current user, including scheduled changes."""
try:
# Get subscription from Stripe (this helper already handles filtering/cleanup)
subscription = await get_user_subscription(current_user_id)
# print("Subscription data for status:", subscription)
if not subscription:
# Default to free tier status if no active subscription for our product
free_tier_id = config.STRIPE_FREE_TIER_ID
free_tier_info = SUBSCRIPTION_TIERS.get(free_tier_id)
return SubscriptionStatus(
status="no_subscription",
plan_name=free_tier_info.get('name', 'free') if free_tier_info else 'free',
price_id=free_tier_id,
minutes_limit=free_tier_info.get('minutes') if free_tier_info else 0
)
# Extract current plan details
current_item = subscription['items']['data'][0]
current_price_id = current_item['price']['id']
current_tier_info = SUBSCRIPTION_TIERS.get(current_price_id)
if not current_tier_info:
# Fallback if somehow subscribed to an unknown price within our product
logger.warning(f"User {current_user_id} subscribed to unknown price {current_price_id}. Defaulting info.")
current_tier_info = {'name': 'unknown', 'minutes': 0}
# Calculate current usage
db = DBConnection()
client = await db.client
current_usage = await calculate_monthly_usage(client, current_user_id)
status_response = SubscriptionStatus(
status=subscription['status'], # 'active', 'trialing', etc.
plan_name=subscription['plan'].get('nickname') or current_tier_info['name'],
price_id=current_price_id,
current_period_end=datetime.fromtimestamp(current_item['current_period_end'], tz=timezone.utc),
cancel_at_period_end=subscription['cancel_at_period_end'],
trial_end=datetime.fromtimestamp(subscription['trial_end'], tz=timezone.utc) if subscription.get('trial_end') else None,
minutes_limit=current_tier_info['minutes'],
current_usage=round(current_usage, 2),
has_schedule=False # Default
)
# Check for an attached schedule (indicates pending downgrade)
schedule_id = subscription.get('schedule')
if schedule_id:
try:
schedule = stripe.SubscriptionSchedule.retrieve(schedule_id)
# Find the *next* phase after the current one
next_phase = None
current_phase_end = current_item['current_period_end']
for phase in schedule.get('phases', []):
# Check if this phase starts exactly when the current one ends
if phase.get('start_date') == current_phase_end:
next_phase = phase
break # Found the immediate next phase
if next_phase:
scheduled_item = next_phase['items'][0] # Assuming single item
scheduled_price_id = scheduled_item['price'] # Price ID might be string here
scheduled_tier_info = SUBSCRIPTION_TIERS.get(scheduled_price_id)
status_response.has_schedule = True
status_response.status = 'scheduled_downgrade' # Override status
status_response.scheduled_plan_name = scheduled_tier_info.get('name', 'unknown') if scheduled_tier_info else 'unknown'
status_response.scheduled_price_id = scheduled_price_id
status_response.scheduled_change_date = datetime.fromtimestamp(next_phase['start_date'], tz=timezone.utc)
except Exception as schedule_error:
logger.error(f"Error retrieving or parsing schedule {schedule_id} for sub {subscription['id']}: {schedule_error}")
# Proceed without schedule info if retrieval fails
return status_response
except Exception as e:
logger.exception(f"Error getting subscription status for user {current_user_id}: {str(e)}") # Use logger.exception
raise HTTPException(status_code=500, detail="Error retrieving subscription status.")
@router.get("/check-status")
async def check_status(
current_user_id: str = Depends(get_current_user_id_from_jwt)
):
"""Check if the user can run agents based on their subscription and usage."""
try:
# Get Supabase client
db = DBConnection()
client = await db.client
can_run, message, subscription = await check_billing_status(client, current_user_id)
return {
"can_run": can_run,
"message": message,
"subscription": subscription
}
except Exception as e:
logger.error(f"Error checking billing status: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
@router.post("/webhook")
async def stripe_webhook(request: Request):
"""Handle Stripe webhook events."""
try:
# Get the webhook secret from config
webhook_secret = config.STRIPE_WEBHOOK_SECRET
# Get the webhook payload
payload = await request.body()
sig_header = request.headers.get('stripe-signature')
# Verify webhook signature
try:
event = stripe.Webhook.construct_event(
payload, sig_header, webhook_secret
)
except ValueError as e:
raise HTTPException(status_code=400, detail="Invalid payload")
except stripe.error.SignatureVerificationError as e:
raise HTTPException(status_code=400, detail="Invalid signature")
# Handle the event
if event.type in ['customer.subscription.created', 'customer.subscription.updated', 'customer.subscription.deleted']:
# We don't need to do anything here as we'll query Stripe directly
pass
return {"status": "success"}
except Exception as e:
logger.error(f"Error processing webhook: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))

View File

@ -1,5 +0,0 @@
{
"imports": {
"@supabase/supabase-js": "https://esm.sh/@supabase/supabase-js"
}
}

View File

@ -1,284 +0,0 @@
import {serve} from "https://deno.land/std@0.168.0/http/server.ts";
import {stripeFunctionHandler} from "https://deno.land/x/basejump@v2.0.3/billing-functions/mod.ts";
import { requireAuthorizedBillingUser } from "https://deno.land/x/basejump@v2.0.3/billing-functions/src/require-authorized-billing-user.ts";
import getBillingStatus from "https://deno.land/x/basejump@v2.0.3/billing-functions/src/wrappers/get-billing-status.ts";
import createSupabaseServiceClient from "https://deno.land/x/basejump@v2.0.3/billing-functions/lib/create-supabase-service-client.ts";
import validateUrl from "https://deno.land/x/basejump@v2.0.3/billing-functions/lib/validate-url.ts";
import Stripe from "https://esm.sh/stripe@11.1.0?target=deno";
console.log("Starting billing functions...");
const defaultAllowedHost = Deno.env.get("ALLOWED_HOST") || "http://localhost:3000";
const allowedHosts = [defaultAllowedHost, "https://www.suna.so", "https://suna.so", "https://staging.suna.so"];
console.log("Default allowed host:", defaultAllowedHost);
export const corsHeaders = {
"Access-Control-Allow-Origin": "*",
"Access-Control-Allow-Headers":
"authorization, x-client-info, apikey, content-type",
};
console.log("Initializing Stripe client...");
const stripeClient = new Stripe(Deno.env.get("STRIPE_API_KEY") as string, {
// This is needed to use the Fetch API rather than relying on the Node http
// package.
apiVersion: "2022-11-15",
httpClient: Stripe.createFetchHttpClient(),
});
console.log("Stripe client initialized");
console.log("Setting up stripe handler...");
const stripeHandler = stripeFunctionHandler({
stripeClient,
defaultPlanId: Deno.env.get("STRIPE_DEFAULT_PLAN_ID") as string,
defaultTrialDays: Deno.env.get("STRIPE_DEFAULT_TRIAL_DAYS") ? Number(Deno.env.get("STRIPE_DEFAULT_TRIAL_DAYS")) : undefined
});
console.log("Stripe handler configured");
serve(async (req) => {
console.log("Received request:", req.method, req.url);
if (req.method === "OPTIONS") {
console.log("Handling OPTIONS request");
return new Response("ok", {headers: corsHeaders});
}
try {
const body = await req.json();
console.log("Request body:", body);
if (!body.args?.account_id) {
console.log("Missing account_id in request");
return new Response(
JSON.stringify({ error: "Account id is required" }),
{
status: 400,
headers: {
...corsHeaders,
"Content-Type": "application/json"
}
}
);
}
switch (body.action) {
case "get_plans":
console.log("Getting plans");
try {
const plans = await stripeHandler.getPlans(body.args);
console.log("Plans retrieved:", plans.length);
return new Response(
JSON.stringify(plans),
{
headers: {
...corsHeaders,
"Content-Type": "application/json",
},
}
);
} catch (e) {
console.log("Error getting plans:", e);
return new Response(
JSON.stringify({ error: "Failed to get plans" }),
{
status: 500,
headers: {
...corsHeaders,
"Content-Type": "application/json"
}
}
);
}
case "get_billing_portal_url":
console.log("Getting billing portal URL for account:", body.args.account_id);
if (!validateUrl(body.args.return_url, allowedHosts)) {
console.log("Invalid return URL:", body.args.return_url);
return new Response(
JSON.stringify({ error: "Return url is not allowed" }),
{
status: 400,
headers: {
...corsHeaders,
"Content-Type": "application/json"
}
}
);
}
return await requireAuthorizedBillingUser(req, {
accountId: body.args.account_id,
authorizedRoles: ["owner"],
async onBillableAndAuthorized(roleInfo) {
console.log("User authorized for billing portal, role info:", roleInfo);
try {
const response = await stripeHandler.getBillingPortalUrl({
accountId: roleInfo.account_id,
subscriptionId: roleInfo.billing_subscription_id,
customerId: roleInfo.billing_customer_id,
returnUrl: body.args.return_url,
});
console.log("Billing portal URL generated");
return new Response(
JSON.stringify({
billing_enabled: roleInfo.billing_enabled,
...response,
}),
{
headers: {
...corsHeaders,
"Content-Type": "application/json",
},
}
);
} catch (e) {
console.log("Error getting billing portal URL:", e);
return new Response(
JSON.stringify({ error: "Failed to generate billing portal URL" }),
{
status: 500,
headers: {
...corsHeaders,
"Content-Type": "application/json"
}
}
);
}
},
});
case "get_new_subscription_url":
console.log("Getting new subscription URL for account:", body.args.account_id);
if (!validateUrl(body.args.success_url, allowedHosts) || !validateUrl(body.args.cancel_url, allowedHosts)) {
console.log("Invalid success or cancel URL:", body.args.success_url, body.args.cancel_url);
return new Response(
JSON.stringify({ error: "Success or cancel url is not allowed" }),
{
status: 400,
headers: {
...corsHeaders,
"Content-Type": "application/json"
}
}
);
}
return await requireAuthorizedBillingUser(req, {
accountId: body.args.account_id,
authorizedRoles: ["owner"],
async onBillableAndAuthorized(roleInfo) {
console.log("User authorized for new subscription, role info:", roleInfo);
try {
const response = await stripeHandler.getNewSubscriptionUrl({
accountId: roleInfo.account_id,
planId: body.args.plan_id,
successUrl: body.args.success_url,
cancelUrl: body.args.cancel_url,
billingEmail: roleInfo.billing_email,
customerId: roleInfo.billing_customer_id,
});
console.log("New subscription URL generated");
return new Response(
JSON.stringify({
billing_enabled: roleInfo.billing_enabled,
...response,
}),
{
headers: {
...corsHeaders,
"Content-Type": "application/json",
},
}
);
} catch (e) {
console.log("Error getting new subscription URL:", e);
return new Response(
JSON.stringify({ error: "Failed to generate new subscription URL" }),
{
status: 500,
headers: {
...corsHeaders,
"Content-Type": "application/json"
}
}
);
}
},
});
case "get_billing_status":
console.log("Getting billing status for account:", body.args.account_id);
return await requireAuthorizedBillingUser(req, {
accountId: body.args.account_id,
authorizedRoles: ["owner"],
async onBillableAndAuthorized(roleInfo) {
console.log("User authorized, role info:", roleInfo);
const supabaseClient = createSupabaseServiceClient();
console.log("Getting billing status...");
try {
const response = await getBillingStatus(
supabaseClient,
roleInfo,
stripeHandler
);
console.log("Billing status response:", response);
return new Response(
JSON.stringify({
...response,
status: response.status || "not_setup",
billing_enabled: roleInfo.billing_enabled,
}),
{
headers: {
...corsHeaders,
"Content-Type": "application/json",
},
}
);
} catch (e) {
console.log("Error getting billing status:", e);
return new Response(
JSON.stringify({ error: "Internal server error" }),
{
status: 500,
headers: {
...corsHeaders,
"Content-Type": "application/json"
}
}
);
}
},
});
default:
console.log("Invalid action requested:", body.action);
return new Response(
JSON.stringify({ error: "Invalid action" }),
{
status: 400,
headers: {
...corsHeaders,
"Content-Type": "application/json"
}
}
);
}
} catch (e) {
console.log("Error processing request:", e);
return new Response(
JSON.stringify({ error: "Internal server error" }),
{
status: 500,
headers: {
...corsHeaders,
"Content-Type": "application/json"
}
}
);
}
});

View File

@ -1,5 +0,0 @@
{
"imports": {
"@supabase/supabase-js": "https://esm.sh/@supabase/supabase-js"
}
}

View File

@ -1,24 +0,0 @@
import {serve} from "https://deno.land/std@0.168.0/http/server.ts";
import {billingWebhooksWrapper, stripeWebhookHandler} from "https://deno.land/x/basejump@v2.0.3/billing-functions/mod.ts";
import Stripe from "https://esm.sh/stripe@11.1.0?target=deno";
const stripeClient = new Stripe(Deno.env.get("STRIPE_API_KEY") as string, {
// This is needed to use the Fetch API rather than relying on the Node http
// package.
apiVersion: "2022-11-15",
httpClient: Stripe.createFetchHttpClient(),
});
const stripeResponse = stripeWebhookHandler({
stripeClient,
stripeWebhookSigningSecret: Deno.env.get("STRIPE_WEBHOOK_SIGNING_SECRET") as string,
});
const webhookEndpoint = billingWebhooksWrapper(stripeResponse);
serve(async (req) => {
const response = await webhookEndpoint(req);
return response;
});

View File

@ -1,133 +0,0 @@
"""
Tests for direct tool execution in AgentPress.
This module tests the performance difference between sequential and parallel
tool execution strategies by directly calling the execution methods without thread overhead.
"""
import os
import asyncio
import sys
from dotenv import load_dotenv
from agentpress.thread_manager import ThreadManager
from agentpress.response_processor import ProcessorConfig
from agent.tools.wait_tool import WaitTool
from agentpress.tool import ToolResult
# Load environment variables
load_dotenv()
async def test_direct_execution():
"""Directly test sequential vs parallel execution without thread overhead."""
print("\n" + "="*80)
print("🧪 TESTING DIRECT TOOL EXECUTION: PARALLEL VS SEQUENTIAL")
print("="*80 + "\n")
# Initialize ThreadManager and register tools
thread_manager = ThreadManager()
thread_manager.add_tool(WaitTool)
# Create wait tool calls
wait_tool_calls = [
{"name": "wait", "arguments": {"seconds": 2, "message": "Wait tool 1"}},
{"name": "wait", "arguments": {"seconds": 2, "message": "Wait tool 2"}},
{"name": "wait", "arguments": {"seconds": 2, "message": "Wait tool 3"}}
]
# Expected values for validation
expected_tool_count = len(wait_tool_calls)
# Test sequential execution
print("🔄 Testing Sequential Execution")
print("-"*60)
sequential_start = asyncio.get_event_loop().time()
sequential_results = await thread_manager.response_processor._execute_tools(
wait_tool_calls,
execution_strategy="sequential"
)
sequential_end = asyncio.get_event_loop().time()
sequential_time = sequential_end - sequential_start
print(f"Sequential execution completed in {sequential_time:.2f} seconds")
# Validate sequential results - results are a list of tuples (tool_call, tool_result)
assert len(sequential_results) == expected_tool_count, f"❌ Expected {expected_tool_count} tool results, got {len(sequential_results)} in sequential execution"
assert all(isinstance(result_tuple, tuple) and len(result_tuple) == 2 for result_tuple in sequential_results), "❌ Not all sequential results are tuples of (tool_call, result)"
assert all(isinstance(result_tuple[1], ToolResult) for result_tuple in sequential_results), "❌ Not all sequential result values are ToolResult instances"
assert all(result_tuple[1].success for result_tuple in sequential_results), "❌ Not all sequential tool executions were successful"
print("✅ PASS: Sequential execution completed all tool calls successfully")
print()
# Test parallel execution
print("⚡ Testing Parallel Execution")
print("-"*60)
parallel_start = asyncio.get_event_loop().time()
parallel_results = await thread_manager.response_processor._execute_tools(
wait_tool_calls,
execution_strategy="parallel"
)
parallel_end = asyncio.get_event_loop().time()
parallel_time = parallel_end - parallel_start
print(f"Parallel execution completed in {parallel_time:.2f} seconds")
# Validate parallel results - results are a list of tuples (tool_call, tool_result)
assert len(parallel_results) == expected_tool_count, f"❌ Expected {expected_tool_count} tool results, got {len(parallel_results)} in parallel execution"
assert all(isinstance(result_tuple, tuple) and len(result_tuple) == 2 for result_tuple in parallel_results), "❌ Not all parallel results are tuples of (tool_call, result)"
assert all(isinstance(result_tuple[1], ToolResult) for result_tuple in parallel_results), "❌ Not all parallel result values are ToolResult instances"
assert all(result_tuple[1].success for result_tuple in parallel_results), "❌ Not all parallel tool executions were successful"
print("✅ PASS: Parallel execution completed all tool calls successfully")
print()
# Report results
print("\n" + "="*80)
print(f"🧮 RESULTS SUMMARY")
print("="*80)
print(f"Sequential: {sequential_time:.2f} seconds")
print(f"Parallel: {parallel_time:.2f} seconds")
# Calculate and validate speedup
speedup = sequential_time / parallel_time if parallel_time > 0 else 0
print(f"Speedup: {speedup:.2f}x faster")
# Validate speedup is significant (at least 1.5x faster)
min_expected_speedup = 1.5
assert speedup >= min_expected_speedup, f"❌ Expected parallel execution to be at least {min_expected_speedup}x faster than sequential, but got {speedup:.2f}x"
print(f"✅ PASS: Parallel execution is {speedup:.2f}x faster than sequential as expected")
# Ideal speedup should be close to the number of tools (3x)
# But allow for some overhead (at least 1.5x)
theoretical_max_speedup = len(wait_tool_calls)
print(f"Note: Theoretical max speedup: {theoretical_max_speedup:.1f}x")
print("\n" + "="*80)
print("✅ ALL TESTS PASSED")
print("="*80)
# Return results for potential further analysis
return {
"sequential": {
"time": sequential_time,
"results": sequential_results
},
"parallel": {
"time": parallel_time,
"results": parallel_results
},
"speedup": speedup
}
if __name__ == "__main__":
try:
asyncio.run(test_direct_execution())
print("\n✅ Test completed successfully")
sys.exit(0)
except AssertionError as e:
print(f"\n\n❌ Test failed: {str(e)}")
sys.exit(1)
except KeyboardInterrupt:
print("\n\n❌ Test interrupted by user")
sys.exit(1)
except Exception as e:
print(f"\n\n❌ Error during test: {str(e)}")
sys.exit(1)

View File

@ -1,193 +0,0 @@
"""
Raw streaming test to analyze tool call streaming behavior.
This script specifically tests how raw streaming chunks are delivered from the Anthropic API
with tool calls containing large JSON payloads.
"""
import asyncio
import json
import sys
import os
from typing import Dict, Any
from anthropic import AsyncAnthropic
from utils.logger import logger
# Example tool schema for Anthropic format
CREATE_FILE_TOOL = {
"name": "create_file",
"description": "Create a new file with the provided contents at a given path in the workspace",
"input_schema": {
"type": "object",
"properties": {
"file_path": {
"type": "string",
"description": "Path to the file to be created"
},
"file_contents": {
"type": "string",
"description": "The content to write to the file"
}
},
"required": ["file_path", "file_contents"]
}
}
async def test_raw_streaming():
"""Test tool calling with streaming to observe raw chunk behavior using Anthropic SDK directly."""
# Setup conversation with a prompt likely to generate large file payloads
messages = [
{"role": "user", "content": "Create a CSS file with a comprehensive set of styles for a modern responsive website."}
]
print("\n=== Testing Raw Streaming Tool Call Behavior ===\n")
try:
# Get API key from environment
api_key = os.environ.get("ANTHROPIC_API_KEY")
if not api_key:
logger.error("ANTHROPIC_API_KEY environment variable not set")
return
# Initialize Anthropic client
client = AsyncAnthropic(api_key=api_key)
# Make API call with tool in streaming mode
print("Sending streaming request...")
stream = await client.messages.create(
model="claude-3-5-sonnet-latest",
max_tokens=4096,
temperature=0.0,
system="You are a helpful assistant with access to file management tools.",
messages=messages,
tools=[CREATE_FILE_TOOL],
tool_choice={"type": "tool", "name": "create_file"},
stream=True
)
# Process streaming response
print("\nResponse stream started. Processing raw chunks:\n")
# Stream statistics
chunk_count = 0
tool_call_chunks = 0
accumulated_tool_input = ""
current_tool_name = None
accumulated_content = ""
# Process each chunk with ZERO buffering
print("\n--- BEGINNING STREAM OUTPUT ---\n", flush=True)
sys.stdout.flush()
# Process each event in the stream
async for event in stream:
chunk_count += 1
# Immediate debug output for every chunk
print(f"\n[CHUNK {chunk_count}] Type: {event.type}", end="", flush=True)
sys.stdout.flush()
# Process based on event type
if event.type == "message_start":
print(f" Message ID: {event.message.id}", end="", flush=True)
elif event.type == "content_block_start":
print(f" Content block start: {event.content_block.type}", end="", flush=True)
elif event.type == "content_block_delta":
if hasattr(event.delta, "text") and event.delta.text:
text = event.delta.text
accumulated_content += text
print(f" Content: {repr(text)}", end="", flush=True)
elif event.type == "tool_use":
current_tool_name = event.tool_use.name
print(f" Tool use: {current_tool_name}", end="", flush=True)
# If input is available immediately
if hasattr(event.tool_use, "input") and event.tool_use.input:
tool_call_chunks += 1
input_json = json.dumps(event.tool_use.input)
input_len = len(input_json)
print(f" Input[{input_len}]: {input_json[:50]}...", end="", flush=True)
accumulated_tool_input = input_json
elif event.type == "tool_use_delta":
if hasattr(event.delta, "input") and event.delta.input:
tool_call_chunks += 1
# For streaming tool inputs, we get partial updates
# The delta.input is a dictionary with partial updates to specific fields
input_json = json.dumps(event.delta.input)
input_len = len(input_json)
print(f" Input delta[{input_len}]: {input_json[:50]}...", end="", flush=True)
# Try to merge the deltas
try:
if accumulated_tool_input:
# Parse existing accumulated JSON
existing_input = json.loads(accumulated_tool_input)
# Update with new delta
existing_input.update(event.delta.input)
accumulated_tool_input = json.dumps(existing_input)
else:
accumulated_tool_input = input_json
except json.JSONDecodeError:
# If we can't parse JSON yet, just append the raw delta
accumulated_tool_input += input_json
elif event.type == "message_delta":
if hasattr(event.delta, "stop_reason") and event.delta.stop_reason:
print(f"\n--- FINISH REASON: {event.delta.stop_reason} ---", flush=True)
elif event.type == "message_stop":
# Access stop_reason directly from the event
if hasattr(event, "stop_reason"):
print(f"\n--- MESSAGE STOP: {event.stop_reason} ---", flush=True)
else:
print("\n--- MESSAGE STOP ---", flush=True)
# Force flush after every chunk
sys.stdout.flush()
print("\n\n--- END STREAM OUTPUT ---\n", flush=True)
sys.stdout.flush()
# Summary after all chunks processed
print("\n=== Streaming Summary ===")
print(f"Total chunks: {chunk_count}")
print(f"Tool call chunks: {tool_call_chunks}")
if current_tool_name:
print(f"\nTool name: {current_tool_name}")
if accumulated_content:
print(f"\nAccumulated content:")
print(accumulated_content)
# Try to parse accumulated arguments as JSON
try:
if accumulated_tool_input:
print(f"\nTotal accumulated tool input length: {len(accumulated_tool_input)}")
input_obj = json.loads(accumulated_tool_input)
print(f"\nSuccessfully parsed accumulated tool input as JSON")
if 'file_path' in input_obj:
print(f"file_path: {input_obj['file_path']}")
if 'file_contents' in input_obj:
contents = input_obj['file_contents']
print(f"file_contents length: {len(contents)}")
print(f"file_contents preview: {contents[:100]}...")
except json.JSONDecodeError as e:
print(f"\nError parsing accumulated tool input: {e}")
print(f"Tool input start: {accumulated_tool_input[:100]}...")
print(f"Tool input end: {accumulated_tool_input[-100:]}")
except Exception as e:
logger.error(f"Error in streaming test: {str(e)}", exc_info=True)
async def main():
"""Run the raw streaming test."""
await test_raw_streaming()
if __name__ == "__main__":
asyncio.run(main())

View File

@ -1,236 +0,0 @@
"""
Simple test script for LLM API with tool calling functionality.
This script tests basic tool calling with both streaming and non-streaming to verify functionality.
"""
import asyncio
import json
from typing import Dict, Any
from services.llm import make_llm_api_call
from utils.logger import logger
# Example tool schema from files_tool.py
CREATE_FILE_SCHEMA = {
"type": "function",
"function": {
"name": "create_file",
"description": "Create a new file with the provided contents at a given path in the workspace",
"parameters": {
"type": "object",
"properties": {
"file_path": {
"type": "string",
"description": "Path to the file to be created"
},
"file_contents": {
"type": "string",
"description": "The content to write to the file"
}
},
"required": ["file_path", "file_contents"]
}
}
}
async def test_simple_tool_call():
"""Test a simple non-streaming tool call to verify functionality."""
# Setup conversation
messages = [
{"role": "system", "content": "You are a helpful assistant with access to file management tools."},
{"role": "user", "content": "Create an HTML file named hello.html with a simple Hello World message."}
]
print("\n=== Testing non-streaming tool call ===\n")
try:
# Make API call with tool
response = await make_llm_api_call(
messages=messages,
model_name="gpt-4o",
temperature=0.0,
tools=[CREATE_FILE_SCHEMA],
tool_choice={"type": "function", "function": {"name": "create_file"}}
)
# Print basic response info
print(f"Response model: {response.model}")
print(f"Response type: {type(response)}")
# Check if the response has tool calls
assistant_message = response.choices[0].message
print(f"\nAssistant message content: {assistant_message.content}")
if hasattr(assistant_message, 'tool_calls') and assistant_message.tool_calls:
print("\nTool calls detected:")
for i, tool_call in enumerate(assistant_message.tool_calls):
print(f"\nTool call {i+1}:")
print(f" ID: {tool_call.id}")
print(f" Type: {tool_call.type}")
print(f" Function: {tool_call.function.name}")
print(f" Arguments:")
try:
args = json.loads(tool_call.function.arguments)
print(json.dumps(args, indent=4))
# Access and print specific arguments
if tool_call.function.name == "create_file":
print(f"\nFile path: {args.get('file_path')}")
print(f"File contents length: {len(args.get('file_contents', ''))}")
print(f"File contents preview: {args.get('file_contents', '')[:100]}...")
except Exception as e:
print(f"Error parsing arguments: {e}")
else:
print("\nNo tool calls found in the response.")
print(f"Full response: {response}")
except Exception as e:
logger.error(f"Error in test: {str(e)}", exc_info=True)
async def test_streaming_tool_call():
"""Test tool calling with streaming to observe behavior."""
# Setup conversation
messages = [
{"role": "system", "content": "You are a helpful assistant with access to file management tools. YOU ALWAYS USE MULTIPLE TOOL FUNCTION CALLS AT ONCE. YOU NEVER USE ONE TOOL FUNCTION CALL AT A TIME."},
{"role": "user", "content": "Create 10 random files with different extensions and content."}
]
print("\n=== Testing streaming tool call ===\n")
try:
# Make API call with tool in streaming mode
print("Sending streaming request...")
stream_response = await make_llm_api_call(
messages=messages,
model_name="anthropic/claude-3-5-sonnet-latest",
temperature=0.0,
tools=[CREATE_FILE_SCHEMA],
tool_choice="auto",
stream=True
)
# Process streaming response
print("\nResponse stream started. Processing chunks:\n")
# Stream statistics
chunk_count = 0
content_chunks = 0
tool_call_chunks = 0
accumulated_content = ""
# Storage for accumulated tool calls
tool_calls = []
last_chunk = None # Variable to store the last chunk
# Process each chunk
async for chunk in stream_response:
chunk_count += 1
last_chunk = chunk # Keep track of the last chunk
# Print chunk number and type
print(f"\n--- Chunk {chunk_count} ---")
print(f"Chunk type: {type(chunk)}")
if not hasattr(chunk, 'choices') or not chunk.choices:
print("No choices in chunk")
continue
delta = chunk.choices[0].delta
# Process content if present
if hasattr(delta, 'content') and delta.content is not None:
content_chunks += 1
accumulated_content += delta.content
print(f"Content: {delta.content}")
# Look for tool calls
if hasattr(delta, 'tool_calls') and delta.tool_calls:
tool_call_chunks += 1
print("Tool call detected in chunk!")
for tool_call in delta.tool_calls:
print(f"Tool call: {tool_call.model_dump()}")
# Track tool call parts
tool_call_index = tool_call.index if hasattr(tool_call, 'index') else 0
# Initialize tool call if new
while len(tool_calls) <= tool_call_index:
tool_calls.append({
"id": "",
"type": "function",
"function": {"name": "", "arguments": ""}
})
# Update tool call ID if present
if hasattr(tool_call, 'id') and tool_call.id:
tool_calls[tool_call_index]["id"] = tool_call.id
# Update function name if present
if hasattr(tool_call, 'function'):
if hasattr(tool_call.function, 'name') and tool_call.function.name:
tool_calls[tool_call_index]["function"]["name"] = tool_call.function.name
# Update function arguments if present
if hasattr(tool_call.function, 'arguments') and tool_call.function.arguments:
tool_calls[tool_call_index]["function"]["arguments"] += tool_call.function.arguments
# Summary after all chunks processed
print("\n=== Streaming Summary ===")
print(f"Total chunks: {chunk_count}")
print(f"Content chunks: {content_chunks}")
print(f"Tool call chunks: {tool_call_chunks}")
if accumulated_content:
print(f"\nAccumulated content: {accumulated_content}")
if tool_calls:
print("\nAccumulated tool calls:")
for i, tool_call in enumerate(tool_calls):
print(f"\nTool call {i+1}:")
print(f" ID: {tool_call['id']}")
print(f" Type: {tool_call['type']}")
print(f" Function: {tool_call['function']['name']}")
print(f" Arguments: {tool_call['function']['arguments']}")
# Try to parse arguments
try:
args = json.loads(tool_call['function']['arguments'])
print("\nParsed arguments:")
print(json.dumps(args, indent=4))
except Exception as e:
print(f"Error parsing arguments: {str(e)}")
else:
print("\nNo tool calls accumulated from streaming response.")
# --- Added logging for last chunk and finish reason ---
finish_reason = None
if last_chunk:
try:
if hasattr(last_chunk, 'choices') and last_chunk.choices:
finish_reason = last_chunk.choices[0].finish_reason
last_chunk_data = last_chunk.model_dump() if hasattr(last_chunk, 'model_dump') else vars(last_chunk)
print("\n--- Last Chunk Received ---")
print(f"Finish Reason: {finish_reason}")
print(f"Raw Last Chunk Data: {json.dumps(last_chunk_data, indent=2)}")
except Exception as log_ex:
print("\n--- Error logging last chunk ---")
print(f"Error: {log_ex}")
print(f"Last Chunk (repr): {repr(last_chunk)}")
else:
print("\n--- No last chunk recorded ---")
# --- End added logging ---
except Exception as e:
logger.error(f"Error in streaming test: {str(e)}", exc_info=True)
async def main():
"""Run both tests for comparison."""
# await test_simple_tool_call()
await test_streaming_tool_call()
if __name__ == "__main__":
asyncio.run(main())

View File

@ -1,237 +0,0 @@
"""
Tests for tool execution strategies in AgentPress.
This module tests both sequential and parallel execution strategies using the WaitTool
in a realistic thread with XML tool calls.
"""
import os
import asyncio
import sys
from unittest.mock import AsyncMock, patch
from dotenv import load_dotenv
from agentpress.thread_manager import ThreadManager
from agentpress.response_processor import ProcessorConfig
from agent.tools.wait_tool import WaitTool
# Load environment variables
load_dotenv()
TOOL_XML_SEQUENTIAL = """
Here are some examples of using the wait tool:
<wait seconds="2">This is sequential wait 1</wait>
<wait seconds="2">This is sequential wait 2</wait>
<wait seconds="2">This is sequential wait 3</wait>
Now wait sequence:
<wait-sequence count="3" seconds="1" label="Sequential Test" />
"""
TOOL_XML_PARALLEL = """
Here are some examples of using the wait tool:
<wait seconds="2">This is parallel wait 1</wait>
<wait seconds="2">This is parallel wait 2</wait>
<wait seconds="2">This is parallel wait 3</wait>
Now wait sequence:
<wait-sequence count="3" seconds="1" label="Parallel Test" />
"""
# Create a simple mock function that logs instead of accessing the database
async def mock_add_message(thread_id, message):
print(f"MOCK: Adding message to thread {thread_id}")
print(f"MOCK: Message role: {message.get('role')}")
print(f"MOCK: Content length: {len(message.get('content', ''))}")
return {"id": "mock-message-id", "thread_id": thread_id}
async def test_execution_strategies():
"""Test both sequential and parallel execution strategies in a thread."""
print("\n" + "="*80)
print("🧪 TESTING TOOL EXECUTION STRATEGIES")
print("="*80 + "\n")
# Initialize ThreadManager and register tools
thread_manager = ThreadManager()
thread_manager.add_tool(WaitTool)
# Mock both ThreadManager's and ResponseProcessor's add_message method
thread_manager.add_message = AsyncMock(side_effect=mock_add_message)
# This is crucial - the ResponseProcessor receives add_message as a callback
thread_manager.response_processor.add_message = AsyncMock(side_effect=mock_add_message)
# Create a test thread - we'll use a dummy ID since we're mocking the database
thread_id = "test-thread-id"
print(f"🧵 Using test thread: {thread_id}\n")
# Set up the get_llm_messages mock
original_get_llm_messages = thread_manager.get_llm_messages
thread_manager.get_llm_messages = AsyncMock()
# Test both strategies
test_cases = [
{"name": "Sequential", "strategy": "sequential", "content": TOOL_XML_SEQUENTIAL},
{"name": "Parallel", "strategy": "parallel", "content": TOOL_XML_PARALLEL}
]
# Expected values for validation - this varies based on XML parsing
# For reliable testing, we look at <wait> tags which we know are being parsed
expected_wait_count = 3 # 3 wait tags per test
test_results = {}
for test in test_cases:
print("\n" + "-"*60)
print(f"🔍 Testing {test['name']} Execution Strategy")
print("-"*60 + "\n")
# Setup mock for get_llm_messages to return our test content
thread_manager.get_llm_messages.return_value = [
{
"role": "system",
"content": "You are a testing assistant that will execute wait commands."
},
{
"role": "assistant",
"content": test["content"]
}
]
# Simulate adding message (mocked)
print(f"MOCK: Adding test message with {test['name']} execution strategy content")
await thread_manager.add_message(
thread_id=thread_id,
type="assistant",
content={
"role": "assistant",
"content": test["content"]
},
is_llm_message=True
)
start_time = asyncio.get_event_loop().time()
print(f"⏱️ Starting execution with {test['strategy']} strategy at {start_time:.2f}s")
# Process the response with appropriate strategy
config = ProcessorConfig(
xml_tool_calling=True,
native_tool_calling=False,
execute_tools=True,
execute_on_stream=False,
tool_execution_strategy=test["strategy"]
)
# Get the last message to process (mocked)
messages = await thread_manager.get_llm_messages(thread_id)
last_message = messages[-1]
# Create a simple non-streaming response object
class MockResponse:
def __init__(self, content):
self.choices = [type('obj', (object,), {
'message': type('obj', (object,), {
'content': content
})
})]
mock_response = MockResponse(last_message["content"])
# Process using the response processor
tool_execution_count = 0
wait_tool_count = 0
tool_results = []
async for chunk in thread_manager.response_processor.process_non_streaming_response(
llm_response=mock_response,
thread_id=thread_id,
config=config
):
if chunk.get('type') == 'tool_result':
tool_name = chunk.get('name', '')
tool_execution_count += 1
if tool_name == 'wait':
wait_tool_count += 1
elapsed = asyncio.get_event_loop().time() - start_time
print(f"⏱️ [{elapsed:.2f}s] Tool result: {chunk['name']}")
print(f" {chunk['result']}")
print()
tool_results.append(chunk)
end_time = asyncio.get_event_loop().time()
elapsed = end_time - start_time
print(f"\n⏱️ {test['name']} execution completed in {elapsed:.2f} seconds")
print(f"🔢 Total tool executions: {tool_execution_count}")
print(f"🔢 Wait tool executions: {wait_tool_count}")
# Store results for validation
test_results[test['name']] = {
'execution_time': elapsed,
'tool_count': tool_execution_count,
'wait_count': wait_tool_count,
'tool_results': tool_results
}
# Assert correct number of wait tools executions (this is more reliable than total count)
assert wait_tool_count == expected_wait_count, f"❌ Expected {expected_wait_count} wait tool executions, got {wait_tool_count} in {test['name']} strategy"
print(f"✅ PASS: {test['name']} executed {wait_tool_count} wait tools as expected")
# Restore original get_llm_messages method
thread_manager.get_llm_messages = original_get_llm_messages
# Additional assertions for both test cases
assert 'Sequential' in test_results, "❌ Sequential test not completed"
assert 'Parallel' in test_results, "❌ Parallel test not completed"
# Validate parallel is faster than sequential for multiple wait tools
sequential_time = test_results['Sequential']['execution_time']
parallel_time = test_results['Parallel']['execution_time']
speedup = sequential_time / parallel_time if parallel_time > 0 else 0
# Parallel should be faster than sequential (at least 1.5x speedup expected)
print(f"\n⏱️ Execution time comparison:")
print(f" Sequential: {sequential_time:.2f}s")
print(f" Parallel: {parallel_time:.2f}s")
print(f" Speedup: {speedup:.2f}x")
min_expected_speedup = 1.5
assert speedup >= min_expected_speedup, f"❌ Expected parallel execution to be at least {min_expected_speedup}x faster than sequential, but got {speedup:.2f}x"
print(f"✅ PASS: Parallel execution is {speedup:.2f}x faster than sequential")
# Check if all results have a status field
all_have_status = all(
'status' in result
for test_data in test_results.values()
for result in test_data['tool_results']
)
# If results have a status field, check if they're all successful
if all_have_status:
all_successful = all(
result.get('status') == 'success'
for test_data in test_results.values()
for result in test_data['tool_results']
)
assert all_successful, "❌ Not all tool executions were successful"
print("✅ PASS: All tool executions completed successfully")
print("\n" + "="*80)
print("✅ ALL TESTS PASSED")
print("="*80 + "\n")
return test_results
if __name__ == "__main__":
try:
asyncio.run(test_execution_strategies())
print("\n✅ Test completed successfully")
sys.exit(0)
except AssertionError as e:
print(f"\n\n❌ Test failed: {str(e)}")
sys.exit(1)
except KeyboardInterrupt:
print("\n\n❌ Test interrupted by user")
sys.exit(1)
except Exception as e:
print(f"\n\n❌ Error during test: {str(e)}")
sys.exit(1)

View File

@ -1,282 +0,0 @@
"""
Tests for XML tool execution in streaming and non-streaming modes.
This module tests XML tool execution with execute_on_stream set to TRUE and FALSE,
to ensure both modes work correctly with the WaitTool.
"""
import os
import asyncio
import sys
from unittest.mock import AsyncMock, patch
from dotenv import load_dotenv
from agentpress.thread_manager import ThreadManager
from agentpress.response_processor import ProcessorConfig
from agent.tools.wait_tool import WaitTool
# Load environment variables
load_dotenv()
# XML content with wait tool calls
XML_CONTENT = """
Here are some examples of using the wait tool:
<wait seconds="1">This is wait 1</wait>
<wait seconds="1">This is wait 2</wait>
<wait seconds="1">This is wait 3</wait>
Now wait sequence:
<wait-sequence count="2" seconds="1" label="Test" />
"""
class MockStreamingResponse:
"""Mock streaming response from an LLM."""
def __init__(self, content):
self.content = content
self.chunk_size = 20 # Small chunks to simulate streaming
async def __aiter__(self):
# Split content into chunks to simulate streaming
for i in range(0, len(self.content), self.chunk_size):
chunk = self.content[i:i+self.chunk_size]
yield type('obj', (object,), {
'choices': [type('obj', (object,), {
'delta': type('obj', (object,), {
'content': chunk
})
})]
})
# Simulate some network delay
await asyncio.sleep(0.1)
class MockNonStreamingResponse:
"""Mock non-streaming response from an LLM."""
def __init__(self, content):
self.choices = [type('obj', (object,), {
'message': type('obj', (object,), {
'content': content
})
})]
# Create a simple mock function that logs instead of accessing the database
async def mock_add_message(thread_id, message):
print(f"MOCK: Adding message to thread {thread_id}")
print(f"MOCK: Message role: {message.get('role')}")
print(f"MOCK: Content length: {len(message.get('content', ''))}")
return {"id": "mock-message-id", "thread_id": thread_id}
async def test_xml_streaming_execution():
"""Test XML tool execution in both streaming and non-streaming modes."""
print("\n" + "="*80)
print("🧪 TESTING XML TOOL EXECUTION: STREAMING VS NON-STREAMING")
print("="*80 + "\n")
# Initialize ThreadManager and register tools
thread_manager = ThreadManager()
thread_manager.add_tool(WaitTool)
# Mock both ThreadManager's and ResponseProcessor's add_message method
thread_manager.add_message = AsyncMock(side_effect=mock_add_message)
thread_manager.response_processor.add_message = AsyncMock(side_effect=mock_add_message)
# Set up the get_llm_messages mock
original_get_llm_messages = thread_manager.get_llm_messages
thread_manager.get_llm_messages = AsyncMock()
# Test cases for streaming and non-streaming
test_cases = [
{"name": "Non-Streaming", "execute_on_stream": False},
{"name": "Streaming", "execute_on_stream": True}
]
# Expected values for validation - focus specifically on wait tools
expected_wait_count = 3 # 3 wait tags in the XML content
test_results = {}
for test in test_cases:
# Create a test thread ID - we're mocking so no actual creation
thread_id = f"test-thread-{test['name'].lower()}"
print("\n" + "-"*60)
print(f"🔍 Testing XML Tool Execution - {test['name']} Mode")
print("-"*60 + "\n")
# Setup mock for get_llm_messages to return test content
thread_manager.get_llm_messages.return_value = [
{
"role": "system",
"content": "You are a testing assistant that will execute wait commands."
},
{
"role": "assistant",
"content": XML_CONTENT
}
]
# Simulate adding system message (mocked)
print(f"MOCK: Adding system message to thread {thread_id}")
await thread_manager.add_message(
thread_id=thread_id,
type="system",
content={
"role": "system",
"content": "You are a testing assistant that will execute wait commands."
},
is_llm_message=False
)
# Simulate adding message with XML content (mocked)
print(f"MOCK: Adding message with XML content to thread {thread_id}")
await thread_manager.add_message(
thread_id=thread_id,
type="assistant",
content={
"role": "assistant",
"content": XML_CONTENT
},
is_llm_message=True
)
print(f"🧵 Using test thread: {thread_id}")
print(f"⚙️ execute_on_stream: {test['execute_on_stream']}")
# Prepare the response processor config
config = ProcessorConfig(
xml_tool_calling=True,
native_tool_calling=False,
execute_tools=True,
execute_on_stream=test['execute_on_stream'],
tool_execution_strategy="sequential"
)
# Get the last message to process (using mock)
messages = await thread_manager.get_llm_messages(thread_id)
last_message = messages[-1]
# Process response based on mode
start_time = asyncio.get_event_loop().time()
print(f"⏱️ Starting execution at {start_time:.2f}s")
tool_execution_count = 0
wait_tool_count = 0
tool_results = []
if test['execute_on_stream']:
# Create streaming response
streaming_response = MockStreamingResponse(last_message["content"])
# Process streaming response
async for chunk in thread_manager.response_processor.process_streaming_response(
llm_response=streaming_response,
thread_id=thread_id,
config=config
):
if chunk.get('type') == 'tool_result':
elapsed = asyncio.get_event_loop().time() - start_time
tool_name = chunk.get('name', '')
tool_execution_count += 1
if tool_name == 'wait':
wait_tool_count += 1
print(f"⏱️ [{elapsed:.2f}s] Tool result: {chunk['name']}")
print(f" {chunk['result']}")
print()
tool_results.append(chunk)
else:
# Create non-streaming response
non_streaming_response = MockNonStreamingResponse(last_message["content"])
# Process non-streaming response
async for chunk in thread_manager.response_processor.process_non_streaming_response(
llm_response=non_streaming_response,
thread_id=thread_id,
config=config
):
if chunk.get('type') == 'tool_result':
elapsed = asyncio.get_event_loop().time() - start_time
tool_name = chunk.get('name', '')
tool_execution_count += 1
if tool_name == 'wait':
wait_tool_count += 1
print(f"⏱️ [{elapsed:.2f}s] Tool result: {chunk['name']}")
print(f" {chunk['result']}")
print()
tool_results.append(chunk)
end_time = asyncio.get_event_loop().time()
elapsed = end_time - start_time
print(f"\n⏱️ {test['name']} execution completed in {elapsed:.2f} seconds")
print(f"🔢 Total tool executions: {tool_execution_count}")
print(f"🔢 Wait tool executions: {wait_tool_count}")
# Store results for validation
test_results[test['name']] = {
'execution_time': elapsed,
'tool_count': tool_execution_count,
'wait_count': wait_tool_count,
'tool_results': tool_results
}
# Assert correct number of wait tool executions
assert wait_tool_count == expected_wait_count, f"❌ Expected {expected_wait_count} wait tool executions, got {wait_tool_count} in {test['name']} mode"
print(f"✅ PASS: {test['name']} executed {wait_tool_count} wait tools as expected")
# Restore original get_llm_messages method
thread_manager.get_llm_messages = original_get_llm_messages
# Additional assertions for both test cases
assert 'Non-Streaming' in test_results, "❌ Non-streaming test not completed"
assert 'Streaming' in test_results, "❌ Streaming test not completed"
# Validate streaming has different timing characteristics than non-streaming
non_streaming_time = test_results['Non-Streaming']['execution_time']
streaming_time = test_results['Streaming']['execution_time']
# Streaming should have different timing due to the nature of execution
# We don't assert strict timing as it can vary, but we validate the tests ran successfully
print(f"\n⏱️ Execution time comparison:")
print(f" Non-Streaming: {non_streaming_time:.2f}s")
print(f" Streaming: {streaming_time:.2f}s")
print(f" Time difference: {abs(non_streaming_time - streaming_time):.2f}s")
# Check if all results have a status field
all_have_status = all(
'status' in result
for test_data in test_results.values()
for result in test_data['tool_results']
)
# If results have a status field, check if they're all successful
if all_have_status:
all_successful = all(
result.get('status') == 'success'
for test_data in test_results.values()
for result in test_data['tool_results']
)
assert all_successful, "❌ Not all tool executions were successful"
print("✅ PASS: All tool executions completed successfully")
print("\n" + "="*80)
print("✅ ALL TESTS PASSED")
print("="*80 + "\n")
return test_results
if __name__ == "__main__":
try:
asyncio.run(test_xml_streaming_execution())
print("\n✅ Test completed successfully")
sys.exit(0)
except AssertionError as e:
print(f"\n\n❌ Test failed: {str(e)}")
sys.exit(1)
except KeyboardInterrupt:
print("\n\n❌ Test interrupted by user")
sys.exit(1)
except Exception as e:
print(f"\n\n❌ Error during test: {str(e)}")
sys.exit(1)

View File

@ -5,7 +5,7 @@ from jwt.exceptions import PyJWTError
from utils.logger import logger
# This function extracts the user ID from Supabase JWT
async def get_current_user_id(request: Request) -> str:
async def get_current_user_id_from_jwt(request: Request) -> str:
"""
Extract and verify the user ID from the JWT in the Authorization header.
@ -56,6 +56,45 @@ async def get_current_user_id(request: Request) -> str:
headers={"WWW-Authenticate": "Bearer"}
)
async def get_account_id_from_thread(client, thread_id: str) -> str:
"""
Extract and verify the account ID from the thread.
Args:
client: The Supabase client
thread_id: The ID of the thread
Returns:
str: The account ID associated with the thread
Raises:
HTTPException: If the thread is not found or if there's an error
"""
try:
response = await client.table('threads').select('account_id').eq('thread_id', thread_id).execute()
if not response.data or len(response.data) == 0:
raise HTTPException(
status_code=404,
detail="Thread not found"
)
account_id = response.data[0].get('account_id')
if not account_id:
raise HTTPException(
status_code=500,
detail="Thread has no associated account"
)
return account_id
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"Error retrieving thread information: {str(e)}"
)
async def get_user_id_from_stream_auth(
request: Request,
token: Optional[str] = None
@ -105,6 +144,7 @@ async def get_user_id_from_stream_auth(
detail="No valid authentication credentials found",
headers={"WWW-Authenticate": "Bearer"}
)
async def verify_thread_access(client, thread_id: str, user_id: str):
"""
Verify that a user has access to a specific thread based on account membership.

View File

@ -1,125 +0,0 @@
from datetime import datetime, timezone
from typing import Dict, Optional, Tuple
from utils.logger import logger
from utils.config import config, EnvMode
# Define subscription tiers and their monthly limits (in minutes)
SUBSCRIPTION_TIERS = {
'price_1RGJ9GG6l1KZGqIroxSqgphC': {'name': 'free', 'minutes': 8},
'price_1RGJ9LG6l1KZGqIrd9pwzeNW': {'name': 'base', 'minutes': 300},
'price_1RGJ9JG6l1KZGqIrVUU4ZRv6': {'name': 'extra', 'minutes': 2400}
}
async def get_account_subscription(client, account_id: str) -> Optional[Dict]:
"""Get the current subscription for an account."""
result = await client.schema('basejump').from_('billing_subscriptions') \
.select('*') \
.eq('account_id', account_id) \
.eq('status', 'active') \
.order('created', desc=True) \
.limit(1) \
.execute()
if result.data and len(result.data) > 0:
return result.data[0]
return None
async def calculate_monthly_usage(client, account_id: str) -> float:
"""Calculate total agent run minutes for the current month for an account."""
# Get start of current month in UTC
now = datetime.now(timezone.utc)
start_of_month = datetime(now.year, now.month, 1, tzinfo=timezone.utc)
# First get all threads for this account
threads_result = await client.table('threads') \
.select('thread_id') \
.eq('account_id', account_id) \
.execute()
if not threads_result.data:
return 0.0
thread_ids = [t['thread_id'] for t in threads_result.data]
# Then get all agent runs for these threads in current month
runs_result = await client.table('agent_runs') \
.select('started_at, completed_at') \
.in_('thread_id', thread_ids) \
.gte('started_at', start_of_month.isoformat()) \
.execute()
if not runs_result.data:
return 0.0
# Calculate total minutes
total_seconds = 0
now_ts = now.timestamp()
for run in runs_result.data:
start_time = datetime.fromisoformat(run['started_at'].replace('Z', '+00:00')).timestamp()
if run['completed_at']:
end_time = datetime.fromisoformat(run['completed_at'].replace('Z', '+00:00')).timestamp()
else:
# For running jobs, use current time
end_time = now_ts
total_seconds += (end_time - start_time)
return total_seconds / 60 # Convert to minutes
async def check_billing_status(client, account_id: str) -> Tuple[bool, str, Optional[Dict]]:
"""
Check if an account can run agents based on their subscription and usage.
Returns:
Tuple[bool, str, Optional[Dict]]: (can_run, message, subscription_info)
"""
if config.ENV_MODE == EnvMode.LOCAL:
logger.info("Running in local development mode - billing checks are disabled")
return True, "Local development mode - billing disabled", {
"price_id": "local_dev",
"plan_name": "Local Development",
"minutes_limit": "no limit"
}
# For staging/production, check subscription status
# Get current subscription
subscription = await get_account_subscription(client, account_id)
# If no subscription, they can use free tier
if not subscription:
subscription = {
'price_id': 'price_1RGJ9GG6l1KZGqIroxSqgphC', # Free tier
'plan_name': 'free'
}
# if not subscription or subscription.get('price_id') is None or subscription.get('price_id') == 'price_1RGJ9GG6l1KZGqIroxSqgphC':
# return False, "You are not subscribed to any plan. Please upgrade your plan to continue.", subscription
# Get tier info
tier_info = SUBSCRIPTION_TIERS.get(subscription['price_id'])
if not tier_info:
return False, "Invalid subscription tier", subscription
# Calculate current month's usage
current_usage = await calculate_monthly_usage(client, account_id)
# Check if within limits
if current_usage >= tier_info['minutes']:
return False, f"Monthly limit of {tier_info['minutes']} minutes reached. Please upgrade your plan or wait until next month.", subscription
return True, "OK", subscription
# Helper function to get account ID from thread
async def get_account_id_from_thread(client, thread_id: str) -> Optional[str]:
"""Get the account ID associated with a thread."""
result = await client.table('threads') \
.select('account_id') \
.eq('thread_id', thread_id) \
.limit(1) \
.execute()
if result.data and len(result.data) > 0:
return result.data[0]['account_id']
return None

View File

@ -39,6 +39,75 @@ class Configuration:
# Environment mode
ENV_MODE: EnvMode = EnvMode.LOCAL
# Subscription tier IDs - Production
STRIPE_FREE_TIER_ID_PROD: str = 'price_1RILb4G6l1KZGqIrK4QLrx9i'
STRIPE_TIER_2_20_ID_PROD: str = 'price_1RILb4G6l1KZGqIrhomjgDnO'
STRIPE_TIER_6_50_ID_PROD: str = 'price_1RILb4G6l1KZGqIr5q0sybWn'
STRIPE_TIER_12_100_ID_PROD: str = 'price_1RILb4G6l1KZGqIr5Y20ZLHm'
STRIPE_TIER_25_200_ID_PROD: str = 'price_1RILb4G6l1KZGqIrGAD8rNjb'
STRIPE_TIER_50_400_ID_PROD: str = 'price_1RILb4G6l1KZGqIruNBUMTF1'
STRIPE_TIER_125_800_ID_PROD: str = 'price_1RILb3G6l1KZGqIrbJA766tN'
STRIPE_TIER_200_1000_ID_PROD: str = 'price_1RILb3G6l1KZGqIrmauYPOiN'
# Subscription tier IDs - Staging
STRIPE_FREE_TIER_ID_STAGING: str = 'price_1RIGvuG6l1KZGqIrw14abxeL'
STRIPE_TIER_2_20_ID_STAGING: str = 'price_1RIGvuG6l1KZGqIrCRu0E4Gi'
STRIPE_TIER_6_50_ID_STAGING: str = 'price_1RIGvuG6l1KZGqIrvjlz5p5V'
STRIPE_TIER_12_100_ID_STAGING: str = 'price_1RIGvuG6l1KZGqIrT6UfgblC'
STRIPE_TIER_25_200_ID_STAGING: str = 'price_1RIGvuG6l1KZGqIrOVLKlOMj'
STRIPE_TIER_50_400_ID_STAGING: str = 'price_1RIKNgG6l1KZGqIrvsat5PW7'
STRIPE_TIER_125_800_ID_STAGING: str = 'price_1RIKNrG6l1KZGqIrjKT0yGvI'
STRIPE_TIER_200_1000_ID_STAGING: str = 'price_1RIKQ2G6l1KZGqIrum9n8SI7'
# Computed subscription tier IDs based on environment
@property
def STRIPE_FREE_TIER_ID(self) -> str:
if self.ENV_MODE == EnvMode.STAGING:
return self.STRIPE_FREE_TIER_ID_STAGING
return self.STRIPE_FREE_TIER_ID_PROD
@property
def STRIPE_TIER_2_20_ID(self) -> str:
if self.ENV_MODE == EnvMode.STAGING:
return self.STRIPE_TIER_2_20_ID_STAGING
return self.STRIPE_TIER_2_20_ID_PROD
@property
def STRIPE_TIER_6_50_ID(self) -> str:
if self.ENV_MODE == EnvMode.STAGING:
return self.STRIPE_TIER_6_50_ID_STAGING
return self.STRIPE_TIER_6_50_ID_PROD
@property
def STRIPE_TIER_12_100_ID(self) -> str:
if self.ENV_MODE == EnvMode.STAGING:
return self.STRIPE_TIER_12_100_ID_STAGING
return self.STRIPE_TIER_12_100_ID_PROD
@property
def STRIPE_TIER_25_200_ID(self) -> str:
if self.ENV_MODE == EnvMode.STAGING:
return self.STRIPE_TIER_25_200_ID_STAGING
return self.STRIPE_TIER_25_200_ID_PROD
@property
def STRIPE_TIER_50_400_ID(self) -> str:
if self.ENV_MODE == EnvMode.STAGING:
return self.STRIPE_TIER_50_400_ID_STAGING
return self.STRIPE_TIER_50_400_ID_PROD
@property
def STRIPE_TIER_125_800_ID(self) -> str:
if self.ENV_MODE == EnvMode.STAGING:
return self.STRIPE_TIER_125_800_ID_STAGING
return self.STRIPE_TIER_125_800_ID_PROD
@property
def STRIPE_TIER_200_1000_ID(self) -> str:
if self.ENV_MODE == EnvMode.STAGING:
return self.STRIPE_TIER_200_1000_ID_STAGING
return self.STRIPE_TIER_200_1000_ID_PROD
# LLM API keys
ANTHROPIC_API_KEY: str = None
OPENAI_API_KEY: Optional[str] = None
@ -80,9 +149,19 @@ class Configuration:
# Stripe configuration
STRIPE_SECRET_KEY: Optional[str] = None
STRIPE_WEBHOOK_SECRET: Optional[str] = None
STRIPE_DEFAULT_PLAN_ID: Optional[str] = None
STRIPE_DEFAULT_TRIAL_DAYS: int = 14
# Stripe Product IDs
STRIPE_PRODUCT_ID_PROD: str = 'prod_SCl7AQ2C8kK1CD' # Production product ID
STRIPE_PRODUCT_ID_STAGING: str = 'prod_SCgIj3G7yPOAWY' # Staging product ID
@property
def STRIPE_PRODUCT_ID(self) -> str:
if self.ENV_MODE == EnvMode.STAGING:
return self.STRIPE_PRODUCT_ID_STAGING
return self.STRIPE_PRODUCT_ID_PROD
def __init__(self):
"""Initialize configuration by loading from environment variables."""

View File

@ -1,5 +1,3 @@
version: '3.8'
services:
redis:
image: redis:7-alpine

View File

@ -1,5 +1,5 @@
import {createClient} from "@/lib/supabase/server";
import AccountBillingStatus from "@/components/basejump/account-billing-status";
import AccountBillingStatus from "@/components/billing/account-billing-status";
const returnUrl = process.env.NEXT_PUBLIC_URL as string;

View File

@ -2,7 +2,7 @@
import React from 'react';
import {createClient} from "@/lib/supabase/server";
import AccountBillingStatus from "@/components/basejump/account-billing-status";
import AccountBillingStatus from "@/components/billing/account-billing-status";
import { Alert, AlertTitle, AlertDescription } from "@/components/ui/alert";
const returnUrl = process.env.NEXT_PUBLIC_URL as string;

View File

@ -7,7 +7,7 @@ import { Button } from '@/components/ui/button';
import {
ArrowDown, CheckCircle, CircleDashed, AlertTriangle, Info, File, ChevronRight
} from 'lucide-react';
import { addUserMessage, getMessages, startAgent, stopAgent, getAgentRuns, getProject, getThread, updateProject, Project, Message as BaseApiMessageType, BillingError } from '@/lib/api';
import { addUserMessage, getMessages, startAgent, stopAgent, getAgentRuns, getProject, getThread, updateProject, Project, Message as BaseApiMessageType, BillingError, checkBillingStatus } from '@/lib/api';
import { toast } from 'sonner';
import { Skeleton } from "@/components/ui/skeleton";
import { ChatInput } from '@/components/thread/chat-input';
@ -20,10 +20,9 @@ import { Markdown } from '@/components/ui/markdown';
import { cn } from "@/lib/utils";
import { useIsMobile } from "@/hooks/use-mobile";
import { BillingErrorAlert } from '@/components/billing/usage-limit-alert';
import { SUBSCRIPTION_PLANS } from '@/components/billing/plan-comparison';
import { createClient } from '@/lib/supabase/client';
import { isLocalMode } from "@/lib/config";
import { UnifiedMessage, ParsedContent, ParsedMetadata, ThreadParams } from '@/components/thread/types';
import { getToolIcon, extractPrimaryParam, safeJsonParse } from '@/components/thread/utils';
@ -1040,122 +1039,62 @@ export default function ThreadPage({ params }: { params: Promise<ThreadParams> }
}
}, [agentStatus, threadId, isLoading, streamHookStatus]);
// Check billing status when agent completes
const checkBillingStatus = useCallback(async () => {
// Update the checkBillingStatus function
const checkBillingLimits = useCallback(async () => {
// Skip billing checks in local development mode
if (isLocalMode()) {
console.log("Running in local development mode - billing checks are disabled");
return false;
}
if (!project?.account_id) return;
const supabase = createClient();
try {
// Check subscription status
const { data: subscriptionData } = await supabase
.schema('basejump')
.from('billing_subscriptions')
.select('price_id')
.eq('account_id', project.account_id)
.eq('status', 'active')
.single();
const result = await checkBillingStatus();
const currentPlanId = subscriptionData?.price_id || SUBSCRIPTION_PLANS.FREE;
// Only check usage limits for free tier users
if (currentPlanId === SUBSCRIPTION_PLANS.FREE) {
// Calculate usage
const startOfMonth = new Date();
startOfMonth.setDate(1);
startOfMonth.setHours(0, 0, 0, 0);
// Get threads for this account
const { data: threadsData } = await supabase
.from('threads')
.select('thread_id')
.eq('account_id', project.account_id);
const threadIds = threadsData?.map(t => t.thread_id) || [];
// Get agent runs for those threads
const { data: agentRunData } = await supabase
.from('agent_runs')
.select('started_at, completed_at')
.in('thread_id', threadIds)
.gte('started_at', startOfMonth.toISOString());
let totalSeconds = 0;
if (agentRunData) {
totalSeconds = agentRunData.reduce((acc, run) => {
const start = new Date(run.started_at);
const end = run.completed_at ? new Date(run.completed_at) : new Date();
const seconds = (end.getTime() - start.getTime()) / 1000;
return acc + seconds;
}, 0);
}
// Convert to hours for display
const hours = totalSeconds / 3600;
const minutesUsed = totalSeconds / 60;
// The free plan has a 10 minute limit as defined in backend/utils/billing.py
const FREE_PLAN_LIMIT_MINUTES = 10;
const FREE_PLAN_LIMIT_HOURS = FREE_PLAN_LIMIT_MINUTES / 60;
// Show alert if over limit
if (minutesUsed > FREE_PLAN_LIMIT_MINUTES) {
console.log("Usage limit exceeded:", {
minutesUsed,
hoursUsed: hours,
limit: FREE_PLAN_LIMIT_MINUTES
});
setBillingData({
currentUsage: Number(hours.toFixed(2)),
limit: FREE_PLAN_LIMIT_HOURS,
message: `You've used ${Math.floor(minutesUsed)} minutes on the Free plan. The limit is ${FREE_PLAN_LIMIT_MINUTES} minutes per month.`,
accountId: project.account_id || null
});
setShowBillingAlert(true);
return true; // Return true if over limit
}
if (!result.can_run) {
setBillingData({
currentUsage: result.subscription?.minutes_limit || 0,
limit: result.subscription?.minutes_limit || 0,
message: result.message || 'Usage limit reached',
accountId: project?.account_id || null
});
setShowBillingAlert(true);
return true;
}
return false; // Return false if not over limit
return false;
} catch (err) {
console.error('Error checking billing status:', err);
return false;
}
}, [project?.account_id]);
// Update useEffect to check billing when agent completes
// Update useEffect to use the renamed function
useEffect(() => {
const previousStatus = previousAgentStatus.current;
// Check if agent just completed (status changed from running to idle)
if (previousStatus === 'running' && agentStatus === 'idle') {
checkBillingStatus();
checkBillingLimits();
}
// Store current status for next comparison
previousAgentStatus.current = agentStatus;
}, [agentStatus, checkBillingStatus]);
}, [agentStatus, checkBillingLimits]);
// Add new useEffect to check billing limits when page first loads or project changes
// Update other useEffect to use the renamed function
useEffect(() => {
if (project?.account_id && initialLoadCompleted.current) {
console.log("Checking billing status on page load");
checkBillingStatus();
checkBillingLimits();
}
}, [project?.account_id, checkBillingStatus, initialLoadCompleted]);
// Also check after messages are loaded to ensure we have the complete state
}, [project?.account_id, checkBillingLimits, initialLoadCompleted]);
// Update the last useEffect to use the renamed function
useEffect(() => {
if (messagesLoadedRef.current && project?.account_id && !isLoading) {
console.log("Checking billing status after messages loaded");
checkBillingStatus();
checkBillingLimits();
}
}, [messagesLoadedRef.current, checkBillingStatus, project?.account_id, isLoading]);
}, [messagesLoadedRef.current, checkBillingLimits, project?.account_id, isLoading]);
if (isLoading && !initialLoadCompleted.current) {
return (

View File

@ -14,7 +14,7 @@ import { Tooltip, TooltipContent, TooltipTrigger } from "@/components/ui/tooltip
import { useBillingError } from "@/hooks/useBillingError";
import { BillingErrorAlert } from "@/components/billing/usage-limit-alert";
import { useAccounts } from "@/hooks/use-accounts";
import { isLocalMode } from "@/lib/config";
import { isLocalMode, config } from "@/lib/config";
import { toast } from "sonner";
// Constant for localStorage key to ensure consistency
@ -103,7 +103,7 @@ function DashboardContent() {
limit: error.detail.limit as number | undefined,
// Include subscription details if available in the error, otherwise provide defaults
subscription: error.detail.subscription || {
price_id: "price_1RGJ9GG6l1KZGqIroxSqgphC", // Default to Free tier
price_id: config.SUBSCRIPTION_TIERS.FREE.priceId, // Default to Free tier
plan_name: "Free"
}
});

View File

@ -6,7 +6,7 @@ import {
SidebarInset,
SidebarProvider,
} from "@/components/ui/sidebar"
import { PricingAlert } from "@/components/billing/pricing-alert"
// import { PricingAlert } from "@/components/billing/pricing-alert"
import { MaintenanceAlert } from "@/components/maintenance-alert"
import { useAccounts } from "@/hooks/use-accounts"
import { useAuth } from "@/components/AuthProvider"
@ -22,7 +22,7 @@ interface DashboardLayoutProps {
export default function DashboardLayout({
children,
}: DashboardLayoutProps) {
const [showPricingAlert, setShowPricingAlert] = useState(false)
// const [showPricingAlert, setShowPricingAlert] = useState(false)
const [showMaintenanceAlert, setShowMaintenanceAlert] = useState(false)
const [isApiHealthy, setIsApiHealthy] = useState(true)
const [isCheckingHealth, setIsCheckingHealth] = useState(true)
@ -32,7 +32,7 @@ export default function DashboardLayout({
const router = useRouter()
useEffect(() => {
setShowPricingAlert(false)
// setShowPricingAlert(false)
setShowMaintenanceAlert(false)
}, [])
@ -91,12 +91,12 @@ export default function DashboardLayout({
</div>
</SidebarInset>
<PricingAlert
{/* <PricingAlert
open={showPricingAlert}
onOpenChange={setShowPricingAlert}
closeable={false}
accountId={personalAccount?.account_id}
/>
/> */}
<MaintenanceAlert
open={showMaintenanceAlert}

View File

@ -27,7 +27,8 @@ export async function signIn(prevState: any, formData: FormData) {
return { message: error.message || "Could not authenticate user" };
}
return redirect(returnUrl || "/dashboard");
// Use client-side navigation instead of server-side redirect
return { success: true, redirectTo: returnUrl || "/dashboard" };
}
export async function signUp(prevState: any, formData: FormData) {
@ -73,7 +74,8 @@ export async function signUp(prevState: any, formData: FormData) {
return { message: "Account created! Check your email to confirm your registration." };
}
return redirect(returnUrl || "/dashboard");
// Use client-side navigation instead of server-side redirect
return { success: true, redirectTo: returnUrl || "/dashboard" };
}
export async function forgotPassword(prevState: any, formData: FormData) {

View File

@ -101,8 +101,19 @@ function LoginContent() {
const handleSignIn = async (prevState: any, formData: FormData) => {
if (returnUrl) {
formData.append("returnUrl", returnUrl);
} else {
formData.append("returnUrl", "/dashboard");
}
return signIn(prevState, formData);
const result = await signIn(prevState, formData);
// Check for success and redirectTo properties
if (result && typeof result === 'object' && 'success' in result && result.success && 'redirectTo' in result) {
// Use window.location for hard navigation to avoid stale state
window.location.href = result.redirectTo as string;
return null; // Return null to prevent normal form action completion
}
return result;
};
const handleSignUp = async (prevState: any, formData: FormData) => {
@ -119,6 +130,13 @@ function LoginContent() {
const result = await signUp(prevState, formData);
// Check for success and redirectTo properties (direct login case)
if (result && typeof result === 'object' && 'success' in result && result.success && 'redirectTo' in result) {
// Use window.location for hard navigation to avoid stale state
window.location.href = result.redirectTo as string;
return null; // Return null to prevent normal form action completion
}
// Check if registration was successful but needs email verification
if (result && typeof result === 'object' && 'message' in result) {
const resultMessage = result.message as string;
@ -166,9 +184,10 @@ function LoginContent() {
const resetRegistrationSuccess = () => {
setRegistrationSuccess(false);
// Remove message from URL
// Remove message from URL and set mode to signin
const params = new URLSearchParams(window.location.search);
params.delete('message');
params.set('mode', 'signin');
const newUrl =
window.location.pathname +

View File

@ -68,6 +68,9 @@ export default function GoogleSignIn({ returnUrl }: GoogleSignInProps) {
try {
setIsLoading(true);
const supabase = createClient();
console.log('Starting Google sign in process');
const { error } = await supabase.auth.signInWithIdToken({
provider: 'google',
token: response.credential,
@ -75,10 +78,13 @@ export default function GoogleSignIn({ returnUrl }: GoogleSignInProps) {
if (error) throw error;
// Add a small delay before redirecting to ensure localStorage is properly saved
console.log('Google sign in successful, preparing redirect to:', returnUrl || "/dashboard");
// Add a longer delay before redirecting to ensure localStorage is properly saved
setTimeout(() => {
console.log('Executing redirect now to:', returnUrl || "/dashboard");
window.location.href = returnUrl || "/dashboard";
}, 100);
}, 500); // Increased from 100ms to 500ms
} catch (error) {
console.error('Error signing in with Google:', error);
setIsLoading(false);

View File

@ -1,178 +0,0 @@
import { createClient } from "@/lib/supabase/server";
import { SubmitButton } from "../ui/submit-button";
import { manageSubscription } from "@/lib/actions/billing";
import { PlanComparison, SUBSCRIPTION_PLANS } from "../billing/plan-comparison";
import { isLocalMode } from "@/lib/config";
type Props = {
accountId: string;
returnUrl: string;
}
export default async function AccountBillingStatus({ accountId, returnUrl }: Props) {
// In local development mode, show a simplified component
if (isLocalMode()) {
return (
<div className="rounded-xl border shadow-sm bg-card p-6">
<h2 className="text-xl font-semibold mb-4">Billing Status</h2>
<div className="p-4 mb-4 bg-muted/30 border border-border rounded-lg text-center">
<p className="text-sm text-muted-foreground">
Running in local development mode - billing features are disabled
</p>
<p className="text-xs text-muted-foreground mt-2">
Agent usage limits are not enforced in this environment
</p>
</div>
</div>
);
}
const supabaseClient = await createClient();
// Get account subscription and usage data
const { data: subscriptionData } = await supabaseClient
.schema('basejump')
.from('billing_subscriptions')
.select('*')
.eq('account_id', accountId)
.eq('status', 'active')
.limit(1)
.order('created_at', { ascending: false })
.single();
// Get agent runs for this account
// Get the account's threads
const { data: threads } = await supabaseClient
.from('threads')
.select('thread_id')
.eq('account_id', accountId);
const threadIds = threads?.map(t => t.thread_id) || [];
// Get current month usage
const now = new Date();
const startOfMonth = new Date(now.getFullYear(), now.getMonth(), 1);
const isoStartOfMonth = startOfMonth.toISOString();
let totalAgentTime = 0;
let usageDisplay = "No usage this month";
if (threadIds.length > 0) {
const { data: agentRuns } = await supabaseClient
.from('agent_runs')
.select('started_at, completed_at')
.in('thread_id', threadIds)
.gte('started_at', isoStartOfMonth);
if (agentRuns && agentRuns.length > 0) {
const nowTimestamp = now.getTime();
totalAgentTime = agentRuns.reduce((total, run) => {
const startTime = new Date(run.started_at).getTime();
const endTime = run.completed_at
? new Date(run.completed_at).getTime()
: nowTimestamp;
return total + (endTime - startTime) / 1000; // In seconds
}, 0);
// Convert to minutes
const totalMinutes = Math.round(totalAgentTime / 60);
usageDisplay = `${totalMinutes} minutes`;
}
}
const isPlan = (planId?: string) => {
return subscriptionData?.price_id === planId;
};
const planName = isPlan(SUBSCRIPTION_PLANS.FREE)
? "Free"
: isPlan(SUBSCRIPTION_PLANS.PRO)
? "Pro"
: isPlan(SUBSCRIPTION_PLANS.ENTERPRISE)
? "Enterprise"
: "Unknown";
return (
<div className="rounded-xl border shadow-sm bg-card p-6">
<h2 className="text-xl font-semibold mb-4">Billing Status</h2>
{subscriptionData ? (
<>
<div className="mb-6">
<div className="rounded-lg border bg-background p-4 grid grid-cols-1 md:grid-cols-2 gap-4">
<div>
<div className="flex justify-between items-center">
<span className="text-sm font-medium text-foreground/90">Current Plan</span>
<span className="text-sm font-medium text-card-title">{planName}</span>
</div>
</div>
<div className="flex justify-between items-center">
<span className="text-sm font-medium text-foreground/90">Agent Usage This Month</span>
<span className="text-sm font-medium text-card-title">{usageDisplay}</span>
</div>
</div>
</div>
{/* Plans Comparison */}
<PlanComparison
accountId={accountId}
returnUrl={returnUrl}
className="mb-6"
/>
{/* Manage Subscription Button */}
<form>
<input type="hidden" name="accountId" value={accountId} />
<input type="hidden" name="returnUrl" value={returnUrl} />
<SubmitButton
pendingText="Loading..."
formAction={manageSubscription}
className="w-full bg-primary text-white hover:bg-primary/90 shadow-md hover:shadow-lg transition-all"
>
Manage Subscription
</SubmitButton>
</form>
</>
) : (
<>
<div className="mb-6">
<div className="rounded-lg border bg-background p-4 gap-4">
<div className="flex justify-between items-center">
<span className="text-sm font-medium text-foreground/90">Current Plan</span>
<span className="text-sm font-medium text-card-title">Free</span>
</div>
<div className="flex justify-between items-center">
<span className="text-sm font-medium text-foreground/90">Agent Usage This Month</span>
<span className="text-sm font-medium text-card-title">{usageDisplay}</span>
</div>
</div>
</div>
{/* Plans Comparison */}
<PlanComparison
accountId={accountId}
returnUrl={returnUrl}
className="mb-6"
/>
{/* Manage Subscription Button */}
<form>
<input type="hidden" name="accountId" value={accountId} />
<input type="hidden" name="returnUrl" value={returnUrl} />
<SubmitButton
pendingText="Loading..."
formAction={manageSubscription}
className="w-full bg-primary text-white hover:bg-primary/90 shadow-md hover:shadow-lg transition-all"
>
Manage Subscription
</SubmitButton>
</form>
</>
)}
</div>
)
}

View File

@ -0,0 +1,180 @@
'use client';
import { useEffect, useState } from 'react';
import { Button } from "@/components/ui/button";
import { PricingSection } from "@/components/home/sections/pricing-section";
import { isLocalMode } from "@/lib/config";
import { getSubscription, createPortalSession, SubscriptionStatus } from "@/lib/api";
import { useAuth } from "@/components/AuthProvider";
import { Skeleton } from "@/components/ui/skeleton";
type Props = {
accountId: string;
returnUrl: string;
}
export default function AccountBillingStatus({ accountId, returnUrl }: Props) {
const { session, isLoading: authLoading } = useAuth();
const [subscriptionData, setSubscriptionData] = useState<SubscriptionStatus | null>(null);
const [isLoading, setIsLoading] = useState(true);
const [error, setError] = useState<string | null>(null);
const [isManaging, setIsManaging] = useState(false);
useEffect(() => {
async function fetchSubscription() {
if (authLoading || !session) return;
try {
const data = await getSubscription();
setSubscriptionData(data);
setError(null);
} catch (err) {
console.error('Failed to get subscription:', err);
setError(err instanceof Error ? err.message : 'Failed to load subscription data');
} finally {
setIsLoading(false);
}
}
fetchSubscription();
}, [session, authLoading]);
const handleManageSubscription = async () => {
try {
setIsManaging(true);
const { url } = await createPortalSession({ return_url: returnUrl });
window.location.href = url;
} catch (err) {
console.error('Failed to create portal session:', err);
setError(err instanceof Error ? err.message : 'Failed to create portal session');
} finally {
setIsManaging(false);
}
};
// In local development mode, show a simplified component
if (isLocalMode()) {
return (
<div className="rounded-xl border shadow-sm bg-card p-6">
<h2 className="text-xl font-semibold mb-4">Billing Status</h2>
<div className="p-4 mb-4 bg-muted/30 border border-border rounded-lg text-center">
<p className="text-sm text-muted-foreground">
Running in local development mode - billing features are disabled
</p>
<p className="text-xs text-muted-foreground mt-2">
Agent usage limits are not enforced in this environment
</p>
</div>
</div>
);
}
// Show loading state
if (isLoading || authLoading) {
return (
<div className="rounded-xl border shadow-sm bg-card p-6">
<h2 className="text-xl font-semibold mb-4">Billing Status</h2>
<div className="space-y-4">
<Skeleton className="h-20 w-full" />
<Skeleton className="h-40 w-full" />
<Skeleton className="h-10 w-full" />
</div>
</div>
);
}
// Show error state
if (error) {
return (
<div className="rounded-xl border shadow-sm bg-card p-6">
<h2 className="text-xl font-semibold mb-4">Billing Status</h2>
<div className="p-4 mb-4 bg-destructive/10 border border-destructive/20 rounded-lg text-center">
<p className="text-sm text-destructive">
Error loading billing status: {error}
</p>
</div>
</div>
);
}
const isPlan = (planId?: string) => {
return subscriptionData?.plan_name === planId;
};
const planName = isPlan('free')
? "Free"
: isPlan('base')
? "Pro"
: isPlan('extra')
? "Enterprise"
: "Unknown";
return (
<div className="rounded-xl border shadow-sm bg-card p-6">
<h2 className="text-xl font-semibold mb-4">Billing Status</h2>
{subscriptionData ? (
<>
<div className="mb-6">
<div className="rounded-lg border bg-background p-4 grid grid-cols-1 md:grid-cols-2 gap-4">
<div className="flex justify-between items-center">
<span className="text-sm font-medium text-foreground/90">Agent Usage This Month</span>
<span className="text-sm font-medium text-card-title">
{subscriptionData.current_usage?.toFixed(2) || '0'} / {subscriptionData.minutes_limit || '0'} minutes
</span>
</div>
</div>
</div>
{/* Plans Comparison */}
<PricingSection
returnUrl={returnUrl}
showTitleAndTabs={false}
/>
{/* Manage Subscription Button */}
<Button
onClick={handleManageSubscription}
disabled={isManaging}
className="w-full bg-primary text-white hover:bg-primary/90 shadow-md hover:shadow-lg transition-all"
>
{isManaging ? "Loading..." : "Manage Subscription"}
</Button>
</>
) : (
<>
<div className="mb-6">
<div className="rounded-lg border bg-background p-4 gap-4">
<div className="flex justify-between items-center">
<span className="text-sm font-medium text-foreground/90">Current Plan</span>
<span className="text-sm font-medium text-card-title">Free</span>
</div>
<div className="flex justify-between items-center">
<span className="text-sm font-medium text-foreground/90">Agent Usage This Month</span>
<span className="text-sm font-medium text-card-title">
{subscriptionData?.current_usage?.toFixed(2) || '0'} / {subscriptionData?.minutes_limit || '0'} minutes
</span>
</div>
</div>
</div>
{/* Plans Comparison */}
<PricingSection
returnUrl={returnUrl}
showTitleAndTabs={false}
/>
{/* Manage Subscription Button */}
<Button
onClick={handleManageSubscription}
disabled={isManaging}
className="w-full bg-primary text-white hover:bg-primary/90 shadow-md hover:shadow-lg transition-all"
>
{isManaging ? "Loading..." : "Manage Subscription"}
</Button>
</>
)}
</div>
);
}

View File

@ -1,263 +0,0 @@
'use client';
import { createClient } from "@/lib/supabase/client";
import { useEffect, useState } from "react";
import { cn } from "@/lib/utils";
import { motion } from "motion/react";
import { setupNewSubscription } from "@/lib/actions/billing";
import { SubmitButton } from "@/components/ui/submit-button";
import { Button } from "@/components/ui/button";
import { siteConfig } from "@/lib/home";
import { isLocalMode } from "@/lib/config";
// Create SUBSCRIPTION_PLANS using stripePriceId from siteConfig
export const SUBSCRIPTION_PLANS = {
FREE: siteConfig.cloudPricingItems.find(item => item.name === 'Free')?.stripePriceId || '',
PRO: siteConfig.cloudPricingItems.find(item => item.name === 'Pro')?.stripePriceId || '',
ENTERPRISE: siteConfig.cloudPricingItems.find(item => item.name === 'Enterprise')?.stripePriceId || '',
};
// Price display animation component
const PriceDisplay = ({ tier, isCompact }: { tier: typeof siteConfig.cloudPricingItems[number]; isCompact?: boolean }) => {
return (
<motion.span
key={tier.price}
className={isCompact ? "text-xl font-semibold" : "text-3xl font-semibold"}
initial={{
opacity: 0,
x: 10,
filter: "blur(5px)",
}}
animate={{ opacity: 1, x: 0, filter: "blur(0px)" }}
transition={{ duration: 0.25, ease: [0.4, 0, 0.2, 1] }}
>
{tier.price}
</motion.span>
);
};
interface PlanComparisonProps {
accountId?: string | null;
returnUrl?: string;
isManaged?: boolean;
onPlanSelect?: (planId: string) => void;
className?: string;
isCompact?: boolean; // When true, uses vertical stacked layout for modals
}
export function PlanComparison({
accountId,
returnUrl = typeof window !== 'undefined' ? window.location.href : '',
isManaged = true,
onPlanSelect,
className = "",
isCompact = false
}: PlanComparisonProps) {
const [currentPlanId, setCurrentPlanId] = useState<string | undefined>();
useEffect(() => {
async function fetchCurrentPlan() {
if (accountId) {
const supabase = createClient();
const { data } = await supabase
.schema('basejump')
.from('billing_subscriptions')
.select('price_id')
.eq('account_id', accountId)
.eq('status', 'active')
.single();
setCurrentPlanId(data?.price_id || SUBSCRIPTION_PLANS.FREE);
} else {
setCurrentPlanId(SUBSCRIPTION_PLANS.FREE);
}
}
fetchCurrentPlan();
}, [accountId]);
// For local development mode, show a message instead
if (isLocalMode()) {
return (
<div className={cn("p-4 bg-muted/30 border border-border rounded-lg text-center", className)}>
<p className="text-sm text-muted-foreground">
Running in local development mode - billing features are disabled
</p>
</div>
);
}
return (
<div
className={cn(
"grid gap-3 w-full mx-auto",
isCompact
? "grid-cols-1 max-w-md"
: "grid-cols-1 md:grid-cols-3 max-w-6xl",
className
)}
>
{siteConfig.cloudPricingItems.map((tier) => {
const isCurrentPlan = currentPlanId === SUBSCRIPTION_PLANS[tier.name.toUpperCase() as keyof typeof SUBSCRIPTION_PLANS];
return (
<div
key={tier.name}
className={cn(
"rounded-lg bg-background border border-border",
isCompact ? "p-3 text-sm" : "p-5",
isCurrentPlan && (isCompact ? "ring-1 ring-primary" : "ring-2 ring-primary")
)}
>
{isCompact ? (
// Compact layout for modal
<>
<div className="flex justify-between mb-2">
<div>
<div className="flex items-center gap-1">
<h3 className="font-medium">{tier.name}</h3>
{tier.isPopular && (
<span className="bg-primary/10 text-primary text-[10px] font-medium px-1.5 py-0.5 rounded-full">
Popular
</span>
)}
{isCurrentPlan && (
<span className="bg-secondary/10 text-secondary text-[10px] font-medium px-1.5 py-0.5 rounded-full">
Current
</span>
)}
</div>
<div className="text-xs text-muted-foreground mt-0.5">{tier.description}</div>
</div>
<div className="text-right">
<div className="flex items-baseline">
<PriceDisplay tier={tier} isCompact={true} />
<span className="text-xs text-muted-foreground ml-1">
{tier.price !== "$0" ? "/mo" : ""}
</span>
</div>
<div className="text-[10px] text-muted-foreground mt-0.5">
{tier.hours}/month
</div>
</div>
</div>
<div className="mb-2.5">
<div className="text-[10px] text-muted-foreground leading-tight max-h-[40px] overflow-y-auto pr-1">
{tier.features.map((feature, index) => (
<span key={index} className="whitespace-normal">
{index > 0 && ' • '}
{feature}
</span>
))}
</div>
</div>
</>
) : (
// Standard layout for normal view
<>
<div className="flex items-center justify-between mb-4">
<h3 className="text-lg font-medium">{tier.name}</h3>
<div className="flex gap-1">
{tier.isPopular && (
<span className="bg-primary/10 text-primary text-xs font-medium px-2 py-0.5 rounded-full">
Popular
</span>
)}
{isCurrentPlan && (
<span className="bg-secondary/10 text-secondary text-xs font-medium px-2 py-0.5 rounded-full">
Current
</span>
)}
</div>
</div>
<div className="flex items-baseline mb-1">
<PriceDisplay tier={tier} />
<span className="text-muted-foreground ml-2">
{tier.price !== "$0" ? "/month" : ""}
</span>
</div>
<div className="inline-flex items-center rounded-full border px-2.5 py-0.5 text-xs font-medium bg-secondary/10 text-secondary mb-4">
{tier.hours}/month
</div>
<p className="text-muted-foreground mb-6">{tier.description}</p>
<div className="mb-6">
<div className="text-sm text-muted-foreground space-y-2">
{tier.features.map((feature, index) => (
<div key={index} className="flex items-start gap-2">
<div className="size-5 rounded-full bg-primary/10 flex items-center justify-center mt-0.5">
<svg
width="12"
height="12"
viewBox="0 0 12 12"
fill="none"
xmlns="http://www.w3.org/2000/svg"
className="text-primary"
>
<path
d="M2.5 6L5 8.5L9.5 4"
stroke="currentColor"
strokeWidth="2"
strokeLinecap="round"
strokeLinejoin="round"
/>
</svg>
</div>
<span>{feature}</span>
</div>
))}
</div>
</div>
</>
)}
<form>
<input type="hidden" name="accountId" value={accountId} />
<input type="hidden" name="returnUrl" value={returnUrl} />
<input type="hidden" name="planId" value={SUBSCRIPTION_PLANS[tier.name.toUpperCase() as keyof typeof SUBSCRIPTION_PLANS]} />
{isManaged ? (
<SubmitButton
pendingText="..."
formAction={setupNewSubscription}
// disabled={isCurrentPlan}
className={cn(
"w-full font-medium transition-colors",
isCompact
? "h-7 rounded-md text-xs"
: "h-10 rounded-full text-sm",
isCurrentPlan
? "bg-muted text-muted-foreground hover:bg-muted"
: tier.buttonColor
)}
>
{isCurrentPlan ? "Current Plan" : (tier.name === "Free" ? tier.buttonText : "Upgrade")}
</SubmitButton>
) : (
<Button
className={cn(
"w-full font-medium transition-colors",
isCompact
? "h-7 rounded-md text-xs"
: "h-10 rounded-full text-sm",
isCurrentPlan
? "bg-muted text-muted-foreground hover:bg-muted"
: tier.buttonColor
)}
disabled={isCurrentPlan}
onClick={() => onPlanSelect?.(SUBSCRIPTION_PLANS[tier.name.toUpperCase() as keyof typeof SUBSCRIPTION_PLANS])}
>
{isCurrentPlan ? "Current Plan" : (tier.name === "Free" ? tier.buttonText : "Upgrade")}
</Button>
)}
</form>
</div>
);
})}
</div>
);
}

View File

@ -1,251 +0,0 @@
"use client"
import { X, Zap, Github, Check } from "lucide-react"
import Link from "next/link"
import { AnimatePresence, motion } from "motion/react"
import { Button } from "@/components/ui/button"
import { Portal } from "@/components/ui/portal"
import { cn } from "@/lib/utils"
import { setupNewSubscription } from "@/lib/actions/billing"
import { SubmitButton } from "@/components/ui/submit-button"
import { siteConfig } from "@/lib/home"
import { isLocalMode } from "@/lib/config"
import { createClient } from "@/lib/supabase/client"
import { useEffect, useState } from "react"
import { SUBSCRIPTION_PLANS } from "./plan-comparison"
interface PricingAlertProps {
open: boolean
onOpenChange: (open: boolean) => void
closeable?: boolean
accountId?: string | null | undefined
}
export function PricingAlert({ open, onOpenChange, closeable = true, accountId }: PricingAlertProps) {
const returnUrl = typeof window !== 'undefined' ? window.location.href : '';
const [hasActiveSubscription, setHasActiveSubscription] = useState(false);
const [isLoading, setIsLoading] = useState(true);
// Check if user has an active subscription
useEffect(() => {
async function checkSubscription() {
if (!accountId) {
setHasActiveSubscription(false);
setIsLoading(false);
return;
}
try {
const supabase = createClient();
const { data } = await supabase
.schema('basejump')
.from('billing_subscriptions')
.select('price_id')
.eq('account_id', accountId)
.eq('status', 'active')
.single();
// Check if the user has a paid subscription (not free tier)
const isPaidSubscription = data?.price_id &&
data.price_id !== SUBSCRIPTION_PLANS.FREE;
setHasActiveSubscription(isPaidSubscription);
} catch (error) {
console.error("Error checking subscription:", error);
setHasActiveSubscription(false);
} finally {
setIsLoading(false);
}
}
checkSubscription();
}, [accountId]);
// Skip rendering in local development mode or if user has an active subscription
if (isLocalMode() || !open || hasActiveSubscription || isLoading) return null;
// Filter plans to show only Pro and Enterprise
const premiumPlans = siteConfig.cloudPricingItems.filter(plan =>
plan.name === 'Pro' || plan.name === 'Enterprise'
);
return (
<Portal>
<AnimatePresence>
{open && (
<>
<motion.div
initial={{ opacity: 0 }}
animate={{ opacity: 1 }}
exit={{ opacity: 0 }}
transition={{ duration: 0.2 }}
className="fixed inset-0 z-[9999] flex items-center justify-center overflow-y-auto py-8 px-4"
>
{/* Backdrop */}
<motion.div
initial={{ opacity: 0 }}
animate={{ opacity: 1 }}
exit={{ opacity: 0 }}
transition={{ duration: 0.2 }}
className="fixed inset-0 bg-black/60 backdrop-blur-sm"
onClick={closeable ? () => onOpenChange(false) : undefined}
aria-hidden="true"
/>
{/* Modal */}
<motion.div
initial={{ opacity: 0, scale: 0.95, y: 20 }}
animate={{ opacity: 1, scale: 1, y: 0 }}
exit={{ opacity: 0, scale: 0.95, y: 20 }}
transition={{ duration: 0.2, ease: [0.4, 0, 0.2, 1] }}
className={cn(
"relative bg-background rounded-xl shadow-2xl w-full max-w-3xl mx-3 border border-border"
)}
role="dialog"
aria-modal="true"
aria-labelledby="pricing-modal-title"
>
<div className="p-6">
{/* Close button */}
{closeable && (
<button
onClick={() => onOpenChange(false)}
className="absolute top-4 right-4 text-muted-foreground hover:text-foreground transition-colors"
aria-label="Close dialog"
>
<X className="h-5 w-5" />
</button>
)}
{/* Header */}
<div className="text-center mb-8">
<div className="inline-flex items-center justify-center p-2 bg-primary/10 rounded-full mb-3">
<Zap className="h-5 w-5 text-primary" />
</div>
<h2 id="pricing-modal-title" className="text-2xl font-medium tracking-tight mb-2">
Choose Your Suna Experience
</h2>
<p className="text-muted-foreground max-w-lg mx-auto">
Due to overwhelming demand and AI costs, we're currently focusing on delivering
our best experience to dedicated users. Select your preferred option below.
</p>
</div>
{/* Plan comparison - 3 column layout */}
<div className="grid md:grid-cols-3 gap-4 mb-6">
{/* Self-Host Option */}
<div className="rounded-xl bg-[#F3F4F6] dark:bg-[#F9FAFB]/[0.02] border border-border hover:border-muted-foreground/30 transition-all duration-300">
<div className="flex flex-col gap-4 p-4">
<p className="text-sm flex items-center">Open Source</p>
<div className="flex items-baseline mt-2">
<span className="text-2xl font-semibold">Self-host</span>
</div>
<p className="text-sm mt-2">Full control with your own infrastructure</p>
<div className="inline-flex items-center rounded-full border px-2.5 py-0.5 text-xs font-semibold bg-primary/10 border-primary/20 text-primary w-fit">
hours / month
</div>
</div>
<div className="px-4 pb-4">
<div className="flex items-start gap-2 mb-3">
<Check className="h-4 w-4 text-green-500 flex-shrink-0 mt-0.5" />
<span className="text-xs text-muted-foreground">No usage limitations</span>
</div>
<Link
href="https://github.com/kortix-ai/suna"
target="_blank"
rel="noopener noreferrer"
className="h-10 w-full flex items-center justify-center gap-2 text-sm font-normal tracking-wide rounded-full px-4 cursor-pointer transition-all ease-out active:scale-95 bg-secondary/10 text-secondary shadow-[0px_1px_2px_0px_rgba(255,255,255,0.16)_inset,0px_3px_3px_-1.5px_rgba(16,24,40,0.24),0px_1px_1px_-0.5px_rgba(16,24,40,0.20)]"
>
<Github className="h-4 w-4" />
<span>View on GitHub</span>
</Link>
</div>
</div>
{/* Pro Plan */}
<div className="rounded-xl md:shadow-[0px_61px_24px_-10px_rgba(0,0,0,0.01),0px_34px_20px_-8px_rgba(0,0,0,0.05),0px_15px_15px_-6px_rgba(0,0,0,0.09),0px_4px_8px_-2px_rgba(0,0,0,0.10),0px_0px_0px_1px_rgba(0,0,0,0.08)] bg-accent relative transform hover:scale-105 transition-all duration-300">
<div className="absolute -top-3 -right-3">
<span className="bg-gradient-to-b from-secondary/50 from-[1.92%] to-secondary to-[100%] text-white h-6 inline-flex w-fit items-center justify-center px-3 rounded-full text-xs font-medium shadow-[0px_6px_6px_-3px_rgba(0,0,0,0.08),0px_3px_3px_-1.5px_rgba(0,0,0,0.08),0px_1px_1px_-0.5px_rgba(0,0,0,0.08),0px_0px_0px_1px_rgba(255,255,255,0.12)_inset,0px_1px_0px_0px_rgba(255,255,255,0.12)_inset]">
Most Popular
</span>
</div>
<div className="flex flex-col gap-4 p-4">
<p className="text-sm flex items-center font-medium">Pro</p>
<div className="flex items-baseline mt-2">
<span className="text-2xl font-semibold">{premiumPlans[0]?.price || "$19"}</span>
<span className="ml-2">/month</span>
</div>
<p className="text-sm mt-2">Supercharge your productivity with {premiumPlans[0]?.hours || "500 hours"} of Suna</p>
<div className="inline-flex items-center rounded-full border px-2.5 py-0.5 text-xs font-semibold bg-primary/10 border-primary/20 text-primary w-fit">
{premiumPlans[0]?.hours || "500 hours"}/month
</div>
</div>
<div className="px-4 pb-4">
<div className="flex items-start gap-2 mb-3">
<Check className="h-4 w-4 text-green-500 flex-shrink-0 mt-0.5" />
<span className="text-xs text-muted-foreground">Perfect for individuals and small teams</span>
</div>
<form>
<input type="hidden" name="accountId" value={accountId || ''} />
<input type="hidden" name="returnUrl" value={returnUrl} />
<input type="hidden" name="planId" value={
premiumPlans[0]?.stripePriceId || ''
} />
<SubmitButton
pendingText="..."
formAction={setupNewSubscription}
className="h-10 w-full flex items-center justify-center text-sm font-medium tracking-wide rounded-full px-4 cursor-pointer transition-all ease-out active:scale-95 bg-primary text-primary-foreground shadow-[inset_0_1px_2px_rgba(255,255,255,0.25),0_3px_3px_-1.5px_rgba(16,24,40,0.06),0_1px_1px_rgba(16,24,40,0.08)]"
>
Get Started Now
</SubmitButton>
</form>
</div>
</div>
{/* Enterprise Plan */}
<div className="rounded-xl bg-[#F3F4F6] dark:bg-[#F9FAFB]/[0.02] border border-border hover:border-muted-foreground/30 transition-all duration-300">
<div className="flex flex-col gap-4 p-4">
<p className="text-sm flex items-center font-medium">Enterprise</p>
<div className="flex items-baseline mt-2">
<span className="text-2xl font-semibold">{premiumPlans[1]?.price || "$99"}</span>
<span className="ml-2">/month</span>
</div>
<p className="text-sm mt-2">Unlock boundless potential with {premiumPlans[1]?.hours || "2000 hours"} of Suna</p>
<div className="inline-flex items-center rounded-full border px-2.5 py-0.5 text-xs font-semibold bg-primary/10 border-primary/20 text-primary w-fit">
{premiumPlans[1]?.hours || "2000 hours"}/month
</div>
</div>
<div className="px-4 pb-4">
<div className="flex items-start gap-2 mb-3">
<Check className="h-4 w-4 text-green-500 flex-shrink-0 mt-0.5" />
<span className="text-xs text-muted-foreground">Ideal for larger organizations and power users</span>
</div>
<form>
<input type="hidden" name="accountId" value={accountId || ''} />
<input type="hidden" name="returnUrl" value={returnUrl} />
<input type="hidden" name="planId" value={
premiumPlans[1]?.stripePriceId || ''
} />
<SubmitButton
pendingText="..."
formAction={setupNewSubscription}
className="h-10 w-full flex items-center justify-center text-sm font-normal tracking-wide rounded-full px-4 cursor-pointer transition-all ease-out active:scale-95 bg-gradient-to-b from-secondary/50 from-[1.92%] to-secondary to-[100%] text-white shadow-[0px_1px_2px_0px_rgba(255,255,255,0.16)_inset,0px_3px_3px_-1.5px_rgba(16,24,40,0.24),0px_1px_1px_-0.5px_rgba(16,24,40,0.20)]"
>
Upgrade to Enterprise
</SubmitButton>
</form>
</div>
</div>
</div>
</div>
</motion.div>
</motion.div>
</>
)}
</AnimatePresence>
</Portal>
)
}

View File

@ -1,18 +1,15 @@
import { AlertCircle, X } from "lucide-react";
'use client';
import { AlertTriangle } from "lucide-react";
import { Button } from "@/components/ui/button";
import { Portal } from "@/components/ui/portal";
import { PlanComparison } from "./plan-comparison";
import { cn } from "@/lib/utils";
import { motion, AnimatePresence } from "motion/react";
import { isLocalMode } from "@/lib/config";
import { useRouter } from "next/navigation";
interface BillingErrorAlertProps {
message?: string;
currentUsage?: number;
limit?: number;
accountId: string | null | undefined;
onDismiss?: () => void;
className?: string;
accountId?: string | null;
onDismiss: () => void;
isOpen: boolean;
}
@ -22,125 +19,42 @@ export function BillingErrorAlert({
limit,
accountId,
onDismiss,
className = "",
isOpen
}: BillingErrorAlertProps) {
const returnUrl = typeof window !== 'undefined' ? window.location.href : '';
// Skip rendering in local development mode
if (isLocalMode() || !isOpen) return null;
const router = useRouter();
if (!isOpen) return null;
return (
<Portal>
<AnimatePresence>
{isOpen && (
<>
<motion.div
initial={{ opacity: 0 }}
animate={{ opacity: 1 }}
exit={{ opacity: 0 }}
transition={{ duration: 0.2 }}
className="fixed inset-0 z-[9999] flex items-center justify-center overflow-y-auto py-4"
>
{/* Backdrop */}
<motion.div
initial={{ opacity: 0 }}
animate={{ opacity: 1 }}
exit={{ opacity: 0 }}
transition={{ duration: 0.2 }}
className="fixed inset-0 bg-black/40 backdrop-blur-sm"
<div className="fixed bottom-4 right-4 z-50">
<div className="bg-destructive/10 border border-destructive/20 rounded-lg p-4 shadow-lg max-w-md">
<div className="flex items-start gap-3">
<div className="flex-shrink-0">
<AlertTriangle className="h-5 w-5 text-destructive" />
</div>
<div className="flex-1">
<h3 className="text-sm font-medium text-destructive mb-1">Usage Limit Reached</h3>
<p className="text-sm text-muted-foreground mb-3">{message}</p>
<div className="flex gap-2">
<Button
variant="outline"
size="sm"
onClick={onDismiss}
aria-hidden="true"
/>
{/* Modal */}
<motion.div
initial={{ opacity: 0, scale: 0.95, y: 20 }}
animate={{ opacity: 1, scale: 1, y: 0 }}
exit={{ opacity: 0, scale: 0.95, y: 20 }}
transition={{ duration: 0.2, ease: [0.4, 0, 0.2, 1] }}
className={cn(
"relative bg-background rounded-lg shadow-xl w-full max-w-sm mx-3",
className
)}
role="dialog"
aria-modal="true"
aria-labelledby="billing-modal-title"
className="text-xs"
>
<div className="p-4">
{/* Close button */}
{onDismiss && (
<button
onClick={onDismiss}
className="absolute top-2 right-2 text-muted-foreground hover:text-foreground transition-colors"
aria-label="Close dialog"
>
<X className="h-4 w-4" />
</button>
)}
{/* Header */}
<div className="text-center mb-4">
<div className="inline-flex items-center justify-center p-1.5 bg-destructive/10 rounded-full mb-2">
<AlertCircle className="h-4 w-4 text-destructive" />
</div>
<h2 id="billing-modal-title" className="text-lg font-medium tracking-tight mb-1">
Usage Limit Reached
</h2>
<p className="text-xs text-muted-foreground">
{message || "You've reached your monthly usage limit."}
</p>
</div>
{/* Usage Stats */}
{currentUsage !== undefined && limit !== undefined && (
<div className="mb-4 p-3 bg-muted/30 border border-border rounded-lg">
<div className="flex justify-between items-center mb-2">
<div>
<p className="text-xs font-medium text-muted-foreground">Usage</p>
<p className="text-base font-semibold">{(currentUsage * 60).toFixed(0)}m</p>
</div>
<div className="text-right">
<p className="text-xs font-medium text-muted-foreground">Limit</p>
<p className="text-base font-semibold">{(limit * 60).toFixed(0)}m</p>
</div>
</div>
<div className="w-full h-1.5 bg-background rounded-full overflow-hidden">
<motion.div
initial={{ width: 0 }}
animate={{ width: `${Math.min((currentUsage / limit) * 100, 100)}%` }}
transition={{ duration: 0.5, ease: [0.4, 0, 0.2, 1] }}
className="h-full bg-destructive rounded-full"
/>
</div>
</div>
)}
{/* Plans Comparison */}
<PlanComparison
accountId={accountId}
returnUrl={returnUrl}
className="mb-3"
isCompact={true}
/>
{/* Dismiss Button */}
{onDismiss && (
<Button
variant="ghost"
size="sm"
className="w-full text-muted-foreground hover:text-foreground text-xs h-7"
onClick={onDismiss}
>
Continue with Current Plan
</Button>
)}
</div>
</motion.div>
</motion.div>
</>
)}
</AnimatePresence>
</Portal>
Dismiss
</Button>
<Button
size="sm"
onClick={() => router.push(`/settings/billing?accountId=${accountId}`)}
className="text-xs"
>
Upgrade Plan
</Button>
</div>
</div>
</div>
</div>
</div>
);
}

View File

@ -25,7 +25,7 @@ import {
import { BillingErrorAlert } from '@/components/billing/usage-limit-alert';
import { useBillingError } from "@/hooks/useBillingError";
import { useAccounts } from "@/hooks/use-accounts";
import { isLocalMode } from "@/lib/config";
import { isLocalMode, config } from "@/lib/config";
import { toast } from "sonner";
// Custom dialog overlay with blur effect
@ -140,7 +140,7 @@ export function HeroSection() {
currentUsage: error.detail.currentUsage as number | undefined,
limit: error.detail.limit as number | undefined,
subscription: error.detail.subscription || {
price_id: "price_1RGJ9GG6l1KZGqIroxSqgphC", // Default Free
price_id: config.SUBSCRIPTION_TIERS.FREE.priceId, // Default Free
plan_name: "Free"
}
});

View File

@ -24,7 +24,7 @@ export function OpenSourceSection() {
<div className="flex flex-col gap-6">
<div className="flex items-center gap-2 text-primary font-medium">
<Github className="h-5 w-5" />
<span>Kortix/Suna</span>
<span>kortix-ai/suna</span>
</div>
<div className="relative">
<h3 className="text-2xl font-semibold tracking-tight">

View File

@ -1,20 +1,72 @@
"use client";
import { SectionHeader } from "@/components/home/section-header";
import type { PricingTier } from "@/lib/home";
import { siteConfig } from "@/lib/home";
import { cn } from "@/lib/utils";
import { motion } from "motion/react";
import { useState } from "react";
import { Github, GitFork, File, Terminal } from "lucide-react";
import { useState, useEffect, useRef } from "react";
import { CheckIcon } from "lucide-react";
import Link from "next/link";
import {
Select,
SelectContent,
SelectItem,
SelectTrigger,
SelectValue,
} from "@/components/ui/select";
import { Button } from "@/components/ui/button";
import { getSubscription, createCheckoutSession, SubscriptionStatus, CreateCheckoutSessionResponse } from "@/lib/api";
import { toast } from "sonner";
import { isLocalMode } from "@/lib/config";
interface TabsProps {
// Constants
const DEFAULT_SELECTED_PLAN = "6 hours";
export const SUBSCRIPTION_PLANS = {
FREE: 'free',
PRO: 'base',
ENTERPRISE: 'extra',
};
// Types
type ButtonVariant = "default" | "secondary" | "ghost" | "outline" | "link" | null;
interface PricingTabsProps {
activeTab: "cloud" | "self-hosted";
setActiveTab: (tab: "cloud" | "self-hosted") => void;
className?: string;
}
function PricingTabs({ activeTab, setActiveTab, className }: TabsProps) {
interface PriceDisplayProps {
price: string;
isCompact?: boolean;
}
interface CustomPriceDisplayProps {
price: string;
}
interface UpgradePlan {
hours: string;
price: string;
stripePriceId: string;
}
interface PricingTierProps {
tier: PricingTier;
isCompact?: boolean;
currentSubscription: SubscriptionStatus | null;
isLoading: Record<string, boolean>;
isFetchingPlan: boolean;
selectedPlan?: string;
onPlanSelect?: (planId: string) => void;
onSubscriptionUpdate?: () => void;
isAuthenticated?: boolean;
returnUrl: string;
}
// Components
function PricingTabs({ activeTab, setActiveTab, className }: PricingTabsProps) {
return (
<div
className={cn(
@ -60,21 +112,460 @@ function PricingTabs({ activeTab, setActiveTab, className }: TabsProps) {
);
}
export function PricingSection() {
const [deploymentType, setDeploymentType] = useState<"cloud" | "self-hosted">(
"cloud",
function PriceDisplay({ price, isCompact }: PriceDisplayProps) {
return (
<motion.span
key={price}
className={isCompact ? "text-xl font-semibold" : "text-4xl font-semibold"}
initial={{
opacity: 0,
x: 10,
filter: "blur(5px)",
}}
animate={{ opacity: 1, x: 0, filter: "blur(0px)" }}
transition={{ duration: 0.25, ease: [0.4, 0, 0.2, 1] }}
>
{price}
</motion.span>
);
}
function CustomPriceDisplay({ price }: CustomPriceDisplayProps) {
return (
<motion.span
key={price}
className="text-4xl font-semibold"
initial={{
opacity: 0,
x: 10,
filter: "blur(5px)",
}}
animate={{ opacity: 1, x: 0, filter: "blur(0px)" }}
transition={{ duration: 0.25, ease: [0.4, 0, 0.2, 1] }}
>
{price}
</motion.span>
);
}
function PricingTier({
tier,
isCompact = false,
currentSubscription,
isLoading,
isFetchingPlan,
selectedPlan,
onPlanSelect,
onSubscriptionUpdate,
isAuthenticated = false,
returnUrl
}: PricingTierProps) {
const [localSelectedPlan, setLocalSelectedPlan] = useState(selectedPlan || DEFAULT_SELECTED_PLAN);
const hasInitialized = useRef(false);
// Auto-select the correct plan only on initial load
useEffect(() => {
if (!hasInitialized.current && tier.name === "Custom" && tier.upgradePlans && currentSubscription?.price_id) {
const matchingPlan = tier.upgradePlans.find(plan => plan.stripePriceId === currentSubscription.price_id);
if (matchingPlan) {
setLocalSelectedPlan(matchingPlan.hours);
}
hasInitialized.current = true;
}
}, [currentSubscription, tier.name, tier.upgradePlans]);
// Only refetch when plan is selected
const handlePlanSelect = (value: string) => {
setLocalSelectedPlan(value);
if (tier.name === "Custom" && onSubscriptionUpdate) {
onSubscriptionUpdate();
}
};
const handleSubscribe = async (planStripePriceId: string) => {
if (!isAuthenticated) {
window.location.href = '/auth';
return;
}
if (isLoading[planStripePriceId]) {
return;
}
try {
// For custom tier, get the selected plan's stripePriceId
let finalPriceId = planStripePriceId;
if (tier.name === "Custom" && tier.upgradePlans) {
const selectedPlan = tier.upgradePlans.find(plan => plan.hours === localSelectedPlan);
if (selectedPlan?.stripePriceId) {
finalPriceId = selectedPlan.stripePriceId;
}
}
onPlanSelect?.(finalPriceId);
const response: CreateCheckoutSessionResponse = await createCheckoutSession({
price_id: finalPriceId,
success_url: returnUrl,
cancel_url: returnUrl
});
console.log('Subscription action response:', response);
switch (response.status) {
case 'new':
case 'checkout_created':
if (response.url) {
window.location.href = response.url;
} else {
console.error("Error: Received status 'checkout_created' but no checkout URL.");
toast.error('Failed to initiate subscription. Please try again.');
}
break;
case 'upgraded':
case 'updated':
const upgradeMessage = response.details?.is_upgrade
? `Subscription upgraded from $${response.details.current_price} to $${response.details.new_price}`
: 'Subscription updated successfully';
toast.success(upgradeMessage);
if (onSubscriptionUpdate) onSubscriptionUpdate();
break;
case 'downgrade_scheduled':
case 'scheduled':
const effectiveDate = response.effective_date
? new Date(response.effective_date).toLocaleDateString()
: 'the end of your billing period';
const downgradeMessage = response.details?.is_upgrade === false
? `Subscription downgrade scheduled from $${response.details.current_price} to $${response.details.new_price}`
: 'Subscription change scheduled';
toast.success(
<div>
<p>{downgradeMessage}</p>
<p className="text-sm mt-1">
Your plan will change on {effectiveDate}.
</p>
</div>
);
if (onSubscriptionUpdate) onSubscriptionUpdate();
break;
case 'no_change':
toast.info(response.message || 'You are already on this plan.');
break;
default:
console.warn('Received unexpected status from createCheckoutSession:', response.status);
toast.error('An unexpected error occurred. Please try again.');
}
} catch (error: any) {
console.error('Error processing subscription:', error);
const errorMessage = error?.response?.data?.detail || error?.message || 'Failed to process subscription. Please try again.';
toast.error(errorMessage);
}
};
const getPriceValue = (tier: typeof siteConfig.cloudPricingItems[0], selectedHours?: string): string => {
if (tier.upgradePlans && selectedHours) {
const plan = tier.upgradePlans.find(plan => plan.hours === selectedHours);
if (plan) {
return plan.price;
}
}
return tier.price;
};
const getDisplayedHours = (tier: typeof siteConfig.cloudPricingItems[0]) => {
if (tier.name === "Custom" && localSelectedPlan) {
return localSelectedPlan;
}
return tier.hours;
};
const getSelectedPlanPriceId = (tier: typeof siteConfig.cloudPricingItems[0]): string => {
if (tier.name === "Custom" && tier.upgradePlans) {
const selectedPlan = tier.upgradePlans.find(plan => plan.hours === localSelectedPlan);
return selectedPlan?.stripePriceId || tier.stripePriceId;
}
return tier.stripePriceId;
};
const getSelectedPlanPrice = (tier: typeof siteConfig.cloudPricingItems[0]): string => {
if (tier.name === "Custom" && tier.upgradePlans) {
const selectedPlan = tier.upgradePlans.find(plan => plan.hours === localSelectedPlan);
return selectedPlan?.price || tier.price;
}
return tier.price;
};
const tierPriceId = getSelectedPlanPriceId(tier);
const isCurrentActivePlan = isAuthenticated && (
// For custom tier, check if the selected plan matches the current subscription
tier.name === "Custom"
? tier.upgradePlans?.some(plan =>
plan.hours === localSelectedPlan &&
plan.stripePriceId === currentSubscription?.price_id
)
: currentSubscription?.price_id === tierPriceId
);
const isScheduled = isAuthenticated && currentSubscription?.has_schedule;
const isScheduledTargetPlan = isScheduled && (
// For custom tier, check if the selected plan matches the scheduled subscription
tier.name === "Custom"
? tier.upgradePlans?.some(plan =>
plan.hours === localSelectedPlan &&
plan.stripePriceId === currentSubscription?.scheduled_price_id
)
: currentSubscription?.scheduled_price_id === tierPriceId
);
const isPlanLoading = isLoading[tierPriceId];
let buttonText = isAuthenticated ? "Select Plan" : "Hire Suna";
let buttonDisabled = isPlanLoading;
let buttonVariant: ButtonVariant = null;
let ringClass = "";
let statusBadge = null;
let buttonClassName = "";
if (isAuthenticated) {
if (isCurrentActivePlan) {
buttonText = "Current Plan";
buttonDisabled = true;
buttonVariant = "secondary";
ringClass = isCompact ? "ring-1 ring-primary" : "ring-2 ring-primary";
buttonClassName = "bg-primary/5 hover:bg-primary/10 text-primary";
statusBadge = (
<span className="bg-primary/10 text-primary text-[10px] font-medium px-1.5 py-0.5 rounded-full">
Current
</span>
);
} else if (isScheduledTargetPlan) {
buttonText = "Scheduled";
buttonDisabled = true;
buttonVariant = "outline";
ringClass = isCompact ? "ring-1 ring-yellow-500" : "ring-2 ring-yellow-500";
buttonClassName = "bg-yellow-500/5 hover:bg-yellow-500/10 text-yellow-600 border-yellow-500/20";
statusBadge = (
<span className="bg-yellow-500/10 text-yellow-600 text-[10px] font-medium px-1.5 py-0.5 rounded-full">
Scheduled
</span>
);
} else if (isScheduled && currentSubscription?.price_id === tierPriceId) {
buttonText = "Change Scheduled";
buttonVariant = "secondary";
ringClass = isCompact ? "ring-1 ring-primary" : "ring-2 ring-primary";
buttonClassName = "bg-primary/5 hover:bg-primary/10 text-primary";
statusBadge = (
<span className="bg-yellow-500/10 text-yellow-600 text-[10px] font-medium px-1.5 py-0.5 rounded-full">
Downgrade Pending
</span>
);
} else {
// For custom tier, find the current plan in upgradePlans
const currentTier = tier.name === "Custom" && tier.upgradePlans
? tier.upgradePlans.find(p => p.stripePriceId === currentSubscription?.price_id)
: siteConfig.cloudPricingItems.find(p => p.stripePriceId === currentSubscription?.price_id);
// Find the highest active plan from upgradePlans
const highestActivePlan = siteConfig.cloudPricingItems.reduce((highest, item) => {
if (item.upgradePlans) {
const activePlan = item.upgradePlans.find(p => p.stripePriceId === currentSubscription?.price_id);
if (activePlan) {
const activeAmount = parseFloat(activePlan.price.replace(/[^\d.]/g, '') || '0') * 100;
const highestAmount = parseFloat(highest?.price?.replace(/[^\d.]/g, '') || '0') * 100;
return activeAmount > highestAmount ? activePlan : highest;
}
}
return highest;
}, null as { price: string; hours: string; stripePriceId: string } | null);
const currentPriceString = currentSubscription ? (highestActivePlan?.price || currentTier?.price || '$0') : '$0';
const selectedPriceString = getSelectedPlanPrice(tier);
const currentAmount = currentPriceString === '$0' ? 0 : parseFloat(currentPriceString.replace(/[^\d.]/g, '') || '0') * 100;
const targetAmount = selectedPriceString === '$0' ? 0 : parseFloat(selectedPriceString.replace(/[^\d.]/g, '') || '0') * 100;
if (currentAmount === 0 && targetAmount === 0 && currentSubscription?.status !== 'no_subscription') {
buttonText = "Select Plan";
buttonDisabled = true;
buttonVariant = "secondary";
buttonClassName = "bg-primary/5 hover:bg-primary/10 text-primary";
} else {
buttonText = targetAmount > currentAmount ? "Upgrade" : "Downgrade";
buttonVariant = tier.buttonColor as ButtonVariant;
buttonClassName = targetAmount > currentAmount
? "bg-primary hover:bg-primary/90 text-primary-foreground"
: "bg-primary/5 hover:bg-primary/10 text-primary";
}
}
if (isPlanLoading) {
buttonText = "Loading...";
buttonClassName = "opacity-70 cursor-not-allowed";
}
} else {
// Non-authenticated state styling
buttonVariant = tier.buttonColor as ButtonVariant;
buttonClassName = tier.buttonColor === "default"
? "bg-primary hover:bg-primary/90 text-white"
: "bg-secondary hover:bg-secondary/90 text-white";
}
return (
<div
className={cn(
"rounded-xl flex flex-col relative h-fit min-h-[400px] min-[650px]:h-full min-[900px]:h-fit",
tier.isPopular
? "md:shadow-[0px_61px_24px_-10px_rgba(0,0,0,0.01),0px_34px_20px_-8px_rgba(0,0,0,0.05),0px_15px_15px_-6px_rgba(0,0,0,0.09),0px_4px_8px_-2px_rgba(0,0,0,0.10),0px_0px_0px_1px_rgba(0,0,0,0.08)] bg-accent"
: "bg-[#F3F4F6] dark:bg-[#F9FAFB]/[0.02] border border-border",
ringClass
)}
>
<div className="flex flex-col gap-4 p-4">
<p className="text-sm flex items-center gap-2">
{tier.name}
{tier.isPopular && (
<span className="bg-gradient-to-b from-secondary/50 from-[1.92%] to-secondary to-[100%] text-white h-6 inline-flex w-fit items-center justify-center px-2 rounded-full text-sm shadow-[0px_6px_6px_-3px_rgba(0,0,0,0.08),0px_3px_3px_-1.5px_rgba(0,0,0,0.08),0px_1px_1px_-0.5px_rgba(0,0,0,0.08),0px_0px_0px_1px_rgba(255,255,255,0.12)_inset,0px_1px_0px_0px_rgba(255,255,255,0.12)_inset]">
Popular
</span>
)}
{isAuthenticated && statusBadge}
</p>
<div className="flex items-baseline mt-2">
{tier.name === "Custom" ? (
<CustomPriceDisplay price={getPriceValue(tier, localSelectedPlan)} />
) : (
<PriceDisplay price={tier.price} />
)}
<span className="ml-2">
{tier.price !== "$0" ? "/month" : ""}
</span>
</div>
<p className="text-sm mt-2">{tier.description}</p>
{tier.name === "Custom" && tier.upgradePlans ? (
<div className="w-full space-y-2">
<p className="text-xs font-medium text-muted-foreground">Customize your monthly usage</p>
<Select
value={localSelectedPlan}
onValueChange={handlePlanSelect}
>
<SelectTrigger className="w-full bg-white dark:bg-background">
<SelectValue placeholder="Select a plan" />
</SelectTrigger>
<SelectContent>
{tier.upgradePlans.map((plan) => (
<SelectItem
key={plan.hours}
value={plan.hours}
className={localSelectedPlan === plan.hours ? "font-medium bg-primary/5" : ""}
>
{plan.hours} - {plan.price}
</SelectItem>
))}
</SelectContent>
</Select>
<div className="inline-flex items-center rounded-full border px-2.5 py-0.5 text-xs font-semibold bg-primary/10 border-primary/20 text-primary w-fit">
{localSelectedPlan}/month
</div>
</div>
) : (
<div className="inline-flex items-center rounded-full border px-2.5 py-0.5 text-xs font-semibold bg-primary/10 border-primary/20 text-primary w-fit">
{getDisplayedHours(tier)}/month
</div>
)}
</div>
<div className="p-4 flex-grow">
{tier.features && tier.features.length > 0 && (
<ul className="space-y-3">
{tier.features.map((feature) => (
<li key={feature} className="flex items-center gap-2">
<div className="size-5 rounded-full border border-primary/20 flex items-center justify-center">
<CheckIcon className="size-3 text-primary" />
</div>
<span className="text-sm">{feature}</span>
</li>
))}
</ul>
)}
</div>
<div className="mt-auto p-4">
<Button
onClick={() => handleSubscribe(tierPriceId)}
disabled={buttonDisabled}
variant={buttonVariant || "default"}
className={cn(
"w-full font-medium transition-all duration-200",
isCompact
? "h-7 rounded-md text-xs"
: "h-10 rounded-full text-sm",
buttonClassName,
isPlanLoading && "animate-pulse"
)}
>
{buttonText}
</Button>
</div>
</div>
);
}
interface PricingSectionProps {
returnUrl?: string;
showTitleAndTabs?: boolean;
}
export function PricingSection({
returnUrl = typeof window !== 'undefined' ? window.location.href : '/',
showTitleAndTabs = true
}: PricingSectionProps) {
const [deploymentType, setDeploymentType] = useState<"cloud" | "self-hosted">("cloud");
const [currentSubscription, setCurrentSubscription] = useState<SubscriptionStatus | null>(null);
const [isLoading, setIsLoading] = useState<Record<string, boolean>>({});
const [isFetchingPlan, setIsFetchingPlan] = useState(true);
const [isAuthenticated, setIsAuthenticated] = useState(false);
const fetchCurrentPlan = async () => {
setIsFetchingPlan(true);
try {
const subscriptionData = await getSubscription();
console.log("Fetched Subscription Status:", subscriptionData);
setCurrentSubscription(subscriptionData);
setIsAuthenticated(true);
} catch (error) {
console.error('Error fetching subscription:', error);
setCurrentSubscription(null);
setIsAuthenticated(false);
} finally {
setIsFetchingPlan(false);
}
};
const handlePlanSelect = (planId: string) => {
setIsLoading(prev => ({ ...prev, [planId]: true }));
};
const handleSubscriptionUpdate = () => {
fetchCurrentPlan();
setTimeout(() => {
setIsLoading({});
}, 1000);
};
useEffect(() => {
fetchCurrentPlan();
}, []);
// Handle tab change
const handleTabChange = (tab: "cloud" | "self-hosted") => {
if (tab === "self-hosted") {
// Scroll to the open-source section when self-hosted tab is clicked
const openSourceSection = document.getElementById("open-source");
if (openSourceSection) {
// Get the position of the section and scroll to a position slightly above it
const rect = openSourceSection.getBoundingClientRect();
const scrollTop = window.pageYOffset || document.documentElement.scrollTop;
const offsetPosition = scrollTop + rect.top - 100; // 100px offset from the top
const offsetPosition = scrollTop + rect.top - 100;
window.scrollTo({
top: offsetPosition,
@ -82,312 +573,64 @@ export function PricingSection() {
});
}
} else {
// Set the deployment type to cloud for cloud tab
setDeploymentType(tab);
}
};
// Update price animation
const PriceDisplay = ({
tier,
}: {
tier: typeof siteConfig.cloudPricingItems[0];
}) => {
const price = tier.price;
if (isLocalMode()) {
return (
<motion.span
key={price}
className="text-4xl font-semibold"
initial={{
opacity: 0,
x: 10,
filter: "blur(5px)",
}}
animate={{ opacity: 1, x: 0, filter: "blur(0px)" }}
transition={{ duration: 0.25, ease: [0.4, 0, 0.2, 1] }}
>
{price}
</motion.span>
);
};
const SelfHostedContent = () => (
<div className="rounded-xl bg-[#F3F4F6] dark:bg-[#F9FAFB]/[0.02] border border-border p-8 w-full max-w-6xl mx-auto">
<div className="flex flex-col gap-6">
<div className="inline-flex h-10 w-fit items-center justify-center gap-2 rounded-full bg-secondary/10 text-secondary px-4">
<Github className="h-5 w-5" />
<span className="text-sm font-medium">100% Open Source</span>
</div>
<div className="space-y-2">
<h3 className="text-2xl font-semibold tracking-tight">
Self-Hosted Version
</h3>
<p className="text-muted-foreground">
Set up and run the platform on your own infrastructure with complete control over your data and deployment.
</p>
</div>
<div className="grid grid-cols-1 md:grid-cols-3 gap-6 mt-6">
<div className="rounded-xl border border-border bg-background p-6 flex flex-col gap-4">
<div className="size-10 flex items-center justify-center rounded-full bg-primary/10 text-primary">
<GitFork className="h-5 w-5" />
</div>
<h4 className="text-lg font-medium">Fork & Setup</h4>
<p className="text-sm text-muted-foreground">
Fork the repository and follow our step-by-step setup guide to deploy on your own infrastructure.
</p>
<Link
href="https://github.com/Kortix-ai/Suna"
target="_blank"
className="text-sm text-primary flex items-center gap-1 mt-auto"
>
Fork on GitHub
<svg width="12" height="12" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M7 17L17 7M17 7H8M17 7V16" stroke="currentColor" strokeWidth="2" strokeLinecap="round" strokeLinejoin="round"/>
</svg>
</Link>
</div>
<div className="rounded-xl border border-border bg-background p-6 flex flex-col gap-4">
<div className="size-10 flex items-center justify-center rounded-full bg-primary/10 text-primary">
<File className="h-5 w-5" />
</div>
<h4 className="text-lg font-medium">Documentation</h4>
<p className="text-sm text-muted-foreground">
Comprehensive documentation with detailed instructions for installation, configuration, and customization.
</p>
<Link
href="#"
className="text-sm text-primary flex items-center gap-1 mt-auto"
>
Read the docs
<svg width="12" height="12" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M7 17L17 7M17 7H8M17 7V16" stroke="currentColor" strokeWidth="2" strokeLinecap="round" strokeLinejoin="round"/>
</svg>
</Link>
</div>
<div className="rounded-xl border border-border bg-background p-6 flex flex-col gap-4">
<div className="size-10 flex items-center justify-center rounded-full bg-primary/10 text-primary">
<Terminal className="h-5 w-5" />
</div>
<h4 className="text-lg font-medium">Custom APIs</h4>
<p className="text-sm text-muted-foreground">
Connect to your preferred language models and customize the platform to fit your specific requirements.
</p>
<Link
href="#"
className="text-sm text-primary flex items-center gap-1 mt-auto"
>
API Reference
<svg width="12" height="12" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M7 17L17 7M17 7H8M17 7V16" stroke="currentColor" strokeWidth="2" strokeLinecap="round" strokeLinejoin="round"/>
</svg>
</Link>
</div>
</div>
<div className="border-t border-border pt-6 mt-4">
<h4 className="font-medium mb-4">Key benefits of self-hosting</h4>
<div className="grid grid-cols-1 md:grid-cols-2 gap-4">
<ul className="space-y-3">
{["Complete data privacy", "Full control over infrastructure", "Custom model integration", "Unlimited usage"].map((feature) => (
<li key={feature} className="flex items-center gap-2">
<div className="size-5 rounded-full border border-primary/20 flex items-center justify-center bg-muted-foreground/10">
<div className="size-3 flex items-center justify-center">
<svg width="8" height="7" viewBox="0 0 8 7" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M1.5 3.48828L3.375 5.36328L6.5 0.988281" stroke="currentColor" strokeWidth="1.5" strokeLinecap="round" strokeLinejoin="round"/>
</svg>
</div>
</div>
<span className="text-sm">{feature}</span>
</li>
))}
</ul>
<ul className="space-y-3">
{["No usage fees", "Apache 2.0 license", "Community support", "Security customization"].map((feature) => (
<li key={feature} className="flex items-center gap-2">
<div className="size-5 rounded-full border border-primary/20 flex items-center justify-center bg-muted-foreground/10">
<div className="size-3 flex items-center justify-center">
<svg width="8" height="7" viewBox="0 0 8 7" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M1.5 3.48828L3.375 5.36328L6.5 0.988281" stroke="currentColor" strokeWidth="1.5" strokeLinecap="round" strokeLinejoin="round"/>
</svg>
</div>
</div>
<span className="text-sm">{feature}</span>
</li>
))}
</ul>
</div>
</div>
<div className="flex flex-col sm:flex-row gap-4 mt-4">
<Link
href="https://github.com/Kortix-ai/Suna"
target="_blank"
className="inline-flex h-11 items-center justify-center gap-2 text-sm font-medium tracking-wide rounded-full bg-primary text-white px-6 shadow-md hover:bg-primary/90 transition-all"
>
<Github className="h-4 w-4" />
View on GitHub
</Link>
<Link
href="#"
className="inline-flex h-11 items-center justify-center gap-2 text-sm font-medium tracking-wide rounded-full bg-secondary/10 text-secondary px-6 border border-secondary/20 hover:bg-secondary/20 transition-all"
>
Read Documentation
</Link>
</div>
<div className="p-4 bg-muted/30 border border-border rounded-lg text-center">
<p className="text-sm text-muted-foreground">
Running in local development mode - billing features are disabled
</p>
</div>
</div>
);
);
}
return (
<section
id="pricing"
className="flex flex-col items-center justify-center gap-10 pb-20 w-full relative"
>
<SectionHeader>
<h2 className="text-3xl md:text-4xl font-medium tracking-tighter text-center text-balance">
General Intelligence available today
</h2>
<p className="text-muted-foreground text-center text-balance font-medium">
You can self-host Suna or use our cloud for managed service.
</p>
</SectionHeader>
<div className="relative w-full h-full">
<div className="absolute -top-14 left-1/2 -translate-x-1/2">
<PricingTabs
activeTab={deploymentType}
setActiveTab={handleTabChange}
className="mx-auto"
/>
</div>
{deploymentType === "cloud" && (
<div className="grid min-[650px]:grid-cols-2 min-[900px]:grid-cols-3 gap-4 w-full max-w-6xl mx-auto px-6">
{siteConfig.cloudPricingItems.map((tier) => (
<div
key={tier.name}
className={cn(
"rounded-xl grid grid-rows-[180px_auto_1fr] relative h-fit min-[650px]:h-full min-[900px]:h-fit",
tier.isPopular
? "md:shadow-[0px_61px_24px_-10px_rgba(0,0,0,0.01),0px_34px_20px_-8px_rgba(0,0,0,0.05),0px_15px_15px_-6px_rgba(0,0,0,0.09),0px_4px_8px_-2px_rgba(0,0,0,0.10),0px_0px_0px_1px_rgba(0,0,0,0.08)] bg-accent"
: "bg-[#F3F4F6] dark:bg-[#F9FAFB]/[0.02] border border-border",
)}
>
<div className="flex flex-col gap-4 p-4">
<p className="text-sm flex items-center">
{tier.name}
{tier.isPopular && (
<span className="bg-gradient-to-b from-secondary/50 from-[1.92%] to-secondary to-[100%] text-white h-6 inline-flex w-fit items-center justify-center px-2 rounded-full text-sm ml-2 shadow-[0px_6px_6px_-3px_rgba(0,0,0,0.08),0px_3px_3px_-1.5px_rgba(0,0,0,0.08),0px_1px_1px_-0.5px_rgba(0,0,0,0.08),0px_0px_0px_1px_rgba(255,255,255,0.12)_inset,0px_1px_0px_0px_rgba(255,255,255,0.12)_inset]">
Popular
</span>
)}
</p>
<div className="flex items-baseline mt-2">
<PriceDisplay tier={tier} />
<span className="ml-2">
{tier.price !== "$0" ? "/month" : ""}
</span>
</div>
<p className="text-sm mt-2">{tier.description}</p>
<div className="inline-flex items-center rounded-full border px-2.5 py-0.5 text-xs font-semibold bg-primary/10 border-primary/20 text-primary w-fit">
{tier.hours}/month
</div>
</div>
<div className="flex flex-col gap-2 p-4">
{tier.buttonText === "Hire Suna" ? (
<Link
href="/auth"
className={`h-10 w-full flex items-center justify-center text-sm font-normal tracking-wide rounded-full px-4 cursor-pointer transition-all ease-out active:scale-95 ${
tier.isPopular
? `${tier.buttonColor} shadow-[inset_0_1px_2px_rgba(255,255,255,0.25),0_3px_3px_-1.5px_rgba(16,24,40,0.06),0_1px_1px_rgba(16,24,40,0.08)]`
: `${tier.buttonColor} shadow-[0px_1px_2px_0px_rgba(255,255,255,0.16)_inset,0px_3px_3px_-1.5px_rgba(16,24,40,0.24),0px_1px_1px_-0.5px_rgba(16,24,40,0.20)]`
}`}
>
{tier.buttonText}
</Link>
) : (
<button
className={`h-10 w-full flex items-center justify-center text-sm font-normal tracking-wide rounded-full px-4 cursor-pointer transition-all ease-out active:scale-95 ${
tier.isPopular
? `${tier.buttonColor} shadow-[inset_0_1px_2px_rgba(255,255,255,0.25),0_3px_3px_-1.5px_rgba(16,24,40,0.06),0_1px_1px_rgba(16,24,40,0.08)]`
: `${tier.buttonColor} shadow-[0px_1px_2px_0px_rgba(255,255,255,0.16)_inset,0px_3px_3px_-1.5px_rgba(16,24,40,0.24),0px_1px_1px_-0.5px_rgba(16,24,40,0.20)]`
}`}
>
{tier.buttonText}
</button>
)}
</div>
{/* <hr className="border-border dark:border-white/20" /> */}
<div className="p-4">
{/*
{tier.name !== "Free" && (
<p className="text-sm mb-4">
Everything in {tier.name === "Pro" ? "Free" : "Pro"} +
</p>
)}
<ul className="space-y-3">
{tier.features.filter(feature => !feature.startsWith('//'))
.map((feature) => (
<li key={feature} className="flex items-center gap-2">
<div
className={cn(
"size-5 rounded-full border border-primary/20 flex items-center justify-center",
tier.isPopular &&
"bg-muted-foreground/40 border-border",
)}
>
<div className="size-3 flex items-center justify-center">
<svg
width="8"
height="7"
viewBox="0 0 8 7"
fill="none"
xmlns="http://www.w3.org/2000/svg"
className="block dark:hidden"
>
<path
d="M1.5 3.48828L3.375 5.36328L6.5 0.988281"
stroke="#101828"
strokeWidth="1.5"
strokeLinecap="round"
strokeLinejoin="round"
/>
</svg>
<svg
width="8"
height="7"
viewBox="0 0 8 7"
fill="none"
xmlns="http://www.w3.org/2000/svg"
className="hidden dark:block"
>
<path
d="M1.5 3.48828L3.375 5.36328L6.5 0.988281"
stroke="#FAFAFA"
strokeWidth="1.5"
strokeLinecap="round"
strokeLinejoin="round"
/>
</svg>
</div>
</div>
<span className="text-sm">{feature}</span>
</li>
))}
</ul>
*/}
</div>
</div>
))}
{showTitleAndTabs && (
<>
<SectionHeader>
<h2 className="text-3xl md:text-4xl font-medium tracking-tighter text-center text-balance">
Choose the right plan for your needs
</h2>
<p className="text-muted-foreground text-center text-balance font-medium">
Start with our free plan or upgrade to a premium plan for more usage hours
</p>
</SectionHeader>
<div className="relative w-full h-full">
<div className="absolute -top-14 left-1/2 -translate-x-1/2">
<PricingTabs
activeTab={deploymentType}
setActiveTab={handleTabChange}
className="mx-auto"
/>
</div>
</div>
)}
</div>
</>
)}
{deploymentType === "cloud" && (
<div className="grid min-[650px]:grid-cols-2 min-[900px]:grid-cols-3 gap-4 w-full max-w-6xl mx-auto px-6">
{siteConfig.cloudPricingItems.map((tier) => (
<PricingTier
key={tier.name}
tier={tier}
currentSubscription={currentSubscription}
isLoading={isLoading}
isFetchingPlan={isFetchingPlan}
onPlanSelect={handlePlanSelect}
onSubscriptionUpdate={handleSubscriptionUpdate}
isAuthenticated={isAuthenticated}
returnUrl={returnUrl}
/>
))}
</div>
)}
</section>
);
}
}

View File

@ -1,55 +0,0 @@
"use server";
import { redirect } from "next/navigation";
import { createClient } from "../supabase/server";
import handleEdgeFunctionError from "../supabase/handle-edge-error";
export async function setupNewSubscription(prevState: any, formData: FormData) {
const accountId = formData.get("accountId") as string;
const returnUrl = formData.get("returnUrl") as string;
const planId = formData.get("planId") as string;
const supabaseClient = await createClient();
const { data, error } = await supabaseClient.functions.invoke('billing-functions', {
body: {
action: "get_new_subscription_url",
args: {
account_id: accountId,
success_url: returnUrl,
cancel_url: returnUrl,
plan_id: planId
}
}
});
if (error) {
return await handleEdgeFunctionError(error);
}
redirect(data.url);
};
export async function manageSubscription(prevState: any, formData: FormData) {
const accountId = formData.get("accountId") as string;
const returnUrl = formData.get("returnUrl") as string;
const supabaseClient = await createClient();
const { data, error } = await supabaseClient.functions.invoke('billing-functions', {
body: {
action: "get_billing_portal_url",
args: {
account_id: accountId,
return_url: returnUrl
}
}
});
console.log(data);
if (error) {
console.error(error);
return await handleEdgeFunctionError(error);
}
redirect(data.url);
};

View File

@ -1177,3 +1177,202 @@ export const checkApiHealth = async (): Promise<HealthCheckResponse> => {
}
};
// Billing API Types
export interface CreateCheckoutSessionRequest {
price_id: string;
success_url: string;
cancel_url: string;
}
export interface CreatePortalSessionRequest {
return_url: string;
}
export interface SubscriptionStatus {
status: string; // Includes 'active', 'trialing', 'past_due', 'scheduled_downgrade', 'no_subscription'
plan_name?: string;
price_id?: string; // Added
current_period_end?: string; // ISO Date string
cancel_at_period_end: boolean;
trial_end?: string; // ISO Date string
minutes_limit?: number;
current_usage?: number;
// Fields for scheduled changes
has_schedule: boolean;
scheduled_plan_name?: string;
scheduled_price_id?: string; // Added
scheduled_change_date?: string; // ISO Date string - Deprecate? Check backend usage
schedule_effective_date?: string; // ISO Date string - Added for consistency
}
export interface BillingStatusResponse {
can_run: boolean;
message: string;
subscription: {
price_id: string;
plan_name: string;
minutes_limit?: number;
};
}
export interface CreateCheckoutSessionResponse {
status: 'upgraded' | 'downgrade_scheduled' | 'checkout_created' | 'no_change' | 'new' | 'updated' | 'scheduled';
subscription_id?: string;
schedule_id?: string;
session_id?: string;
url?: string;
effective_date?: string;
message?: string;
details?: {
is_upgrade?: boolean;
effective_date?: string;
current_price?: number;
new_price?: number;
invoice?: {
id: string;
status: string;
amount_due: number;
amount_paid: number;
};
};
}
// Billing API Functions
export const createCheckoutSession = async (request: CreateCheckoutSessionRequest): Promise<CreateCheckoutSessionResponse> => {
try {
const supabase = createClient();
const { data: { session } } = await supabase.auth.getSession();
if (!session?.access_token) {
throw new Error('No access token available');
}
const response = await fetch(`${API_URL}/billing/create-checkout-session`, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
'Authorization': `Bearer ${session.access_token}`,
},
body: JSON.stringify(request),
});
if (!response.ok) {
const errorText = await response.text().catch(() => 'No error details available');
console.error(`Error creating checkout session: ${response.status} ${response.statusText}`, errorText);
throw new Error(`Error creating checkout session: ${response.statusText} (${response.status})`);
}
const data = await response.json();
console.log('Checkout session response:', data);
// Handle all possible statuses
switch (data.status) {
case 'upgraded':
case 'updated':
case 'downgrade_scheduled':
case 'scheduled':
case 'no_change':
return data;
case 'new':
case 'checkout_created':
if (!data.url) {
throw new Error('No checkout URL provided');
}
return data;
default:
console.warn('Unexpected status from createCheckoutSession:', data.status);
return data;
}
} catch (error) {
console.error('Failed to create checkout session:', error);
throw error;
}
};
export const createPortalSession = async (request: CreatePortalSessionRequest): Promise<{ url: string }> => {
try {
const supabase = createClient();
const { data: { session } } = await supabase.auth.getSession();
if (!session?.access_token) {
throw new Error('No access token available');
}
const response = await fetch(`${API_URL}/billing/create-portal-session`, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
'Authorization': `Bearer ${session.access_token}`,
},
body: JSON.stringify(request),
});
if (!response.ok) {
const errorText = await response.text().catch(() => 'No error details available');
console.error(`Error creating portal session: ${response.status} ${response.statusText}`, errorText);
throw new Error(`Error creating portal session: ${response.statusText} (${response.status})`);
}
return response.json();
} catch (error) {
console.error('Failed to create portal session:', error);
throw error;
}
};
export const getSubscription = async (): Promise<SubscriptionStatus> => {
try {
const supabase = createClient();
const { data: { session } } = await supabase.auth.getSession();
if (!session?.access_token) {
throw new Error('No access token available');
}
const response = await fetch(`${API_URL}/billing/subscription`, {
headers: {
'Authorization': `Bearer ${session.access_token}`,
},
});
if (!response.ok) {
const errorText = await response.text().catch(() => 'No error details available');
console.error(`Error getting subscription: ${response.status} ${response.statusText}`, errorText);
throw new Error(`Error getting subscription: ${response.statusText} (${response.status})`);
}
return response.json();
} catch (error) {
console.error('Failed to get subscription:', error);
throw error;
}
};
export const checkBillingStatus = async (): Promise<BillingStatusResponse> => {
try {
const supabase = createClient();
const { data: { session } } = await supabase.auth.getSession();
if (!session?.access_token) {
throw new Error('No access token available');
}
const response = await fetch(`${API_URL}/billing/check-status`, {
headers: {
'Authorization': `Bearer ${session.access_token}`,
},
});
if (!response.ok) {
const errorText = await response.text().catch(() => 'No error details available');
console.error(`Error checking billing status: ${response.status} ${response.statusText}`, errorText);
throw new Error(`Error checking billing status: ${response.statusText} (${response.status})`);
}
return response.json();
} catch (error) {
console.error('Failed to check billing status:', error);
throw error;
}
};

View File

@ -5,12 +5,103 @@ export enum EnvMode {
PRODUCTION = 'production',
}
// Subscription tier structure
export interface SubscriptionTierData {
priceId: string;
name: string;
}
// Subscription tiers structure
export interface SubscriptionTiers {
FREE: SubscriptionTierData;
TIER_2_20: SubscriptionTierData;
TIER_6_50: SubscriptionTierData;
TIER_12_100: SubscriptionTierData;
TIER_25_200: SubscriptionTierData;
TIER_50_400: SubscriptionTierData;
TIER_125_800: SubscriptionTierData;
TIER_200_1000: SubscriptionTierData;
}
// Configuration object
interface Config {
ENV_MODE: EnvMode;
IS_LOCAL: boolean;
SUBSCRIPTION_TIERS: SubscriptionTiers;
}
// Production tier IDs
const PROD_TIERS: SubscriptionTiers = {
FREE: {
priceId: 'price_1RILb4G6l1KZGqIrK4QLrx9i',
name: 'Free',
},
TIER_2_20: {
priceId: 'price_1RILb4G6l1KZGqIrhomjgDnO',
name: '2h/$20',
},
TIER_6_50: {
priceId: 'price_1RILb4G6l1KZGqIr5q0sybWn',
name: '6h/$50',
},
TIER_12_100: {
priceId: 'price_1RILb4G6l1KZGqIr5Y20ZLHm',
name: '12h/$100',
},
TIER_25_200: {
priceId: 'price_1RILb4G6l1KZGqIrGAD8rNjb',
name: '25h/$200',
},
TIER_50_400: {
priceId: 'price_1RILb4G6l1KZGqIruNBUMTF1',
name: '50h/$400',
},
TIER_125_800: {
priceId: 'price_1RILb3G6l1KZGqIrbJA766tN',
name: '125h/$800',
},
TIER_200_1000: {
priceId: 'price_1RILb3G6l1KZGqIrmauYPOiN',
name: '200h/$1000',
}
} as const;
// Staging tier IDs
const STAGING_TIERS: SubscriptionTiers = {
FREE: {
priceId: 'price_1RIGvuG6l1KZGqIrw14abxeL',
name: 'Free',
},
TIER_2_20: {
priceId: 'price_1RIGvuG6l1KZGqIrCRu0E4Gi',
name: '2h/$20',
},
TIER_6_50: {
priceId: 'price_1RIGvuG6l1KZGqIrvjlz5p5V',
name: '6h/$50',
},
TIER_12_100: {
priceId: 'price_1RIGvuG6l1KZGqIrT6UfgblC',
name: '12h/$100',
},
TIER_25_200: {
priceId: 'price_1RIGvuG6l1KZGqIrOVLKlOMj',
name: '25h/$200',
},
TIER_50_400: {
priceId: 'price_1RIKNgG6l1KZGqIrvsat5PW7',
name: '50h/$400',
},
TIER_125_800: {
priceId: 'price_1RIKNrG6l1KZGqIrjKT0yGvI',
name: '125h/$800',
},
TIER_200_1000: {
priceId: 'price_1RIKQ2G6l1KZGqIrum9n8SI7',
name: '200h/$1000',
}
} as const;
// Determine the environment mode from environment variables
const getEnvironmentMode = (): EnvMode => {
// Get the environment mode from the environment variable, if set
@ -47,9 +138,13 @@ const currentEnvMode = getEnvironmentMode();
export const config: Config = {
ENV_MODE: currentEnvMode,
IS_LOCAL: currentEnvMode === EnvMode.LOCAL,
SUBSCRIPTION_TIERS: currentEnvMode === EnvMode.STAGING ? STAGING_TIERS : PROD_TIERS,
};
// Helper function to check if we're in local mode (for component conditionals)
export const isLocalMode = (): boolean => {
return config.IS_LOCAL;
};
};
// Export subscription tier type for typing elsewhere
export type SubscriptionTier = keyof typeof PROD_TIERS;

View File

@ -6,6 +6,7 @@ import { FlickeringGrid } from "@/components/home/ui/flickering-grid";
import { Globe } from "@/components/home/ui/globe";
import { cn } from "@/lib/utils";
import { motion } from "motion/react";
import { config } from '@/lib/config';
export const Highlight = ({
children,
@ -28,6 +29,25 @@ export const Highlight = ({
export const BLUR_FADE_DELAY = 0.15;
interface UpgradePlan {
hours: string;
price: string;
stripePriceId: string;
}
export interface PricingTier {
name: string;
price: string;
description: string;
buttonText: string;
buttonColor: string;
isPopular: boolean;
hours: string;
features: string[];
stripePriceId: string;
upgradePlans: UpgradePlan[];
}
export const siteConfig = {
name: "Kortix Suna",
description: "The Generalist AI Agent that can act on your behalf.",
@ -79,54 +99,53 @@ export const siteConfig = {
{
name: "Free",
price: "$0",
description: "For individual use and exploration",
description: "Get started with",
buttonText: "Hire Suna",
buttonColor: "bg-secondary text-white",
isPopular: false,
hours: "10 min",
features: [
"10 minutes",
// "Community support",
// "Single user",
// "Standard response time",
"Public Projects",
],
stripePriceId: 'price_1RGJ9GG6l1KZGqIroxSqgphC',
stripePriceId: config.SUBSCRIPTION_TIERS.FREE.priceId,
upgradePlans: [],
},
{
name: "Pro",
price: "$29",
description: "For professionals and small teams",
price: "$20",
description: "Everything in Free, plus:",
buttonText: "Hire Suna",
buttonColor: "bg-primary text-white dark:text-black",
isPopular: true,
hours: "4 hours",
hours: "2 hours",
features: [
"4 hours usage per month",
// "Priority support",
// "Advanced features",
// "5 team members",
// "Custom integrations",
"2 hours",
"Private projects",
"Team functionality (coming soon)",
],
stripePriceId: 'price_1RGJ9LG6l1KZGqIrd9pwzeNW',
stripePriceId: config.SUBSCRIPTION_TIERS.TIER_2_20.priceId,
upgradePlans: [],
},
{
name: "Enterprise",
price: "$199",
description: "For organizations with complex needs",
name: "Custom",
price: "$50",
description: "Everything in Pro, plus:",
buttonText: "Hire Suna",
buttonColor: "bg-secondary text-white",
isPopular: false,
hours: "40 hours",
hours: "6 hours",
features: [
"40 hours usage per month",
// "Dedicated support",
// "SSO & advanced security",
// "Unlimited team members",
// "Service level agreement",
// "Custom AI model training",
"Unlimited seats",
],
showContactSales: true,
stripePriceId: 'price_1RGJ9JG6l1KZGqIrVUU4ZRv6',
upgradePlans: [
{ hours: "6 hours", price: "$50", stripePriceId: config.SUBSCRIPTION_TIERS.TIER_6_50.priceId },
{ hours: "12 hours", price: "$100", stripePriceId: config.SUBSCRIPTION_TIERS.TIER_12_100.priceId },
{ hours: "25 hours", price: "$200", stripePriceId: config.SUBSCRIPTION_TIERS.TIER_25_200.priceId },
{ hours: "50 hours", price: "$400", stripePriceId: config.SUBSCRIPTION_TIERS.TIER_50_400.priceId },
{ hours: "125 hours", price: "$800", stripePriceId: config.SUBSCRIPTION_TIERS.TIER_125_800.priceId },
{ hours: "200 hours", price: "$1000", stripePriceId: config.SUBSCRIPTION_TIERS.TIER_200_1000.priceId },
],
stripePriceId: config.SUBSCRIPTION_TIERS.TIER_6_50.priceId,
},
],
companyShowcase: {

View File

@ -13,8 +13,8 @@ export const createClient = async () => {
supabaseUrl = `http://${supabaseUrl}`;
}
console.log('[SERVER] Supabase URL:', supabaseUrl);
console.log('[SERVER] Supabase Anon Key:', supabaseAnonKey);
// console.log('[SERVER] Supabase URL:', supabaseUrl);
// console.log('[SERVER] Supabase Anon Key:', supabaseAnonKey);
return createServerClient(
supabaseUrl,