feat(billing): implement hardcoded model pricing and enhance cost calculation logic

This commit is contained in:
sharath 2025-06-26 15:18:37 +00:00
parent bec4494084
commit 79b71db250
No known key found for this signature in database
3 changed files with 138 additions and 72 deletions

View File

@ -23,6 +23,49 @@ stripe.api_key = config.STRIPE_SECRET_KEY
# Initialize router # Initialize router
router = APIRouter(prefix="/billing", tags=["billing"]) router = APIRouter(prefix="/billing", tags=["billing"])
# Hardcoded pricing for specific models (prices per million tokens)
HARDCODED_MODEL_PRICES = {
"openrouter/deepseek/deepseek-chat": {
"input_cost_per_million_tokens": 0.38,
"output_cost_per_million_tokens": 0.89
},
"deepseek/deepseek-chat": {
"input_cost_per_million_tokens": 0.38,
"output_cost_per_million_tokens": 0.89
},
"qwen/qwen3-235b-a22b": {
"input_cost_per_million_tokens": 0.13,
"output_cost_per_million_tokens": 0.60
},
"openrouter/qwen/qwen3-235b-a22b": {
"input_cost_per_million_tokens": 0.13,
"output_cost_per_million_tokens": 0.60
},
"google/gemini-2.5-flash-preview-05-20": {
"input_cost_per_million_tokens": 0.15,
"output_cost_per_million_tokens": 0.60
},
"openrouter/google/gemini-2.5-flash-preview-05-20": {
"input_cost_per_million_tokens": 0.15,
"output_cost_per_million_tokens": 0.60
}
}
def get_model_pricing(model: str) -> tuple[float, float] | None:
"""
Get pricing for a model. Returns (input_cost_per_million, output_cost_per_million) or None.
Args:
model: The model name to get pricing for
Returns:
Tuple of (input_cost_per_million_tokens, output_cost_per_million_tokens) or None if not found
"""
if model in HARDCODED_MODEL_PRICES:
pricing = HARDCODED_MODEL_PRICES[model]
return pricing["input_cost_per_million_tokens"], pricing["output_cost_per_million_tokens"]
return None
SUBSCRIPTION_TIERS = { SUBSCRIPTION_TIERS = {
config.STRIPE_FREE_TIER_ID: {'name': 'free', 'minutes': 60, 'cost': 5}, config.STRIPE_FREE_TIER_ID: {'name': 'free', 'minutes': 60, 'cost': 5},
@ -209,20 +252,34 @@ async def calculate_monthly_usage(client, user_id: str) -> float:
if not token_messages.data: if not token_messages.data:
return 0.0 return 0.0
# Calculate total minutes # Calculate total cost per message (to handle different models correctly)
total_prompt_tokens = 0 total_cost = 0.0
total_completion_tokens = 0
for run in token_messages.data: for run in token_messages.data:
prompt_tokens = run['content']['usage']['prompt_tokens'] prompt_tokens = run['content']['usage']['prompt_tokens']
completion_tokens = run['content']['usage']['completion_tokens'] completion_tokens = run['content']['usage']['completion_tokens']
model = run['content']['model'] model = run['content']['model']
total_prompt_tokens += prompt_tokens # Check if we have hardcoded pricing for this model
total_completion_tokens += completion_tokens hardcoded_pricing = get_model_pricing(model)
if hardcoded_pricing:
input_cost_per_million, output_cost_per_million = hardcoded_pricing
input_cost = (prompt_tokens / 1_000_000) * input_cost_per_million
output_cost = (completion_tokens / 1_000_000) * output_cost_per_million
message_cost = input_cost + output_cost
else:
# Use litellm pricing as fallback
try:
prompt_token_cost, completion_token_cost = cost_per_token(model, int(prompt_tokens), int(completion_tokens))
message_cost = prompt_token_cost + completion_token_cost
except Exception as e:
logger.warning(f"Could not get pricing for model {model}: {str(e)}, skipping message")
continue
total_cost += message_cost
prompt_token_cost, completion_token_cost = cost_per_token(model, int(total_prompt_tokens), int(total_completion_tokens)) # Return total cost * 2 (as per original logic)
total_cost = (prompt_token_cost + completion_token_cost) * 2 # Return total cost * 2 total_cost = total_cost * 2
logger.info(f"Total cost for user {user_id}: {total_cost}") logger.info(f"Total cost for user {user_id}: {total_cost}")
return total_cost return total_cost
@ -973,72 +1030,83 @@ async def get_available_models(
# Check if model is available with current subscription # Check if model is available with current subscription
is_available = model in allowed_models is_available = model in allowed_models
# Get pricing information from litellm using cost_per_token # Get pricing information - check hardcoded prices first, then litellm
pricing_info = {} pricing_info = {}
try:
# Try to get pricing using cost_per_token function # Check if we have hardcoded pricing for this model
models_to_try = [] hardcoded_pricing = get_model_pricing(model)
if hardcoded_pricing:
# Add the original model name input_cost_per_million, output_cost_per_million = hardcoded_pricing
models_to_try.append(model) pricing_info = {
"input_cost_per_million_tokens": input_cost_per_million,
# Try to resolve the model name using MODEL_NAME_ALIASES "output_cost_per_million_tokens": output_cost_per_million,
if model in MODEL_NAME_ALIASES: "max_tokens": None
resolved_model = MODEL_NAME_ALIASES[model] }
models_to_try.append(resolved_model) else:
# Also try without provider prefix if it has one try:
if '/' in resolved_model: # Try to get pricing using cost_per_token function
models_to_try.append(resolved_model.split('/', 1)[1]) models_to_try = []
# If model is a value in aliases, try to find a matching key # Add the original model name
for alias_key, alias_value in MODEL_NAME_ALIASES.items(): models_to_try.append(model)
if alias_value == model:
models_to_try.append(alias_key) # Try to resolve the model name using MODEL_NAME_ALIASES
break if model in MODEL_NAME_ALIASES:
resolved_model = MODEL_NAME_ALIASES[model]
# Also try without provider prefix for the original model models_to_try.append(resolved_model)
if '/' in model: # Also try without provider prefix if it has one
models_to_try.append(model.split('/', 1)[1]) if '/' in resolved_model:
models_to_try.append(resolved_model.split('/', 1)[1])
# Special handling for Google models accessed via OpenRouter
if model.startswith('openrouter/google/'): # If model is a value in aliases, try to find a matching key
google_model_name = model.replace('openrouter/', '') for alias_key, alias_value in MODEL_NAME_ALIASES.items():
models_to_try.append(google_model_name) if alias_value == model:
models_to_try.append(alias_key)
# Try each model name variation until we find one that works
input_cost_per_token = None
output_cost_per_token = None
for model_name in models_to_try:
try:
# Use cost_per_token with sample token counts to get the per-token costs
input_cost, output_cost = cost_per_token(model_name, 1000000, 1000000)
if input_cost is not None and output_cost is not None:
input_cost_per_token = input_cost
output_cost_per_token = output_cost
break break
except Exception:
continue # Also try without provider prefix for the original model
if '/' in model:
if input_cost_per_token is not None and output_cost_per_token is not None: models_to_try.append(model.split('/', 1)[1])
pricing_info = {
"input_cost_per_million_tokens": round(input_cost_per_token * 2, 2), # Special handling for Google models accessed via OpenRouter
"output_cost_per_million_tokens": round(output_cost_per_token * 2, 2), if model.startswith('openrouter/google/'):
"max_tokens": None # cost_per_token doesn't provide max_tokens info google_model_name = model.replace('openrouter/', '')
} models_to_try.append(google_model_name)
else:
# Try each model name variation until we find one that works
input_cost_per_token = None
output_cost_per_token = None
for model_name in models_to_try:
try:
# Use cost_per_token with sample token counts to get the per-token costs
input_cost, output_cost = cost_per_token(model_name, 1000000, 1000000)
if input_cost is not None and output_cost is not None:
input_cost_per_token = input_cost
output_cost_per_token = output_cost
break
except Exception:
continue
if input_cost_per_token is not None and output_cost_per_token is not None:
pricing_info = {
"input_cost_per_million_tokens": round(input_cost_per_token * 2, 2),
"output_cost_per_million_tokens": round(output_cost_per_token * 2, 2),
"max_tokens": None # cost_per_token doesn't provide max_tokens info
}
else:
pricing_info = {
"input_cost_per_million_tokens": None,
"output_cost_per_million_tokens": None,
"max_tokens": None
}
except Exception as e:
logger.warning(f"Could not get pricing for model {model}: {str(e)}")
pricing_info = { pricing_info = {
"input_cost_per_million_tokens": None, "input_cost_per_million_tokens": None,
"output_cost_per_million_tokens": None, "output_cost_per_million_tokens": None,
"max_tokens": None "max_tokens": None
} }
except Exception as e:
logger.warning(f"Could not get pricing for model {model}: {str(e)}")
pricing_info = {
"input_cost_per_million_tokens": None,
"output_cost_per_million_tokens": None,
"max_tokens": None
}
model_info.append({ model_info.append({
"id": model, "id": model,

View File

@ -491,7 +491,7 @@ export default function UsageLogs({ accountId }: Props) {
<TableHead>Time</TableHead> <TableHead>Time</TableHead>
<TableHead>Model</TableHead> <TableHead>Model</TableHead>
<TableHead className="text-right"> <TableHead className="text-right">
Total Tokens
</TableHead> </TableHead>
<TableHead className="text-right">Cost</TableHead> <TableHead className="text-right">Cost</TableHead>
<TableHead className="text-center"> <TableHead className="text-center">
@ -508,10 +508,7 @@ export default function UsageLogs({ accountId }: Props) {
).toLocaleTimeString()} ).toLocaleTimeString()}
</TableCell> </TableCell>
<TableCell> <TableCell>
<Badge <Badge className="font-mono text-xs">
variant="secondary"
className="font-mono text-xs"
>
{log.content.model} {log.content.model}
</Badge> </Badge>
</TableCell> </TableCell>

View File

@ -548,7 +548,7 @@ export function PricingSection({
const [deploymentType, setDeploymentType] = useState<'cloud' | 'self-hosted'>( const [deploymentType, setDeploymentType] = useState<'cloud' | 'self-hosted'>(
'cloud', 'cloud',
); );
const { data: subscriptionData, isLoading: isFetchingPlan, error: subscriptionQueryError } = useSubscription(); const { data: subscriptionData, isLoading: isFetchingPlan, error: subscriptionQueryError, refetch: refetchSubscription } = useSubscription();
// Derive authentication and subscription status from the hook data // Derive authentication and subscription status from the hook data
const isAuthenticated = !!subscriptionData && subscriptionQueryError === null; const isAuthenticated = !!subscriptionData && subscriptionQueryError === null;
@ -592,6 +592,7 @@ export function PricingSection({
}; };
const handleSubscriptionUpdate = () => { const handleSubscriptionUpdate = () => {
refetchSubscription();
// The useSubscription hook will automatically refetch, so we just need to clear loading states // The useSubscription hook will automatically refetch, so we just need to clear loading states
setTimeout(() => { setTimeout(() => {
setPlanLoadingStates({}); setPlanLoadingStates({});