mirror of https://github.com/kortix-ai/suna.git
feat(billing): implement hardcoded model pricing and enhance cost calculation logic
This commit is contained in:
parent bec4494084
commit 79b71db250
@@ -23,6 +23,49 @@ stripe.api_key = config.STRIPE_SECRET_KEY

 # Initialize router
 router = APIRouter(prefix="/billing", tags=["billing"])

+# Hardcoded pricing for specific models (prices per million tokens)
+HARDCODED_MODEL_PRICES = {
+    "openrouter/deepseek/deepseek-chat": {
+        "input_cost_per_million_tokens": 0.38,
+        "output_cost_per_million_tokens": 0.89
+    },
+    "deepseek/deepseek-chat": {
+        "input_cost_per_million_tokens": 0.38,
+        "output_cost_per_million_tokens": 0.89
+    },
+    "qwen/qwen3-235b-a22b": {
+        "input_cost_per_million_tokens": 0.13,
+        "output_cost_per_million_tokens": 0.60
+    },
+    "openrouter/qwen/qwen3-235b-a22b": {
+        "input_cost_per_million_tokens": 0.13,
+        "output_cost_per_million_tokens": 0.60
+    },
+    "google/gemini-2.5-flash-preview-05-20": {
+        "input_cost_per_million_tokens": 0.15,
+        "output_cost_per_million_tokens": 0.60
+    },
+    "openrouter/google/gemini-2.5-flash-preview-05-20": {
+        "input_cost_per_million_tokens": 0.15,
+        "output_cost_per_million_tokens": 0.60
+    }
+}
+
+def get_model_pricing(model: str) -> tuple[float, float] | None:
+    """
+    Get pricing for a model. Returns (input_cost_per_million, output_cost_per_million) or None.
+
+    Args:
+        model: The model name to get pricing for
+
+    Returns:
+        Tuple of (input_cost_per_million_tokens, output_cost_per_million_tokens) or None if not found
+    """
+    if model in HARDCODED_MODEL_PRICES:
+        pricing = HARDCODED_MODEL_PRICES[model]
+        return pricing["input_cost_per_million_tokens"], pricing["output_cost_per_million_tokens"]
+    return None
+
 SUBSCRIPTION_TIERS = {
     config.STRIPE_FREE_TIER_ID: {'name': 'free', 'minutes': 60, 'cost': 5},
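For readers skimming the hunk above: the helper is a plain dictionary lookup keyed on the exact model string, returning an (input, output) price pair in USD per million tokens, or None so callers can fall back to litellm pricing. A condensed, self-contained sketch of that contract (the SAMPLE_PRICES table, lookup_pricing name, and the unknown model name below are illustrative, not taken verbatim from the diff):

# Condensed sketch of the lookup contract (illustrative names and data).
SAMPLE_PRICES = {
    "deepseek/deepseek-chat": {
        "input_cost_per_million_tokens": 0.38,
        "output_cost_per_million_tokens": 0.89,
    },
}

def lookup_pricing(model: str) -> tuple[float, float] | None:
    pricing = SAMPLE_PRICES.get(model)
    if pricing is None:
        return None  # caller falls back to litellm pricing
    return pricing["input_cost_per_million_tokens"], pricing["output_cost_per_million_tokens"]

print(lookup_pricing("deepseek/deepseek-chat"))  # (0.38, 0.89)
print(lookup_pricing("some/unknown-model"))      # None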
@@ -209,20 +252,34 @@ async def calculate_monthly_usage(client, user_id: str) -> float:
     if not token_messages.data:
         return 0.0

-    # Calculate total minutes
-    total_prompt_tokens = 0
-    total_completion_tokens = 0
+    # Calculate total cost per message (to handle different models correctly)
+    total_cost = 0.0

     for run in token_messages.data:
         prompt_tokens = run['content']['usage']['prompt_tokens']
         completion_tokens = run['content']['usage']['completion_tokens']
         model = run['content']['model']

-        total_prompt_tokens += prompt_tokens
-        total_completion_tokens += completion_tokens
+        # Check if we have hardcoded pricing for this model
+        hardcoded_pricing = get_model_pricing(model)
+        if hardcoded_pricing:
+            input_cost_per_million, output_cost_per_million = hardcoded_pricing
+            input_cost = (prompt_tokens / 1_000_000) * input_cost_per_million
+            output_cost = (completion_tokens / 1_000_000) * output_cost_per_million
+            message_cost = input_cost + output_cost
+        else:
+            # Use litellm pricing as fallback
+            try:
+                prompt_token_cost, completion_token_cost = cost_per_token(model, int(prompt_tokens), int(completion_tokens))
+                message_cost = prompt_token_cost + completion_token_cost
+            except Exception as e:
+                logger.warning(f"Could not get pricing for model {model}: {str(e)}, skipping message")
+                continue
+
+        total_cost += message_cost

-    prompt_token_cost, completion_token_cost = cost_per_token(model, int(total_prompt_tokens), int(total_completion_tokens))
-    total_cost = (prompt_token_cost + completion_token_cost) * 2  # Return total cost * 2
+    # Return total cost * 2 (as per original logic)
+    total_cost = total_cost * 2
     logger.info(f"Total cost for user {user_id}: {total_cost}")

     return total_cost
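To make the per-message arithmetic in this hunk concrete, here is a small worked example using the hardcoded deepseek-chat prices and hypothetical token counts (the counts are illustrative, not from the repository):

# Hypothetical token counts for a single message run on deepseek-chat.
prompt_tokens, completion_tokens = 12_000, 3_000
input_per_million, output_per_million = 0.38, 0.89  # hardcoded prices above

input_cost = (prompt_tokens / 1_000_000) * input_per_million        # 0.00456
output_cost = (completion_tokens / 1_000_000) * output_per_million  # 0.00267
message_cost = input_cost + output_cost                             # 0.00723

# calculate_monthly_usage sums message_cost over all runs, then doubles the total.
total_cost = message_cost * 2
print(round(total_cost, 5))  # 0.01446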
@@ -973,72 +1030,83 @@ async def get_available_models(
         # Check if model is available with current subscription
         is_available = model in allowed_models

-        # Get pricing information from litellm using cost_per_token
+        # Get pricing information - check hardcoded prices first, then litellm
         pricing_info = {}
-        try:
-            # Try to get pricing using cost_per_token function
-            models_to_try = []
-
-            # Add the original model name
-            models_to_try.append(model)
-
-            # Try to resolve the model name using MODEL_NAME_ALIASES
-            if model in MODEL_NAME_ALIASES:
-                resolved_model = MODEL_NAME_ALIASES[model]
-                models_to_try.append(resolved_model)
-                # Also try without provider prefix if it has one
-                if '/' in resolved_model:
-                    models_to_try.append(resolved_model.split('/', 1)[1])
-
-            # If model is a value in aliases, try to find a matching key
-            for alias_key, alias_value in MODEL_NAME_ALIASES.items():
-                if alias_value == model:
-                    models_to_try.append(alias_key)
-                    break
-
-            # Also try without provider prefix for the original model
-            if '/' in model:
-                models_to_try.append(model.split('/', 1)[1])
-
-            # Special handling for Google models accessed via OpenRouter
-            if model.startswith('openrouter/google/'):
-                google_model_name = model.replace('openrouter/', '')
-                models_to_try.append(google_model_name)
-
-            # Try each model name variation until we find one that works
-            input_cost_per_token = None
-            output_cost_per_token = None
-
-            for model_name in models_to_try:
-                try:
-                    # Use cost_per_token with sample token counts to get the per-token costs
-                    input_cost, output_cost = cost_per_token(model_name, 1000000, 1000000)
-                    if input_cost is not None and output_cost is not None:
-                        input_cost_per_token = input_cost
-                        output_cost_per_token = output_cost
-                        break
-                except Exception:
-                    continue
-
-            if input_cost_per_token is not None and output_cost_per_token is not None:
-                pricing_info = {
-                    "input_cost_per_million_tokens": round(input_cost_per_token * 2, 2),
-                    "output_cost_per_million_tokens": round(output_cost_per_token * 2, 2),
-                    "max_tokens": None  # cost_per_token doesn't provide max_tokens info
-                }
-            else:
-                pricing_info = {
-                    "input_cost_per_million_tokens": None,
-                    "output_cost_per_million_tokens": None,
-                    "max_tokens": None
-                }
-        except Exception as e:
-            logger.warning(f"Could not get pricing for model {model}: {str(e)}")
-            pricing_info = {
-                "input_cost_per_million_tokens": None,
-                "output_cost_per_million_tokens": None,
-                "max_tokens": None
-            }
+
+        # Check if we have hardcoded pricing for this model
+        hardcoded_pricing = get_model_pricing(model)
+        if hardcoded_pricing:
+            input_cost_per_million, output_cost_per_million = hardcoded_pricing
+            pricing_info = {
+                "input_cost_per_million_tokens": input_cost_per_million,
+                "output_cost_per_million_tokens": output_cost_per_million,
+                "max_tokens": None
+            }
+        else:
+            try:
+                # Try to get pricing using cost_per_token function
+                models_to_try = []
+
+                # Add the original model name
+                models_to_try.append(model)
+
+                # Try to resolve the model name using MODEL_NAME_ALIASES
+                if model in MODEL_NAME_ALIASES:
+                    resolved_model = MODEL_NAME_ALIASES[model]
+                    models_to_try.append(resolved_model)
+                    # Also try without provider prefix if it has one
+                    if '/' in resolved_model:
+                        models_to_try.append(resolved_model.split('/', 1)[1])
+
+                # If model is a value in aliases, try to find a matching key
+                for alias_key, alias_value in MODEL_NAME_ALIASES.items():
+                    if alias_value == model:
+                        models_to_try.append(alias_key)
+                        break
+
+                # Also try without provider prefix for the original model
+                if '/' in model:
+                    models_to_try.append(model.split('/', 1)[1])
+
+                # Special handling for Google models accessed via OpenRouter
+                if model.startswith('openrouter/google/'):
+                    google_model_name = model.replace('openrouter/', '')
+                    models_to_try.append(google_model_name)
+
+                # Try each model name variation until we find one that works
+                input_cost_per_token = None
+                output_cost_per_token = None
+
+                for model_name in models_to_try:
+                    try:
+                        # Use cost_per_token with sample token counts to get the per-token costs
+                        input_cost, output_cost = cost_per_token(model_name, 1000000, 1000000)
+                        if input_cost is not None and output_cost is not None:
+                            input_cost_per_token = input_cost
+                            output_cost_per_token = output_cost
+                            break
+                    except Exception:
+                        continue
+
+                if input_cost_per_token is not None and output_cost_per_token is not None:
+                    pricing_info = {
+                        "input_cost_per_million_tokens": round(input_cost_per_token * 2, 2),
+                        "output_cost_per_million_tokens": round(output_cost_per_token * 2, 2),
+                        "max_tokens": None  # cost_per_token doesn't provide max_tokens info
+                    }
+                else:
+                    pricing_info = {
+                        "input_cost_per_million_tokens": None,
+                        "output_cost_per_million_tokens": None,
+                        "max_tokens": None
+                    }
+            except Exception as e:
+                logger.warning(f"Could not get pricing for model {model}: {str(e)}")
+                pricing_info = {
+                    "input_cost_per_million_tokens": None,
+                    "output_cost_per_million_tokens": None,
+                    "max_tokens": None
+                }

         model_info.append({
             "id": model,
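The fallback branch above leans on litellm's cost_per_token in the same positional style as the existing code. The sketch below isolates that path (the helper name litellm_pricing_info is hypothetical, the alias-resolution loop is omitted, and the behaviour assumes cost_per_token returns a (prompt_cost, completion_cost) tuple in USD, which is what the diff already relies on):

from litellm import cost_per_token

def litellm_pricing_info(model_name: str) -> dict:
    """Hypothetical helper mirroring only the litellm fallback path in the hunk above."""
    try:
        # Passing 1,000,000 tokens on each side makes the returned costs read
        # directly as "cost per million tokens".
        input_cost, output_cost = cost_per_token(model_name, 1_000_000, 1_000_000)
    except Exception:
        return {
            "input_cost_per_million_tokens": None,
            "output_cost_per_million_tokens": None,
            "max_tokens": None,
        }
    return {
        # The endpoint applies a 2x multiplier before exposing prices.
        "input_cost_per_million_tokens": round(input_cost * 2, 2),
        "output_cost_per_million_tokens": round(output_cost * 2, 2),
        "max_tokens": None,  # cost_per_token does not report max_tokens
    }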
@@ -491,7 +491,7 @@ export default function UsageLogs({ accountId }: Props) {
               <TableHead>Time</TableHead>
               <TableHead>Model</TableHead>
               <TableHead className="text-right">
-                Total
+                Tokens
               </TableHead>
               <TableHead className="text-right">Cost</TableHead>
               <TableHead className="text-center">
@@ -508,10 +508,7 @@ export default function UsageLogs({ accountId }: Props) {
                 ).toLocaleTimeString()}
               </TableCell>
               <TableCell>
-                <Badge
-                  variant="secondary"
-                  className="font-mono text-xs"
-                >
+                <Badge className="font-mono text-xs">
                   {log.content.model}
                 </Badge>
               </TableCell>
@@ -548,7 +548,7 @@ export function PricingSection({
   const [deploymentType, setDeploymentType] = useState<'cloud' | 'self-hosted'>(
     'cloud',
   );
-  const { data: subscriptionData, isLoading: isFetchingPlan, error: subscriptionQueryError } = useSubscription();
+  const { data: subscriptionData, isLoading: isFetchingPlan, error: subscriptionQueryError, refetch: refetchSubscription } = useSubscription();

   // Derive authentication and subscription status from the hook data
   const isAuthenticated = !!subscriptionData && subscriptionQueryError === null;
@@ -592,6 +592,7 @@ export function PricingSection({
   };

   const handleSubscriptionUpdate = () => {
+    refetchSubscription();
     // The useSubscription hook will automatically refetch, so we just need to clear loading states
     setTimeout(() => {
       setPlanLoadingStates({});