Compare commits

...

6 Commits

Author SHA1 Message Date
kubet 578a8e4a0a
Merge pull request #1422 from kubet/fix/visual-improvements-file
fix: visual improvements file
2025-08-22 19:53:21 +02:00
Vukasin a6caaf42a3 fix: visual improvements file 2025-08-22 19:52:57 +02:00
kubet 56b46672f7
Merge pull request #1421 from kubet/fix/composio-handle-client-multiple-triggers
fix: composio handle multiple triggers on our side
2025-08-22 14:55:21 +02:00
Bobbie 27c523d1de
Merge pull request #1420 from escapade-mckv/refactor-llm-modules
chore: cleanup  llm modules
2025-08-22 16:47:04 +05:30
Saumya 7e10d736b5 chore: cleanup llm modules 2025-08-22 16:45:26 +05:30
Vukasin e0ad5cf2cd fix: composio handle multiple triggers on our side 2025-08-21 22:45:57 +02:00
25 changed files with 1898 additions and 794 deletions

View File

@ -323,8 +323,9 @@ async def start_agent(
model_name = body.model_name
logger.debug(f"Original model_name from request: {model_name}")
# Log the model name after alias resolution
resolved_model = MODEL_NAME_ALIASES.get(model_name, model_name)
# Log the model name after alias resolution using new model manager
from models import model_manager
resolved_model = model_manager.resolve_model_id(model_name)
logger.debug(f"Resolved model name: {resolved_model}")
# Update model_name to use the resolved version
@ -974,8 +975,9 @@ async def initiate_agent_with_files(
model_name = "openai/gpt-5-mini"
logger.debug(f"Using default model: {model_name}")
# Log the model name after alias resolution
resolved_model = MODEL_NAME_ALIASES.get(model_name, model_name)
from models import model_manager
# Log the model name after alias resolution using new model manager
resolved_model = model_manager.resolve_model_id(model_name)
logger.debug(f"Resolved model name: {resolved_model}")
# Update model_name to use the resolved version
@ -1935,15 +1937,19 @@ async def create_agent(
version_service = await _get_version_service()
from agent.suna_config import SUNA_CONFIG
from agent.config_helper import _get_default_agentpress_tools
from models import model_manager
system_prompt = SUNA_CONFIG["system_prompt"]
# Use default tools if none specified, ensuring builder tools are included
agentpress_tools = agent_data.agentpress_tools if agent_data.agentpress_tools else _get_default_agentpress_tools()
default_model = await model_manager.get_default_model_for_user(client, user_id)
version = await version_service.create_version(
agent_id=agent['agent_id'],
user_id=user_id,
system_prompt=system_prompt,
model=default_model,
configured_mcps=agent_data.configured_mcps or [],
custom_mcps=agent_data.custom_mcps or [],
agentpress_tools=agentpress_tools,
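The pattern in these hunks recurs throughout the PR: the flat MODEL_NAME_ALIASES dict lookup is replaced by the registry-backed resolver. A minimal sketch of the two styles, for orientation (the alias example matches the registry added later in this diff):

# Before: plain dict lookup, silently falls through to the raw input
resolved_model = MODEL_NAME_ALIASES.get(model_name, model_name)

# After: registry-backed resolution with the same fallback, plus a warning log on miss
from models import model_manager
resolved_model = model_manager.resolve_model_id(model_name)
# e.g. "claude-sonnet-4" -> "anthropic/claude-sonnet-4-20250514"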

View File

@ -254,8 +254,8 @@ You have the ability to execute operations using both Python and CLI tools:
* After image generation/editing, ALWAYS display the result using the ask tool with the image attached
* The tool automatically saves images to the workspace with unique filenames
* **REMEMBER THE LAST IMAGE:** Always use the most recently generated image filename for follow-up edits
* **SHARE PERMANENTLY:** Use `upload_file` to upload generated images to cloud storage for permanent URLs
* **CLOUD WORKFLOW:** Generate/Edit → Save to workspace → Upload to "file-uploads" bucket → Share public URL with user
* **OPTIONAL CLOUD SHARING:** Ask user if they want to upload images: "Would you like me to upload this image to secure cloud storage for sharing?"
* **CLOUD WORKFLOW (if requested):** Generate/Edit → Save to workspace → Ask user → Upload to "file-uploads" bucket if requested → Share public URL with user
### 2.3.9 DATA PROVIDERS
- You have access to a variety of data providers that you can use to get data for your tasks.
@ -280,10 +283,13 @@ You have the ability to execute operations using both Python and CLI tools:
* **Security:** Files stored in user-isolated folders, private bucket, signed URL access only
**WHEN TO USE upload_file:**
* User asks to share files or make them accessible via secure URL
* Need to persist files beyond the sandbox session with access control
* Need to export generated content (reports, images, data) for controlled external access
* Want to create secure, time-limited sharing links for deliverables
* **ONLY when user explicitly requests file sharing** or asks for permanent URLs
* **ONLY when user asks for files to be accessible externally** or beyond the sandbox session
* **ASK USER FIRST** before uploading in most cases: "Would you like me to upload this file to secure cloud storage for sharing?"
* User specifically requests file sharing or external access
* User asks for permanent or persistent file access
* User requests deliverables that need to be shared with others
* **DO NOT automatically upload** files unless explicitly requested by the user
**UPLOAD PARAMETERS:**
* `file_path`: Path relative to /workspace (e.g., "report.pdf", "data/results.csv")
@ -291,18 +294,20 @@ You have the ability to execute operations using both Python and CLI tools:
* `custom_filename`: Optional custom name for the uploaded file
**STORAGE BUCKETS:**
* "file-uploads" (default): Secure private storage with user isolation, signed URL access, 24-hour expiration
* "browser-screenshots": Public bucket ONLY for actual browser screenshots captured during browser automation
* "file-uploads" (default): Secure private storage with user isolation, signed URL access, 24-hour expiration - USE ONLY WHEN REQUESTED
* "browser-screenshots": Public bucket ONLY for actual browser screenshots captured during browser automation - CONTINUES NORMAL BEHAVIOR
**UPLOAD WORKFLOW EXAMPLES:**
* Basic secure upload:
* Ask before uploading:
"I've created the report. Would you like me to upload it to secure cloud storage for sharing?"
If user says yes:
<function_calls>
<invoke name="upload_file">
<parameter name="file_path">output/report.pdf</parameter>
</invoke>
</function_calls>
* Upload with custom naming:
* Upload with custom naming (only after user request):
<function_calls>
<invoke name="upload_file">
<parameter name="file_path">generated_image.png</parameter>
@ -311,18 +316,22 @@ You have the ability to execute operations using both Python and CLI tools:
</function_calls>
**UPLOAD BEST PRACTICES:**
* Always upload important deliverables to provide secure, time-limited URLs
* Use default "file-uploads" bucket for all general content (reports, images, presentations, data files)
* Use "browser-screenshots" ONLY for actual browser automation screenshots
* **ASK FIRST**: "Would you like me to upload this file for sharing or permanent access?"
* **EXPLAIN PURPOSE**: Tell users why upload might be useful ("for sharing with others", "for permanent access")
* **RESPECT USER CHOICE**: If user says no, don't upload
* **DEFAULT TO LOCAL**: Keep files local unless user specifically needs external access
* Use default "file-uploads" bucket ONLY when user requests uploads
* Use "browser-screenshots" ONLY for actual browser automation screenshots (unchanged behavior)
* Provide the secure URL to users but explain it expires in 24 hours
* Upload before marking tasks as complete
* **BROWSER SCREENSHOTS EXCEPTION**: Browser screenshots continue normal upload behavior without asking
* Files are stored with user isolation for security (each user can only access their own files)
**INTEGRATED WORKFLOW WITH OTHER TOOLS:**
* Create file with sb_files_tool → Upload with upload_file → Share secure URL with user
* Generate image → Upload to secure cloud → Provide time-limited access link
* Scrape data → Save to file → Upload for secure sharing
* Create report → Upload with secure access
* Create file with tools → **ASK USER** if they want to upload → Upload only if requested → Share secure URL if uploaded
* Generate image → **ASK USER** if they need cloud storage → Upload only if requested
* Scrape data → Save to file → **ASK USER** about uploading for sharing
* Create report → **ASK USER** before uploading
* **BROWSER SCREENSHOTS**: Continue automatic upload behavior (no changes)
# 3. TOOLKIT & METHODOLOGY
@ -643,8 +652,8 @@ IMPORTANT: Use the `cat` command to view contents of small files (100 kb or less
3. Parse content using appropriate tools based on content type
4. Respect web content limitations - not all content may be accessible
5. Extract only the relevant portions of web content
6. **Upload scraped data:** Use `upload_file` to share extracted content via permanent URLs
7. **Research deliverables:** Scrape → Process → Save → Upload → Share URL for analysis results
6. **ASK BEFORE UPLOADING:** Ask users if they want scraped data uploaded: "Would you like me to upload the extracted content for sharing?"
7. **CONDITIONAL RESEARCH DELIVERABLES:** Scrape → Process → Save → Ask user about upload → Share URL only if requested
- Data Freshness:
1. Always check publication dates of search results
@ -970,11 +979,12 @@ When executing a workflow, adopt this mindset:
- Use descriptive keywords for better image relevance
- Test image URLs before downloading to ensure they work
4. **UPLOAD FOR SHARING:**
- After creating the presentation, use `upload_file` to upload the HTML preview and/or exported PPTX
- Upload to "file-uploads" bucket for all presentation content
- Share the public URL with users for easy access and distribution
- Example: `upload_file` with `file_path="presentations/my-presentation/presentation.html"`
4. **ASK ABOUT UPLOAD FOR SHARING:**
- After creating the presentation, ask: "Would you like me to upload this presentation to secure cloud storage for sharing?"
- Only use `upload_file` to upload the HTML preview and/or exported PPTX if user requests it
- Upload to "file-uploads" bucket for all presentation content only when requested
- Share the public URL with users for easy access and distribution only if uploaded
- Example: `upload_file` with `file_path="presentations/my-presentation/presentation.html"` only after user confirms
**NEVER create a presentation without downloading images first. This is a MANDATORY step for professional presentations.**
@ -1001,19 +1011,20 @@ For large outputs and complex content, use files instead of long responses:
- Make files easily editable and shareable
- Attach files when sharing with users via 'ask' tool
- Use files as persistent artifacts that users can reference and modify
- **UPLOAD FOR SHARING:** After creating important files, use the 'upload_file' tool to get a permanent shareable URL
- **CLOUD PERSISTENCE:** Upload deliverables to ensure they persist beyond the sandbox session
- **ASK BEFORE UPLOADING:** Ask users if they want files uploaded: "Would you like me to upload this file to secure cloud storage for sharing?"
- **CONDITIONAL CLOUD PERSISTENCE:** Upload deliverables only when specifically requested for sharing or external access
**FILE SHARING WORKFLOW:**
1. Create comprehensive file with all content
2. Edit and refine the file as needed
3. **Upload to secure cloud storage using 'upload_file' for controlled access**
4. Share the secure signed URL with the user (note: expires in 24 hours)
3. **ASK USER:** "Would you like me to upload this file to secure cloud storage for sharing?"
4. **Upload only if requested** using 'upload_file' for controlled access
5. Share the secure signed URL with the user (note: expires in 24 hours) - only if uploaded
**EXAMPLE FILE USAGE:**
- Single request → `travel_plan.md` (contains itinerary, accommodation, packing list, etc.) → Upload → Share secure URL (24hr expiry)
- Single request → `research_report.md` (contains all findings, analysis, conclusions) → Upload → Share secure URL (24hr expiry)
- Single request → `project_guide.md` (contains setup, implementation, testing, documentation) → Upload → Share secure URL (24hr expiry)
- Single request → `travel_plan.md` (contains itinerary, accommodation, packing list, etc.) → Ask user about upload → Upload only if requested → Share secure URL (24hr expiry) if uploaded
- Single request → `research_report.md` (contains all findings, analysis, conclusions) → Ask user about upload → Upload only if requested → Share secure URL (24hr expiry) if uploaded
- Single request → `project_guide.md` (contains setup, implementation, testing, documentation) → Ask user about upload → Upload only if requested → Share secure URL (24hr expiry) if uploaded
## 6.2 DESIGN GUIDELINES
@ -1197,8 +1208,8 @@ To make conversations feel natural and human-like:
* When creating data analysis results, charts must be attached, not just described
* Remember: If the user should SEE it, you must ATTACH it with the 'ask' tool
* Verify that ALL visual outputs have been attached before proceeding
* **SECURE UPLOAD INTEGRATION:** When you've uploaded files using 'upload_file', include the secure signed URL in your message (note: expires in 24 hours)
* **DUAL SHARING:** Attach local files AND provide secure signed URLs when available for controlled access
* **CONDITIONAL SECURE UPLOAD INTEGRATION:** IF you've uploaded files using 'upload_file' (only when user requested), include the secure signed URL in your message (note: expires in 24 hours)
* **DUAL SHARING:** Attach local files AND provide secure signed URLs only when user has requested uploads for controlled access
- **Attachment Checklist:**
* Data visualizations (charts, graphs, plots)
@ -1210,7 +1221,7 @@ To make conversations feel natural and human-like:
* Analysis results with visual components
* UI designs and mockups
* Any file intended for user viewing or interaction
* **Secure signed URLs** (when using upload_file tool - note 24hr expiry)
* **Secure signed URLs** (only when user requested upload_file tool usage - note 24hr expiry)
# 9. COMPLETION PROTOCOLS

View File

@ -515,6 +515,7 @@ class AgentRunner:
return await mcp_manager.register_mcp_tools(self.config.agent_config)
def get_max_tokens(self) -> Optional[int]:
logger.debug(f"get_max_tokens called with: '{self.config.model_name}' (type: {type(self.config.model_name)})")
if "sonnet" in self.config.model_name.lower():
return 8192
elif "gpt-4" in self.config.model_name.lower():
@ -535,7 +536,7 @@ class AgentRunner:
self.config.thread_id,
mcp_wrapper_instance, self.client
)
logger.debug(f"model_name received: {self.config.model_name}")
iteration_count = 0
continue_execution = True
@ -572,7 +573,7 @@ class AgentRunner:
temporary_message = await message_manager.build_temporary_message()
max_tokens = self.get_max_tokens()
logger.debug(f"max_tokens: {max_tokens}")
generation = self.config.trace.generation(name="thread_manager.run_thread") if self.config.trace else None
try:
response = await self.thread_manager.run_thread(
@ -720,13 +721,15 @@ async def run_agent(
trace: Optional[StatefulTraceClient] = None
):
effective_model = model_name
if model_name == "openai/gpt-5-mini" and agent_config and agent_config.get('model'):
is_tier_default = model_name in ["Kimi K2", "Claude Sonnet 4", "openai/gpt-5-mini"]
if is_tier_default and agent_config and agent_config.get('model'):
effective_model = agent_config['model']
logger.debug(f"Using model from agent config: {effective_model} (no user selection)")
elif model_name != "openai/gpt-5-mini":
logger.debug(f"Using model from agent config: {effective_model} (tier default was {model_name})")
elif not is_tier_default:
logger.debug(f"Using user-selected model: {effective_model}")
else:
logger.debug(f"Using default model: {effective_model}")
logger.debug(f"Using tier default model: {effective_model}")
config = AgentConfig(
thread_id=thread_id,
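The tier-default handling above, isolated as a hedged sketch (mirrors the diff's logic; not a drop-in):

TIER_DEFAULTS = {"Kimi K2", "Claude Sonnet 4", "openai/gpt-5-mini"}

def pick_effective_model(model_name, agent_config):
    # A tier default means the user made no explicit choice, so an
    # agent-configured model, if present, takes precedence.
    if model_name in TIER_DEFAULTS and agent_config and agent_config.get("model"):
        return agent_config["model"]
    # Otherwise honor the user-selected model (or keep the tier default).
    return model_name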

View File

@ -320,14 +320,25 @@ class VersionService:
client = await self._get_client()
result = await client.table('agent_versions').select('*').eq(
'agent_id', agent_id
).eq('is_active', True).execute()
if not result.data:
agent_result = await client.table('agents').select('current_version_id').eq('agent_id', agent_id).execute()
if not agent_result.data or not agent_result.data[0].get('current_version_id'):
logger.warning(f"No current_version_id found for agent {agent_id}")
return None
return self._version_from_db_row(result.data[0])
current_version_id = agent_result.data[0]['current_version_id']
logger.debug(f"Agent {agent_id} current_version_id: {current_version_id}")
result = await client.table('agent_versions').select('*').eq(
'version_id', current_version_id
).eq('agent_id', agent_id).execute()
if not result.data:
logger.warning(f"Current version {current_version_id} not found for agent {agent_id}")
return None
version = self._version_from_db_row(result.data[0])
logger.debug(f"Retrieved active version for agent {agent_id}: model='{version.model}', version_name='{version.version_name}'")
return version
async def get_all_versions(self, agent_id: str, user_id: str) -> List[AgentVersion]:
is_owner, is_public = await self._verify_agent_access(agent_id, user_id)

View File

@ -692,20 +692,10 @@ async def create_composio_trigger(req: CreateComposioTriggerRequest, current_use
body = {
"user_id": composio_user_id,
"userId": composio_user_id,
"trigger_config": coerced_config,
"triggerConfig": coerced_config,
"webhook": {
"url": req.webhook_url or f"{base_url}/api/composio/webhook",
"headers": webhook_headers,
"method": "POST",
},
}
if req.connected_account_id:
body["connectedAccountId"] = req.connected_account_id
body["connected_account_id"] = req.connected_account_id
body["connectedAccountIds"] = [req.connected_account_id]
body["connected_account_ids"] = [req.connected_account_id]
async with httpx.AsyncClient(timeout=20) as http_client:
resp = await http_client.post(url, headers=headers, json=body)
@ -838,6 +828,8 @@ async def composio_webhook(request: Request):
logger.info("Composio webhook body read failed", error=str(e))
body_str = ""
# Get webhook ID early for logging
wid = request.headers.get("webhook-id", "")
# Minimal request diagnostics (no secrets)
try:
@ -859,15 +851,6 @@ async def composio_webhook(request: Request):
}
except Exception:
payload_preview = {"keys": []}
logger.debug(
"Composio webhook incoming",
client_ip=client_ip,
header_names=header_names,
has_authorization=has_auth,
has_x_composio_secret=has_x_secret,
has_x_trigger_secret=has_x_trigger,
payload_meta=payload_preview,
)
except Exception:
pass
@ -882,10 +865,10 @@ async def composio_webhook(request: Request):
# Parse payload for processing
try:
payload = json.loads(body_str) if body_str else {}
except Exception:
except Exception as parse_error:
logger.error(f"Failed to parse webhook payload: {parse_error}", payload_raw=body_str)
payload = {}
wid = request.headers.get("webhook-id", "")
# Look for trigger_nano_id in data.trigger_nano_id (the actual Composio trigger instance ID)
composio_trigger_id = (
(payload.get("data", {}) or {}).get("trigger_nano_id")
@ -906,11 +889,11 @@ async def composio_webhook(request: Request):
# Basic parsed-field logging (no secrets)
try:
logger.debug(
logger.info(
"Composio parsed fields",
webhook_id=wid,
trigger_slug=trigger_slug,
payload_id=composio_trigger_id,
composio_trigger_id=composio_trigger_id,
provider_event_id=provider_event_id,
payload_keys=list(payload.keys()) if isinstance(payload, dict) else [],
)
@ -933,14 +916,6 @@ async def composio_webhook(request: Request):
rows = []
matched = []
try:
logger.debug(
"Composio matching begin",
have_id=bool(composio_trigger_id),
payload_id=composio_trigger_id,
)
except Exception:
pass
for row in rows:
cfg = row.get("config") or {}
@ -948,43 +923,31 @@ async def composio_webhook(request: Request):
continue
prov = cfg.get("provider_id") or row.get("provider_id")
if prov != "composio":
try:
logger.debug("Composio skip non-provider", trigger_id=row.get("trigger_id"), provider_id=prov)
except Exception:
pass
logger.debug("Composio skip non-provider", trigger_id=row.get("trigger_id"), provider_id=prov)
continue
# ONLY match by exact composio_trigger_id - no slug fallback
cfg_tid = cfg.get("composio_trigger_id")
if composio_trigger_id and cfg_tid == composio_trigger_id:
logger.debug(
logger.info(
"Composio EXACT ID MATCH",
trigger_id=row.get("trigger_id"),
cfg_id=cfg_tid,
payload_id=composio_trigger_id
cfg_composio_trigger_id=cfg_tid,
payload_composio_trigger_id=composio_trigger_id,
is_active=row.get("is_active")
)
matched.append(row)
continue
else:
logger.debug(
logger.info(
"Composio ID mismatch",
trigger_id=row.get("trigger_id"),
cfg_id=cfg_tid,
payload_id=composio_trigger_id,
match_found=False
cfg_composio_trigger_id=cfg_tid,
payload_composio_trigger_id=composio_trigger_id,
match_found=False,
is_active=row.get("is_active")
)
try:
logger.debug(
"Composio matching result",
total=len(rows),
matched=len(matched),
have_id=bool(composio_trigger_id),
payload_id=composio_trigger_id,
)
except Exception:
pass
if not matched:
logger.error(
f"No exact ID match found for Composio trigger {composio_trigger_id}",

View File

@ -0,0 +1,13 @@
from .registry import ModelRegistry, registry
from .models import Model, ModelProvider, ModelCapability
from .manager import ModelManager, model_manager
__all__ = [
'ModelRegistry',
'registry',
'Model',
'ModelProvider',
'ModelCapability',
'ModelManager',
'model_manager',
]

backend/models/manager.py (new file, 224 lines)
View File

@ -0,0 +1,224 @@
from typing import Optional, List, Dict, Any, Tuple
from .registry import registry
from .models import Model, ModelCapability
from utils.logger import logger
from .registry import DEFAULT_PREMIUM_MODEL, DEFAULT_FREE_MODEL
class ModelManager:
def __init__(self):
self.registry = registry
def get_model(self, model_id: str) -> Optional[Model]:
return self.registry.get(model_id)
def resolve_model_id(self, model_id: str) -> str:
logger.debug(f"resolve_model_id called with: '{model_id}' (type: {type(model_id)})")
resolved = self.registry.resolve_model_id(model_id)
if resolved:
logger.debug(f"Resolved model '{model_id}' to '{resolved}'")
return resolved
all_aliases = list(self.registry._aliases.keys())
logger.warning(f"Could not resolve model ID: '{model_id}'. Available aliases: {all_aliases[:10]}...")
return model_id
def validate_model(self, model_id: str) -> Tuple[bool, str]:
model = self.get_model(model_id)
if not model:
return False, f"Model '{model_id}' not found"
if not model.enabled:
return False, f"Model '{model.name}' is currently disabled"
return True, ""
def calculate_cost(
self,
model_id: str,
input_tokens: int,
output_tokens: int
) -> Optional[float]:
model = self.get_model(model_id)
if not model or not model.pricing:
logger.warning(f"No pricing available for model: {model_id}")
return None
input_cost = input_tokens * model.pricing.input_cost_per_token
output_cost = output_tokens * model.pricing.output_cost_per_token
total_cost = input_cost + output_cost
logger.debug(
f"Cost calculation for {model.name}: "
f"{input_tokens} input tokens (${input_cost:.6f}) + "
f"{output_tokens} output tokens (${output_cost:.6f}) = "
f"${total_cost:.6f}"
)
return total_cost
def get_models_for_tier(self, tier: str) -> List[Model]:
return self.registry.get_by_tier(tier, enabled_only=True)
def get_models_with_capability(self, capability: ModelCapability) -> List[Model]:
return self.registry.get_by_capability(capability, enabled_only=True)
def select_best_model(
self,
tier: str,
required_capabilities: Optional[List[ModelCapability]] = None,
min_context_window: Optional[int] = None,
prefer_cheaper: bool = False
) -> Optional[Model]:
models = self.get_models_for_tier(tier)
if required_capabilities:
models = [
m for m in models
if all(cap in m.capabilities for cap in required_capabilities)
]
if min_context_window:
models = [m for m in models if m.context_window >= min_context_window]
if not models:
return None
if prefer_cheaper and any(m.pricing for m in models):
models_with_pricing = [m for m in models if m.pricing]
if models_with_pricing:
models = sorted(
models_with_pricing,
key=lambda m: m.pricing.input_cost_per_million_tokens
)
else:
models = sorted(
models,
key=lambda m: (-m.priority, not m.recommended)
)
return models[0] if models else None
def get_default_model(self, tier: str = "free") -> Optional[Model]:
models = self.get_models_for_tier(tier)
recommended = [m for m in models if m.recommended]
if recommended:
recommended = sorted(recommended, key=lambda m: -m.priority)
return recommended[0]
if models:
models = sorted(models, key=lambda m: -m.priority)
return models[0]
return None
def get_context_window(self, model_id: str, default: int = 31_000) -> int:
return self.registry.get_context_window(model_id, default)
def check_token_limit(
self,
model_id: str,
token_count: int,
is_input: bool = True
) -> Tuple[bool, int]:
model = self.get_model(model_id)
if not model:
return False, 0
if is_input:
max_allowed = model.context_window
else:
max_allowed = model.max_output_tokens or model.context_window
return token_count <= max_allowed, max_allowed
def format_model_info(self, model_id: str) -> Dict[str, Any]:
model = self.get_model(model_id)
if not model:
return {"error": f"Model '{model_id}' not found"}
return {
"id": model.id,
"name": model.name,
"provider": model.provider.value,
"context_window": model.context_window,
"max_output_tokens": model.max_output_tokens,
"capabilities": [cap.value for cap in model.capabilities],
"pricing": {
"input_per_million": model.pricing.input_cost_per_million_tokens,
"output_per_million": model.pricing.output_cost_per_million_tokens,
} if model.pricing else None,
"enabled": model.enabled,
"beta": model.beta,
"tier_availability": model.tier_availability,
"priority": model.priority,
"recommended": model.recommended,
}
def list_available_models(
self,
tier: Optional[str] = None,
include_disabled: bool = False
) -> List[Dict[str, Any]]:
logger.debug(f"list_available_models called with tier='{tier}', include_disabled={include_disabled}")
if tier:
models = self.registry.get_by_tier(tier, enabled_only=not include_disabled)
logger.debug(f"Found {len(models)} models for tier '{tier}'")
else:
models = self.registry.get_all(enabled_only=not include_disabled)
logger.debug(f"Found {len(models)} total models")
if models:
model_names = [m.name for m in models]
logger.debug(f"Models: {model_names}")
else:
logger.warning(f"No models found for tier '{tier}' - this might indicate a configuration issue")
models = sorted(
models,
key=lambda m: (not m.is_free_tier, -m.priority, m.name)
)
return [self.format_model_info(m.id) for m in models]
def get_legacy_constants(self) -> Dict:
return self.registry.to_legacy_format()
async def get_default_model_for_user(self, client, user_id: str) -> str:
try:
from utils.config import config, EnvMode
if config.ENV_MODE == EnvMode.LOCAL:
return DEFAULT_PREMIUM_MODEL
from services.billing import get_user_subscription, SUBSCRIPTION_TIERS
subscription = await get_user_subscription(user_id)
is_paid_tier = False
if subscription:
price_id = None
if subscription.get('items') and subscription['items'].get('data') and len(subscription['items']['data']) > 0:
price_id = subscription['items']['data'][0]['price']['id']
else:
price_id = subscription.get('price_id')
tier_info = SUBSCRIPTION_TIERS.get(price_id)
if tier_info and tier_info['name'] != 'free':
is_paid_tier = True
if is_paid_tier:
logger.debug(f"Setting Claude Sonnet 4 as default for paid user {user_id}")
return DEFAULT_PREMIUM_MODEL
else:
logger.debug(f"Setting Kimi K2 as default for free user {user_id}")
return DEFAULT_FREE_MODEL
except Exception as e:
logger.warning(f"Failed to determine user tier for {user_id}: {e}")
return DEFAULT_FREE_MODEL
model_manager = ModelManager()
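Example usage of the singleton (pricing and tiers as registered in backend/models/registry.py below; the cost figures are arithmetic, not captured output):

from models import model_manager

model_manager.validate_model("gpt-5-mini")            # (True, "") - alias resolves, model enabled
model_manager.calculate_cost("gpt-5-mini", 10_000, 2_000)
# GPT-5 Mini is $0.25/M input and $2.00/M output:
# 10_000 * 0.25 / 1e6 + 2_000 * 2.00 / 1e6 = 0.0025 + 0.0040 = $0.0065

model_manager.get_default_model(tier="free")          # Kimi K2 (highest-priority free model)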

backend/models/models.py (new file, 104 lines)
View File

@ -0,0 +1,104 @@
from dataclasses import dataclass, field
from typing import List, Optional, Dict, Any
from enum import Enum
class ModelProvider(Enum):
OPENAI = "openai"
ANTHROPIC = "anthropic"
OPENROUTER = "openrouter"
GOOGLE = "google"
GEMINI = "gemini"
XAI = "xai"
MOONSHOTAI = "moonshotai"
class ModelCapability(Enum):
CHAT = "chat"
FUNCTION_CALLING = "function_calling"
VISION = "vision"
CODE_INTERPRETER = "code_interpreter"
WEB_SEARCH = "web_search"
THINKING = "thinking"
STRUCTURED_OUTPUT = "structured_output"
@dataclass
class ModelPricing:
input_cost_per_million_tokens: float
output_cost_per_million_tokens: float
@property
def input_cost_per_token(self) -> float:
return self.input_cost_per_million_tokens / 1_000_000
@property
def output_cost_per_token(self) -> float:
return self.output_cost_per_million_tokens / 1_000_000
@dataclass
class Model:
id: str
name: str
provider: ModelProvider
aliases: List[str] = field(default_factory=list)
context_window: int = 128_000
max_output_tokens: Optional[int] = None
capabilities: List[ModelCapability] = field(default_factory=list)
pricing: Optional[ModelPricing] = None
enabled: bool = True
beta: bool = False
tier_availability: List[str] = field(default_factory=lambda: ["paid"])
metadata: Dict[str, Any] = field(default_factory=dict)
priority: int = 0
recommended: bool = False
def __post_init__(self):
if self.max_output_tokens is None:
self.max_output_tokens = min(self.context_window // 4, 32_000)
if ModelCapability.CHAT not in self.capabilities:
self.capabilities.insert(0, ModelCapability.CHAT)
@property
def full_id(self) -> str:
if "/" in self.id:
return self.id
return f"{self.provider.value}/{self.id}"
@property
def supports_thinking(self) -> bool:
return ModelCapability.THINKING in self.capabilities
@property
def supports_functions(self) -> bool:
return ModelCapability.FUNCTION_CALLING in self.capabilities
@property
def supports_vision(self) -> bool:
return ModelCapability.VISION in self.capabilities
@property
def is_free_tier(self) -> bool:
return "free" in self.tier_availability
def to_dict(self) -> Dict[str, Any]:
return {
"id": self.id,
"name": self.name,
"provider": self.provider.value,
"aliases": self.aliases,
"context_window": self.context_window,
"max_output_tokens": self.max_output_tokens,
"capabilities": [cap.value for cap in self.capabilities],
"pricing": {
"input_cost_per_million_tokens": self.pricing.input_cost_per_million_tokens,
"output_cost_per_million_tokens": self.pricing.output_cost_per_million_tokens,
} if self.pricing else None,
"enabled": self.enabled,
"beta": self.beta,
"tier_availability": self.tier_availability,
"metadata": self.metadata,
"priority": self.priority,
"recommended": self.recommended,
}
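A quick illustration of the __post_init__ defaults, using a hypothetical model (not one of the registered ones):

m = Model(
    id="example/test-model",
    name="Test Model",
    provider=ModelProvider.OPENAI,
    context_window=128_000,
    capabilities=[ModelCapability.FUNCTION_CALLING],
)
m.max_output_tokens    # 32_000 == min(128_000 // 4, 32_000)
m.capabilities[0]      # ModelCapability.CHAT, inserted automatically
m.full_id              # "example/test-model" - id already carries a provider prefix
m.is_free_tier         # False - tier_availability defaults to ["paid"]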

backend/models/registry.py (new file, 327 lines)
View File

@ -0,0 +1,327 @@
from typing import Dict, List, Optional, Set
from .models import Model, ModelProvider, ModelCapability, ModelPricing
DEFAULT_FREE_MODEL = "Kimi K2"
DEFAULT_PREMIUM_MODEL = "Claude Sonnet 4"
class ModelRegistry:
def __init__(self):
self._models: Dict[str, Model] = {}
self._aliases: Dict[str, str] = {}
self._initialize_models()
def _initialize_models(self):
self.register(Model(
id="anthropic/claude-sonnet-4-20250514",
name="Claude Sonnet 4",
provider=ModelProvider.ANTHROPIC,
aliases=["claude-sonnet-4", "anthropic/claude-sonnet-4", "Claude Sonnet 4", "claude-sonnet-4-20250514"],
context_window=200_000,
capabilities=[
ModelCapability.CHAT,
ModelCapability.FUNCTION_CALLING,
ModelCapability.VISION,
ModelCapability.THINKING,
],
pricing=ModelPricing(
input_cost_per_million_tokens=3.00,
output_cost_per_million_tokens=15.00
),
tier_availability=["paid"],
priority=100,
recommended=True,
enabled=True
))
self.register(Model(
id="anthropic/claude-3-7-sonnet-latest",
name="Claude 3.7 Sonnet",
provider=ModelProvider.ANTHROPIC,
aliases=["sonnet-3.7", "claude-3.7", "Claude 3.7 Sonnet", "claude-3-7-sonnet-latest"],
context_window=200_000,
capabilities=[
ModelCapability.CHAT,
ModelCapability.FUNCTION_CALLING,
ModelCapability.VISION,
],
pricing=ModelPricing(
input_cost_per_million_tokens=3.00,
output_cost_per_million_tokens=15.00
),
tier_availability=["paid"],
priority=93,
enabled=True
))
self.register(Model(
id="anthropic/claude-3-5-sonnet-latest",
name="Claude 3.5 Sonnet",
provider=ModelProvider.ANTHROPIC,
aliases=["sonnet-3.5", "claude-3.5", "Claude 3.5 Sonnet", "claude-3-5-sonnet-latest"],
context_window=200_000,
capabilities=[
ModelCapability.CHAT,
ModelCapability.FUNCTION_CALLING,
ModelCapability.VISION,
],
pricing=ModelPricing(
input_cost_per_million_tokens=3.00,
output_cost_per_million_tokens=15.00
),
tier_availability=["paid"],
priority=90,
enabled=True
))
self.register(Model(
id="openai/gpt-5",
name="GPT-5",
provider=ModelProvider.OPENAI,
aliases=["gpt-5", "GPT-5"],
context_window=400_000,
capabilities=[
ModelCapability.CHAT,
ModelCapability.FUNCTION_CALLING,
ModelCapability.VISION,
ModelCapability.STRUCTURED_OUTPUT,
],
pricing=ModelPricing(
input_cost_per_million_tokens=1.25,
output_cost_per_million_tokens=10.00
),
tier_availability=["paid"],
priority=99,
enabled=True
))
self.register(Model(
id="openai/gpt-5-mini",
name="GPT-5 Mini",
provider=ModelProvider.OPENAI,
aliases=["gpt-5-mini", "GPT-5 Mini"],
context_window=400_000,
capabilities=[
ModelCapability.CHAT,
ModelCapability.FUNCTION_CALLING,
ModelCapability.STRUCTURED_OUTPUT,
],
pricing=ModelPricing(
input_cost_per_million_tokens=0.25,
output_cost_per_million_tokens=2.00
),
tier_availability=["free", "paid"],
priority=85,
enabled=True
))
self.register(Model(
id="gemini/gemini-2.5-pro",
name="Gemini 2.5 Pro",
provider=ModelProvider.GEMINI,
aliases=["google/gemini-2.5-pro", "gemini-2.5-pro", "Gemini 2.5 Pro"],
context_window=2_000_000,
capabilities=[
ModelCapability.CHAT,
ModelCapability.FUNCTION_CALLING,
ModelCapability.VISION,
ModelCapability.STRUCTURED_OUTPUT,
],
pricing=ModelPricing(
input_cost_per_million_tokens=1.25,
output_cost_per_million_tokens=10.00
),
tier_availability=["paid"],
priority=96,
enabled=True
))
self.register(Model(
id="xai/grok-4",
name="Grok 4",
provider=ModelProvider.XAI,
aliases=["grok-4", "x-ai/grok-4", "openrouter/x-ai/grok-4", "Grok 4"],
context_window=128_000,
capabilities=[
ModelCapability.CHAT,
ModelCapability.FUNCTION_CALLING,
],
pricing=ModelPricing(
input_cost_per_million_tokens=5.00,
output_cost_per_million_tokens=15.00
),
tier_availability=["paid"],
priority=94,
enabled=True
))
self.register(Model(
id="openrouter/moonshotai/kimi-k2",
name="Kimi K2",
provider=ModelProvider.MOONSHOTAI,
aliases=["moonshotai/kimi-k2", "kimi-k2", "Kimi K2"],
context_window=200_000,
capabilities=[
ModelCapability.CHAT,
ModelCapability.FUNCTION_CALLING,
],
pricing=ModelPricing(
input_cost_per_million_tokens=1.00,
output_cost_per_million_tokens=3.00
),
tier_availability=["free", "paid"],
priority=100,
enabled=True
))
"""
# DeepSeek Models
self.register(Model(
id="openrouter/deepseek/deepseek-chat",
name="DeepSeek Chat",
provider=ModelProvider.OPENROUTER,
aliases=["deepseek", "deepseek-chat"],
context_window=128_000,
capabilities=[
ModelCapability.CHAT,
ModelCapability.FUNCTION_CALLING
],
pricing=ModelPricing(
input_cost_per_million_tokens=0.38,
output_cost_per_million_tokens=0.89
),
tier_availability=["free", "paid"],
priority=95,
enabled=False # Currently disabled
))
# Qwen Models
self.register(Model(
id="openrouter/qwen/qwen3-235b-a22b",
name="Qwen3 235B",
provider=ModelProvider.OPENROUTER,
aliases=["qwen3", "qwen-3"],
context_window=128_000,
capabilities=[
ModelCapability.CHAT,
ModelCapability.FUNCTION_CALLING
],
pricing=ModelPricing(
input_cost_per_million_tokens=0.13,
output_cost_per_million_tokens=0.60
),
tier_availability=["free", "paid"],
priority=90,
enabled=False # Currently disabled
))
"""
def register(self, model: Model) -> None:
self._models[model.id] = model
for alias in model.aliases:
self._aliases[alias] = model.id
def get(self, model_id: str) -> Optional[Model]:
if model_id in self._models:
return self._models[model_id]
if model_id in self._aliases:
actual_id = self._aliases[model_id]
return self._models.get(actual_id)
return None
def get_all(self, enabled_only: bool = True) -> List[Model]:
models = list(self._models.values())
if enabled_only:
models = [m for m in models if m.enabled]
return models
def get_by_tier(self, tier: str, enabled_only: bool = True) -> List[Model]:
models = self.get_all(enabled_only)
return [m for m in models if tier in m.tier_availability]
def get_by_provider(self, provider: ModelProvider, enabled_only: bool = True) -> List[Model]:
models = self.get_all(enabled_only)
return [m for m in models if m.provider == provider]
def get_by_capability(self, capability: ModelCapability, enabled_only: bool = True) -> List[Model]:
models = self.get_all(enabled_only)
return [m for m in models if capability in m.capabilities]
def resolve_model_id(self, model_id: str) -> Optional[str]:
model = self.get(model_id)
return model.id if model else None
def get_aliases(self, model_id: str) -> List[str]:
model = self.get(model_id)
return model.aliases if model else []
def enable_model(self, model_id: str) -> bool:
model = self.get(model_id)
if model:
model.enabled = True
return True
return False
def disable_model(self, model_id: str) -> bool:
model = self.get(model_id)
if model:
model.enabled = False
return True
return False
def get_context_window(self, model_id: str, default: int = 31_000) -> int:
model = self.get(model_id)
return model.context_window if model else default
def get_pricing(self, model_id: str) -> Optional[ModelPricing]:
model = self.get(model_id)
return model.pricing if model else None
def to_legacy_format(self) -> Dict:
models_dict = {}
aliases_dict = {}
pricing_dict = {}
context_windows_dict = {}
for model in self.get_all(enabled_only=True):
models_dict[model.id] = {
"aliases": model.aliases,
"pricing": {
"input_cost_per_million_tokens": model.pricing.input_cost_per_million_tokens,
"output_cost_per_million_tokens": model.pricing.output_cost_per_million_tokens,
} if model.pricing else None,
"context_window": model.context_window,
"tier_availability": model.tier_availability,
}
for alias in model.aliases:
aliases_dict[alias] = model.id
if model.pricing:
pricing_dict[model.id] = {
"input_cost_per_million_tokens": model.pricing.input_cost_per_million_tokens,
"output_cost_per_million_tokens": model.pricing.output_cost_per_million_tokens,
}
context_windows_dict[model.id] = model.context_window
free_models = [m.id for m in self.get_by_tier("free")]
paid_models = [m.id for m in self.get_by_tier("paid")]
# Debug logging
from utils.logger import logger
logger.debug(f"Legacy format generation: {len(free_models)} free models, {len(paid_models)} paid models")
logger.debug(f"Free models: {free_models}")
logger.debug(f"Paid models: {paid_models}")
return {
"MODELS": models_dict,
"MODEL_NAME_ALIASES": aliases_dict,
"HARDCODED_MODEL_PRICES": pricing_dict,
"MODEL_CONTEXT_WINDOWS": context_windows_dict,
"FREE_TIER_MODELS": free_models,
"PAID_TIER_MODELS": paid_models,
}
registry = ModelRegistry()
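Typical lookups against the module-level singleton (aliases and tiers exactly as registered above):

registry.get("sonnet-3.7").id              # "anthropic/claude-3-7-sonnet-latest" (alias hit)
registry.resolve_model_id("grok-4")        # "xai/grok-4"
registry.resolve_model_id("unknown-model") # None - ModelManager then falls back to the input string
[m.name for m in registry.get_by_tier("free")]  # ["GPT-5 Mini", "Kimi K2"]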

View File

@ -111,20 +111,19 @@ async def run_agent_background(
"agent_config": agent_config,
})
effective_model = model_name
if model_name == "openai/gpt-5-mini" and agent_config and agent_config.get('model'):
from models import model_manager
is_tier_default = model_name in ["Kimi K2", "Claude Sonnet 4", "openai/gpt-5-mini"]
if is_tier_default and agent_config and agent_config.get('model'):
agent_model = agent_config['model']
from utils.constants import MODEL_NAME_ALIASES
resolved_agent_model = MODEL_NAME_ALIASES.get(agent_model, agent_model)
effective_model = resolved_agent_model
logger.debug(f"Using model from agent config: {agent_model} -> {effective_model} (no user selection)")
effective_model = model_manager.resolve_model_id(agent_model)
logger.debug(f"Using model from agent config: {agent_model} -> {effective_model} (tier default was {model_name})")
else:
from utils.constants import MODEL_NAME_ALIASES
effective_model = MODEL_NAME_ALIASES.get(model_name, model_name)
if model_name != "openai/gpt-5-mini":
effective_model = model_manager.resolve_model_id(model_name)
if not is_tier_default:
logger.debug(f"Using user-selected model: {model_name} -> {effective_model}")
else:
logger.debug(f"Using default model: {effective_model}")
logger.debug(f"Using tier default model: {model_name} -> {effective_model}")
logger.debug(f"🚀 Using model: {effective_model} (thinking: {enable_thinking}, reasoning_effort: {reasoning_effort})")
if agent_config:

View File

@ -110,14 +110,22 @@ def get_model_pricing(model: str) -> tuple[float, float] | None:
Get pricing for a model. Returns (input_cost_per_million, output_cost_per_million) or None.
Args:
model: The model name to get pricing for
model: The model name to get pricing for (can be display name or model ID)
Returns:
Tuple of (input_cost_per_million_tokens, output_cost_per_million_tokens) or None if not found
"""
# Try direct lookup first
if model in HARDCODED_MODEL_PRICES:
pricing = HARDCODED_MODEL_PRICES[model]
return pricing["input_cost_per_million_tokens"], pricing["output_cost_per_million_tokens"]
from models import model_manager
resolved_model = model_manager.resolve_model_id(model)
if resolved_model != model and resolved_model in HARDCODED_MODEL_PRICES:
pricing = HARDCODED_MODEL_PRICES[resolved_model]
return pricing["input_cost_per_million_tokens"], pricing["output_cost_per_million_tokens"]
return None
@ -570,8 +578,9 @@ def calculate_token_cost(prompt_tokens: int, completion_tokens: int, model: str)
prompt_tokens = int(prompt_tokens) if prompt_tokens is not None else 0
completion_tokens = int(completion_tokens) if completion_tokens is not None else 0
# Try to resolve the model name using MODEL_NAME_ALIASES first
resolved_model = MODEL_NAME_ALIASES.get(model, model)
# Try to resolve the model name using new model manager first
from models import model_manager
resolved_model = model_manager.resolve_model_id(model)
# Check if we have hardcoded pricing for this model (try both original and resolved)
hardcoded_pricing = get_model_pricing(model) or get_model_pricing(resolved_model)
@ -672,7 +681,8 @@ async def can_use_model(client, user_id: str, model_name: str):
}
allowed_models = await get_allowed_models_for_user(client, user_id)
resolved_model = MODEL_NAME_ALIASES.get(model_name, model_name)
from models import model_manager
resolved_model = model_manager.resolve_model_id(model_name)
if resolved_model in allowed_models:
return True, "Model access allowed", allowed_models
@ -1751,6 +1761,9 @@ async def get_available_models(
):
"""Get the list of models available to the user based on their subscription tier."""
try:
# Import the new model manager
from models import model_manager
# Get Supabase client
db = DBConnection()
client = await db.client
@ -1759,18 +1772,23 @@ async def get_available_models(
if config.ENV_MODE == EnvMode.LOCAL:
logger.debug("Running in local development mode - billing checks are disabled")
# In local mode, return all models from MODEL_NAME_ALIASES
# In local mode, return all enabled models
all_models = model_manager.list_available_models(include_disabled=False)
model_info = []
for short_name, full_name in MODEL_NAME_ALIASES.items():
# Skip entries where the key is a full name to avoid duplicates
# if short_name == full_name or '/' in short_name:
# continue
for model_data in all_models:
# Create clean model info for frontend
model_info.append({
"id": full_name,
"display_name": short_name,
"short_name": short_name,
"requires_subscription": False # Always false in local dev mode
"id": model_data["id"],
"display_name": model_data["name"],
"short_name": model_data.get("aliases", [model_data["name"]])[0] if model_data.get("aliases") else model_data["name"],
"requires_subscription": False, # Always false in local dev mode
"input_cost_per_million_tokens": model_data["pricing"]["input_per_million"] if model_data["pricing"] else None,
"output_cost_per_million_tokens": model_data["pricing"]["output_per_million"] if model_data["pricing"] else None,
"context_window": model_data["context_window"],
"capabilities": model_data["capabilities"],
"recommended": model_data["recommended"],
"priority": model_data["priority"]
})
return {
@ -1780,10 +1798,7 @@ async def get_available_models(
}
# For non-local mode, get list of allowed models for this user
allowed_models = await get_allowed_models_for_user(client, current_user_id)
free_tier_models = MODEL_ACCESS_TIERS.get('free', [])
# For non-local mode, use new model manager system
# Get subscription info for context
subscription = await get_user_subscription(current_user_id)
@ -1801,122 +1816,56 @@ async def get_available_models(
if tier_info:
tier_name = tier_info['name']
# Get all unique full model names from MODEL_NAME_ALIASES
all_models = set()
model_aliases = {}
# Get ALL enabled models for preview UI (don't filter by tier here)
all_models = model_manager.list_available_models(tier=None, include_disabled=False)
logger.debug(f"Found {len(all_models)} total models available")
for short_name, full_name in MODEL_NAME_ALIASES.items():
# Add all unique full model names
all_models.add(full_name)
# Get allowed models for this specific user (for access checking)
allowed_models = await get_allowed_models_for_user(client, current_user_id)
logger.debug(f"User {current_user_id} allowed models: {allowed_models}")
logger.debug(f"User tier: {tier_name}")
# Only include short names that don't match their full names for aliases
if short_name != full_name and not short_name.startswith("openai/") and not short_name.startswith("anthropic/") and not short_name.startswith("openrouter/") and not short_name.startswith("xai/"):
if full_name not in model_aliases:
model_aliases[full_name] = short_name
# Create model info with display names for ALL models
# Create clean model info for frontend
model_info = []
for model in all_models:
display_name = model_aliases.get(model, model.split('/')[-1] if '/' in model else model)
# Check if model requires subscription (not in free tier)
requires_sub = model not in free_tier_models
for model_data in all_models:
model_id = model_data["id"]
# Check if model is available with current subscription
is_available = model in allowed_models
is_available = model_id in allowed_models
# Get pricing information - check hardcoded prices first, then litellm
# Get pricing with multiplier applied
pricing_info = {}
# Check if we have hardcoded pricing for this model
hardcoded_pricing = get_model_pricing(model)
if hardcoded_pricing:
input_cost_per_million, output_cost_per_million = hardcoded_pricing
if model_data["pricing"]:
pricing_info = {
"input_cost_per_million_tokens": input_cost_per_million * TOKEN_PRICE_MULTIPLIER,
"output_cost_per_million_tokens": output_cost_per_million * TOKEN_PRICE_MULTIPLIER,
"max_tokens": None
"input_cost_per_million_tokens": model_data["pricing"]["input_per_million"] * TOKEN_PRICE_MULTIPLIER,
"output_cost_per_million_tokens": model_data["pricing"]["output_per_million"] * TOKEN_PRICE_MULTIPLIER,
"max_tokens": model_data["max_output_tokens"]
}
else:
try:
# Try to get pricing using cost_per_token function
models_to_try = []
# Add the original model name
models_to_try.append(model)
# Try to resolve the model name using MODEL_NAME_ALIASES
if model in MODEL_NAME_ALIASES:
resolved_model = MODEL_NAME_ALIASES[model]
models_to_try.append(resolved_model)
# Also try without provider prefix if it has one
if '/' in resolved_model:
models_to_try.append(resolved_model.split('/', 1)[1])
# If model is a value in aliases, try to find a matching key
for alias_key, alias_value in MODEL_NAME_ALIASES.items():
if alias_value == model:
models_to_try.append(alias_key)
break
# Also try without provider prefix for the original model
if '/' in model:
models_to_try.append(model.split('/', 1)[1])
# Special handling for Google models accessed via Google API
if model.startswith('gemini/'):
google_model_name = model.replace('gemini/', '')
models_to_try.append(google_model_name)
# Special handling for Google models accessed via Google API
if model.startswith('gemini/'):
google_model_name = model.replace('gemini/', '')
models_to_try.append(google_model_name)
# Try each model name variation until we find one that works
input_cost_per_token = None
output_cost_per_token = None
for model_name in models_to_try:
try:
# Use cost_per_token with sample token counts to get the per-token costs
input_cost, output_cost = cost_per_token(model_name, 1000000, 1000000)
if input_cost is not None and output_cost is not None:
input_cost_per_token = input_cost
output_cost_per_token = output_cost
break
except Exception:
continue
if input_cost_per_token is not None and output_cost_per_token is not None:
pricing_info = {
"input_cost_per_million_tokens": input_cost_per_token * TOKEN_PRICE_MULTIPLIER,
"output_cost_per_million_tokens": output_cost_per_million * TOKEN_PRICE_MULTIPLIER,
"max_tokens": None # cost_per_token doesn't provide max_tokens info
}
else:
pricing_info = {
"input_cost_per_million_tokens": None,
"output_cost_per_million_tokens": None,
"max_tokens": None
}
except Exception as e:
logger.warning(f"Could not get pricing for model {model}: {str(e)}")
pricing_info = {
"input_cost_per_million_tokens": None,
"output_cost_per_million_tokens": None,
"max_tokens": None
}
pricing_info = {
"input_cost_per_million_tokens": None,
"output_cost_per_million_tokens": None,
"max_tokens": None
}
model_info.append({
"id": model,
"display_name": display_name,
"short_name": model_aliases.get(model),
"requires_subscription": requires_sub,
"id": model_id,
"display_name": model_data["name"],
"short_name": model_data.get("aliases", [model_data["name"]])[0] if model_data.get("aliases") else model_data["name"],
"requires_subscription": not model_data.get("tier_availability", []) or "free" not in model_data["tier_availability"],
"is_available": is_available,
"context_window": model_data["context_window"],
"capabilities": model_data["capabilities"],
"recommended": model_data["recommended"],
"priority": model_data["priority"],
**pricing_info
})
logger.debug(f"Returning {len(model_info)} models to user {current_user_id} (tier: {tier_name})")
if model_info:
model_names = [m["display_name"] for m in model_info]
logger.debug(f"Model names: {model_names}")
return {
"models": model_info,
"subscription_tier": tier_name,

View File

@ -257,9 +257,12 @@ def prepare_params(
enable_thinking: Optional[bool] = False,
reasoning_effort: Optional[str] = 'low'
) -> Dict[str, Any]:
"""Prepare parameters for the API call."""
from models import model_manager
resolved_model_name = model_manager.resolve_model_id(model_name)
logger.debug(f"Model resolution: '{model_name}' -> '{resolved_model_name}'")
params = {
"model": model_name,
"model": resolved_model_name,
"messages": messages,
"temperature": temperature,
"response_format": response_format,
@ -276,22 +279,22 @@ def prepare_params(
params["model_id"] = model_id
# Handle token limits
_configure_token_limits(params, model_name, max_tokens)
_configure_token_limits(params, resolved_model_name, max_tokens)
# Add tools if provided
_add_tools_config(params, tools, tool_choice)
# Add Anthropic-specific parameters
_configure_anthopic(params, model_name, params["messages"])
_configure_anthopic(params, resolved_model_name, params["messages"])
# Add OpenRouter-specific parameters
_configure_openrouter(params, model_name)
_configure_openrouter(params, resolved_model_name)
# Add Bedrock-specific parameters
_configure_bedrock(params, model_name, model_id)
_configure_bedrock(params, resolved_model_name, model_id)
_add_fallback_model(params, model_name, messages)
_add_fallback_model(params, resolved_model_name, messages)
# Add OpenAI GPT-5 specific parameters
_configure_openai_gpt5(params, model_name)
_configure_openai_gpt5(params, resolved_model_name)
# Add Kimi K2-specific parameters
_configure_kimi_k2(params, model_name)
_configure_thinking(params, model_name, enable_thinking, reasoning_effort)
_configure_kimi_k2(params, resolved_model_name)
_configure_thinking(params, resolved_model_name, enable_thinking, reasoning_effort)
return params

View File

@ -121,6 +121,7 @@ class WorkflowUpdateRequest(BaseModel):
class WorkflowExecuteRequest(BaseModel):
input_data: Optional[Dict[str, Any]] = None
model_name: Optional[str] = None
WorkflowStepRequest.model_rebuild()
@ -761,6 +762,7 @@ async def execute_agent_workflow(
execution_data: WorkflowExecuteRequest,
user_id: str = Depends(get_current_user_id_from_jwt)
):
print("DEBUG: Executing workflow", workflow_id, "for agent", agent_id)
await verify_agent_access(agent_id, user_id)
client = await db.client
@ -786,7 +788,9 @@ async def execute_agent_workflow(
if active_version and active_version.model:
model_name = active_version.model
else:
model_name = "openai/gpt-5-mini"
from models import model_manager
model_name = await model_manager.get_default_model_for_user(client, account_id)
print("DEBUG: Using tier-based default model:", model_name)
can_use, model_message, allowed_models = await can_use_model(client, account_id, model_name)
if not can_use:
@ -807,7 +811,8 @@ async def execute_agent_workflow(
'triggered_by': 'manual',
'execution_timestamp': datetime.now(timezone.utc).isoformat(),
'user_id': user_id,
'execution_source': 'workflow_api'
'execution_source': 'workflow_api',
'model_name': model_name # Use the model from agent version config
}
)

View File

@ -360,7 +360,20 @@ class AgentExecutor:
trigger_variables: Dict[str, Any]
) -> str:
client = await self._db.client
model_name = agent_config.get('model') or "openai/gpt-5-mini"
# Debug: Log the agent config to see what model is set
logger.debug(f"Agent config for trigger execution: model='{agent_config.get('model')}', keys={list(agent_config.keys())}")
model_name = agent_config.get('model')
logger.debug(f"Model from agent config: '{model_name}' (type: {type(model_name)})")
if not model_name:
account_id = agent_config.get('account_id')
if account_id:
from models import model_manager
model_name = await model_manager.get_default_model_for_user(client, account_id)
else:
model_name = "Kimi K2"
account_id = agent_config.get('account_id')
if not account_id:
@ -439,7 +452,7 @@ class WorkflowExecutor:
agent_config, account_id = await self._get_agent_data(agent_id)
enhanced_agent_config = await self._enhance_agent_config_for_workflow(
agent_config, workflow_config, steps_json, workflow_input, account_id
agent_config, workflow_config, steps_json, workflow_input, account_id, trigger_result
)
thread_id, project_id = await self._session_manager.create_workflow_session(
@ -528,7 +541,8 @@ class WorkflowExecutor:
workflow_config: Dict[str, Any],
steps_json: list,
workflow_input: Dict[str, Any],
account_id: str = None
account_id: str = None,
trigger_result: Optional[TriggerResult] = None
) -> Dict[str, Any]:
available_tools = self._get_available_tools(agent_config)
workflow_prompt = format_workflow_for_llm(
@ -547,6 +561,13 @@ class WorkflowExecutor:
if account_id:
enhanced_config['account_id'] = account_id
# Check for user-specified model in trigger execution variables
if trigger_result and hasattr(trigger_result, 'execution_variables'):
user_model = trigger_result.execution_variables.get('model_name')
if user_model:
enhanced_config['model'] = user_model
logger.debug(f"Using user-specified model for workflow: {user_model}")
return enhanced_config
def _get_available_tools(self, agent_config: Dict[str, Any]) -> list:
@ -589,7 +610,8 @@ class WorkflowExecutor:
from services.billing import check_billing_status, can_use_model
client = await self._db.client
model_name = "openai/gpt-5-mini"
from models import model_manager
model_name = await model_manager.get_default_model_for_user(client, account_id)
can_use, model_message, _ = await can_use_model(client, account_id, model_name)
if not can_use:
@ -636,7 +658,25 @@ class WorkflowExecutor:
agent_config: Dict[str, Any]
) -> str:
client = await self._db.client
model_name = agent_config.get('model') or "openai/gpt-5-mini"
# Debug: Log the agent config to see what model is set
logger.debug(f"Agent config for workflow execution: model='{agent_config.get('model')}', keys={list(agent_config.keys())}")
model_name = agent_config.get('model')
logger.debug(f"Model from agent config: '{model_name}' (type: {type(model_name)})")
if not model_name:
account_id = agent_config.get('account_id')
if not account_id:
thread_result = await client.table('threads').select('account_id').eq('thread_id', thread_id).execute()
if thread_result.data:
account_id = thread_result.data[0]['account_id']
if account_id:
from models import model_manager
model_name = await model_manager.get_default_model_for_user(client, account_id)
else:
model_name = "Kimi K2"
account_id = agent_config.get('account_id')
if not account_id:
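Both executors now follow the same fallback chain for picking a model; condensed into one hedged sketch (table and column names as in the hunks above):

async def resolve_execution_model(client, agent_config, thread_id):
    model_name = agent_config.get('model')
    if model_name:
        return model_name                   # agent version pins a model
    account_id = agent_config.get('account_id')
    if not account_id:
        result = await client.table('threads').select('account_id').eq('thread_id', thread_id).execute()
        if result.data:
            account_id = result.data[0]['account_id']
    if account_id:
        from models import model_manager
        return await model_manager.get_default_model_for_user(client, account_id)
    return "Kimi K2"                        # last-resort free-tier default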

View File

@ -289,7 +289,9 @@ class ProviderService:
def _initialize_providers(self):
self._providers["schedule"] = ScheduleProvider()
self._providers["webhook"] = WebhookProvider()
self._providers["composio"] = ComposioEventProvider()
composio_provider = ComposioEventProvider()
composio_provider.set_db(self._db)
self._providers["composio"] = composio_provider
async def get_available_providers(self) -> List[Dict[str, Any]]:
providers = []
@ -446,6 +448,45 @@ class ComposioEventProvider(TriggerProvider):
super().__init__("composio", TriggerType.WEBHOOK)
self._api_base = os.getenv("COMPOSIO_API_BASE", "https://backend.composio.dev")
self._api_key = os.getenv("COMPOSIO_API_KEY", "")
self._db: Optional[DBConnection] = None
def set_db(self, db: DBConnection):
"""Set database connection for provider"""
self._db = db
async def _count_triggers_with_composio_id(self, composio_trigger_id: str, exclude_trigger_id: Optional[str] = None) -> int:
"""Count how many triggers use the same composio_trigger_id (excluding specified trigger)"""
if not self._db:
return 0
client = await self._db.client
# Use PostgreSQL JSON operator for exact match
query = client.table('agent_triggers').select('trigger_id', count='exact').eq('trigger_type', 'webhook').eq('config->>composio_trigger_id', composio_trigger_id)
if exclude_trigger_id:
query = query.neq('trigger_id', exclude_trigger_id)
result = await query.execute()
count = result.count or 0
return count
async def _count_active_triggers_with_composio_id(self, composio_trigger_id: str, exclude_trigger_id: Optional[str] = None) -> int:
"""Count how many ACTIVE triggers use the same composio_trigger_id (excluding specified trigger)"""
if not self._db:
return 0
client = await self._db.client
# Use PostgreSQL JSON operator for exact match
query = client.table('agent_triggers').select('trigger_id', count='exact').eq('trigger_type', 'webhook').eq('is_active', True).eq('config->>composio_trigger_id', composio_trigger_id)
if exclude_trigger_id:
query = query.neq('trigger_id', exclude_trigger_id)
result = await query.execute()
count = result.count or 0
return count
def _headers(self) -> Dict[str, str]:
return {"x-api-key": self._api_key, "Content-Type": "application/json"}
@ -482,14 +523,23 @@ class ComposioEventProvider(TriggerProvider):
return config
async def setup_trigger(self, trigger: Trigger) -> bool:
# Re-enable the Composio trigger instance if present
# Enable in Composio only if this will be the first active trigger with this composio_trigger_id
try:
trigger_id = trigger.config.get("composio_trigger_id")
if not trigger_id:
composio_trigger_id = trigger.config.get("composio_trigger_id")
if not composio_trigger_id or not self._api_key:
return True
if not self._api_key:
# Check if other ACTIVE triggers are using this composio_trigger_id
other_active_count = await self._count_active_triggers_with_composio_id(composio_trigger_id, trigger.trigger_id)
logger.debug(f"Setup trigger {trigger.trigger_id}: other_active_count={other_active_count} for composio_id={composio_trigger_id}")
if other_active_count > 0:
# Other active triggers exist, don't touch Composio - just mark our trigger as active locally
logger.debug(f"Skipping Composio enable - {other_active_count} other active triggers exist")
return True
# Use canonical payload first per Composio API; include tolerant fallbacks
# We're the first/only active trigger, enable in Composio
logger.debug(f"Enabling trigger in Composio - first active trigger for {composio_trigger_id}")
payload_candidates: List[Dict[str, Any]] = [
{"status": "enable"},
{"status": "enabled"},
@ -497,11 +547,12 @@ class ComposioEventProvider(TriggerProvider):
]
async with httpx.AsyncClient(timeout=10) as client:
for api_base in self._api_bases():
url = f"{api_base}/api/v3/trigger_instances/manage/{trigger_id}"
url = f"{api_base}/api/v3/trigger_instances/manage/{composio_trigger_id}"
for body in payload_candidates:
try:
resp = await client.patch(url, headers=self._headers(), json=body)
if resp.status_code in (200, 204):
logger.debug(f"Successfully enabled trigger in Composio: {composio_trigger_id}")
return True
except Exception:
continue
@ -510,14 +561,23 @@ class ComposioEventProvider(TriggerProvider):
return True
async def teardown_trigger(self, trigger: Trigger) -> bool:
# Disable the Composio trigger instance so it stops sending webhooks
# Disable in Composio only if this was the last active trigger with this composio_trigger_id
try:
trigger_id = trigger.config.get("composio_trigger_id")
if not trigger_id:
composio_trigger_id = trigger.config.get("composio_trigger_id")
if not composio_trigger_id or not self._api_key:
logger.info(f"TEARDOWN: Skipping - no composio_id or api_key")
return True
if not self._api_key:
# Check if other ACTIVE triggers are using this composio_trigger_id
other_active_count = await self._count_active_triggers_with_composio_id(composio_trigger_id, trigger.trigger_id)
if other_active_count > 0:
# Other active triggers exist, don't touch Composio - just mark our trigger as inactive locally
logger.info(f"TEARDOWN: Skipping Composio disable - {other_active_count} other active triggers exist")
return True
# Use canonical payload first per Composio API; include tolerant fallbacks
# We're the last active trigger, disable in Composio
payload_candidates: List[Dict[str, Any]] = [
{"status": "disable"},
{"status": "disabled"},
@ -525,29 +585,38 @@ class ComposioEventProvider(TriggerProvider):
]
async with httpx.AsyncClient(timeout=10) as client:
for api_base in self._api_bases():
url = f"{api_base}/api/v3/trigger_instances/manage/{trigger_id}"
url = f"{api_base}/api/v3/trigger_instances/manage/{composio_trigger_id}"
for body in payload_candidates:
try:
resp = await client.patch(url, headers=self._headers(), json=body)
if resp.status_code in (200, 204):
return True
except Exception:
except Exception as e:
logger.warning(f"TEARDOWN: Failed to disable with body {body}: {e}")
continue
logger.warning(f"TEARDOWN: Failed to disable trigger in Composio: {composio_trigger_id}")
return True
except Exception:
except Exception as e:
logger.error(f"TEARDOWN: Exception in teardown_trigger: {e}")
return True
async def delete_remote_trigger(self, trigger: Trigger) -> bool:
# Permanently remove the remote Composio trigger instance
# Only permanently remove the remote Composio trigger if this is the last trigger using it
try:
trigger_id = trigger.config.get("composio_trigger_id")
if not trigger_id:
composio_trigger_id = trigger.config.get("composio_trigger_id")
if not composio_trigger_id or not self._api_key:
return True
if not self._api_key:
# Check if other triggers are using this composio_trigger_id
other_count = await self._count_triggers_with_composio_id(composio_trigger_id, trigger.trigger_id)
if other_count > 0:
# Other triggers exist, don't delete from Composio - just remove our local trigger
return True
# We're the last trigger, permanently delete from Composio
async with httpx.AsyncClient(timeout=10) as client:
for api_base in self._api_bases():
url = f"{api_base}/api/v3/trigger_instances/manage/{trigger_id}"
url = f"{api_base}/api/v3/trigger_instances/manage/{composio_trigger_id}"
try:
resp = await client.delete(url, headers=self._headers())
if resp.status_code in (200, 204):
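The net effect is reference counting over a shared composio_trigger_id: enable the remote instance on the first activation, disable it on the last deactivation, and delete it remotely only when no local trigger references it at all, active or not. A condensed sketch of those three decisions, with count_active/count_all standing in for the two count helpers above and enable/disable/delete for the Composio API calls (all names here are stand-ins):

async def on_activate(composio_id, my_id, count_active, enable):
    # First active local trigger for this instance -> enable remotely.
    if await count_active(composio_id, exclude=my_id) == 0:
        await enable(composio_id)

async def on_deactivate(composio_id, my_id, count_active, disable):
    # Last active local trigger just went away -> disable remotely.
    if await count_active(composio_id, exclude=my_id) == 0:
        await disable(composio_id)

async def on_delete(composio_id, my_id, count_all, delete):
    # No local trigger (active or inactive) references it -> delete remotely.
    if await count_all(composio_id, exclude=my_id) == 0:
        await delete(composio_id)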

View File

@ -145,6 +145,10 @@ class TriggerService:
config_changed = config is not None
activation_toggled = (is_active is not None) and (previous_is_active != trigger.is_active)
# UPDATE DATABASE FIRST so provider methods see correct state
await self._update_trigger(trigger)
if config_changed or activation_toggled:
from .provider_service import get_provider_service
provider_service = get_provider_service(self._db)
@ -156,8 +160,8 @@ class TriggerService:
setup_success = await provider_service.setup_trigger(trigger)
if not setup_success:
raise ValueError(f"Failed to update trigger setup: {trigger_id}")
else:
# Only activation toggled; call the minimal required action
elif activation_toggled:
# Only activation toggled; call the appropriate action
if trigger.is_active:
setup_success = await provider_service.setup_trigger(trigger)
if not setup_success:
@ -165,8 +169,6 @@ class TriggerService:
else:
await provider_service.teardown_trigger(trigger)
await self._update_trigger(trigger)
logger.debug(f"Updated trigger {trigger_id}")
return trigger
@ -175,9 +177,17 @@ class TriggerService:
if not trigger:
return False
# DELETE FROM DATABASE FIRST so provider methods see correct state
client = await self._db.client
result = await client.table('agent_triggers').delete().eq('trigger_id', trigger_id).execute()
success = len(result.data) > 0
if not success:
return False
from .provider_service import get_provider_service
provider_service = get_provider_service(self._db)
# First disable remotely so webhooks stop quickly
# Now disable remotely so webhooks stop quickly
try:
await provider_service.teardown_trigger(trigger)
except Exception:
@ -188,13 +198,6 @@ class TriggerService:
except Exception:
pass
client = await self._db.client
result = await client.table('agent_triggers').delete().eq('trigger_id', trigger_id).execute()
success = len(result.data) > 0
if success:
logger.debug(f"Deleted trigger {trigger_id}")
return success
async def process_trigger_event(self, trigger_id: str, raw_data: Dict[str, Any]) -> TriggerResult:
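The reordering here is what makes the provider's counting correct: teardown_trigger decides whether to disable the shared Composio instance by counting rows in agent_triggers, so the row for the trigger being deleted must be removed first (the update path is reordered the same way), otherwise the count still includes it and the remote disable is skipped. A sketch of the corrected delete path, with delete_row/teardown standing in for the database delete and provider call above:

async def delete_trigger(trigger, delete_row, teardown) -> bool:
    # 1. Remove the row first so the provider's count queries no longer
    #    see this trigger.
    if not await delete_row(trigger.trigger_id):
        return False
    # 2. With the row gone, the "other active triggers" count is accurate:
    #    if this was the last one, teardown disables the remote instance.
    try:
        await teardown(trigger)
    except Exception:
        pass  # remote disable stays best-effort, as above
    return True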

View File

@ -1,209 +1,13 @@
# Master model configuration - single source of truth
MODELS = {
# Free tier models
from models import model_manager
"anthropic/claude-sonnet-4-20250514": {
"aliases": ["claude-sonnet-4"],
"pricing": {
"input_cost_per_million_tokens": 3.00,
"output_cost_per_million_tokens": 15.00
},
"context_window": 200_000, # 200k tokens
"tier_availability": ["paid"]
},
# "openrouter/deepseek/deepseek-chat": {
# "aliases": ["deepseek"],
# "pricing": {
# "input_cost_per_million_tokens": 0.38,
# "output_cost_per_million_tokens": 0.89
# },
# "context_window": 128_000, # 128k tokens
# "tier_availability": ["free", "paid"]
# },
# "openrouter/qwen/qwen3-235b-a22b": {
# "aliases": ["qwen3"],
# "pricing": {
# "input_cost_per_million_tokens": 0.13,
# "output_cost_per_million_tokens": 0.60
# },
# "context_window": 128_000, # 128k tokens
# "tier_availability": ["free", "paid"]
# },
# "openrouter/google/gemini-2.5-flash-preview-05-20": {
# "aliases": ["gemini-flash-2.5"],
# "pricing": {
# "input_cost_per_million_tokens": 0.15,
# "output_cost_per_million_tokens": 0.60
# },
# "tier_availability": ["free", "paid"]
# },
# "openrouter/deepseek/deepseek-chat-v3-0324": {
# "aliases": ["deepseek/deepseek-chat-v3-0324"],
# "pricing": {
# "input_cost_per_million_tokens": 0.38,
# "output_cost_per_million_tokens": 0.89
# },
# "tier_availability": ["free", "paid"]
# },
"openrouter/moonshotai/kimi-k2": {
"aliases": ["moonshotai/kimi-k2"],
"pricing": {
"input_cost_per_million_tokens": 1.00,
"output_cost_per_million_tokens": 3.00
},
"context_window": 200_000, # 200k tokens
"tier_availability": ["free", "paid"]
},
"xai/grok-4": {
"aliases": ["grok-4", "x-ai/grok-4"],
"pricing": {
"input_cost_per_million_tokens": 5.00,
"output_cost_per_million_tokens": 15.00
},
"context_window": 128_000, # 128k tokens
"tier_availability": ["paid"]
},
_legacy_data = model_manager.get_legacy_constants()
# Paid tier only models
"gemini/gemini-2.5-pro": {
"aliases": ["google/gemini-2.5-pro"],
"pricing": {
"input_cost_per_million_tokens": 1.25,
"output_cost_per_million_tokens": 10.00
},
"context_window": 2_000_000, # 2M tokens
"tier_availability": ["paid"]
},
# "openai/gpt-4o": {
# "aliases": ["gpt-4o"],
# "pricing": {
# "input_cost_per_million_tokens": 2.50,
# "output_cost_per_million_tokens": 10.00
# },
# "tier_availability": ["paid"]
# },
# "openai/gpt-4.1": {
# "aliases": ["gpt-4.1"],
# "pricing": {
# "input_cost_per_million_tokens": 15.00,
# "output_cost_per_million_tokens": 60.00
# },
# "tier_availability": ["paid"]
# },
"openai/gpt-5": {
"aliases": ["gpt-5"],
"pricing": {
"input_cost_per_million_tokens": 1.25,
"output_cost_per_million_tokens": 10.00
},
"context_window": 400_000, # 400k tokens
"tier_availability": ["paid"]
},
"openai/gpt-5-mini": {
"aliases": ["gpt-5-mini"],
"pricing": {
"input_cost_per_million_tokens": 0.25,
"output_cost_per_million_tokens": 2.00
},
"context_window": 400_000, # 400k tokens
"tier_availability": ["free", "paid"]
},
# "openai/gpt-4.1-mini": {
# "aliases": ["gpt-4.1-mini"],
# "pricing": {
# "input_cost_per_million_tokens": 1.50,
# "output_cost_per_million_tokens": 6.00
# },
# "tier_availability": ["paid"]
# },
"anthropic/claude-3-7-sonnet-latest": {
"aliases": ["sonnet-3.7"],
"pricing": {
"input_cost_per_million_tokens": 3.00,
"output_cost_per_million_tokens": 15.00
},
"context_window": 200_000, # 200k tokens
"tier_availability": ["paid"]
},
"anthropic/claude-3-5-sonnet-latest": {
"aliases": ["sonnet-3.5"],
"pricing": {
"input_cost_per_million_tokens": 3.00,
"output_cost_per_million_tokens": 15.00
},
"context_window": 200_000, # 200k tokens
"tier_availability": ["paid"]
},
}
# Derived structures (auto-generated from MODELS)
def _generate_model_structures():
"""Generate all model structures from the master MODELS dictionary."""
# Generate tier lists
free_models = []
paid_models = []
# Generate aliases
aliases = {}
# Generate pricing
pricing = {}
# Generate context window limits
context_windows = {}
for model_name, config in MODELS.items():
# Add to tier lists
if "free" in config["tier_availability"]:
free_models.append(model_name)
if "paid" in config["tier_availability"]:
paid_models.append(model_name)
# Add aliases
for alias in config["aliases"]:
aliases[alias] = model_name
# Add pricing
pricing[model_name] = config["pricing"]
# Add context window limits
if "context_window" in config:
context_windows[model_name] = config["context_window"]
# Also add pricing and context windows for legacy model name variations
if model_name.startswith("openrouter/deepseek/"):
legacy_name = model_name.replace("openrouter/", "")
pricing[legacy_name] = config["pricing"]
if "context_window" in config:
context_windows[legacy_name] = config["context_window"]
elif model_name.startswith("openrouter/qwen/"):
legacy_name = model_name.replace("openrouter/", "")
pricing[legacy_name] = config["pricing"]
if "context_window" in config:
context_windows[legacy_name] = config["context_window"]
elif model_name.startswith("gemini/"):
legacy_name = model_name.replace("gemini/", "")
pricing[legacy_name] = config["pricing"]
if "context_window" in config:
context_windows[legacy_name] = config["context_window"]
elif model_name.startswith("anthropic/"):
# Add anthropic/claude-sonnet-4 alias for claude-sonnet-4-20250514
if "claude-sonnet-4-20250514" in model_name:
pricing["anthropic/claude-sonnet-4"] = config["pricing"]
if "context_window" in config:
context_windows["anthropic/claude-sonnet-4"] = config["context_window"]
elif model_name.startswith("xai/"):
# Add pricing for OpenRouter x-ai models
openrouter_name = model_name.replace("xai/", "openrouter/x-ai/")
pricing[openrouter_name] = config["pricing"]
if "context_window" in config:
context_windows[openrouter_name] = config["context_window"]
return free_models, paid_models, aliases, pricing, context_windows
# Generate all structures
FREE_TIER_MODELS, PAID_TIER_MODELS, MODEL_NAME_ALIASES, HARDCODED_MODEL_PRICES, MODEL_CONTEXT_WINDOWS = _generate_model_structures()
MODELS = _legacy_data["MODELS"]
MODEL_NAME_ALIASES = _legacy_data["MODEL_NAME_ALIASES"]
HARDCODED_MODEL_PRICES = _legacy_data["HARDCODED_MODEL_PRICES"]
MODEL_CONTEXT_WINDOWS = _legacy_data["MODEL_CONTEXT_WINDOWS"]
FREE_TIER_MODELS = _legacy_data["FREE_TIER_MODELS"]
PAID_TIER_MODELS = _legacy_data["PAID_TIER_MODELS"]
MODEL_ACCESS_TIERS = {
"free": FREE_TIER_MODELS,
@ -220,38 +24,4 @@ MODEL_ACCESS_TIERS = {
}
def get_model_context_window(model_name: str, default: int = 31_000) -> int:
"""
Get the context window size for a given model.
Args:
model_name: The model name or alias
default: Default context window if model not found
Returns:
Context window size in tokens
"""
# Check direct model name first
if model_name in MODEL_CONTEXT_WINDOWS:
return MODEL_CONTEXT_WINDOWS[model_name]
# Check if it's an alias
if model_name in MODEL_NAME_ALIASES:
canonical_name = MODEL_NAME_ALIASES[model_name]
if canonical_name in MODEL_CONTEXT_WINDOWS:
return MODEL_CONTEXT_WINDOWS[canonical_name]
# Fallback patterns for common model naming variations
if 'sonnet' in model_name.lower():
return 200_000 # Claude Sonnet models
elif 'gpt-5' in model_name.lower():
return 400_000 # GPT-5 models
elif 'gemini' in model_name.lower():
return 2_000_000 # Gemini models
elif 'grok' in model_name.lower():
return 128_000 # Grok models
elif 'gpt' in model_name.lower():
return 128_000 # GPT-4 and variants
elif 'deepseek' in model_name.lower():
return 128_000 # DeepSeek models
return default
return model_manager.get_context_window(model_name, default)
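The module is now a thin shim: the constants are regenerated from the model manager, so existing imports keep working while new code can call the manager directly. A short sketch of the two equivalent call styles, using only the manager methods that appear above (key names taken from the legacy assignments; the 31_000 default matches the old signature):

from models import model_manager

# Legacy style: dict lookups against the regenerated constants.
legacy = model_manager.get_legacy_constants()
model_id = legacy["MODEL_NAME_ALIASES"].get("claude-sonnet-4", "claude-sonnet-4")
window = legacy["MODEL_CONTEXT_WINDOWS"].get(model_id, 31_000)

# New style: ask the manager directly; it applies the same alias and
# fallback handling internally.
window_new = model_manager.get_context_window("claude-sonnet-4", 31_000)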

View File

@ -19,12 +19,10 @@ import { cn } from '@/lib/utils';
import {
useModelSelection,
MODELS,
STORAGE_KEY_CUSTOM_MODELS,
formatModelName,
getCustomModels,
DEFAULT_FREE_MODEL_ID,
DEFAULT_PREMIUM_MODEL_ID
} from '@/components/thread/chat-input/_use-model-selection';
import { formatModelName, getPrefixedModelId } from '@/lib/stores/model-store';
import { useAvailableModels } from '@/hooks/react-query/subscriptions/use-billing';
import { isLocalMode } from '@/lib/config';
import { CustomModelDialog, CustomModelFormData } from '@/components/thread/chat-input/custom-model-dialog';
@ -52,36 +50,33 @@ export function AgentModelSelector({
variant = 'default',
className,
}: AgentModelSelectorProps) {
const { allModels, canAccessModel, subscriptionStatus } = useModelSelection();
const {
allModels,
canAccessModel,
subscriptionStatus,
selectedModel: storeSelectedModel,
handleModelChange: storeHandleModelChange,
customModels: storeCustomModels,
addCustomModel: storeAddCustomModel,
updateCustomModel: storeUpdateCustomModel,
removeCustomModel: storeRemoveCustomModel
} = useModelSelection();
const { data: modelsData } = useAvailableModels();
const [isOpen, setIsOpen] = useState(false);
const [searchQuery, setSearchQuery] = useState('');
const [highlightedIndex, setHighlightedIndex] = useState<number>(-1);
const searchInputRef = useRef<HTMLInputElement>(null);
// Paywall and billing states
const [paywallOpen, setPaywallOpen] = useState(false);
const [lockedModel, setLockedModel] = useState<string | null>(null);
const [billingModalOpen, setBillingModalOpen] = useState(false);
// Custom model states for local mode
const [customModels, setCustomModels] = useState<CustomModel[]>([]);
const [isCustomModelDialogOpen, setIsCustomModelDialogOpen] = useState(false);
const [dialogInitialData, setDialogInitialData] = useState<CustomModelFormData>({ id: '', label: '' });
const [dialogMode, setDialogMode] = useState<'add' | 'edit'>('add');
const [editingModelId, setEditingModelId] = useState<string | null>(null);
useEffect(() => {
if (isLocalMode()) {
setCustomModels(getCustomModels());
}
}, []);
useEffect(() => {
if (isLocalMode() && customModels.length > 0) {
localStorage.setItem(STORAGE_KEY_CUSTOM_MODELS, JSON.stringify(customModels));
}
}, [customModels]);
const customModels = storeCustomModels;
const normalizeModelId = (modelId?: string): string => {
if (!modelId) return subscriptionStatus === 'active' ? DEFAULT_PREMIUM_MODEL_ID : DEFAULT_FREE_MODEL_ID;
@ -103,17 +98,45 @@ export function AgentModelSelector({
return modelId;
};
const selectedModel = normalizeModelId(value);
const normalizedValue = normalizeModelId(value);
useEffect(() => {
if (normalizedValue && normalizedValue !== storeSelectedModel) {
storeHandleModelChange(normalizedValue);
}
}, [normalizedValue, storeSelectedModel, storeHandleModelChange]);
const selectedModel = storeSelectedModel;
const enhancedModelOptions = useMemo(() => {
const modelMap = new Map();
allModels.forEach(model => {
modelMap.set(model.id, {
...model,
isCustom: false
if (modelsData?.models) {
modelsData.models.forEach(model => {
const shortName = model.short_name || model.id;
const displayName = model.display_name || shortName;
modelMap.set(shortName, {
id: shortName,
label: displayName,
requiresSubscription: model.requires_subscription || false,
priority: model.priority || 0,
recommended: model.recommended || false,
top: (model.priority || 0) >= 90,
capabilities: model.capabilities || [],
contextWindow: model.context_window || 128000,
isCustom: false
});
});
});
} else {
// Fallback to allModels if API data not available
allModels.forEach(model => {
modelMap.set(model.id, {
...model,
isCustom: false
});
});
}
if (isLocalMode()) {
customModels.forEach(model => {
@ -136,7 +159,7 @@ export function AgentModelSelector({
}
return Array.from(modelMap.values());
}, [allModels, customModels]);
}, [modelsData?.models, allModels, customModels]);
const selectedModelDisplay = useMemo(() => {
const model = enhancedModelOptions.find(m => m.id === selectedModel);
@ -162,7 +185,7 @@ export function AgentModelSelector({
const freeModels = sortedModels.filter(m => !m.requiresSubscription);
const premiumModels = sortedModels.filter(m => m.requiresSubscription);
const shouldDisplayAll = (!isLocalMode() && subscriptionStatus === 'no_subscription') && premiumModels.length > 0;
const shouldDisplayAll = !isLocalMode() && premiumModels.length > 0;
useEffect(() => {
if (isOpen && searchInputRef.current) {
@ -176,20 +199,18 @@ export function AgentModelSelector({
}, [isOpen]);
const handleSelect = (modelId: string) => {
console.log('🔧 AgentModelSelector: Selecting model:', modelId);
console.log('🔧 AgentModelSelector: Current selectedModel (normalized):', selectedModel);
console.log('🔧 AgentModelSelector: Current value prop:', value);
const isCustomModel = customModels.some(model => model.id === modelId);
if (isCustomModel && isLocalMode()) {
storeHandleModelChange(modelId);
onChange(modelId);
setIsOpen(false);
return;
}
if (isLocalMode() || canAccessModel(modelId)) {
// Don't transform the modelId - pass it as-is to match what useModelSelection expects
const hasAccess = isLocalMode() || canAccessModel(modelId);
if (hasAccess) {
storeHandleModelChange(modelId);
onChange(modelId);
setIsOpen(false);
} else {
@ -262,22 +283,16 @@ export function AgentModelSelector({
closeCustomModelDialog();
const newModel = { id: modelId, label: modelLabel };
const updatedModels = dialogMode === 'add'
? [...customModels, newModel]
: customModels.map(model => model.id === editingModelId ? newModel : model);
try {
localStorage.setItem(STORAGE_KEY_CUSTOM_MODELS, JSON.stringify(updatedModels));
} catch (error) {
console.error('Failed to save custom models to localStorage:', error);
}
setCustomModels(updatedModels);
if (dialogMode === 'add') {
storeAddCustomModel(newModel);
storeHandleModelChange(modelId);
onChange(modelId);
} else if (selectedModel === editingModelId) {
onChange(modelId);
} else {
storeUpdateCustomModel(editingModelId!, newModel);
if (selectedModel === editingModelId) {
storeHandleModelChange(modelId);
onChange(modelId);
}
}
setIsOpen(false);
@ -293,20 +308,11 @@ export function AgentModelSelector({
e?.stopPropagation();
e?.preventDefault();
const updatedCustomModels = customModels.filter(model => model.id !== modelId);
if (isLocalMode() && typeof window !== 'undefined') {
try {
localStorage.setItem(STORAGE_KEY_CUSTOM_MODELS, JSON.stringify(updatedCustomModels));
} catch (error) {
console.error('Failed to update custom models in localStorage:', error);
}
}
setCustomModels(updatedCustomModels);
storeRemoveCustomModel(modelId);
if (selectedModel === modelId) {
const defaultModel = subscriptionStatus === 'active' ? DEFAULT_PREMIUM_MODEL_ID : DEFAULT_FREE_MODEL_ID;
storeHandleModelChange(defaultModel);
onChange(defaultModel);
}
};
@ -461,7 +467,7 @@ export function AgentModelSelector({
</TooltipProvider>
<DropdownMenuContent
align={variant === 'menu-item' ? 'end' : 'start'}
className="w-72 p-0 overflow-hidden"
className="w-76 p-0 overflow-hidden"
sideOffset={variant === 'menu-item' ? 8 : 4}
>
<div className="max-h-[400px] overflow-y-auto scrollbar-thin scrollbar-thumb-zinc-300 dark:scrollbar-thumb-zinc-700 scrollbar-track-transparent w-full">
@ -535,58 +541,77 @@ export function AgentModelSelector({
<div className="mt-4 border-t border-border pt-2">
<div className="px-3 py-1.5 text-xs font-medium text-blue-500 flex items-center">
<Crown className="h-3.5 w-3.5 mr-1.5" />
Additional Models
{subscriptionStatus === 'active' ? 'Premium Models' : 'Additional Models'}
</div>
<div className="relative h-40 overflow-hidden px-2">
{premiumModels.slice(0, 3).map((model, index) => (
<TooltipProvider key={`premium-${model.id}-${index}`}>
<Tooltip>
<TooltipTrigger asChild>
<div className='w-full'>
<DropdownMenuItem
className="text-sm px-3 rounded-lg py-2 mx-2 my-0.5 flex items-center justify-between opacity-70 cursor-pointer pointer-events-none"
>
<div className="flex items-center">
<span className="font-medium">{model.label}</span>
</div>
<div className="flex items-center gap-2">
{MODELS[model.id]?.recommended && (
<span className="text-xs px-1.5 py-0.5 rounded-sm bg-blue-100 dark:bg-blue-900 text-blue-600 dark:text-blue-300 font-medium whitespace-nowrap">
Recommended
</span>
<div className="relative overflow-hidden" style={{ maxHeight: subscriptionStatus === 'active' ? 'none' : '160px' }}>
{(subscriptionStatus === 'active' ? premiumModels : premiumModels.slice(0, 3)).map((model, index) => {
const canAccess = isLocalMode() || canAccessModel(model.id);
const isRecommended = model.recommended;
return (
<TooltipProvider key={`premium-${model.id}-${index}`}>
<Tooltip>
<TooltipTrigger asChild>
<div className='w-full'>
<DropdownMenuItem
className={cn(
"text-sm px-3 rounded-lg py-2 mx-2 my-0.5 flex items-center justify-between cursor-pointer",
!canAccess && "opacity-70"
)}
<Crown className="h-3.5 w-3.5 text-blue-500" />
</div>
</DropdownMenuItem>
</div>
</TooltipTrigger>
<TooltipContent side="left" className="text-xs max-w-xs">
<p>Requires subscription to access premium model</p>
</TooltipContent>
</Tooltip>
</TooltipProvider>
))}
<div className="absolute inset-0 bg-gradient-to-t from-background via-background/95 to-transparent flex items-end justify-center">
<div className="w-full p-3">
<div className="rounded-xl bg-gradient-to-br from-blue-50/80 to-blue-200/70 dark:from-blue-950/40 dark:to-blue-900/30 shadow-sm border border-blue-200/50 dark:border-blue-800/50 p-3">
<div className="flex flex-col space-y-2">
<div className="flex items-center">
<Crown className="h-4 w-4 text-blue-500 mr-2 flex-shrink-0" />
<div>
<p className="text-sm font-medium">Unlock all models + higher limits</p>
onClick={() => handleSelect(model.id)}
>
<div className="flex items-center">
<span className="font-medium">{model.label}</span>
</div>
<div className="flex items-center gap-2">
{isRecommended && (
<span className="text-xs px-1.5 py-0.5 rounded-sm bg-blue-100 dark:bg-blue-900 text-blue-600 dark:text-blue-300 font-medium whitespace-nowrap">
Recommended
</span>
)}
{!canAccess && <Crown className="h-3.5 w-3.5 text-blue-500" />}
{selectedModel === model.id && (
<Check className="h-4 w-4 text-blue-500" />
)}
</div>
</DropdownMenuItem>
</div>
</TooltipTrigger>
<TooltipContent side="left" className="text-xs max-w-xs">
<p>
{canAccess
? (isRecommended ? 'Recommended for optimal performance' : 'Premium model')
: 'Requires subscription to access premium model'
}
</p>
</TooltipContent>
</Tooltip>
</TooltipProvider>
);
})}
{subscriptionStatus !== 'active' && (
<div className="absolute inset-0 bg-gradient-to-t from-background via-background/95 to-transparent flex items-end justify-center">
<div className="w-full p-3">
<div className="rounded-xl bg-gradient-to-br from-blue-50/80 to-blue-200/70 dark:from-blue-950/40 dark:to-blue-900/30 shadow-sm border border-blue-200/50 dark:border-blue-800/50 p-3">
<div className="flex flex-col space-y-2">
<div className="flex items-center">
<Crown className="h-4 w-4 text-blue-500 mr-2 flex-shrink-0" />
<div>
<p className="text-sm font-medium">Unlock all models + higher limits</p>
</div>
</div>
<Button
size="sm"
className="w-full h-8 font-medium"
onClick={handleUpgradeClick}
>
Upgrade now
</Button>
</div>
<Button
size="sm"
className="w-full h-8 font-medium"
onClick={handleUpgradeClick}
>
Upgrade now
</Button>
</div>
</div>
</div>
</div>
)}
</div>
</div>
</>

View File

@ -0,0 +1,230 @@
'use client';
import { useSubscriptionData } from '@/contexts/SubscriptionContext';
import { useEffect, useMemo } from 'react';
import { isLocalMode } from '@/lib/config';
import { useAvailableModels } from '@/hooks/react-query/subscriptions/use-model';
import {
useModelStore,
canAccessModel,
formatModelName,
getPrefixedModelId,
type SubscriptionStatus,
type ModelOption,
type CustomModel
} from '@/lib/stores/model-store';
export const useModelSelection = () => {
const { data: subscriptionData } = useSubscriptionData();
const { data: modelsData, isLoading: isLoadingModels } = useAvailableModels({
refetchOnMount: false,
});
const {
selectedModel,
customModels,
hasHydrated,
setSelectedModel,
addCustomModel,
updateCustomModel,
removeCustomModel,
setCustomModels,
setHasHydrated,
getDefaultModel,
resetToDefault,
} = useModelStore();
const subscriptionStatus: SubscriptionStatus = (subscriptionData?.status === 'active' || subscriptionData?.status === 'trialing')
? 'active'
: 'no_subscription';
// Load custom models from localStorage for local mode
useEffect(() => {
if (isLocalMode() && hasHydrated && typeof window !== 'undefined') {
try {
const storedModels = localStorage.getItem('customModels');
if (storedModels) {
const parsedModels = JSON.parse(storedModels);
if (Array.isArray(parsedModels)) {
const validModels = parsedModels.filter((model: any) =>
model && typeof model === 'object' &&
typeof model.id === 'string' &&
typeof model.label === 'string'
);
setCustomModels(validModels);
}
}
} catch (e) {
console.error('Error loading custom models:', e);
}
}
}, [isLocalMode, hasHydrated, setCustomModels]);
const MODEL_OPTIONS = useMemo(() => {
let models: ModelOption[] = [];
if (!modelsData?.models || isLoadingModels) {
models = [
{
id: 'moonshotai/kimi-k2',
label: 'Kimi K2',
requiresSubscription: false,
priority: 100,
recommended: true
},
{
id: 'claude-sonnet-4',
label: 'Claude Sonnet 4',
requiresSubscription: true,
priority: 100,
recommended: true
},
];
} else {
models = modelsData.models.map(model => {
const shortName = model.short_name || model.id;
const displayName = model.display_name || shortName;
return {
id: shortName,
label: displayName,
requiresSubscription: model.requires_subscription || false,
priority: model.priority || 0,
recommended: model.recommended || false,
top: (model.priority || 0) >= 90,
capabilities: model.capabilities || [],
contextWindow: model.context_window || 128000
};
});
}
if (isLocalMode() && customModels.length > 0) {
const customModelOptions = customModels.map(model => ({
id: model.id,
label: model.label || formatModelName(model.id),
requiresSubscription: false,
top: false,
isCustom: true,
priority: 30,
}));
models = [...models, ...customModelOptions];
}
const sortedModels = models.sort((a, b) => {
if (a.recommended !== b.recommended) {
return a.recommended ? -1 : 1;
}
if (a.priority !== b.priority) {
return (b.priority || 0) - (a.priority || 0);
}
return a.label.localeCompare(b.label);
});
return sortedModels;
}, [modelsData, isLoadingModels, customModels]);
const availableModels = useMemo(() => {
return isLocalMode()
? MODEL_OPTIONS
: MODEL_OPTIONS.filter(model =>
canAccessModel(subscriptionStatus, model.requiresSubscription)
);
}, [MODEL_OPTIONS, subscriptionStatus]);
// Validate model selection only after hydration and when subscription status changes
useEffect(() => {
// Skip validation until hydrated and models are loaded
if (!hasHydrated || isLoadingModels || typeof window === 'undefined') {
return;
}
// Check if the selected model is still valid
const isValidModel = MODEL_OPTIONS.some(model => model.id === selectedModel) ||
(isLocalMode() && customModels.some(model => model.id === selectedModel));
if (!isValidModel) {
console.log('🔧 ModelSelection: Invalid model detected, resetting to default');
resetToDefault(subscriptionStatus);
return;
}
// For non-local mode, check if user still has access to the selected model
if (!isLocalMode()) {
const modelOption = MODEL_OPTIONS.find(m => m.id === selectedModel);
if (modelOption && !canAccessModel(subscriptionStatus, modelOption.requiresSubscription)) {
console.log('🔧 ModelSelection: User lost access to model, resetting to default');
resetToDefault(subscriptionStatus);
}
}
}, [hasHydrated, selectedModel, subscriptionStatus, MODEL_OPTIONS, customModels, isLoadingModels, resetToDefault]);
const handleModelChange = (modelId: string) => {
const isCustomModel = isLocalMode() && customModels.some(model => model.id === modelId);
const modelOption = MODEL_OPTIONS.find(option => option.id === modelId);
if (!modelOption && !isCustomModel) {
resetToDefault(subscriptionStatus);
return;
}
if (!isCustomModel && !isLocalMode() &&
!canAccessModel(subscriptionStatus, modelOption?.requiresSubscription ?? false)) {
return;
}
setSelectedModel(modelId);
};
const getActualModelId = (modelId: string): string => {
const isCustomModel = isLocalMode() && customModels.some(model => model.id === modelId);
return isCustomModel ? getPrefixedModelId(modelId, true) : modelId;
};
// Function to refresh custom models from localStorage
const refreshCustomModels = () => {
if (isLocalMode() && typeof window !== 'undefined') {
try {
const storedModels = localStorage.getItem('customModels');
if (storedModels) {
const parsedModels = JSON.parse(storedModels);
if (Array.isArray(parsedModels)) {
const validModels = parsedModels.filter((model: any) =>
model && typeof model === 'object' &&
typeof model.id === 'string' &&
typeof model.label === 'string'
);
setCustomModels(validModels);
}
}
} catch (e) {
console.error('Error loading custom models:', e);
}
}
};
return {
selectedModel,
handleModelChange,
setSelectedModel: handleModelChange, // Alias for backward compatibility
availableModels,
allModels: MODEL_OPTIONS,
customModels,
addCustomModel,
updateCustomModel,
removeCustomModel,
refreshCustomModels,
getActualModelId,
canAccessModel: (modelId: string) => {
if (isLocalMode()) return true;
const model = MODEL_OPTIONS.find(m => m.id === modelId);
return model ? canAccessModel(subscriptionStatus, model.requiresSubscription) : false;
},
isSubscriptionRequired: (modelId: string) => {
return MODEL_OPTIONS.find(m => m.id === modelId)?.requiresSubscription || false;
},
subscriptionStatus,
};
};

View File

@ -169,7 +169,7 @@ const saveModelPreference = (modelId: string): void => {
}
};
export const useModelSelection = () => {
export const useModelSelectionOld = () => {
const [selectedModel, setSelectedModel] = useState(DEFAULT_FREE_MODEL_ID);
const [customModels, setCustomModels] = useState<CustomModel[]>([]);
const [hasInitialized, setHasInitialized] = useState(false);
@ -200,52 +200,41 @@ export const useModelSelection = () => {
const MODEL_OPTIONS = useMemo(() => {
let models = [];
// Default models if API data not available
if (!modelsData?.models || isLoadingModels) {
models = [
{
id: DEFAULT_FREE_MODEL_ID,
label: 'KIMI K2',
requiresSubscription: false,
priority: MODELS[DEFAULT_FREE_MODEL_ID]?.priority || 100
},
{
id: DEFAULT_PREMIUM_MODEL_ID,
label: 'Claude Sonnet 4',
requiresSubscription: true,
priority: MODELS[DEFAULT_PREMIUM_MODEL_ID]?.priority || 100
},
];
// Default models if API data not available
if (!modelsData?.models || isLoadingModels) {
models = [
{
id: DEFAULT_FREE_MODEL_ID,
label: 'KIMI K2',
requiresSubscription: false,
priority: 100,
recommended: true
},
{
id: DEFAULT_PREMIUM_MODEL_ID,
label: 'Claude Sonnet 4',
requiresSubscription: true,
priority: 100,
recommended: true
},
];
} else {
// Process API-provided models
// Process API-provided models - use clean data from new backend system
models = modelsData.models.map(model => {
// Use the clean data directly from the API (no more duplicates!)
const shortName = model.short_name || model.id;
const displayName = model.display_name || shortName;
// Format the display label
let cleanLabel = displayName;
if (cleanLabel.includes('/')) {
cleanLabel = cleanLabel.split('/').pop() || cleanLabel;
}
cleanLabel = cleanLabel
.replace(/-/g, ' ')
.split(' ')
.map(word => word.charAt(0).toUpperCase() + word.slice(1))
.join(' ');
// Get model data from our central MODELS constant
const modelData = MODELS[shortName] || {};
const isPremium = model?.requires_subscription || modelData.tier === 'premium' || false;
return {
id: shortName,
label: cleanLabel,
requiresSubscription: isPremium,
top: modelData.priority >= 90, // Mark high-priority models as "top"
priority: modelData.priority || 0,
lowQuality: modelData.lowQuality || false,
recommended: modelData.recommended || false
label: displayName,
requiresSubscription: model.requires_subscription || false,
priority: model.priority || 0,
recommended: model.recommended || false,
top: (model.priority || 0) >= 90, // Mark high-priority models as "top"
lowQuality: false, // All models in new system are quality controlled
capabilities: model.capabilities || [],
contextWindow: model.context_window || 128000
};
});
}
@ -516,4 +505,7 @@ export const useModelSelection = () => {
};
};
// Export the new model selection hook
export { useModelSelection } from './_use-model-selection-new';
// Export the hook but not any sorting logic - sorting is handled internally

View File

@ -14,7 +14,7 @@ import { Card, CardContent } from '@/components/ui/card';
import { handleFiles } from './file-upload-handler';
import { MessageInput } from './message-input';
import { AttachmentGroup } from '../attachment-group';
import { useModelSelection } from './_use-model-selection';
import { useModelSelection } from './_use-model-selection-new';
import { useFileDelete } from '@/hooks/react-query/files';
import { useQueryClient } from '@tanstack/react-query';
import { ToolCallInput } from './floating-tool-preview';

View File

@ -452,7 +452,8 @@ export function FileAttachment({
"group relative w-full",
"rounded-xl border bg-card overflow-hidden pt-10", // Consistent card styling with header space
isPdf ? "!min-h-[200px] sm:min-h-0 sm:h-[400px] max-h-[500px] sm:!min-w-[300px]" :
standalone ? "min-h-[300px] h-auto" : "h-[300px]", // Better height handling for standalone
isHtmlOrMd ? "!min-h-[200px] sm:min-h-0 sm:h-[400px] max-h-[600px] sm:!min-w-[300px]" :
standalone ? "min-h-[300px] h-auto" : "h-[300px]", // Better height handling for standalone
className
)}
style={{
@ -469,8 +470,8 @@ export function FileAttachment({
style={{
minWidth: 0,
width: '100%',
containIntrinsicSize: isPdf ? '100% 500px' : undefined,
contain: isPdf ? 'layout size' : undefined
containIntrinsicSize: (isPdf || isHtmlOrMd) ? '100% 500px' : undefined,
contain: (isPdf || isHtmlOrMd) ? 'layout size' : undefined
}}
>
{/* Render PDF or text-based previews */}

View File

@ -3,29 +3,68 @@ import {
FileDiff,
CheckCircle,
AlertTriangle,
ExternalLink,
Loader2,
Code,
Eye,
File,
ChevronDown,
ChevronUp,
Copy,
Check,
Minus,
Plus,
} from 'lucide-react';
import {
extractFilePath,
extractFileContent,
extractStreamingFileContent,
formatTimestamp,
getToolTitle,
normalizeContentToString,
extractToolData,
} from '../utils';
import {
MarkdownRenderer,
processUnicodeContent,
} from '@/components/file-renderers/markdown-renderer';
import { CsvRenderer } from '@/components/file-renderers/csv-renderer';
import { XlsxRenderer } from '@/components/file-renderers/xlsx-renderer';
import { cn } from '@/lib/utils';
import { Card, CardContent, CardHeader, CardTitle } from '@/components/ui/card';
import { useTheme } from 'next-themes';
import { CodeBlockCode } from '@/components/ui/code-block';
import { constructHtmlPreviewUrl } from '@/lib/utils/url';
import {
Card,
CardContent,
CardHeader,
CardTitle,
} from '@/components/ui/card';
import { Badge } from '@/components/ui/badge';
import { Tooltip, TooltipContent, TooltipProvider, TooltipTrigger } from "@/components/ui/tooltip";
import { Button } from '@/components/ui/button';
import { Tabs, TabsContent, TabsList, TabsTrigger } from "@/components/ui/tabs";
import { ScrollArea } from "@/components/ui/scroll-area";
import {
extractFileEditData,
generateLineDiff,
calculateDiffStats,
LineDiff,
DiffStats
DiffStats,
getLanguageFromFileName,
getOperationType,
getOperationConfigs,
getFileIcon,
processFilePath,
getFileName,
getFileExtension,
isFileType,
hasLanguageHighlighting,
splitContentIntoLines,
type FileOperation,
type OperationConfig,
} from './_utils';
import { formatTimestamp, getToolTitle } from '../utils';
import { ToolViewProps } from '../types';
import { GenericToolView } from '../GenericToolView';
import { LoadingState } from '../shared/LoadingState';
import { toast } from 'sonner';
import ReactDiffViewer from 'react-diff-viewer-continued';
const UnifiedDiffView: React.FC<{ oldCode: string; newCode: string }> = ({ oldCode, newCode }) => (
@ -60,40 +99,6 @@ const UnifiedDiffView: React.FC<{ oldCode: string; newCode: string }> = ({ oldCo
/>
);
const SplitDiffView: React.FC<{ oldCode: string; newCode: string }> = ({ oldCode, newCode }) => (
<ReactDiffViewer
oldValue={oldCode}
newValue={newCode}
splitView={true}
useDarkTheme={document.documentElement.classList.contains('dark')}
styles={{
variables: {
dark: {
diffViewerColor: '#e2e8f0',
diffViewerBackground: '#09090b',
addedBackground: '#104a32',
addedColor: '#6ee7b7',
removedBackground: '#5c1a2e',
removedColor: '#fca5a5',
},
},
diffContainer: {
backgroundColor: 'var(--card)',
border: 'none',
},
gutter: {
backgroundColor: 'var(--muted)',
'&:hover': {
backgroundColor: 'var(--accent)',
},
},
line: {
fontFamily: 'monospace',
},
}}
/>
);
const ErrorState: React.FC<{ message?: string }> = ({ message }) => (
<div className="flex flex-col items-center justify-center h-full py-12 px-6 bg-gradient-to-b from-white to-zinc-50 dark:from-zinc-950 dark:to-zinc-900">
<div className="text-center w-full max-w-xs">
@ -116,8 +121,42 @@ export function FileEditToolView({
toolTimestamp,
isSuccess = true,
isStreaming = false,
project,
}: ToolViewProps): JSX.Element {
const [viewMode, setViewMode] = useState<'unified' | 'split'>('unified');
const { resolvedTheme } = useTheme();
const isDarkTheme = resolvedTheme === 'dark';
// Add copy functionality state
const [isCopyingContent, setIsCopyingContent] = useState(false);
// Copy functions
const copyToClipboard = async (text: string) => {
try {
await navigator.clipboard.writeText(text);
return true;
} catch (err) {
console.error('Failed to copy text: ', err);
return false;
}
};
const handleCopyContent = async () => {
if (!updatedContent) return;
setIsCopyingContent(true);
const success = await copyToClipboard(updatedContent);
if (success) {
toast.success('File content copied to clipboard');
} else {
toast.error('Failed to copy file content');
}
setTimeout(() => setIsCopyingContent(false), 500);
};
const operation = getOperationType(name, assistantContent);
const configs = getOperationConfigs();
const config = configs[operation] || configs['edit']; // fallback to edit config
const Icon = FileDiff; // Always use FileDiff for edit operations
const {
filePath,
@ -135,103 +174,275 @@ export function FileEditToolView({
);
const toolTitle = getToolTitle(name);
const processedFilePath = processFilePath(filePath);
const fileName = getFileName(processedFilePath);
const fileExtension = getFileExtension(fileName);
const isMarkdown = isFileType.markdown(fileExtension);
const isHtml = isFileType.html(fileExtension);
const isCsv = isFileType.csv(fileExtension);
const isXlsx = isFileType.xlsx(fileExtension);
const language = getLanguageFromFileName(fileName);
const hasHighlighting = hasLanguageHighlighting(language);
const contentLines = splitContentIntoLines(updatedContent);
const htmlPreviewUrl =
isHtml && project?.sandbox?.sandbox_url && processedFilePath
? constructHtmlPreviewUrl(project.sandbox.sandbox_url, processedFilePath)
: undefined;
const FileIcon = getFileIcon(fileName);
const lineDiff = originalContent && updatedContent ? generateLineDiff(originalContent, updatedContent) : [];
const stats: DiffStats = calculateDiffStats(lineDiff);
const shouldShowError = !isStreaming && (!actualIsSuccess || (actualIsSuccess && (originalContent === null || updatedContent === null)));
return (
<Card className="gap-0 flex border shadow-none border-t border-b-0 border-x-0 p-0 rounded-none flex-col h-full bg-card">
<CardHeader className="h-14 bg-zinc-50/80 dark:bg-zinc-900/80 backdrop-blur-sm border-b p-2 px-4 space-y-2">
<div className="flex flex-row items-center justify-between">
<div className="flex items-center gap-2">
<div className="relative p-2 rounded-lg bg-gradient-to-br from-blue-500/20 to-blue-600/10 border border-blue-500/20">
<FileDiff className="w-5 h-5 text-blue-500 dark:text-blue-400" />
</div>
<CardTitle className="text-base font-medium text-zinc-900 dark:text-zinc-100">
{toolTitle}
</CardTitle>
</div>
if (!isStreaming && !processedFilePath && !updatedContent) {
return (
<GenericToolView
name={name || 'edit-file'}
assistantContent={assistantContent}
toolContent={toolContent}
assistantTimestamp={assistantTimestamp}
toolTimestamp={toolTimestamp}
isSuccess={isSuccess}
isStreaming={isStreaming}
/>
);
}
{!isStreaming && (
<Badge
variant="secondary"
className={
actualIsSuccess
? "bg-gradient-to-b from-emerald-200 to-emerald-100 text-emerald-700 dark:from-emerald-800/50 dark:to-emerald-900/60 dark:text-emerald-300"
: "bg-gradient-to-b from-rose-200 to-rose-100 text-rose-700 dark:from-rose-800/50 dark:to-rose-900/60 dark:text-rose-300"
}
>
{actualIsSuccess ? (
<CheckCircle className="h-3.5 w-3.5 mr-1" />
) : (
<AlertTriangle className="h-3.5 w-3.5 mr-1" />
)}
{actualIsSuccess ? 'Edit applied' : 'Edit failed'}
</Badge>
)}
const renderFilePreview = () => {
if (!updatedContent) {
return (
<div className="flex items-center justify-center h-full p-12">
<div className="text-center">
<FileIcon className="h-12 w-12 mx-auto mb-4 text-zinc-400" />
<p className="text-sm text-zinc-500 dark:text-zinc-400">No content to preview</p>
</div>
</div>
</CardHeader>
);
}
<CardContent className="p-0 flex-1 flex flex-col min-h-0">
{isStreaming ? (
<LoadingState
icon={FileDiff}
iconColor="text-blue-500 dark:text-blue-400"
bgColor="bg-gradient-to-b from-blue-100 to-blue-50 shadow-inner dark:from-blue-800/40 dark:to-blue-900/60 dark:shadow-blue-950/20"
title="Applying File Edit"
filePath={filePath || 'Processing file...'}
progressText="Analyzing changes"
subtitle="Please wait while the file is being modified"
if (isHtml && htmlPreviewUrl) {
return (
<div className="flex flex-col h-[calc(100vh-16rem)]">
<iframe
src={htmlPreviewUrl}
title={`HTML Preview of ${fileName}`}
className="flex-grow border-0"
sandbox="allow-same-origin allow-scripts"
/>
) : shouldShowError ? (
<ErrorState message={errorMessage} />
) : (
<div className="flex-1 flex flex-col min-h-0">
<div className="shrink-0 p-3 border-b border-zinc-200 dark:border-zinc-800 bg-accent flex items-center justify-between">
<div className="flex items-center">
<File className="h-4 w-4 mr-2 text-zinc-500 dark:text-zinc-400" />
<code className="text-sm font-mono text-zinc-700 dark:text-zinc-300">
{filePath || 'Unknown file'}
</code>
</div>
</div>
);
}
<div className="flex items-center gap-2">
<div className="flex items-center text-xs text-zinc-500 dark:text-zinc-400 gap-3">
{stats.additions === 0 && stats.deletions === 0 ? (
<Badge variant="outline" className="text-xs font-normal">No changes</Badge>
) : (
<>
<div className="flex items-center">
<Plus className="h-3.5 w-3.5 text-emerald-500 mr-1" />
<span>{stats.additions}</span>
</div>
<div className="flex items-center">
<Minus className="h-3.5 w-3.5 text-red-500 mr-1" />
<span>{stats.deletions}</span>
</div>
</>
)}
</div>
<Tabs value={viewMode} onValueChange={(v) => setViewMode(v as 'unified' | 'split')} className="w-auto">
<TabsList className="h-7 p-0.5">
<TabsTrigger value="unified" className="text-xs h-6 px-2">Unified</TabsTrigger>
<TabsTrigger value="split" className="text-xs h-6 px-2">Split</TabsTrigger>
</TabsList>
</Tabs>
if (isMarkdown) {
return (
<div className="p-1 py-0 prose dark:prose-invert prose-zinc max-w-none">
<MarkdownRenderer
content={processUnicodeContent(updatedContent)}
/>
</div>
);
}
if (isCsv) {
return (
<div className="h-full w-full p-4">
<div className="h-[calc(100vh-17rem)] w-full bg-muted/20 border rounded-xl overflow-auto">
<CsvRenderer content={processUnicodeContent(updatedContent)} />
</div>
</div>
);
}
if (isXlsx) {
return (
<div className="h-full w-full p-4">
<div className="h-[calc(100vh-17rem)] w-full bg-muted/20 border rounded-xl overflow-auto">
<XlsxRenderer
content={updatedContent}
filePath={processedFilePath}
fileName={fileName}
project={project}
/>
</div>
</div>
);
}
return (
<div className="p-4">
<div className='w-full h-full bg-muted/20 border rounded-xl px-4 py-2 pb-6'>
<pre className="text-sm font-mono text-zinc-800 dark:text-zinc-300 whitespace-pre-wrap break-words">
{processUnicodeContent(updatedContent)}
</pre>
</div>
</div>
);
};
const renderSourceCode = () => {
if (!originalContent || !updatedContent) {
return (
<div className="flex items-center justify-center h-full p-12">
<div className="text-center">
<FileIcon className="h-12 w-12 mx-auto mb-4 text-zinc-400" />
<p className="text-sm text-zinc-500 dark:text-zinc-400">No diff to display</p>
</div>
</div>
);
}
// Show unified diff view in source tab
return (
<div className="flex-1 overflow-auto min-h-0 text-xs">
<UnifiedDiffView oldCode={originalContent} newCode={updatedContent} />
</div>
);
};
return (
<Card className="flex border shadow-none border-t border-b-0 border-x-0 p-0 rounded-none flex-col h-full overflow-hidden bg-card">
<Tabs defaultValue={isMarkdown || isHtml || isCsv || isXlsx ? 'preview' : 'code'} className="w-full h-full">
<CardHeader className="h-14 bg-zinc-50/80 dark:bg-zinc-900/80 backdrop-blur-sm border-b p-2 px-4 space-y-2 mb-0">
<div className="flex flex-row items-center justify-between">
<div className="flex items-center gap-2">
<div className="relative p-2 rounded-lg bg-gradient-to-br from-blue-500/20 to-blue-600/10 border border-blue-500/20">
<FileDiff className="w-5 h-5 text-blue-500 dark:text-blue-400" />
</div>
<div>
<CardTitle className="text-base font-medium text-zinc-900 dark:text-zinc-100">
{toolTitle}
</CardTitle>
</div>
</div>
<div className="flex-1 overflow-auto min-h-0 text-xs">
{viewMode === 'unified' ? (
<UnifiedDiffView oldCode={originalContent!} newCode={updatedContent!} />
) : (
<SplitDiffView oldCode={originalContent!} newCode={updatedContent!} />
<div className='flex items-center gap-2'>
{isHtml && htmlPreviewUrl && !isStreaming && (
<Button variant="outline" size="sm" className="h-8 text-xs bg-white dark:bg-muted/50 hover:bg-zinc-100 dark:hover:bg-zinc-800 shadow-none" asChild>
<a href={htmlPreviewUrl} target="_blank" rel="noopener noreferrer">
<ExternalLink className="h-3.5 w-3.5 mr-1.5" />
Open in Browser
</a>
</Button>
)}
{/* Copy button - only show when there's file content */}
{updatedContent && !isStreaming && (
<Button
variant="outline"
size="sm"
onClick={handleCopyContent}
disabled={isCopyingContent}
className="h-8 text-xs bg-white dark:bg-muted/50 hover:bg-zinc-100 dark:hover:bg-zinc-800 shadow-none"
title="Copy file content"
>
{isCopyingContent ? (
<Check className="h-3.5 w-3.5 mr-1.5" />
) : (
<Copy className="h-3.5 w-3.5 mr-1.5" />
)}
<span className="hidden sm:inline">Copy</span>
</Button>
)}
                  {/* Diff stats badge for the source tab - shown when the edit produced no changes */}
{originalContent && updatedContent && (
<div className="flex items-center gap-2">
<div className="flex items-center text-xs text-zinc-500 dark:text-zinc-400 gap-3">
{stats.additions === 0 && stats.deletions === 0 && (
<Badge variant="outline" className="text-xs font-normal">No changes</Badge>
)}
</div>
</div>
)}
<TabsList className="h-8 bg-muted/50 border border-border/50 p-0.5 gap-1">
<TabsTrigger
value="code"
className="flex items-center gap-1.5 px-4 py-2 text-xs font-medium transition-all [&[data-state=active]]:bg-white [&[data-state=active]]:dark:bg-primary/10 [&[data-state=active]]:text-foreground hover:bg-background/50 text-muted-foreground shadow-none"
>
<Code className="h-3.5 w-3.5" />
Source
</TabsTrigger>
<TabsTrigger
value="preview"
className="flex items-center gap-1.5 px-4 py-2 text-xs font-medium transition-all [&[data-state=active]]:bg-white [&[data-state=active]]:dark:bg-primary/10 [&[data-state=active]]:text-foreground hover:bg-background/50 text-muted-foreground shadow-none"
>
<Eye className="h-3.5 w-3.5" />
Preview
</TabsTrigger>
</TabsList>
</div>
</div>
)}
</CardContent>
</CardHeader>
<CardContent className="p-0 -my-2 h-full flex-1 overflow-hidden relative">
<TabsContent value="code" className="flex-1 h-full mt-0 p-0 overflow-hidden">
<ScrollArea className="h-screen w-full min-h-0">
{isStreaming && !updatedContent ? (
<LoadingState
icon={FileDiff}
iconColor="text-blue-500 dark:text-blue-400"
bgColor="bg-gradient-to-b from-blue-100 to-blue-50 shadow-inner dark:from-blue-800/40 dark:to-blue-900/60 dark:shadow-blue-950/20"
title="Applying File Edit"
filePath={processedFilePath || 'Processing file...'}
subtitle="Please wait while the file is being modified"
showProgress={false}
/>
) : shouldShowError ? (
<ErrorState message={errorMessage} />
) : (
renderSourceCode()
)}
</ScrollArea>
</TabsContent>
<TabsContent value="preview" className="w-full flex-1 h-full mt-0 p-0 overflow-hidden">
<ScrollArea className="h-full w-full min-h-0">
{isStreaming && !updatedContent ? (
<LoadingState
icon={FileDiff}
iconColor="text-blue-500 dark:text-blue-400"
bgColor="bg-gradient-to-b from-blue-100 to-blue-50 shadow-inner dark:from-blue-800/40 dark:to-blue-900/60 dark:shadow-blue-950/20"
title="Applying File Edit"
filePath={processedFilePath || 'Processing file...'}
subtitle="Please wait while the file is being modified"
showProgress={false}
/>
) : shouldShowError ? (
<ErrorState message={errorMessage} />
) : (
renderFilePreview()
)}
{isStreaming && updatedContent && (
<div className="sticky bottom-4 right-4 float-right mr-4 mb-4">
<Badge className="bg-blue-500/90 text-white border-none shadow-lg animate-pulse">
<Loader2 className="h-3 w-3 animate-spin mr-1" />
Streaming...
</Badge>
</div>
)}
</ScrollArea>
</TabsContent>
</CardContent>
<div className="px-4 py-2 h-10 bg-gradient-to-r from-zinc-50/90 to-zinc-100/90 dark:from-zinc-900/90 dark:to-zinc-800/90 backdrop-blur-sm border-t border-zinc-200 dark:border-zinc-800 flex justify-between items-center gap-4">
<div className="h-full flex items-center gap-2 text-sm text-zinc-500 dark:text-zinc-400">
<Badge variant="outline" className="py-0.5 h-6">
<FileIcon className="h-3 w-3" />
{hasHighlighting ? language.toUpperCase() : fileExtension.toUpperCase() || 'TEXT'}
</Badge>
</div>
<div className="text-xs text-zinc-500 dark:text-zinc-400">
{actualToolTimestamp && !isStreaming
? formatTimestamp(actualToolTimestamp)
: assistantTimestamp
? formatTimestamp(assistantTimestamp)
: ''}
</div>
</div>
</Tabs>
</Card>
);
}

View File

@ -1695,6 +1695,10 @@ export interface Model {
input_cost_per_million_tokens?: number | null;
output_cost_per_million_tokens?: number | null;
max_tokens?: number | null;
context_window?: number;
capabilities?: string[];
recommended?: boolean;
priority?: number;
}
export interface AvailableModelsResponse {

View File

@ -0,0 +1,141 @@
import { create } from 'zustand';
import { persist } from 'zustand/middleware';
import { isLocalMode } from '@/lib/config';
export interface CustomModel {
id: string;
label: string;
}
export interface ModelOption {
id: string;
label: string;
requiresSubscription: boolean;
description?: string;
top?: boolean;
isCustom?: boolean;
priority?: number;
recommended?: boolean;
capabilities?: string[];
contextWindow?: number;
}
export type SubscriptionStatus = 'no_subscription' | 'active';
interface ModelStore {
selectedModel: string;
customModels: CustomModel[];
hasHydrated: boolean;
setSelectedModel: (model: string) => void;
addCustomModel: (model: CustomModel) => void;
updateCustomModel: (id: string, model: CustomModel) => void;
removeCustomModel: (id: string) => void;
setCustomModels: (models: CustomModel[]) => void;
setHasHydrated: (hydrated: boolean) => void;
getDefaultModel: (subscriptionStatus: SubscriptionStatus) => string;
resetToDefault: (subscriptionStatus: SubscriptionStatus) => void;
}
const DEFAULT_FREE_MODEL_ID = 'moonshotai/kimi-k2';
const DEFAULT_PREMIUM_MODEL_ID = 'claude-sonnet-4';
export const useModelStore = create<ModelStore>()(
persist(
(set, get) => ({
selectedModel: DEFAULT_FREE_MODEL_ID,
customModels: [],
hasHydrated: false,
setSelectedModel: (model: string) => {
set({ selectedModel: model });
},
addCustomModel: (model: CustomModel) => {
const { customModels } = get();
if (customModels.some(existing => existing.id === model.id)) {
return;
}
const newCustomModels = [...customModels, model];
set({ customModels: newCustomModels });
},
updateCustomModel: (id: string, model: CustomModel) => {
const { customModels } = get();
const newCustomModels = customModels.map(existing =>
existing.id === id ? model : existing
);
set({ customModels: newCustomModels });
},
removeCustomModel: (id: string) => {
const { customModels, selectedModel } = get();
const newCustomModels = customModels.filter(model => model.id !== id);
const updates: Partial<ModelStore> = { customModels: newCustomModels };
if (selectedModel === id) {
updates.selectedModel = DEFAULT_FREE_MODEL_ID;
}
set(updates);
},
setCustomModels: (models: CustomModel[]) => {
set({ customModels: models });
},
setHasHydrated: (hydrated: boolean) => {
set({ hasHydrated: hydrated });
},
getDefaultModel: (subscriptionStatus: SubscriptionStatus) => {
if (isLocalMode()) {
return DEFAULT_PREMIUM_MODEL_ID;
}
return subscriptionStatus === 'active' ? DEFAULT_PREMIUM_MODEL_ID : DEFAULT_FREE_MODEL_ID;
},
resetToDefault: (subscriptionStatus: SubscriptionStatus) => {
const defaultModel = get().getDefaultModel(subscriptionStatus);
set({ selectedModel: defaultModel });
},
}),
{
name: 'suna-model-selection-v2',
partialize: (state) => ({
selectedModel: state.selectedModel,
customModels: state.customModels,
}),
onRehydrateStorage: () => (state) => {
if (state) {
state.setHasHydrated(true);
}
},
}
)
);
export const canAccessModel = (
subscriptionStatus: SubscriptionStatus,
requiresSubscription: boolean,
): boolean => {
if (isLocalMode()) {
return true;
}
return subscriptionStatus === 'active' || !requiresSubscription;
};
export const formatModelName = (name: string): string => {
return name
.split('-')
.map(word => word.charAt(0).toUpperCase() + word.slice(1))
.join(' ');
};
export const getPrefixedModelId = (modelId: string, isCustom: boolean): string => {
if (isCustom && !modelId.startsWith('openrouter/')) {
return `openrouter/${modelId}`;
}
return modelId;
};