mirror of https://github.com/kortix-ai/suna.git
feat(billing): implement hardcoded model pricing and enhance cost calculation logic
parent bec4494084
commit 79b71db250
@@ -23,6 +23,49 @@ stripe.api_key = config.STRIPE_SECRET_KEY
 # Initialize router
 router = APIRouter(prefix="/billing", tags=["billing"])
 
+# Hardcoded pricing for specific models (prices per million tokens)
+HARDCODED_MODEL_PRICES = {
+    "openrouter/deepseek/deepseek-chat": {
+        "input_cost_per_million_tokens": 0.38,
+        "output_cost_per_million_tokens": 0.89
+    },
+    "deepseek/deepseek-chat": {
+        "input_cost_per_million_tokens": 0.38,
+        "output_cost_per_million_tokens": 0.89
+    },
+    "qwen/qwen3-235b-a22b": {
+        "input_cost_per_million_tokens": 0.13,
+        "output_cost_per_million_tokens": 0.60
+    },
+    "openrouter/qwen/qwen3-235b-a22b": {
+        "input_cost_per_million_tokens": 0.13,
+        "output_cost_per_million_tokens": 0.60
+    },
+    "google/gemini-2.5-flash-preview-05-20": {
+        "input_cost_per_million_tokens": 0.15,
+        "output_cost_per_million_tokens": 0.60
+    },
+    "openrouter/google/gemini-2.5-flash-preview-05-20": {
+        "input_cost_per_million_tokens": 0.15,
+        "output_cost_per_million_tokens": 0.60
+    }
+}
+
+def get_model_pricing(model: str) -> tuple[float, float] | None:
+    """
+    Get pricing for a model. Returns (input_cost_per_million, output_cost_per_million) or None.
+
+    Args:
+        model: The model name to get pricing for
+
+    Returns:
+        Tuple of (input_cost_per_million_tokens, output_cost_per_million_tokens) or None if not found
+    """
+    if model in HARDCODED_MODEL_PRICES:
+        pricing = HARDCODED_MODEL_PRICES[model]
+        return pricing["input_cost_per_million_tokens"], pricing["output_cost_per_million_tokens"]
+    return None
+
+
 SUBSCRIPTION_TIERS = {
     config.STRIPE_FREE_TIER_ID: {'name': 'free', 'minutes': 60, 'cost': 5},
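For reference, a minimal sketch of how the lookup above could be combined with token counts to price a single call. The helper name estimate_cost and the example token counts are illustrative and not part of the commit; it assumes HARDCODED_MODEL_PRICES and get_model_pricing are in scope as defined above.

def estimate_cost(model: str, prompt_tokens: int, completion_tokens: int) -> float | None:
    """Estimate the USD cost of one call, or return None if the model has no hardcoded price."""
    pricing = get_model_pricing(model)
    if pricing is None:
        return None
    input_per_million, output_per_million = pricing
    return (prompt_tokens / 1_000_000) * input_per_million + (completion_tokens / 1_000_000) * output_per_million

# e.g. 10,000 prompt + 2,000 completion tokens on "deepseek/deepseek-chat":
# 0.01 * 0.38 + 0.002 * 0.89 = 0.00558 USD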
@@ -209,20 +252,34 @@ async def calculate_monthly_usage(client, user_id: str) -> float:
     if not token_messages.data:
         return 0.0
 
-    # Calculate total minutes
-    total_prompt_tokens = 0
-    total_completion_tokens = 0
+    # Calculate total cost per message (to handle different models correctly)
+    total_cost = 0.0
 
     for run in token_messages.data:
         prompt_tokens = run['content']['usage']['prompt_tokens']
         completion_tokens = run['content']['usage']['completion_tokens']
         model = run['content']['model']
 
-        total_prompt_tokens += prompt_tokens
-        total_completion_tokens += completion_tokens
+        # Check if we have hardcoded pricing for this model
+        hardcoded_pricing = get_model_pricing(model)
+        if hardcoded_pricing:
+            input_cost_per_million, output_cost_per_million = hardcoded_pricing
+            input_cost = (prompt_tokens / 1_000_000) * input_cost_per_million
+            output_cost = (completion_tokens / 1_000_000) * output_cost_per_million
+            message_cost = input_cost + output_cost
+        else:
+            # Use litellm pricing as fallback
+            try:
+                prompt_token_cost, completion_token_cost = cost_per_token(model, int(prompt_tokens), int(completion_tokens))
+                message_cost = prompt_token_cost + completion_token_cost
+            except Exception as e:
+                logger.warning(f"Could not get pricing for model {model}: {str(e)}, skipping message")
+                continue
 
-    prompt_token_cost, completion_token_cost = cost_per_token(model, int(total_prompt_tokens), int(total_completion_tokens))
-    total_cost = (prompt_token_cost + completion_token_cost) * 2 # Return total cost * 2
+        total_cost += message_cost
+
+    # Return total cost * 2 (as per original logic)
+    total_cost = total_cost * 2
     logger.info(f"Total cost for user {user_id}: {total_cost}")
 
     return total_cost
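To make the per-message arithmetic concrete, a small worked example with illustrative numbers (the model is one of the hardcoded entries above; the token counts are not from the commit):

# One message on "qwen/qwen3-235b-a22b", hardcoded above at 0.13 / 0.60 per million tokens.
prompt_tokens, completion_tokens = 50_000, 4_000
input_cost = (prompt_tokens / 1_000_000) * 0.13       # 0.0065
output_cost = (completion_tokens / 1_000_000) * 0.60  # 0.0024
message_cost = input_cost + output_cost               # 0.0089
reported = message_cost * 2                           # 0.0178 after the final "total_cost * 2" step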
@@ -973,72 +1030,83 @@ async def get_available_models(
         # Check if model is available with current subscription
         is_available = model in allowed_models
 
-        # Get pricing information from litellm using cost_per_token
+        # Get pricing information - check hardcoded prices first, then litellm
         pricing_info = {}
-        try:
-            # Try to get pricing using cost_per_token function
-            models_to_try = []
-
-            # Add the original model name
-            models_to_try.append(model)
-
-            # Try to resolve the model name using MODEL_NAME_ALIASES
-            if model in MODEL_NAME_ALIASES:
-                resolved_model = MODEL_NAME_ALIASES[model]
-                models_to_try.append(resolved_model)
-                # Also try without provider prefix if it has one
-                if '/' in resolved_model:
-                    models_to_try.append(resolved_model.split('/', 1)[1])
-
-            # If model is a value in aliases, try to find a matching key
-            for alias_key, alias_value in MODEL_NAME_ALIASES.items():
-                if alias_value == model:
-                    models_to_try.append(alias_key)
-                    break
-
-            # Also try without provider prefix for the original model
-            if '/' in model:
-                models_to_try.append(model.split('/', 1)[1])
-
-            # Special handling for Google models accessed via OpenRouter
-            if model.startswith('openrouter/google/'):
-                google_model_name = model.replace('openrouter/', '')
-                models_to_try.append(google_model_name)
-
-            # Try each model name variation until we find one that works
-            input_cost_per_token = None
-            output_cost_per_token = None
-
-            for model_name in models_to_try:
-                try:
-                    # Use cost_per_token with sample token counts to get the per-token costs
-                    input_cost, output_cost = cost_per_token(model_name, 1000000, 1000000)
-                    if input_cost is not None and output_cost is not None:
-                        input_cost_per_token = input_cost
-                        output_cost_per_token = output_cost
-                        break
-                except Exception:
-                    continue
-
-            if input_cost_per_token is not None and output_cost_per_token is not None:
-                pricing_info = {
-                    "input_cost_per_million_tokens": round(input_cost_per_token * 2, 2),
-                    "output_cost_per_million_tokens": round(output_cost_per_token * 2, 2),
-                    "max_tokens": None  # cost_per_token doesn't provide max_tokens info
-                }
-            else:
-                pricing_info = {
-                    "input_cost_per_million_tokens": None,
-                    "output_cost_per_million_tokens": None,
-                    "max_tokens": None
-                }
-        except Exception as e:
-            logger.warning(f"Could not get pricing for model {model}: {str(e)}")
-            pricing_info = {
-                "input_cost_per_million_tokens": None,
-                "output_cost_per_million_tokens": None,
-                "max_tokens": None
-            }
+
+        # Check if we have hardcoded pricing for this model
+        hardcoded_pricing = get_model_pricing(model)
+        if hardcoded_pricing:
+            input_cost_per_million, output_cost_per_million = hardcoded_pricing
+            pricing_info = {
+                "input_cost_per_million_tokens": input_cost_per_million,
+                "output_cost_per_million_tokens": output_cost_per_million,
+                "max_tokens": None
+            }
+        else:
+            try:
+                # Try to get pricing using cost_per_token function
+                models_to_try = []
+
+                # Add the original model name
+                models_to_try.append(model)
+
+                # Try to resolve the model name using MODEL_NAME_ALIASES
+                if model in MODEL_NAME_ALIASES:
+                    resolved_model = MODEL_NAME_ALIASES[model]
+                    models_to_try.append(resolved_model)
+                    # Also try without provider prefix if it has one
+                    if '/' in resolved_model:
+                        models_to_try.append(resolved_model.split('/', 1)[1])
+
+                # If model is a value in aliases, try to find a matching key
+                for alias_key, alias_value in MODEL_NAME_ALIASES.items():
+                    if alias_value == model:
+                        models_to_try.append(alias_key)
+                        break
+
+                # Also try without provider prefix for the original model
+                if '/' in model:
+                    models_to_try.append(model.split('/', 1)[1])
+
+                # Special handling for Google models accessed via OpenRouter
+                if model.startswith('openrouter/google/'):
+                    google_model_name = model.replace('openrouter/', '')
+                    models_to_try.append(google_model_name)
+
+                # Try each model name variation until we find one that works
+                input_cost_per_token = None
+                output_cost_per_token = None
+
+                for model_name in models_to_try:
+                    try:
+                        # Use cost_per_token with sample token counts to get the per-token costs
+                        input_cost, output_cost = cost_per_token(model_name, 1000000, 1000000)
+                        if input_cost is not None and output_cost is not None:
+                            input_cost_per_token = input_cost
+                            output_cost_per_token = output_cost
+                            break
+                    except Exception:
+                        continue
+
+                if input_cost_per_token is not None and output_cost_per_token is not None:
+                    pricing_info = {
+                        "input_cost_per_million_tokens": round(input_cost_per_token * 2, 2),
+                        "output_cost_per_million_tokens": round(output_cost_per_token * 2, 2),
+                        "max_tokens": None  # cost_per_token doesn't provide max_tokens info
+                    }
+                else:
+                    pricing_info = {
+                        "input_cost_per_million_tokens": None,
+                        "output_cost_per_million_tokens": None,
+                        "max_tokens": None
+                    }
+            except Exception as e:
+                logger.warning(f"Could not get pricing for model {model}: {str(e)}")
+                pricing_info = {
+                    "input_cost_per_million_tokens": None,
+                    "output_cost_per_million_tokens": None,
+                    "max_tokens": None
+                }
 
         model_info.append({
             "id": model,
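The endpoint change above prefers the hardcoded table and only falls back to litellm's cost_per_token for display prices. Below is a condensed sketch of that decision, assuming, as the handler does, that cost_per_token(model, 1_000_000, 1_000_000) returns the dollar cost of one million prompt and one million completion tokens. display_prices is a hypothetical helper, get_model_pricing is the function added above, and the alias resolution (models_to_try) and error handling from the real code are omitted.

from litellm import cost_per_token

def display_prices(model: str) -> dict:
    # Hypothetical condensation of the branch above, for illustration only.
    hardcoded = get_model_pricing(model)
    if hardcoded:
        # Hardcoded prices are already per million tokens and are used as-is.
        input_per_million, output_per_million = hardcoded
    else:
        # Fallback: price one million tokens each way via litellm, then apply
        # the same *2 factor and rounding used in the handler above.
        prompt_cost, completion_cost = cost_per_token(model, 1_000_000, 1_000_000)
        input_per_million = round(prompt_cost * 2, 2)
        output_per_million = round(completion_cost * 2, 2)
    return {
        "input_cost_per_million_tokens": input_per_million,
        "output_cost_per_million_tokens": output_per_million,
        "max_tokens": None,
    }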
@@ -491,7 +491,7 @@ export default function UsageLogs({ accountId }: Props) {
              <TableHead>Time</TableHead>
              <TableHead>Model</TableHead>
              <TableHead className="text-right">
-                Total
+                Tokens
              </TableHead>
              <TableHead className="text-right">Cost</TableHead>
              <TableHead className="text-center">
@@ -508,10 +508,7 @@ export default function UsageLogs({ accountId }: Props) {
                  ).toLocaleTimeString()}
                </TableCell>
                <TableCell>
-                  <Badge
-                    variant="secondary"
-                    className="font-mono text-xs"
-                  >
+                  <Badge className="font-mono text-xs">
                    {log.content.model}
                  </Badge>
                </TableCell>
@@ -548,7 +548,7 @@ export function PricingSection({
   const [deploymentType, setDeploymentType] = useState<'cloud' | 'self-hosted'>(
     'cloud',
   );
-  const { data: subscriptionData, isLoading: isFetchingPlan, error: subscriptionQueryError } = useSubscription();
+  const { data: subscriptionData, isLoading: isFetchingPlan, error: subscriptionQueryError, refetch: refetchSubscription } = useSubscription();
 
   // Derive authentication and subscription status from the hook data
   const isAuthenticated = !!subscriptionData && subscriptionQueryError === null;
@@ -592,6 +592,7 @@ export function PricingSection({
   };
 
   const handleSubscriptionUpdate = () => {
+    refetchSubscription();
     // The useSubscription hook will automatically refetch, so we just need to clear loading states
     setTimeout(() => {
       setPlanLoadingStates({});