mirror of https://github.com/kortix-ai/suna.git
Fix billing service to get customer data from Stripe if it's missing from DB
- Updated `get_stripe_customer_id` function to accept a Supabase client and handle missing user_id metadata in Stripe customers. - Added logic to create or update records in the `billing_customers` table based on Stripe customer data. - Improved logging for customer metadata updates and billing record changes.
This commit is contained in:
parent
14093acc75
commit
584a4192d7
|
@ -8,6 +8,8 @@ from fastapi import APIRouter, HTTPException, Depends, Request
|
||||||
from typing import Optional, Dict, Tuple
|
from typing import Optional, Dict, Tuple
|
||||||
import stripe
|
import stripe
|
||||||
from datetime import datetime, timezone, timedelta
|
from datetime import datetime, timezone, timedelta
|
||||||
|
|
||||||
|
from supabase import Client as SupabaseClient
|
||||||
from utils.logger import logger
|
from utils.logger import logger
|
||||||
from utils.config import config, EnvMode
|
from utils.config import config, EnvMode
|
||||||
from services.supabase import DBConnection
|
from services.supabase import DBConnection
|
||||||
|
@ -162,7 +164,7 @@ class SubscriptionStatus(BaseModel):
|
||||||
subscription: Optional[Dict] = None
|
subscription: Optional[Dict] = None
|
||||||
|
|
||||||
# Helper functions
|
# Helper functions
|
||||||
async def get_stripe_customer_id(client, user_id: str) -> Optional[str]:
|
async def get_stripe_customer_id(client: SupabaseClient, user_id: str) -> Optional[str]:
|
||||||
"""Get the Stripe customer ID for a user."""
|
"""Get the Stripe customer ID for a user."""
|
||||||
result = await client.schema('basejump').from_('billing_customers') \
|
result = await client.schema('basejump').from_('billing_customers') \
|
||||||
.select('id') \
|
.select('id') \
|
||||||
|
@ -171,6 +173,42 @@ async def get_stripe_customer_id(client, user_id: str) -> Optional[str]:
|
||||||
|
|
||||||
if result.data and len(result.data) > 0:
|
if result.data and len(result.data) > 0:
|
||||||
return result.data[0]['id']
|
return result.data[0]['id']
|
||||||
|
|
||||||
|
customer_result = await stripe.Customer.search_async(
|
||||||
|
query=f"metadata['user_id']:'{user_id}' OR metadata['basejump_account_id']:'{user_id}'"
|
||||||
|
)
|
||||||
|
|
||||||
|
if customer_result.data and len(customer_result.data) > 0:
|
||||||
|
customer = customer_result.data[0]
|
||||||
|
# If the customer does not have 'user_id' in metadata, add it now
|
||||||
|
if not customer.get('metadata', {}).get('user_id'):
|
||||||
|
try:
|
||||||
|
await stripe.Customer.modify_async(
|
||||||
|
customer['id'],
|
||||||
|
metadata={**customer.get('metadata', {}), 'user_id': user_id}
|
||||||
|
)
|
||||||
|
logger.info(f"Added missing user_id metadata to Stripe customer {customer['id']}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to add user_id metadata to Stripe customer {customer['id']}: {str(e)}")
|
||||||
|
|
||||||
|
has_active = len((await stripe.Subscription.list_async(
|
||||||
|
customer=customer['id'],
|
||||||
|
status='active',
|
||||||
|
limit=1
|
||||||
|
)).get('data', [])) > 0
|
||||||
|
|
||||||
|
# Create or update record in billing_customers table
|
||||||
|
await client.schema('basejump').from_('billing_customers').upsert({
|
||||||
|
'id': customer['id'],
|
||||||
|
'account_id': user_id,
|
||||||
|
'email': customer.get('email'),
|
||||||
|
'provider': 'stripe',
|
||||||
|
'active': has_active
|
||||||
|
}).execute()
|
||||||
|
logger.info(f"Updated billing_customers record for customer {customer['id']} and user {user_id}")
|
||||||
|
|
||||||
|
return customer['id']
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def create_stripe_customer(client, user_id: str, email: str) -> str:
|
async def create_stripe_customer(client, user_id: str, email: str) -> str:
|
||||||
|
@ -676,7 +714,7 @@ async def create_checkout_session(
|
||||||
# Get or create Stripe customer
|
# Get or create Stripe customer
|
||||||
customer_id = await get_stripe_customer_id(client, current_user_id)
|
customer_id = await get_stripe_customer_id(client, current_user_id)
|
||||||
if not customer_id: customer_id = await create_stripe_customer(client, current_user_id, email)
|
if not customer_id: customer_id = await create_stripe_customer(client, current_user_id, email)
|
||||||
|
|
||||||
# Get the target price and product ID
|
# Get the target price and product ID
|
||||||
try:
|
try:
|
||||||
price = await stripe.Price.retrieve_async(request.price_id, expand=['product'])
|
price = await stripe.Price.retrieve_async(request.price_id, expand=['product'])
|
||||||
|
|
Loading…
Reference in New Issue