mirror of https://github.com/kortix-ai/suna.git
Compare commits
6 Commits
45d238aa62
...
578a8e4a0a
Author | SHA1 | Date |
---|---|---|
|
578a8e4a0a | |
|
a6caaf42a3 | |
|
56b46672f7 | |
|
27c523d1de | |
|
7e10d736b5 | |
|
e0ad5cf2cd |
|
@ -323,8 +323,9 @@ async def start_agent(
|
|||
model_name = body.model_name
|
||||
logger.debug(f"Original model_name from request: {model_name}")
|
||||
|
||||
# Log the model name after alias resolution
|
||||
resolved_model = MODEL_NAME_ALIASES.get(model_name, model_name)
|
||||
# Log the model name after alias resolution using new model manager
|
||||
from models import model_manager
|
||||
resolved_model = model_manager.resolve_model_id(model_name)
|
||||
logger.debug(f"Resolved model name: {resolved_model}")
|
||||
|
||||
# Update model_name to use the resolved version
|
||||
|
@ -974,8 +975,9 @@ async def initiate_agent_with_files(
|
|||
model_name = "openai/gpt-5-mini"
|
||||
logger.debug(f"Using default model: {model_name}")
|
||||
|
||||
# Log the model name after alias resolution
|
||||
resolved_model = MODEL_NAME_ALIASES.get(model_name, model_name)
|
||||
from models import model_manager
|
||||
# Log the model name after alias resolution using new model manager
|
||||
resolved_model = model_manager.resolve_model_id(model_name)
|
||||
logger.debug(f"Resolved model name: {resolved_model}")
|
||||
|
||||
# Update model_name to use the resolved version
|
||||
|
@ -1935,15 +1937,19 @@ async def create_agent(
|
|||
version_service = await _get_version_service()
|
||||
from agent.suna_config import SUNA_CONFIG
|
||||
from agent.config_helper import _get_default_agentpress_tools
|
||||
from models import model_manager
|
||||
|
||||
system_prompt = SUNA_CONFIG["system_prompt"]
|
||||
|
||||
# Use default tools if none specified, ensuring builder tools are included
|
||||
agentpress_tools = agent_data.agentpress_tools if agent_data.agentpress_tools else _get_default_agentpress_tools()
|
||||
|
||||
default_model = await model_manager.get_default_model_for_user(client, user_id)
|
||||
|
||||
version = await version_service.create_version(
|
||||
agent_id=agent['agent_id'],
|
||||
user_id=user_id,
|
||||
system_prompt=system_prompt,
|
||||
model=default_model,
|
||||
configured_mcps=agent_data.configured_mcps or [],
|
||||
custom_mcps=agent_data.custom_mcps or [],
|
||||
agentpress_tools=agentpress_tools,
|
||||
|
|
|
@ -254,8 +254,8 @@ You have the abilixwty to execute operations using both Python and CLI tools:
|
|||
* After image generation/editing, ALWAYS display the result using the ask tool with the image attached
|
||||
* The tool automatically saves images to the workspace with unique filenames
|
||||
* **REMEMBER THE LAST IMAGE:** Always use the most recently generated image filename for follow-up edits
|
||||
* **SHARE PERMANENTLY:** Use `upload_file` to upload generated images to cloud storage for permanent URLs
|
||||
* **CLOUD WORKFLOW:** Generate/Edit → Save to workspace → Upload to "file-uploads" bucket → Share public URL with user
|
||||
* **OPTIONAL CLOUD SHARING:** Ask user if they want to upload images: "Would you like me to upload this image to secure cloud storage for sharing?"
|
||||
* **CLOUD WORKFLOW (if requested):** Generate/Edit → Save to workspace → Ask user → Upload to "file-uploads" bucket if requested → Share public URL with user
|
||||
|
||||
### 2.3.9 DATA PROVIDERS
|
||||
- You have access to a variety of data providers that you can use to get data for your tasks.
|
||||
|
@ -280,10 +280,13 @@ You have the abilixwty to execute operations using both Python and CLI tools:
|
|||
* **Security:** Files stored in user-isolated folders, private bucket, signed URL access only
|
||||
|
||||
**WHEN TO USE upload_file:**
|
||||
* User asks to share files or make them accessible via secure URL
|
||||
* Need to persist files beyond the sandbox session with access control
|
||||
* Need to export generated content (reports, images, data) for controlled external access
|
||||
* Want to create secure, time-limited sharing links for deliverables
|
||||
* **ONLY when user explicitly requests file sharing** or asks for permanent URLs
|
||||
* **ONLY when user asks for files to be accessible externally** or beyond the sandbox session
|
||||
* **ASK USER FIRST** before uploading in most cases: "Would you like me to upload this file to secure cloud storage for sharing?"
|
||||
* User specifically requests file sharing or external access
|
||||
* User asks for permanent or persistent file access
|
||||
* User requests deliverables that need to be shared with others
|
||||
* **DO NOT automatically upload** files unless explicitly requested by the user
|
||||
|
||||
**UPLOAD PARAMETERS:**
|
||||
* `file_path`: Path relative to /workspace (e.g., "report.pdf", "data/results.csv")
|
||||
|
@ -291,18 +294,20 @@ You have the abilixwty to execute operations using both Python and CLI tools:
|
|||
* `custom_filename`: Optional custom name for the uploaded file
|
||||
|
||||
**STORAGE BUCKETS:**
|
||||
* "file-uploads" (default): Secure private storage with user isolation, signed URL access, 24-hour expiration
|
||||
* "browser-screenshots": Public bucket ONLY for actual browser screenshots captured during browser automation
|
||||
* "file-uploads" (default): Secure private storage with user isolation, signed URL access, 24-hour expiration - USE ONLY WHEN REQUESTED
|
||||
* "browser-screenshots": Public bucket ONLY for actual browser screenshots captured during browser automation - CONTINUES NORMAL BEHAVIOR
|
||||
|
||||
**UPLOAD WORKFLOW EXAMPLES:**
|
||||
* Basic secure upload:
|
||||
* Ask before uploading:
|
||||
"I've created the report. Would you like me to upload it to secure cloud storage for sharing?"
|
||||
If user says yes:
|
||||
<function_calls>
|
||||
<invoke name="upload_file">
|
||||
<parameter name="file_path">output/report.pdf</parameter>
|
||||
</invoke>
|
||||
</function_calls>
|
||||
|
||||
* Upload with custom naming:
|
||||
* Upload with custom naming (only after user request):
|
||||
<function_calls>
|
||||
<invoke name="upload_file">
|
||||
<parameter name="file_path">generated_image.png</parameter>
|
||||
|
@ -311,18 +316,22 @@ You have the abilixwty to execute operations using both Python and CLI tools:
|
|||
</function_calls>
|
||||
|
||||
**UPLOAD BEST PRACTICES:**
|
||||
* Always upload important deliverables to provide secure, time-limited URLs
|
||||
* Use default "file-uploads" bucket for all general content (reports, images, presentations, data files)
|
||||
* Use "browser-screenshots" ONLY for actual browser automation screenshots
|
||||
* **ASK FIRST**: "Would you like me to upload this file for sharing or permanent access?"
|
||||
* **EXPLAIN PURPOSE**: Tell users why upload might be useful ("for sharing with others", "for permanent access")
|
||||
* **RESPECT USER CHOICE**: If user says no, don't upload
|
||||
* **DEFAULT TO LOCAL**: Keep files local unless user specifically needs external access
|
||||
* Use default "file-uploads" bucket ONLY when user requests uploads
|
||||
* Use "browser-screenshots" ONLY for actual browser automation screenshots (unchanged behavior)
|
||||
* Provide the secure URL to users but explain it expires in 24 hours
|
||||
* Upload before marking tasks as complete
|
||||
* **BROWSER SCREENSHOTS EXCEPTION**: Browser screenshots continue normal upload behavior without asking
|
||||
* Files are stored with user isolation for security (each user can only access their own files)
|
||||
|
||||
**INTEGRATED WORKFLOW WITH OTHER TOOLS:**
|
||||
* Create file with sb_files_tool → Upload with upload_file → Share secure URL with user
|
||||
* Generate image → Upload to secure cloud → Provide time-limited access link
|
||||
* Scrape data → Save to file → Upload for secure sharing
|
||||
* Create report → Upload with secure access
|
||||
* Create file with tools → **ASK USER** if they want to upload → Upload only if requested → Share secure URL if uploaded
|
||||
* Generate image → **ASK USER** if they need cloud storage → Upload only if requested
|
||||
* Scrape data → Save to file → **ASK USER** about uploading for sharing
|
||||
* Create report → **ASK USER** before uploading
|
||||
* **BROWSER SCREENSHOTS**: Continue automatic upload behavior (no changes)
|
||||
|
||||
# 3. TOOLKIT & METHODOLOGY
|
||||
|
||||
|
@ -643,8 +652,8 @@ IMPORTANT: Use the `cat` command to view contents of small files (100 kb or less
|
|||
3. Parse content using appropriate tools based on content type
|
||||
4. Respect web content limitations - not all content may be accessible
|
||||
5. Extract only the relevant portions of web content
|
||||
6. **Upload scraped data:** Use `upload_file` to share extracted content via permanent URLs
|
||||
7. **Research deliverables:** Scrape → Process → Save → Upload → Share URL for analysis results
|
||||
6. **ASK BEFORE UPLOADING:** Ask users if they want scraped data uploaded: "Would you like me to upload the extracted content for sharing?"
|
||||
7. **CONDITIONAL RESEARCH DELIVERABLES:** Scrape → Process → Save → Ask user about upload → Share URL only if requested
|
||||
|
||||
- Data Freshness:
|
||||
1. Always check publication dates of search results
|
||||
|
@ -970,11 +979,12 @@ When executing a workflow, adopt this mindset:
|
|||
- Use descriptive keywords for better image relevance
|
||||
- Test image URLs before downloading to ensure they work
|
||||
|
||||
4. **UPLOAD FOR SHARING:**
|
||||
- After creating the presentation, use `upload_file` to upload the HTML preview and/or exported PPTX
|
||||
- Upload to "file-uploads" bucket for all presentation content
|
||||
- Share the public URL with users for easy access and distribution
|
||||
- Example: `upload_file` with `file_path="presentations/my-presentation/presentation.html"`
|
||||
4. **ASK ABOUT UPLOAD FOR SHARING:**
|
||||
- After creating the presentation, ask: "Would you like me to upload this presentation to secure cloud storage for sharing?"
|
||||
- Only use `upload_file` to upload the HTML preview and/or exported PPTX if user requests it
|
||||
- Upload to "file-uploads" bucket for all presentation content only when requested
|
||||
- Share the public URL with users for easy access and distribution only if uploaded
|
||||
- Example: `upload_file` with `file_path="presentations/my-presentation/presentation.html"` only after user confirms
|
||||
|
||||
**NEVER create a presentation without downloading images first. This is a MANDATORY step for professional presentations.**
|
||||
|
||||
|
@ -1001,19 +1011,20 @@ For large outputs and complex content, use files instead of long responses:
|
|||
- Make files easily editable and shareable
|
||||
- Attach files when sharing with users via 'ask' tool
|
||||
- Use files as persistent artifacts that users can reference and modify
|
||||
- **UPLOAD FOR SHARING:** After creating important files, use the 'upload_file' tool to get a permanent shareable URL
|
||||
- **CLOUD PERSISTENCE:** Upload deliverables to ensure they persist beyond the sandbox session
|
||||
- **ASK BEFORE UPLOADING:** Ask users if they want files uploaded: "Would you like me to upload this file to secure cloud storage for sharing?"
|
||||
- **CONDITIONAL CLOUD PERSISTENCE:** Upload deliverables only when specifically requested for sharing or external access
|
||||
|
||||
**FILE SHARING WORKFLOW:**
|
||||
1. Create comprehensive file with all content
|
||||
2. Edit and refine the file as needed
|
||||
3. **Upload to secure cloud storage using 'upload_file' for controlled access**
|
||||
4. Share the secure signed URL with the user (note: expires in 24 hours)
|
||||
3. **ASK USER:** "Would you like me to upload this file to secure cloud storage for sharing?"
|
||||
4. **Upload only if requested** using 'upload_file' for controlled access
|
||||
5. Share the secure signed URL with the user (note: expires in 24 hours) - only if uploaded
|
||||
|
||||
**EXAMPLE FILE USAGE:**
|
||||
- Single request → `travel_plan.md` (contains itinerary, accommodation, packing list, etc.) → Upload → Share secure URL (24hr expiry)
|
||||
- Single request → `research_report.md` (contains all findings, analysis, conclusions) → Upload → Share secure URL (24hr expiry)
|
||||
- Single request → `project_guide.md` (contains setup, implementation, testing, documentation) → Upload → Share secure URL (24hr expiry)
|
||||
- Single request → `travel_plan.md` (contains itinerary, accommodation, packing list, etc.) → Ask user about upload → Upload only if requested → Share secure URL (24hr expiry) if uploaded
|
||||
- Single request → `research_report.md` (contains all findings, analysis, conclusions) → Ask user about upload → Upload only if requested → Share secure URL (24hr expiry) if uploaded
|
||||
- Single request → `project_guide.md` (contains setup, implementation, testing, documentation) → Ask user about upload → Upload only if requested → Share secure URL (24hr expiry) if uploaded
|
||||
|
||||
## 6.2 DESIGN GUIDELINES
|
||||
|
||||
|
@ -1197,8 +1208,8 @@ To make conversations feel natural and human-like:
|
|||
* When creating data analysis results, charts must be attached, not just described
|
||||
* Remember: If the user should SEE it, you must ATTACH it with the 'ask' tool
|
||||
* Verify that ALL visual outputs have been attached before proceeding
|
||||
* **SECURE UPLOAD INTEGRATION:** When you've uploaded files using 'upload_file', include the secure signed URL in your message (note: expires in 24 hours)
|
||||
* **DUAL SHARING:** Attach local files AND provide secure signed URLs when available for controlled access
|
||||
* **CONDITIONAL SECURE UPLOAD INTEGRATION:** IF you've uploaded files using 'upload_file' (only when user requested), include the secure signed URL in your message (note: expires in 24 hours)
|
||||
* **DUAL SHARING:** Attach local files AND provide secure signed URLs only when user has requested uploads for controlled access
|
||||
|
||||
- **Attachment Checklist:**
|
||||
* Data visualizations (charts, graphs, plots)
|
||||
|
@ -1210,7 +1221,7 @@ To make conversations feel natural and human-like:
|
|||
* Analysis results with visual components
|
||||
* UI designs and mockups
|
||||
* Any file intended for user viewing or interaction
|
||||
* **Secure signed URLs** (when using upload_file tool - note 24hr expiry)
|
||||
* **Secure signed URLs** (only when user requested upload_file tool usage - note 24hr expiry)
|
||||
|
||||
|
||||
# 9. COMPLETION PROTOCOLS
|
||||
|
|
|
@ -515,6 +515,7 @@ class AgentRunner:
|
|||
return await mcp_manager.register_mcp_tools(self.config.agent_config)
|
||||
|
||||
def get_max_tokens(self) -> Optional[int]:
|
||||
logger.debug(f"get_max_tokens called with: '{self.config.model_name}' (type: {type(self.config.model_name)})")
|
||||
if "sonnet" in self.config.model_name.lower():
|
||||
return 8192
|
||||
elif "gpt-4" in self.config.model_name.lower():
|
||||
|
@ -535,7 +536,7 @@ class AgentRunner:
|
|||
self.config.thread_id,
|
||||
mcp_wrapper_instance, self.client
|
||||
)
|
||||
|
||||
logger.debug(f"model_name received: {self.config.model_name}")
|
||||
iteration_count = 0
|
||||
continue_execution = True
|
||||
|
||||
|
@ -572,7 +573,7 @@ class AgentRunner:
|
|||
|
||||
temporary_message = await message_manager.build_temporary_message()
|
||||
max_tokens = self.get_max_tokens()
|
||||
|
||||
logger.debug(f"max_tokens: {max_tokens}")
|
||||
generation = self.config.trace.generation(name="thread_manager.run_thread") if self.config.trace else None
|
||||
try:
|
||||
response = await self.thread_manager.run_thread(
|
||||
|
@ -720,13 +721,15 @@ async def run_agent(
|
|||
trace: Optional[StatefulTraceClient] = None
|
||||
):
|
||||
effective_model = model_name
|
||||
if model_name == "openai/gpt-5-mini" and agent_config and agent_config.get('model'):
|
||||
is_tier_default = model_name in ["Kimi K2", "Claude Sonnet 4", "openai/gpt-5-mini"]
|
||||
|
||||
if is_tier_default and agent_config and agent_config.get('model'):
|
||||
effective_model = agent_config['model']
|
||||
logger.debug(f"Using model from agent config: {effective_model} (no user selection)")
|
||||
elif model_name != "openai/gpt-5-mini":
|
||||
logger.debug(f"Using model from agent config: {effective_model} (tier default was {model_name})")
|
||||
elif not is_tier_default:
|
||||
logger.debug(f"Using user-selected model: {effective_model}")
|
||||
else:
|
||||
logger.debug(f"Using default model: {effective_model}")
|
||||
logger.debug(f"Using tier default model: {effective_model}")
|
||||
|
||||
config = AgentConfig(
|
||||
thread_id=thread_id,
|
||||
|
|
|
@ -320,14 +320,25 @@ class VersionService:
|
|||
|
||||
client = await self._get_client()
|
||||
|
||||
result = await client.table('agent_versions').select('*').eq(
|
||||
'agent_id', agent_id
|
||||
).eq('is_active', True).execute()
|
||||
|
||||
if not result.data:
|
||||
agent_result = await client.table('agents').select('current_version_id').eq('agent_id', agent_id).execute()
|
||||
if not agent_result.data or not agent_result.data[0].get('current_version_id'):
|
||||
logger.warning(f"No current_version_id found for agent {agent_id}")
|
||||
return None
|
||||
|
||||
return self._version_from_db_row(result.data[0])
|
||||
current_version_id = agent_result.data[0]['current_version_id']
|
||||
logger.debug(f"Agent {agent_id} current_version_id: {current_version_id}")
|
||||
|
||||
result = await client.table('agent_versions').select('*').eq(
|
||||
'version_id', current_version_id
|
||||
).eq('agent_id', agent_id).execute()
|
||||
|
||||
if not result.data:
|
||||
logger.warning(f"Current version {current_version_id} not found for agent {agent_id}")
|
||||
return None
|
||||
|
||||
version = self._version_from_db_row(result.data[0])
|
||||
logger.debug(f"Retrieved active version for agent {agent_id}: model='{version.model}', version_name='{version.version_name}'")
|
||||
return version
|
||||
|
||||
async def get_all_versions(self, agent_id: str, user_id: str) -> List[AgentVersion]:
|
||||
is_owner, is_public = await self._verify_agent_access(agent_id, user_id)
|
||||
|
|
|
@ -692,20 +692,10 @@ async def create_composio_trigger(req: CreateComposioTriggerRequest, current_use
|
|||
|
||||
body = {
|
||||
"user_id": composio_user_id,
|
||||
"userId": composio_user_id,
|
||||
"trigger_config": coerced_config,
|
||||
"triggerConfig": coerced_config,
|
||||
"webhook": {
|
||||
"url": req.webhook_url or f"{base_url}/api/composio/webhook",
|
||||
"headers": webhook_headers,
|
||||
"method": "POST",
|
||||
},
|
||||
}
|
||||
if req.connected_account_id:
|
||||
body["connectedAccountId"] = req.connected_account_id
|
||||
body["connected_account_id"] = req.connected_account_id
|
||||
body["connectedAccountIds"] = [req.connected_account_id]
|
||||
body["connected_account_ids"] = [req.connected_account_id]
|
||||
|
||||
async with httpx.AsyncClient(timeout=20) as http_client:
|
||||
resp = await http_client.post(url, headers=headers, json=body)
|
||||
|
@ -838,6 +828,8 @@ async def composio_webhook(request: Request):
|
|||
logger.info("Composio webhook body read failed", error=str(e))
|
||||
body_str = ""
|
||||
|
||||
# Get webhook ID early for logging
|
||||
wid = request.headers.get("webhook-id", "")
|
||||
|
||||
# Minimal request diagnostics (no secrets)
|
||||
try:
|
||||
|
@ -859,15 +851,6 @@ async def composio_webhook(request: Request):
|
|||
}
|
||||
except Exception:
|
||||
payload_preview = {"keys": []}
|
||||
logger.debug(
|
||||
"Composio webhook incoming",
|
||||
client_ip=client_ip,
|
||||
header_names=header_names,
|
||||
has_authorization=has_auth,
|
||||
has_x_composio_secret=has_x_secret,
|
||||
has_x_trigger_secret=has_x_trigger,
|
||||
payload_meta=payload_preview,
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
@ -882,10 +865,10 @@ async def composio_webhook(request: Request):
|
|||
# Parse payload for processing
|
||||
try:
|
||||
payload = json.loads(body_str) if body_str else {}
|
||||
except Exception:
|
||||
except Exception as parse_error:
|
||||
logger.error(f"Failed to parse webhook payload: {parse_error}", payload_raw=body_str)
|
||||
payload = {}
|
||||
|
||||
wid = request.headers.get("webhook-id", "")
|
||||
# Look for trigger_nano_id in data.trigger_nano_id (the actual Composio trigger instance ID)
|
||||
composio_trigger_id = (
|
||||
(payload.get("data", {}) or {}).get("trigger_nano_id")
|
||||
|
@ -906,11 +889,11 @@ async def composio_webhook(request: Request):
|
|||
|
||||
# Basic parsed-field logging (no secrets)
|
||||
try:
|
||||
logger.debug(
|
||||
logger.info(
|
||||
"Composio parsed fields",
|
||||
webhook_id=wid,
|
||||
trigger_slug=trigger_slug,
|
||||
payload_id=composio_trigger_id,
|
||||
composio_trigger_id=composio_trigger_id,
|
||||
provider_event_id=provider_event_id,
|
||||
payload_keys=list(payload.keys()) if isinstance(payload, dict) else [],
|
||||
)
|
||||
|
@ -933,14 +916,6 @@ async def composio_webhook(request: Request):
|
|||
rows = []
|
||||
|
||||
matched = []
|
||||
try:
|
||||
logger.debug(
|
||||
"Composio matching begin",
|
||||
have_id=bool(composio_trigger_id),
|
||||
payload_id=composio_trigger_id,
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
for row in rows:
|
||||
cfg = row.get("config") or {}
|
||||
|
@ -948,43 +923,31 @@ async def composio_webhook(request: Request):
|
|||
continue
|
||||
prov = cfg.get("provider_id") or row.get("provider_id")
|
||||
if prov != "composio":
|
||||
try:
|
||||
logger.debug("Composio skip non-provider", trigger_id=row.get("trigger_id"), provider_id=prov)
|
||||
except Exception:
|
||||
pass
|
||||
logger.debug("Composio skip non-provider", trigger_id=row.get("trigger_id"), provider_id=prov)
|
||||
continue
|
||||
|
||||
# ONLY match by exact composio_trigger_id - no slug fallback
|
||||
cfg_tid = cfg.get("composio_trigger_id")
|
||||
if composio_trigger_id and cfg_tid == composio_trigger_id:
|
||||
logger.debug(
|
||||
logger.info(
|
||||
"Composio EXACT ID MATCH",
|
||||
trigger_id=row.get("trigger_id"),
|
||||
cfg_id=cfg_tid,
|
||||
payload_id=composio_trigger_id
|
||||
cfg_composio_trigger_id=cfg_tid,
|
||||
payload_composio_trigger_id=composio_trigger_id,
|
||||
is_active=row.get("is_active")
|
||||
)
|
||||
matched.append(row)
|
||||
continue
|
||||
else:
|
||||
logger.debug(
|
||||
logger.info(
|
||||
"Composio ID mismatch",
|
||||
trigger_id=row.get("trigger_id"),
|
||||
cfg_id=cfg_tid,
|
||||
payload_id=composio_trigger_id,
|
||||
match_found=False
|
||||
cfg_composio_trigger_id=cfg_tid,
|
||||
payload_composio_trigger_id=composio_trigger_id,
|
||||
match_found=False,
|
||||
is_active=row.get("is_active")
|
||||
)
|
||||
|
||||
try:
|
||||
logger.debug(
|
||||
"Composio matching result",
|
||||
total=len(rows),
|
||||
matched=len(matched),
|
||||
have_id=bool(composio_trigger_id),
|
||||
payload_id=composio_trigger_id,
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if not matched:
|
||||
logger.error(
|
||||
f"No exact ID match found for Composio trigger {composio_trigger_id}",
|
||||
|
|
|
@ -0,0 +1,13 @@
|
|||
from .registry import ModelRegistry, registry
|
||||
from .models import Model, ModelProvider, ModelCapability
|
||||
from .manager import ModelManager, model_manager
|
||||
|
||||
__all__ = [
|
||||
'ModelRegistry',
|
||||
'registry',
|
||||
'Model',
|
||||
'ModelProvider',
|
||||
'ModelCapability',
|
||||
'ModelManager',
|
||||
'model_manager',
|
||||
]
|
|
@ -0,0 +1,224 @@
|
|||
from typing import Optional, List, Dict, Any, Tuple
|
||||
from .registry import registry
|
||||
from .models import Model, ModelCapability
|
||||
from utils.logger import logger
|
||||
from .registry import DEFAULT_PREMIUM_MODEL, DEFAULT_FREE_MODEL
|
||||
|
||||
class ModelManager:
|
||||
def __init__(self):
|
||||
self.registry = registry
|
||||
|
||||
def get_model(self, model_id: str) -> Optional[Model]:
|
||||
return self.registry.get(model_id)
|
||||
|
||||
def resolve_model_id(self, model_id: str) -> str:
|
||||
logger.debug(f"resolve_model_id called with: '{model_id}' (type: {type(model_id)})")
|
||||
|
||||
resolved = self.registry.resolve_model_id(model_id)
|
||||
if resolved:
|
||||
logger.debug(f"Resolved model '{model_id}' to '{resolved}'")
|
||||
return resolved
|
||||
|
||||
all_aliases = list(self.registry._aliases.keys())
|
||||
logger.warning(f"Could not resolve model ID: '{model_id}'. Available aliases: {all_aliases[:10]}...")
|
||||
return model_id
|
||||
|
||||
def validate_model(self, model_id: str) -> Tuple[bool, str]:
|
||||
model = self.get_model(model_id)
|
||||
|
||||
if not model:
|
||||
return False, f"Model '{model_id}' not found"
|
||||
|
||||
if not model.enabled:
|
||||
return False, f"Model '{model.name}' is currently disabled"
|
||||
|
||||
return True, ""
|
||||
|
||||
def calculate_cost(
|
||||
self,
|
||||
model_id: str,
|
||||
input_tokens: int,
|
||||
output_tokens: int
|
||||
) -> Optional[float]:
|
||||
model = self.get_model(model_id)
|
||||
if not model or not model.pricing:
|
||||
logger.warning(f"No pricing available for model: {model_id}")
|
||||
return None
|
||||
|
||||
input_cost = input_tokens * model.pricing.input_cost_per_token
|
||||
output_cost = output_tokens * model.pricing.output_cost_per_token
|
||||
total_cost = input_cost + output_cost
|
||||
|
||||
logger.debug(
|
||||
f"Cost calculation for {model.name}: "
|
||||
f"{input_tokens} input tokens (${input_cost:.6f}) + "
|
||||
f"{output_tokens} output tokens (${output_cost:.6f}) = "
|
||||
f"${total_cost:.6f}"
|
||||
)
|
||||
|
||||
return total_cost
|
||||
|
||||
def get_models_for_tier(self, tier: str) -> List[Model]:
|
||||
return self.registry.get_by_tier(tier, enabled_only=True)
|
||||
|
||||
def get_models_with_capability(self, capability: ModelCapability) -> List[Model]:
|
||||
return self.registry.get_by_capability(capability, enabled_only=True)
|
||||
|
||||
def select_best_model(
|
||||
self,
|
||||
tier: str,
|
||||
required_capabilities: Optional[List[ModelCapability]] = None,
|
||||
min_context_window: Optional[int] = None,
|
||||
prefer_cheaper: bool = False
|
||||
) -> Optional[Model]:
|
||||
models = self.get_models_for_tier(tier)
|
||||
|
||||
if required_capabilities:
|
||||
models = [
|
||||
m for m in models
|
||||
if all(cap in m.capabilities for cap in required_capabilities)
|
||||
]
|
||||
|
||||
if min_context_window:
|
||||
models = [m for m in models if m.context_window >= min_context_window]
|
||||
|
||||
if not models:
|
||||
return None
|
||||
|
||||
if prefer_cheaper and any(m.pricing for m in models):
|
||||
models_with_pricing = [m for m in models if m.pricing]
|
||||
if models_with_pricing:
|
||||
models = sorted(
|
||||
models_with_pricing,
|
||||
key=lambda m: m.pricing.input_cost_per_million_tokens
|
||||
)
|
||||
else:
|
||||
models = sorted(
|
||||
models,
|
||||
key=lambda m: (-m.priority, not m.recommended)
|
||||
)
|
||||
|
||||
return models[0] if models else None
|
||||
|
||||
def get_default_model(self, tier: str = "free") -> Optional[Model]:
|
||||
models = self.get_models_for_tier(tier)
|
||||
|
||||
recommended = [m for m in models if m.recommended]
|
||||
if recommended:
|
||||
recommended = sorted(recommended, key=lambda m: -m.priority)
|
||||
return recommended[0]
|
||||
|
||||
if models:
|
||||
models = sorted(models, key=lambda m: -m.priority)
|
||||
return models[0]
|
||||
|
||||
return None
|
||||
|
||||
def get_context_window(self, model_id: str, default: int = 31_000) -> int:
|
||||
return self.registry.get_context_window(model_id, default)
|
||||
|
||||
def check_token_limit(
|
||||
self,
|
||||
model_id: str,
|
||||
token_count: int,
|
||||
is_input: bool = True
|
||||
) -> Tuple[bool, int]:
|
||||
model = self.get_model(model_id)
|
||||
if not model:
|
||||
return False, 0
|
||||
|
||||
if is_input:
|
||||
max_allowed = model.context_window
|
||||
else:
|
||||
max_allowed = model.max_output_tokens or model.context_window
|
||||
|
||||
return token_count <= max_allowed, max_allowed
|
||||
|
||||
def format_model_info(self, model_id: str) -> Dict[str, Any]:
|
||||
model = self.get_model(model_id)
|
||||
if not model:
|
||||
return {"error": f"Model '{model_id}' not found"}
|
||||
|
||||
return {
|
||||
"id": model.id,
|
||||
"name": model.name,
|
||||
"provider": model.provider.value,
|
||||
"context_window": model.context_window,
|
||||
"max_output_tokens": model.max_output_tokens,
|
||||
"capabilities": [cap.value for cap in model.capabilities],
|
||||
"pricing": {
|
||||
"input_per_million": model.pricing.input_cost_per_million_tokens,
|
||||
"output_per_million": model.pricing.output_cost_per_million_tokens,
|
||||
} if model.pricing else None,
|
||||
"enabled": model.enabled,
|
||||
"beta": model.beta,
|
||||
"tier_availability": model.tier_availability,
|
||||
"priority": model.priority,
|
||||
"recommended": model.recommended,
|
||||
}
|
||||
|
||||
def list_available_models(
|
||||
self,
|
||||
tier: Optional[str] = None,
|
||||
include_disabled: bool = False
|
||||
) -> List[Dict[str, Any]]:
|
||||
logger.debug(f"list_available_models called with tier='{tier}', include_disabled={include_disabled}")
|
||||
|
||||
if tier:
|
||||
models = self.registry.get_by_tier(tier, enabled_only=not include_disabled)
|
||||
logger.debug(f"Found {len(models)} models for tier '{tier}'")
|
||||
else:
|
||||
models = self.registry.get_all(enabled_only=not include_disabled)
|
||||
logger.debug(f"Found {len(models)} total models")
|
||||
|
||||
if models:
|
||||
model_names = [m.name for m in models]
|
||||
logger.debug(f"Models: {model_names}")
|
||||
else:
|
||||
logger.warning(f"No models found for tier '{tier}' - this might indicate a configuration issue")
|
||||
|
||||
models = sorted(
|
||||
models,
|
||||
key=lambda m: (not m.is_free_tier, -m.priority, m.name)
|
||||
)
|
||||
|
||||
return [self.format_model_info(m.id) for m in models]
|
||||
|
||||
def get_legacy_constants(self) -> Dict:
|
||||
return self.registry.to_legacy_format()
|
||||
|
||||
async def get_default_model_for_user(self, client, user_id: str) -> str:
|
||||
try:
|
||||
from utils.config import config, EnvMode
|
||||
if config.ENV_MODE == EnvMode.LOCAL:
|
||||
return DEFAULT_PREMIUM_MODEL
|
||||
|
||||
from services.billing import get_user_subscription, SUBSCRIPTION_TIERS
|
||||
|
||||
subscription = await get_user_subscription(user_id)
|
||||
|
||||
is_paid_tier = False
|
||||
if subscription:
|
||||
price_id = None
|
||||
if subscription.get('items') and subscription['items'].get('data') and len(subscription['items']['data']) > 0:
|
||||
price_id = subscription['items']['data'][0]['price']['id']
|
||||
else:
|
||||
price_id = subscription.get('price_id')
|
||||
|
||||
tier_info = SUBSCRIPTION_TIERS.get(price_id)
|
||||
if tier_info and tier_info['name'] != 'free':
|
||||
is_paid_tier = True
|
||||
|
||||
if is_paid_tier:
|
||||
logger.debug(f"Setting Claude Sonnet 4 as default for paid user {user_id}")
|
||||
return DEFAULT_PREMIUM_MODEL
|
||||
else:
|
||||
logger.debug(f"Setting Kimi K2 as default for free user {user_id}")
|
||||
return DEFAULT_FREE_MODEL
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to determine user tier for {user_id}: {e}")
|
||||
return DEFAULT_FREE_MODEL
|
||||
|
||||
|
||||
model_manager = ModelManager()
|
|
@ -0,0 +1,104 @@
|
|||
from dataclasses import dataclass, field
|
||||
from typing import List, Optional, Dict, Any
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class ModelProvider(Enum):
|
||||
OPENAI = "openai"
|
||||
ANTHROPIC = "anthropic"
|
||||
OPENROUTER = "openrouter"
|
||||
GOOGLE = "google"
|
||||
GEMINI = "gemini"
|
||||
XAI = "xai"
|
||||
MOONSHOTAI = "moonshotai"
|
||||
|
||||
class ModelCapability(Enum):
|
||||
CHAT = "chat"
|
||||
FUNCTION_CALLING = "function_calling"
|
||||
VISION = "vision"
|
||||
CODE_INTERPRETER = "code_interpreter"
|
||||
WEB_SEARCH = "web_search"
|
||||
THINKING = "thinking"
|
||||
STRUCTURED_OUTPUT = "structured_output"
|
||||
|
||||
|
||||
@dataclass
class ModelPricing:
    """Model pricing expressed in USD per one million tokens."""

    input_cost_per_million_tokens: float
    output_cost_per_million_tokens: float

    @property
    def input_cost_per_token(self) -> float:
        """USD cost of a single input token, derived from the per-million rate."""
        per_million = self.input_cost_per_million_tokens
        return per_million / 1_000_000

    @property
    def output_cost_per_token(self) -> float:
        """USD cost of a single output token, derived from the per-million rate."""
        per_million = self.output_cost_per_million_tokens
        return per_million / 1_000_000
|
||||
|
||||
|
||||
@dataclass
class Model:
    """Metadata record for a single LLM offering in the registry.

    `id` may or may not already carry a provider prefix; `full_id`
    normalizes to the "<provider>/<id>" form either way. Aliases are
    alternate names (display names, shorthand ids) that the registry
    resolves back to this record.
    """

    id: str
    name: str  # human-readable display name shown in the UI
    provider: ModelProvider
    aliases: List[str] = field(default_factory=list)  # alternate lookup keys
    context_window: int = 128_000  # max input context size in tokens
    max_output_tokens: Optional[int] = None  # derived in __post_init__ when omitted
    capabilities: List[ModelCapability] = field(default_factory=list)
    pricing: Optional[ModelPricing] = None  # None => no hardcoded pricing known
    enabled: bool = True  # disabled models are hidden from default listings
    beta: bool = False
    tier_availability: List[str] = field(default_factory=lambda: ["paid"])  # billing tiers that may use this model
    metadata: Dict[str, Any] = field(default_factory=dict)  # free-form extra data
    priority: int = 0  # higher sorts first in listings
    recommended: bool = False

    def __post_init__(self):
        # Default the output cap to a quarter of the context window,
        # bounded at 32k tokens.
        if self.max_output_tokens is None:
            self.max_output_tokens = min(self.context_window // 4, 32_000)

        # CHAT is mandatory: ensure it is present (inserted at the front).
        # NOTE(review): this mutates the caller-provided capabilities list.
        if ModelCapability.CHAT not in self.capabilities:
            self.capabilities.insert(0, ModelCapability.CHAT)

    @property
    def full_id(self) -> str:
        """Provider-qualified id; returns `id` unchanged if already prefixed."""
        if "/" in self.id:
            return self.id
        return f"{self.provider.value}/{self.id}"

    @property
    def supports_thinking(self) -> bool:
        """True when the model supports extended-reasoning mode."""
        return ModelCapability.THINKING in self.capabilities

    @property
    def supports_functions(self) -> bool:
        """True when the model supports tool/function calling."""
        return ModelCapability.FUNCTION_CALLING in self.capabilities

    @property
    def supports_vision(self) -> bool:
        """True when the model accepts image input."""
        return ModelCapability.VISION in self.capabilities

    @property
    def is_free_tier(self) -> bool:
        """True when free-tier users may select this model."""
        return "free" in self.tier_availability

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a plain JSON-friendly dict (enums flattened to values)."""
        return {
            "id": self.id,
            "name": self.name,
            "provider": self.provider.value,
            "aliases": self.aliases,
            "context_window": self.context_window,
            "max_output_tokens": self.max_output_tokens,
            "capabilities": [cap.value for cap in self.capabilities],
            "pricing": {
                "input_cost_per_million_tokens": self.pricing.input_cost_per_million_tokens,
                "output_cost_per_million_tokens": self.pricing.output_cost_per_million_tokens,
            } if self.pricing else None,
            "enabled": self.enabled,
            "beta": self.beta,
            "tier_availability": self.tier_availability,
            "metadata": self.metadata,
            "priority": self.priority,
            "recommended": self.recommended,
        }
|
|
@ -0,0 +1,327 @@
|
|||
from typing import Dict, List, Optional, Set
|
||||
from .models import Model, ModelProvider, ModelCapability, ModelPricing
|
||||
|
||||
# Default model choices per billing tier. These are display names, not
# canonical ids — they resolve through the registry's alias table
# (each is registered as an alias of its model below).
DEFAULT_FREE_MODEL = "Kimi K2"
DEFAULT_PREMIUM_MODEL = "Claude Sonnet 4"
|
||||
|
||||
class ModelRegistry:
    """In-memory catalog of every supported LLM.

    Holds one `Model` record per canonical id plus an alias table so that
    display names and shorthand ids resolve to the same record. Offers
    filtered listings (tier / provider / capability), enable/disable
    toggles, and a legacy-format export for code that still consumes the
    old constants-module dictionaries.
    """

    def __init__(self):
        # canonical model id -> Model record
        self._models: Dict[str, Model] = {}
        # alias (display name or shorthand) -> canonical model id
        self._aliases: Dict[str, str] = {}
        self._initialize_models()

    def _initialize_models(self):
        """Register the built-in model catalog (ids, aliases, pricing, tiers)."""
        self.register(Model(
            id="anthropic/claude-sonnet-4-20250514",
            name="Claude Sonnet 4",
            provider=ModelProvider.ANTHROPIC,
            aliases=["claude-sonnet-4", "anthropic/claude-sonnet-4", "Claude Sonnet 4", "claude-sonnet-4-20250514"],
            context_window=200_000,
            capabilities=[
                ModelCapability.CHAT,
                ModelCapability.FUNCTION_CALLING,
                ModelCapability.VISION,
                ModelCapability.THINKING,
            ],
            pricing=ModelPricing(
                input_cost_per_million_tokens=3.00,
                output_cost_per_million_tokens=15.00
            ),
            tier_availability=["paid"],
            priority=100,
            recommended=True,
            enabled=True
        ))

        self.register(Model(
            id="anthropic/claude-3-7-sonnet-latest",
            name="Claude 3.7 Sonnet",
            provider=ModelProvider.ANTHROPIC,
            aliases=["sonnet-3.7", "claude-3.7", "Claude 3.7 Sonnet", "claude-3-7-sonnet-latest"],
            context_window=200_000,
            capabilities=[
                ModelCapability.CHAT,
                ModelCapability.FUNCTION_CALLING,
                ModelCapability.VISION,
            ],
            pricing=ModelPricing(
                input_cost_per_million_tokens=3.00,
                output_cost_per_million_tokens=15.00
            ),
            tier_availability=["paid"],
            priority=93,
            enabled=True
        ))

        self.register(Model(
            id="anthropic/claude-3-5-sonnet-latest",
            name="Claude 3.5 Sonnet",
            provider=ModelProvider.ANTHROPIC,
            aliases=["sonnet-3.5", "claude-3.5", "Claude 3.5 Sonnet", "claude-3-5-sonnet-latest"],
            context_window=200_000,
            capabilities=[
                ModelCapability.CHAT,
                ModelCapability.FUNCTION_CALLING,
                ModelCapability.VISION,
            ],
            pricing=ModelPricing(
                input_cost_per_million_tokens=3.00,
                output_cost_per_million_tokens=15.00
            ),
            tier_availability=["paid"],
            priority=90,
            enabled=True
        ))

        self.register(Model(
            id="openai/gpt-5",
            name="GPT-5",
            provider=ModelProvider.OPENAI,
            aliases=["gpt-5", "GPT-5"],
            context_window=400_000,
            capabilities=[
                ModelCapability.CHAT,
                ModelCapability.FUNCTION_CALLING,
                ModelCapability.VISION,
                ModelCapability.STRUCTURED_OUTPUT,
            ],
            pricing=ModelPricing(
                input_cost_per_million_tokens=1.25,
                output_cost_per_million_tokens=10.00
            ),
            tier_availability=["paid"],
            priority=99,
            enabled=True
        ))

        self.register(Model(
            id="openai/gpt-5-mini",
            name="GPT-5 Mini",
            provider=ModelProvider.OPENAI,
            aliases=["gpt-5-mini", "GPT-5 Mini"],
            context_window=400_000,
            capabilities=[
                ModelCapability.CHAT,
                ModelCapability.FUNCTION_CALLING,
                ModelCapability.STRUCTURED_OUTPUT,
            ],
            pricing=ModelPricing(
                input_cost_per_million_tokens=0.25,
                output_cost_per_million_tokens=2.00
            ),
            tier_availability=["free", "paid"],
            priority=85,
            enabled=True
        ))

        self.register(Model(
            id="gemini/gemini-2.5-pro",
            name="Gemini 2.5 Pro",
            provider=ModelProvider.GEMINI,
            aliases=["google/gemini-2.5-pro", "gemini-2.5-pro", "Gemini 2.5 Pro"],
            context_window=2_000_000,
            capabilities=[
                ModelCapability.CHAT,
                ModelCapability.FUNCTION_CALLING,
                ModelCapability.VISION,
                ModelCapability.STRUCTURED_OUTPUT,
            ],
            pricing=ModelPricing(
                input_cost_per_million_tokens=1.25,
                output_cost_per_million_tokens=10.00
            ),
            tier_availability=["paid"],
            priority=96,
            enabled=True
        ))

        self.register(Model(
            id="xai/grok-4",
            name="Grok 4",
            provider=ModelProvider.XAI,
            aliases=["grok-4", "x-ai/grok-4", "openrouter/x-ai/grok-4", "Grok 4"],
            context_window=128_000,
            capabilities=[
                ModelCapability.CHAT,
                ModelCapability.FUNCTION_CALLING,
            ],
            pricing=ModelPricing(
                input_cost_per_million_tokens=5.00,
                output_cost_per_million_tokens=15.00
            ),
            tier_availability=["paid"],
            priority=94,
            enabled=True
        ))

        self.register(Model(
            id="openrouter/moonshotai/kimi-k2",
            name="Kimi K2",
            provider=ModelProvider.MOONSHOTAI,
            aliases=["moonshotai/kimi-k2", "kimi-k2", "Kimi K2"],
            context_window=200_000,
            capabilities=[
                ModelCapability.CHAT,
                ModelCapability.FUNCTION_CALLING,
            ],
            pricing=ModelPricing(
                input_cost_per_million_tokens=1.00,
                output_cost_per_million_tokens=3.00
            ),
            tier_availability=["free", "paid"],
            priority=100,
            enabled=True
        ))

        # Disabled registrations kept for reference (no-op string literal);
        # re-enable by moving a block out of this string.
        """
        # DeepSeek Models
        self.register(Model(
            id="openrouter/deepseek/deepseek-chat",
            name="DeepSeek Chat",
            provider=ModelProvider.OPENROUTER,
            aliases=["deepseek", "deepseek-chat"],
            context_window=128_000,
            capabilities=[
                ModelCapability.CHAT,
                ModelCapability.FUNCTION_CALLING
            ],
            pricing=ModelPricing(
                input_cost_per_million_tokens=0.38,
                output_cost_per_million_tokens=0.89
            ),
            tier_availability=["free", "paid"],
            priority=95,
            enabled=False  # Currently disabled
        ))

        # Qwen Models
        self.register(Model(
            id="openrouter/qwen/qwen3-235b-a22b",
            name="Qwen3 235B",
            provider=ModelProvider.OPENROUTER,
            aliases=["qwen3", "qwen-3"],
            context_window=128_000,
            capabilities=[
                ModelCapability.CHAT,
                ModelCapability.FUNCTION_CALLING
            ],
            pricing=ModelPricing(
                input_cost_per_million_tokens=0.13,
                output_cost_per_million_tokens=0.60
            ),
            tier_availability=["free", "paid"],
            priority=90,
            enabled=False  # Currently disabled
        ))
        """

    def register(self, model: Model) -> None:
        """Add `model` to the catalog and index all of its aliases.

        Re-registering an id (or reusing an alias) silently overwrites the
        previous mapping.
        """
        self._models[model.id] = model
        for alias in model.aliases:
            self._aliases[alias] = model.id

    def get(self, model_id: str) -> Optional[Model]:
        """Look up a model by canonical id first, then by alias.

        Returns None when the name is unknown.
        """
        if model_id in self._models:
            return self._models[model_id]

        if model_id in self._aliases:
            actual_id = self._aliases[model_id]
            return self._models.get(actual_id)

        return None

    def get_all(self, enabled_only: bool = True) -> List[Model]:
        """All registered models; disabled ones are excluded by default."""
        models = list(self._models.values())
        if enabled_only:
            models = [m for m in models if m.enabled]
        return models

    def get_by_tier(self, tier: str, enabled_only: bool = True) -> List[Model]:
        """Models available to the given billing tier (e.g. "free", "paid")."""
        return [m for m in self.get_all(enabled_only) if tier in m.tier_availability]

    def get_by_provider(self, provider: ModelProvider, enabled_only: bool = True) -> List[Model]:
        """Models served by the given provider."""
        return [m for m in self.get_all(enabled_only) if m.provider == provider]

    def get_by_capability(self, capability: ModelCapability, enabled_only: bool = True) -> List[Model]:
        """Models advertising the given capability."""
        return [m for m in self.get_all(enabled_only) if capability in m.capabilities]

    def resolve_model_id(self, model_id: str) -> str:
        """Resolve an alias or id to the canonical model id.

        Unknown names are returned unchanged — mirroring the legacy
        `MODEL_NAME_ALIASES.get(name, name)` fallback — because callers
        (e.g. LLM parameter preparation, billing checks) use the result
        directly as a model name; returning None here would propagate a
        None model into those code paths.
        """
        model = self.get(model_id)
        return model.id if model else model_id

    def get_aliases(self, model_id: str) -> List[str]:
        """All aliases of the model, or [] when the name is unknown."""
        model = self.get(model_id)
        return model.aliases if model else []

    def enable_model(self, model_id: str) -> bool:
        """Mark a model enabled; returns False when the name is unknown."""
        model = self.get(model_id)
        if model:
            model.enabled = True
            return True
        return False

    def disable_model(self, model_id: str) -> bool:
        """Mark a model disabled; returns False when the name is unknown."""
        model = self.get(model_id)
        if model:
            model.enabled = False
            return True
        return False

    def get_context_window(self, model_id: str, default: int = 31_000) -> int:
        """Context window of the model, or `default` for unknown names."""
        model = self.get(model_id)
        return model.context_window if model else default

    def get_pricing(self, model_id: str) -> Optional[ModelPricing]:
        """Pricing record of the model, or None when unknown/unpriced."""
        model = self.get(model_id)
        return model.pricing if model else None

    def to_legacy_format(self) -> Dict:
        """Export the enabled catalog as the old constants-module dicts.

        Keeps pre-registry consumers working: returns MODELS,
        MODEL_NAME_ALIASES, HARDCODED_MODEL_PRICES, MODEL_CONTEXT_WINDOWS
        plus FREE_TIER_MODELS / PAID_TIER_MODELS id lists.
        """
        models_dict = {}
        aliases_dict = {}
        pricing_dict = {}
        context_windows_dict = {}

        for model in self.get_all(enabled_only=True):
            models_dict[model.id] = {
                "aliases": model.aliases,
                "pricing": {
                    "input_cost_per_million_tokens": model.pricing.input_cost_per_million_tokens,
                    "output_cost_per_million_tokens": model.pricing.output_cost_per_million_tokens,
                } if model.pricing else None,
                "context_window": model.context_window,
                "tier_availability": model.tier_availability,
            }

            for alias in model.aliases:
                aliases_dict[alias] = model.id

            if model.pricing:
                pricing_dict[model.id] = {
                    "input_cost_per_million_tokens": model.pricing.input_cost_per_million_tokens,
                    "output_cost_per_million_tokens": model.pricing.output_cost_per_million_tokens,
                }

            context_windows_dict[model.id] = model.context_window

        free_models = [m.id for m in self.get_by_tier("free")]
        paid_models = [m.id for m in self.get_by_tier("paid")]

        # Debug logging (local import, as elsewhere in this module)
        from utils.logger import logger
        logger.debug(f"Legacy format generation: {len(free_models)} free models, {len(paid_models)} paid models")
        logger.debug(f"Free models: {free_models}")
        logger.debug(f"Paid models: {paid_models}")

        return {
            "MODELS": models_dict,
            "MODEL_NAME_ALIASES": aliases_dict,
            "HARDCODED_MODEL_PRICES": pricing_dict,
            "MODEL_CONTEXT_WINDOWS": context_windows_dict,
            "FREE_TIER_MODELS": free_models,
            "PAID_TIER_MODELS": paid_models,
        }


# Module-level singleton registry used across the backend.
registry = ModelRegistry()
|
|
@ -111,20 +111,19 @@ async def run_agent_background(
|
|||
"agent_config": agent_config,
|
||||
})
|
||||
|
||||
effective_model = model_name
|
||||
if model_name == "openai/gpt-5-mini" and agent_config and agent_config.get('model'):
|
||||
from models import model_manager
|
||||
is_tier_default = model_name in ["Kimi K2", "Claude Sonnet 4", "openai/gpt-5-mini"]
|
||||
|
||||
if is_tier_default and agent_config and agent_config.get('model'):
|
||||
agent_model = agent_config['model']
|
||||
from utils.constants import MODEL_NAME_ALIASES
|
||||
resolved_agent_model = MODEL_NAME_ALIASES.get(agent_model, agent_model)
|
||||
effective_model = resolved_agent_model
|
||||
logger.debug(f"Using model from agent config: {agent_model} -> {effective_model} (no user selection)")
|
||||
effective_model = model_manager.resolve_model_id(agent_model)
|
||||
logger.debug(f"Using model from agent config: {agent_model} -> {effective_model} (tier default was {model_name})")
|
||||
else:
|
||||
from utils.constants import MODEL_NAME_ALIASES
|
||||
effective_model = MODEL_NAME_ALIASES.get(model_name, model_name)
|
||||
if model_name != "openai/gpt-5-mini":
|
||||
effective_model = model_manager.resolve_model_id(model_name)
|
||||
if not is_tier_default:
|
||||
logger.debug(f"Using user-selected model: {model_name} -> {effective_model}")
|
||||
else:
|
||||
logger.debug(f"Using default model: {effective_model}")
|
||||
logger.debug(f"Using tier default model: {model_name} -> {effective_model}")
|
||||
|
||||
logger.debug(f"🚀 Using model: {effective_model} (thinking: {enable_thinking}, reasoning_effort: {reasoning_effort})")
|
||||
if agent_config:
|
||||
|
|
|
@ -110,14 +110,22 @@ def get_model_pricing(model: str) -> tuple[float, float] | None:
|
|||
Get pricing for a model. Returns (input_cost_per_million, output_cost_per_million) or None.
|
||||
|
||||
Args:
|
||||
model: The model name to get pricing for
|
||||
model: The model name to get pricing for (can be display name or model ID)
|
||||
|
||||
Returns:
|
||||
Tuple of (input_cost_per_million_tokens, output_cost_per_million_tokens) or None if not found
|
||||
"""
|
||||
# Try direct lookup first
|
||||
if model in HARDCODED_MODEL_PRICES:
|
||||
pricing = HARDCODED_MODEL_PRICES[model]
|
||||
return pricing["input_cost_per_million_tokens"], pricing["output_cost_per_million_tokens"]
|
||||
|
||||
from models import model_manager
|
||||
resolved_model = model_manager.resolve_model_id(model)
|
||||
if resolved_model != model and resolved_model in HARDCODED_MODEL_PRICES:
|
||||
pricing = HARDCODED_MODEL_PRICES[resolved_model]
|
||||
return pricing["input_cost_per_million_tokens"], pricing["output_cost_per_million_tokens"]
|
||||
|
||||
return None
|
||||
|
||||
|
||||
|
@ -570,8 +578,9 @@ def calculate_token_cost(prompt_tokens: int, completion_tokens: int, model: str)
|
|||
prompt_tokens = int(prompt_tokens) if prompt_tokens is not None else 0
|
||||
completion_tokens = int(completion_tokens) if completion_tokens is not None else 0
|
||||
|
||||
# Try to resolve the model name using MODEL_NAME_ALIASES first
|
||||
resolved_model = MODEL_NAME_ALIASES.get(model, model)
|
||||
# Try to resolve the model name using new model manager first
|
||||
from models import model_manager
|
||||
resolved_model = model_manager.resolve_model_id(model)
|
||||
|
||||
# Check if we have hardcoded pricing for this model (try both original and resolved)
|
||||
hardcoded_pricing = get_model_pricing(model) or get_model_pricing(resolved_model)
|
||||
|
@ -672,7 +681,8 @@ async def can_use_model(client, user_id: str, model_name: str):
|
|||
}
|
||||
|
||||
allowed_models = await get_allowed_models_for_user(client, user_id)
|
||||
resolved_model = MODEL_NAME_ALIASES.get(model_name, model_name)
|
||||
from models import model_manager
|
||||
resolved_model = model_manager.resolve_model_id(model_name)
|
||||
if resolved_model in allowed_models:
|
||||
return True, "Model access allowed", allowed_models
|
||||
|
||||
|
@ -1751,6 +1761,9 @@ async def get_available_models(
|
|||
):
|
||||
"""Get the list of models available to the user based on their subscription tier."""
|
||||
try:
|
||||
# Import the new model manager
|
||||
from models import model_manager
|
||||
|
||||
# Get Supabase client
|
||||
db = DBConnection()
|
||||
client = await db.client
|
||||
|
@ -1759,18 +1772,23 @@ async def get_available_models(
|
|||
if config.ENV_MODE == EnvMode.LOCAL:
|
||||
logger.debug("Running in local development mode - billing checks are disabled")
|
||||
|
||||
# In local mode, return all models from MODEL_NAME_ALIASES
|
||||
# In local mode, return all enabled models
|
||||
all_models = model_manager.list_available_models(include_disabled=False)
|
||||
model_info = []
|
||||
for short_name, full_name in MODEL_NAME_ALIASES.items():
|
||||
# Skip entries where the key is a full name to avoid duplicates
|
||||
# if short_name == full_name or '/' in short_name:
|
||||
# continue
|
||||
|
||||
for model_data in all_models:
|
||||
# Create clean model info for frontend
|
||||
model_info.append({
|
||||
"id": full_name,
|
||||
"display_name": short_name,
|
||||
"short_name": short_name,
|
||||
"requires_subscription": False # Always false in local dev mode
|
||||
"id": model_data["id"],
|
||||
"display_name": model_data["name"],
|
||||
"short_name": model_data.get("aliases", [model_data["name"]])[0] if model_data.get("aliases") else model_data["name"],
|
||||
"requires_subscription": False, # Always false in local dev mode
|
||||
"input_cost_per_million_tokens": model_data["pricing"]["input_per_million"] if model_data["pricing"] else None,
|
||||
"output_cost_per_million_tokens": model_data["pricing"]["output_per_million"] if model_data["pricing"] else None,
|
||||
"context_window": model_data["context_window"],
|
||||
"capabilities": model_data["capabilities"],
|
||||
"recommended": model_data["recommended"],
|
||||
"priority": model_data["priority"]
|
||||
})
|
||||
|
||||
return {
|
||||
|
@ -1780,10 +1798,7 @@ async def get_available_models(
|
|||
}
|
||||
|
||||
|
||||
# For non-local mode, get list of allowed models for this user
|
||||
allowed_models = await get_allowed_models_for_user(client, current_user_id)
|
||||
free_tier_models = MODEL_ACCESS_TIERS.get('free', [])
|
||||
|
||||
# For non-local mode, use new model manager system
|
||||
# Get subscription info for context
|
||||
subscription = await get_user_subscription(current_user_id)
|
||||
|
||||
|
@ -1801,122 +1816,56 @@ async def get_available_models(
|
|||
if tier_info:
|
||||
tier_name = tier_info['name']
|
||||
|
||||
# Get all unique full model names from MODEL_NAME_ALIASES
|
||||
all_models = set()
|
||||
model_aliases = {}
|
||||
# Get ALL enabled models for preview UI (don't filter by tier here)
|
||||
all_models = model_manager.list_available_models(tier=None, include_disabled=False)
|
||||
logger.debug(f"Found {len(all_models)} total models available")
|
||||
|
||||
for short_name, full_name in MODEL_NAME_ALIASES.items():
|
||||
# Add all unique full model names
|
||||
all_models.add(full_name)
|
||||
# Get allowed models for this specific user (for access checking)
|
||||
allowed_models = await get_allowed_models_for_user(client, current_user_id)
|
||||
logger.debug(f"User {current_user_id} allowed models: {allowed_models}")
|
||||
logger.debug(f"User tier: {tier_name}")
|
||||
|
||||
# Only include short names that don't match their full names for aliases
|
||||
if short_name != full_name and not short_name.startswith("openai/") and not short_name.startswith("anthropic/") and not short_name.startswith("openrouter/") and not short_name.startswith("xai/"):
|
||||
if full_name not in model_aliases:
|
||||
model_aliases[full_name] = short_name
|
||||
|
||||
# Create model info with display names for ALL models
|
||||
# Create clean model info for frontend
|
||||
model_info = []
|
||||
for model in all_models:
|
||||
display_name = model_aliases.get(model, model.split('/')[-1] if '/' in model else model)
|
||||
|
||||
# Check if model requires subscription (not in free tier)
|
||||
requires_sub = model not in free_tier_models
|
||||
for model_data in all_models:
|
||||
model_id = model_data["id"]
|
||||
|
||||
# Check if model is available with current subscription
|
||||
is_available = model in allowed_models
|
||||
is_available = model_id in allowed_models
|
||||
|
||||
# Get pricing information - check hardcoded prices first, then litellm
|
||||
# Get pricing with multiplier applied
|
||||
pricing_info = {}
|
||||
|
||||
# Check if we have hardcoded pricing for this model
|
||||
hardcoded_pricing = get_model_pricing(model)
|
||||
if hardcoded_pricing:
|
||||
input_cost_per_million, output_cost_per_million = hardcoded_pricing
|
||||
if model_data["pricing"]:
|
||||
pricing_info = {
|
||||
"input_cost_per_million_tokens": input_cost_per_million * TOKEN_PRICE_MULTIPLIER,
|
||||
"output_cost_per_million_tokens": output_cost_per_million * TOKEN_PRICE_MULTIPLIER,
|
||||
"max_tokens": None
|
||||
"input_cost_per_million_tokens": model_data["pricing"]["input_per_million"] * TOKEN_PRICE_MULTIPLIER,
|
||||
"output_cost_per_million_tokens": model_data["pricing"]["output_per_million"] * TOKEN_PRICE_MULTIPLIER,
|
||||
"max_tokens": model_data["max_output_tokens"]
|
||||
}
|
||||
else:
|
||||
try:
|
||||
# Try to get pricing using cost_per_token function
|
||||
models_to_try = []
|
||||
|
||||
# Add the original model name
|
||||
models_to_try.append(model)
|
||||
|
||||
# Try to resolve the model name using MODEL_NAME_ALIASES
|
||||
if model in MODEL_NAME_ALIASES:
|
||||
resolved_model = MODEL_NAME_ALIASES[model]
|
||||
models_to_try.append(resolved_model)
|
||||
# Also try without provider prefix if it has one
|
||||
if '/' in resolved_model:
|
||||
models_to_try.append(resolved_model.split('/', 1)[1])
|
||||
|
||||
# If model is a value in aliases, try to find a matching key
|
||||
for alias_key, alias_value in MODEL_NAME_ALIASES.items():
|
||||
if alias_value == model:
|
||||
models_to_try.append(alias_key)
|
||||
break
|
||||
|
||||
# Also try without provider prefix for the original model
|
||||
if '/' in model:
|
||||
models_to_try.append(model.split('/', 1)[1])
|
||||
|
||||
# Special handling for Google models accessed via Google API
|
||||
if model.startswith('gemini/'):
|
||||
google_model_name = model.replace('gemini/', '')
|
||||
models_to_try.append(google_model_name)
|
||||
|
||||
# Special handling for Google models accessed via Google API
|
||||
if model.startswith('gemini/'):
|
||||
google_model_name = model.replace('gemini/', '')
|
||||
models_to_try.append(google_model_name)
|
||||
|
||||
# Try each model name variation until we find one that works
|
||||
input_cost_per_token = None
|
||||
output_cost_per_token = None
|
||||
|
||||
for model_name in models_to_try:
|
||||
try:
|
||||
# Use cost_per_token with sample token counts to get the per-token costs
|
||||
input_cost, output_cost = cost_per_token(model_name, 1000000, 1000000)
|
||||
if input_cost is not None and output_cost is not None:
|
||||
input_cost_per_token = input_cost
|
||||
output_cost_per_token = output_cost
|
||||
break
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
if input_cost_per_token is not None and output_cost_per_token is not None:
|
||||
pricing_info = {
|
||||
"input_cost_per_million_tokens": input_cost_per_token * TOKEN_PRICE_MULTIPLIER,
|
||||
"output_cost_per_million_tokens": output_cost_per_million * TOKEN_PRICE_MULTIPLIER,
|
||||
"max_tokens": None # cost_per_token doesn't provide max_tokens info
|
||||
}
|
||||
else:
|
||||
pricing_info = {
|
||||
"input_cost_per_million_tokens": None,
|
||||
"output_cost_per_million_tokens": None,
|
||||
"max_tokens": None
|
||||
}
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not get pricing for model {model}: {str(e)}")
|
||||
pricing_info = {
|
||||
"input_cost_per_million_tokens": None,
|
||||
"output_cost_per_million_tokens": None,
|
||||
"max_tokens": None
|
||||
}
|
||||
pricing_info = {
|
||||
"input_cost_per_million_tokens": None,
|
||||
"output_cost_per_million_tokens": None,
|
||||
"max_tokens": None
|
||||
}
|
||||
|
||||
model_info.append({
|
||||
"id": model,
|
||||
"display_name": display_name,
|
||||
"short_name": model_aliases.get(model),
|
||||
"requires_subscription": requires_sub,
|
||||
"id": model_id,
|
||||
"display_name": model_data["name"],
|
||||
"short_name": model_data.get("aliases", [model_data["name"]])[0] if model_data.get("aliases") else model_data["name"],
|
||||
"requires_subscription": not model_data.get("tier_availability", []) or "free" not in model_data["tier_availability"],
|
||||
"is_available": is_available,
|
||||
"context_window": model_data["context_window"],
|
||||
"capabilities": model_data["capabilities"],
|
||||
"recommended": model_data["recommended"],
|
||||
"priority": model_data["priority"],
|
||||
**pricing_info
|
||||
})
|
||||
|
||||
logger.debug(f"Returning {len(model_info)} models to user {current_user_id} (tier: {tier_name})")
|
||||
if model_info:
|
||||
model_names = [m["display_name"] for m in model_info]
|
||||
logger.debug(f"Model names: {model_names}")
|
||||
|
||||
return {
|
||||
"models": model_info,
|
||||
"subscription_tier": tier_name,
|
||||
|
|
|
@ -257,9 +257,12 @@ def prepare_params(
|
|||
enable_thinking: Optional[bool] = False,
|
||||
reasoning_effort: Optional[str] = 'low'
|
||||
) -> Dict[str, Any]:
|
||||
"""Prepare parameters for the API call."""
|
||||
from models import model_manager
|
||||
resolved_model_name = model_manager.resolve_model_id(model_name)
|
||||
logger.debug(f"Model resolution: '{model_name}' -> '{resolved_model_name}'")
|
||||
|
||||
params = {
|
||||
"model": model_name,
|
||||
"model": resolved_model_name,
|
||||
"messages": messages,
|
||||
"temperature": temperature,
|
||||
"response_format": response_format,
|
||||
|
@ -276,22 +279,22 @@ def prepare_params(
|
|||
params["model_id"] = model_id
|
||||
|
||||
# Handle token limits
|
||||
_configure_token_limits(params, model_name, max_tokens)
|
||||
_configure_token_limits(params, resolved_model_name, max_tokens)
|
||||
# Add tools if provided
|
||||
_add_tools_config(params, tools, tool_choice)
|
||||
# Add Anthropic-specific parameters
|
||||
_configure_anthopic(params, model_name, params["messages"])
|
||||
_configure_anthopic(params, resolved_model_name, params["messages"])
|
||||
# Add OpenRouter-specific parameters
|
||||
_configure_openrouter(params, model_name)
|
||||
_configure_openrouter(params, resolved_model_name)
|
||||
# Add Bedrock-specific parameters
|
||||
_configure_bedrock(params, model_name, model_id)
|
||||
_configure_bedrock(params, resolved_model_name, model_id)
|
||||
|
||||
_add_fallback_model(params, model_name, messages)
|
||||
_add_fallback_model(params, resolved_model_name, messages)
|
||||
# Add OpenAI GPT-5 specific parameters
|
||||
_configure_openai_gpt5(params, model_name)
|
||||
_configure_openai_gpt5(params, resolved_model_name)
|
||||
# Add Kimi K2-specific parameters
|
||||
_configure_kimi_k2(params, model_name)
|
||||
_configure_thinking(params, model_name, enable_thinking, reasoning_effort)
|
||||
_configure_kimi_k2(params, resolved_model_name)
|
||||
_configure_thinking(params, resolved_model_name, enable_thinking, reasoning_effort)
|
||||
|
||||
return params
|
||||
|
||||
|
|
|
@ -121,6 +121,7 @@ class WorkflowUpdateRequest(BaseModel):
|
|||
|
||||
class WorkflowExecuteRequest(BaseModel):
|
||||
input_data: Optional[Dict[str, Any]] = None
|
||||
model_name: Optional[str] = None
|
||||
|
||||
|
||||
WorkflowStepRequest.model_rebuild()
|
||||
|
@ -761,6 +762,7 @@ async def execute_agent_workflow(
|
|||
execution_data: WorkflowExecuteRequest,
|
||||
user_id: str = Depends(get_current_user_id_from_jwt)
|
||||
):
|
||||
print("DEBUG: Executing workflow", workflow_id, "for agent", agent_id)
|
||||
await verify_agent_access(agent_id, user_id)
|
||||
|
||||
client = await db.client
|
||||
|
@ -786,7 +788,9 @@ async def execute_agent_workflow(
|
|||
if active_version and active_version.model:
|
||||
model_name = active_version.model
|
||||
else:
|
||||
model_name = "openai/gpt-5-mini"
|
||||
from models import model_manager
|
||||
model_name = await model_manager.get_default_model_for_user(client, account_id)
|
||||
print("DEBUG: Using tier-based default model:", model_name)
|
||||
|
||||
can_use, model_message, allowed_models = await can_use_model(client, account_id, model_name)
|
||||
if not can_use:
|
||||
|
@ -807,7 +811,8 @@ async def execute_agent_workflow(
|
|||
'triggered_by': 'manual',
|
||||
'execution_timestamp': datetime.now(timezone.utc).isoformat(),
|
||||
'user_id': user_id,
|
||||
'execution_source': 'workflow_api'
|
||||
'execution_source': 'workflow_api',
|
||||
'model_name': model_name # Use the model from agent version config
|
||||
}
|
||||
)
|
||||
|
||||
|
|
|
@ -360,7 +360,20 @@ class AgentExecutor:
|
|||
trigger_variables: Dict[str, Any]
|
||||
) -> str:
|
||||
client = await self._db.client
|
||||
model_name = agent_config.get('model') or "openai/gpt-5-mini"
|
||||
|
||||
# Debug: Log the agent config to see what model is set
|
||||
logger.debug(f"Agent config for trigger execution: model='{agent_config.get('model')}', keys={list(agent_config.keys())}")
|
||||
|
||||
model_name = agent_config.get('model')
|
||||
logger.debug(f"Model from agent config: '{model_name}' (type: {type(model_name)})")
|
||||
|
||||
if not model_name:
|
||||
account_id = agent_config.get('account_id')
|
||||
if account_id:
|
||||
from models import model_manager
|
||||
model_name = await model_manager.get_default_model_for_user(client, account_id)
|
||||
else:
|
||||
model_name = "Kimi K2"
|
||||
|
||||
account_id = agent_config.get('account_id')
|
||||
if not account_id:
|
||||
|
@ -439,7 +452,7 @@ class WorkflowExecutor:
|
|||
agent_config, account_id = await self._get_agent_data(agent_id)
|
||||
|
||||
enhanced_agent_config = await self._enhance_agent_config_for_workflow(
|
||||
agent_config, workflow_config, steps_json, workflow_input, account_id
|
||||
agent_config, workflow_config, steps_json, workflow_input, account_id, trigger_result
|
||||
)
|
||||
|
||||
thread_id, project_id = await self._session_manager.create_workflow_session(
|
||||
|
@ -528,7 +541,8 @@ class WorkflowExecutor:
|
|||
workflow_config: Dict[str, Any],
|
||||
steps_json: list,
|
||||
workflow_input: Dict[str, Any],
|
||||
account_id: str = None
|
||||
account_id: str = None,
|
||||
trigger_result: Optional[TriggerResult] = None
|
||||
) -> Dict[str, Any]:
|
||||
available_tools = self._get_available_tools(agent_config)
|
||||
workflow_prompt = format_workflow_for_llm(
|
||||
|
@ -547,6 +561,13 @@ class WorkflowExecutor:
|
|||
if account_id:
|
||||
enhanced_config['account_id'] = account_id
|
||||
|
||||
# Check for user-specified model in trigger execution variables
|
||||
if trigger_result and hasattr(trigger_result, 'execution_variables'):
|
||||
user_model = trigger_result.execution_variables.get('model_name')
|
||||
if user_model:
|
||||
enhanced_config['model'] = user_model
|
||||
logger.debug(f"Using user-specified model for workflow: {user_model}")
|
||||
|
||||
return enhanced_config
|
||||
|
||||
def _get_available_tools(self, agent_config: Dict[str, Any]) -> list:
|
||||
|
@ -589,7 +610,8 @@ class WorkflowExecutor:
|
|||
from services.billing import check_billing_status, can_use_model
|
||||
|
||||
client = await self._db.client
|
||||
model_name = "openai/gpt-5-mini"
|
||||
from models import model_manager
|
||||
model_name = await model_manager.get_default_model_for_user(client, account_id)
|
||||
|
||||
can_use, model_message, _ = await can_use_model(client, account_id, model_name)
|
||||
if not can_use:
|
||||
|
@ -636,7 +658,25 @@ class WorkflowExecutor:
|
|||
agent_config: Dict[str, Any]
|
||||
) -> str:
|
||||
client = await self._db.client
|
||||
model_name = agent_config.get('model') or "openai/gpt-5-mini"
|
||||
|
||||
# Debug: Log the agent config to see what model is set
|
||||
logger.debug(f"Agent config for workflow execution: model='{agent_config.get('model')}', keys={list(agent_config.keys())}")
|
||||
|
||||
model_name = agent_config.get('model')
|
||||
logger.debug(f"Model from agent config: '{model_name}' (type: {type(model_name)})")
|
||||
|
||||
if not model_name:
|
||||
account_id = agent_config.get('account_id')
|
||||
if not account_id:
|
||||
thread_result = await client.table('threads').select('account_id').eq('thread_id', thread_id).execute()
|
||||
if thread_result.data:
|
||||
account_id = thread_result.data[0]['account_id']
|
||||
|
||||
if account_id:
|
||||
from models import model_manager
|
||||
model_name = await model_manager.get_default_model_for_user(client, account_id)
|
||||
else:
|
||||
model_name = "Kimi K2"
|
||||
|
||||
account_id = agent_config.get('account_id')
|
||||
if not account_id:
|
||||
|
|
|
@ -289,7 +289,9 @@ class ProviderService:
|
|||
def _initialize_providers(self):
|
||||
self._providers["schedule"] = ScheduleProvider()
|
||||
self._providers["webhook"] = WebhookProvider()
|
||||
self._providers["composio"] = ComposioEventProvider()
|
||||
composio_provider = ComposioEventProvider()
|
||||
composio_provider.set_db(self._db)
|
||||
self._providers["composio"] = composio_provider
|
||||
|
||||
async def get_available_providers(self) -> List[Dict[str, Any]]:
|
||||
providers = []
|
||||
|
@ -446,6 +448,45 @@ class ComposioEventProvider(TriggerProvider):
|
|||
super().__init__("composio", TriggerType.WEBHOOK)
|
||||
self._api_base = os.getenv("COMPOSIO_API_BASE", "https://backend.composio.dev")
|
||||
self._api_key = os.getenv("COMPOSIO_API_KEY", "")
|
||||
self._db: Optional[DBConnection] = None
|
||||
|
||||
def set_db(self, db: DBConnection):
|
||||
"""Set database connection for provider"""
|
||||
self._db = db
|
||||
|
||||
async def _count_triggers_with_composio_id(self, composio_trigger_id: str, exclude_trigger_id: Optional[str] = None) -> int:
|
||||
"""Count how many triggers use the same composio_trigger_id (excluding specified trigger)"""
|
||||
if not self._db:
|
||||
return 0
|
||||
client = await self._db.client
|
||||
|
||||
# Use PostgreSQL JSON operator for exact match
|
||||
query = client.table('agent_triggers').select('trigger_id', count='exact').eq('trigger_type', 'webhook').eq('config->>composio_trigger_id', composio_trigger_id)
|
||||
|
||||
if exclude_trigger_id:
|
||||
query = query.neq('trigger_id', exclude_trigger_id)
|
||||
|
||||
result = await query.execute()
|
||||
count = result.count or 0
|
||||
|
||||
return count
|
||||
|
||||
async def _count_active_triggers_with_composio_id(self, composio_trigger_id: str, exclude_trigger_id: Optional[str] = None) -> int:
|
||||
"""Count how many ACTIVE triggers use the same composio_trigger_id (excluding specified trigger)"""
|
||||
if not self._db:
|
||||
return 0
|
||||
client = await self._db.client
|
||||
|
||||
# Use PostgreSQL JSON operator for exact match
|
||||
query = client.table('agent_triggers').select('trigger_id', count='exact').eq('trigger_type', 'webhook').eq('is_active', True).eq('config->>composio_trigger_id', composio_trigger_id)
|
||||
|
||||
if exclude_trigger_id:
|
||||
query = query.neq('trigger_id', exclude_trigger_id)
|
||||
|
||||
result = await query.execute()
|
||||
count = result.count or 0
|
||||
|
||||
return count
|
||||
|
||||
def _headers(self) -> Dict[str, str]:
|
||||
return {"x-api-key": self._api_key, "Content-Type": "application/json"}
|
||||
|
@ -482,14 +523,23 @@ class ComposioEventProvider(TriggerProvider):
|
|||
return config
|
||||
|
||||
async def setup_trigger(self, trigger: Trigger) -> bool:
|
||||
# Re-enable the Composio trigger instance if present
|
||||
# Enable in Composio only if this will be the first active trigger with this composio_trigger_id
|
||||
try:
|
||||
trigger_id = trigger.config.get("composio_trigger_id")
|
||||
if not trigger_id:
|
||||
composio_trigger_id = trigger.config.get("composio_trigger_id")
|
||||
if not composio_trigger_id or not self._api_key:
|
||||
return True
|
||||
if not self._api_key:
|
||||
|
||||
# Check if other ACTIVE triggers are using this composio_trigger_id
|
||||
other_active_count = await self._count_active_triggers_with_composio_id(composio_trigger_id, trigger.trigger_id)
|
||||
logger.debug(f"Setup trigger {trigger.trigger_id}: other_active_count={other_active_count} for composio_id={composio_trigger_id}")
|
||||
|
||||
if other_active_count > 0:
|
||||
# Other active triggers exist, don't touch Composio - just mark our trigger as active locally
|
||||
logger.debug(f"Skipping Composio enable - {other_active_count} other active triggers exist")
|
||||
return True
|
||||
# Use canonical payload first per Composio API; include tolerant fallbacks
|
||||
|
||||
# We're the first/only active trigger, enable in Composio
|
||||
logger.debug(f"Enabling trigger in Composio - first active trigger for {composio_trigger_id}")
|
||||
payload_candidates: List[Dict[str, Any]] = [
|
||||
{"status": "enable"},
|
||||
{"status": "enabled"},
|
||||
|
@ -497,11 +547,12 @@ class ComposioEventProvider(TriggerProvider):
|
|||
]
|
||||
async with httpx.AsyncClient(timeout=10) as client:
|
||||
for api_base in self._api_bases():
|
||||
url = f"{api_base}/api/v3/trigger_instances/manage/{trigger_id}"
|
||||
url = f"{api_base}/api/v3/trigger_instances/manage/{composio_trigger_id}"
|
||||
for body in payload_candidates:
|
||||
try:
|
||||
resp = await client.patch(url, headers=self._headers(), json=body)
|
||||
if resp.status_code in (200, 204):
|
||||
logger.debug(f"Successfully enabled trigger in Composio: {composio_trigger_id}")
|
||||
return True
|
||||
except Exception:
|
||||
continue
|
||||
|
@ -510,14 +561,23 @@ class ComposioEventProvider(TriggerProvider):
|
|||
return True
|
||||
|
||||
async def teardown_trigger(self, trigger: Trigger) -> bool:
|
||||
# Disable the Composio trigger instance so it stops sending webhooks
|
||||
# Disable in Composio only if this was the last active trigger with this composio_trigger_id
|
||||
try:
|
||||
trigger_id = trigger.config.get("composio_trigger_id")
|
||||
if not trigger_id:
|
||||
composio_trigger_id = trigger.config.get("composio_trigger_id")
|
||||
|
||||
if not composio_trigger_id or not self._api_key:
|
||||
logger.info(f"TEARDOWN: Skipping - no composio_id or api_key")
|
||||
return True
|
||||
if not self._api_key:
|
||||
|
||||
# Check if other ACTIVE triggers are using this composio_trigger_id
|
||||
other_active_count = await self._count_active_triggers_with_composio_id(composio_trigger_id, trigger.trigger_id)
|
||||
|
||||
if other_active_count > 0:
|
||||
# Other active triggers exist, don't touch Composio - just mark our trigger as inactive locally
|
||||
logger.info(f"TEARDOWN: Skipping Composio disable - {other_active_count} other active triggers exist")
|
||||
return True
|
||||
# Use canonical payload first per Composio API; include tolerant fallbacks
|
||||
|
||||
# We're the last active trigger, disable in Composio
|
||||
payload_candidates: List[Dict[str, Any]] = [
|
||||
{"status": "disable"},
|
||||
{"status": "disabled"},
|
||||
|
@ -525,29 +585,38 @@ class ComposioEventProvider(TriggerProvider):
|
|||
]
|
||||
async with httpx.AsyncClient(timeout=10) as client:
|
||||
for api_base in self._api_bases():
|
||||
url = f"{api_base}/api/v3/trigger_instances/manage/{trigger_id}"
|
||||
url = f"{api_base}/api/v3/trigger_instances/manage/{composio_trigger_id}"
|
||||
for body in payload_candidates:
|
||||
try:
|
||||
resp = await client.patch(url, headers=self._headers(), json=body)
|
||||
if resp.status_code in (200, 204):
|
||||
return True
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
logger.warning(f"TEARDOWN: Failed to disable with body {body}: {e}")
|
||||
continue
|
||||
logger.warning(f"TEARDOWN: Failed to disable trigger in Composio: {composio_trigger_id}")
|
||||
return True
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
logger.error(f"TEARDOWN: Exception in teardown_trigger: {e}")
|
||||
return True
|
||||
|
||||
async def delete_remote_trigger(self, trigger: Trigger) -> bool:
|
||||
# Permanently remove the remote Composio trigger instance
|
||||
# Only permanently remove the remote Composio trigger if this is the last trigger using it
|
||||
try:
|
||||
trigger_id = trigger.config.get("composio_trigger_id")
|
||||
if not trigger_id:
|
||||
composio_trigger_id = trigger.config.get("composio_trigger_id")
|
||||
if not composio_trigger_id or not self._api_key:
|
||||
return True
|
||||
if not self._api_key:
|
||||
|
||||
# Check if other triggers are using this composio_trigger_id
|
||||
other_count = await self._count_triggers_with_composio_id(composio_trigger_id, trigger.trigger_id)
|
||||
if other_count > 0:
|
||||
# Other triggers exist, don't delete from Composio - just remove our local trigger
|
||||
return True
|
||||
|
||||
# We're the last trigger, permanently delete from Composio
|
||||
async with httpx.AsyncClient(timeout=10) as client:
|
||||
for api_base in self._api_bases():
|
||||
url = f"{api_base}/api/v3/trigger_instances/manage/{trigger_id}"
|
||||
url = f"{api_base}/api/v3/trigger_instances/manage/{composio_trigger_id}"
|
||||
try:
|
||||
resp = await client.delete(url, headers=self._headers())
|
||||
if resp.status_code in (200, 204):
|
||||
|
|
|
@ -145,6 +145,10 @@ class TriggerService:
|
|||
config_changed = config is not None
|
||||
activation_toggled = (is_active is not None) and (previous_is_active != trigger.is_active)
|
||||
|
||||
|
||||
# UPDATE DATABASE FIRST so provider methods see correct state
|
||||
await self._update_trigger(trigger)
|
||||
|
||||
if config_changed or activation_toggled:
|
||||
from .provider_service import get_provider_service
|
||||
provider_service = get_provider_service(self._db)
|
||||
|
@ -156,8 +160,8 @@ class TriggerService:
|
|||
setup_success = await provider_service.setup_trigger(trigger)
|
||||
if not setup_success:
|
||||
raise ValueError(f"Failed to update trigger setup: {trigger_id}")
|
||||
else:
|
||||
# Only activation toggled; call the minimal required action
|
||||
elif activation_toggled:
|
||||
# Only activation toggled; call the appropriate action
|
||||
if trigger.is_active:
|
||||
setup_success = await provider_service.setup_trigger(trigger)
|
||||
if not setup_success:
|
||||
|
@ -165,8 +169,6 @@ class TriggerService:
|
|||
else:
|
||||
await provider_service.teardown_trigger(trigger)
|
||||
|
||||
await self._update_trigger(trigger)
|
||||
|
||||
logger.debug(f"Updated trigger {trigger_id}")
|
||||
return trigger
|
||||
|
||||
|
@ -175,9 +177,17 @@ class TriggerService:
|
|||
if not trigger:
|
||||
return False
|
||||
|
||||
# DELETE FROM DATABASE FIRST so provider methods see correct state
|
||||
client = await self._db.client
|
||||
result = await client.table('agent_triggers').delete().eq('trigger_id', trigger_id).execute()
|
||||
|
||||
success = len(result.data) > 0
|
||||
if not success:
|
||||
return False
|
||||
|
||||
from .provider_service import get_provider_service
|
||||
provider_service = get_provider_service(self._db)
|
||||
# First disable remotely so webhooks stop quickly
|
||||
# Now disable remotely so webhooks stop quickly
|
||||
try:
|
||||
await provider_service.teardown_trigger(trigger)
|
||||
except Exception:
|
||||
|
@ -188,13 +198,6 @@ class TriggerService:
|
|||
except Exception:
|
||||
pass
|
||||
|
||||
client = await self._db.client
|
||||
result = await client.table('agent_triggers').delete().eq('trigger_id', trigger_id).execute()
|
||||
|
||||
success = len(result.data) > 0
|
||||
if success:
|
||||
logger.debug(f"Deleted trigger {trigger_id}")
|
||||
|
||||
return success
|
||||
|
||||
async def process_trigger_event(self, trigger_id: str, raw_data: Dict[str, Any]) -> TriggerResult:
|
||||
|
|
|
@ -1,209 +1,13 @@
|
|||
# Master model configuration - single source of truth
|
||||
MODELS = {
|
||||
# Free tier models
|
||||
from models import model_manager
|
||||
|
||||
"anthropic/claude-sonnet-4-20250514": {
|
||||
"aliases": ["claude-sonnet-4"],
|
||||
"pricing": {
|
||||
"input_cost_per_million_tokens": 3.00,
|
||||
"output_cost_per_million_tokens": 15.00
|
||||
},
|
||||
"context_window": 200_000, # 200k tokens
|
||||
"tier_availability": ["paid"]
|
||||
},
|
||||
# "openrouter/deepseek/deepseek-chat": {
|
||||
# "aliases": ["deepseek"],
|
||||
# "pricing": {
|
||||
# "input_cost_per_million_tokens": 0.38,
|
||||
# "output_cost_per_million_tokens": 0.89
|
||||
# },
|
||||
# "context_window": 128_000, # 128k tokens
|
||||
# "tier_availability": ["free", "paid"]
|
||||
# },
|
||||
# "openrouter/qwen/qwen3-235b-a22b": {
|
||||
# "aliases": ["qwen3"],
|
||||
# "pricing": {
|
||||
# "input_cost_per_million_tokens": 0.13,
|
||||
# "output_cost_per_million_tokens": 0.60
|
||||
# },
|
||||
# "context_window": 128_000, # 128k tokens
|
||||
# "tier_availability": ["free", "paid"]
|
||||
# },
|
||||
# "openrouter/google/gemini-2.5-flash-preview-05-20": {
|
||||
# "aliases": ["gemini-flash-2.5"],
|
||||
# "pricing": {
|
||||
# "input_cost_per_million_tokens": 0.15,
|
||||
# "output_cost_per_million_tokens": 0.60
|
||||
# },
|
||||
# "tier_availability": ["free", "paid"]
|
||||
# },
|
||||
# "openrouter/deepseek/deepseek-chat-v3-0324": {
|
||||
# "aliases": ["deepseek/deepseek-chat-v3-0324"],
|
||||
# "pricing": {
|
||||
# "input_cost_per_million_tokens": 0.38,
|
||||
# "output_cost_per_million_tokens": 0.89
|
||||
# },
|
||||
# "tier_availability": ["free", "paid"]
|
||||
# },
|
||||
"openrouter/moonshotai/kimi-k2": {
|
||||
"aliases": ["moonshotai/kimi-k2"],
|
||||
"pricing": {
|
||||
"input_cost_per_million_tokens": 1.00,
|
||||
"output_cost_per_million_tokens": 3.00
|
||||
},
|
||||
"context_window": 200_000, # 200k tokens
|
||||
"tier_availability": ["free", "paid"]
|
||||
},
|
||||
"xai/grok-4": {
|
||||
"aliases": ["grok-4", "x-ai/grok-4"],
|
||||
"pricing": {
|
||||
"input_cost_per_million_tokens": 5.00,
|
||||
"output_cost_per_million_tokens": 15.00
|
||||
},
|
||||
"context_window": 128_000, # 128k tokens
|
||||
"tier_availability": ["paid"]
|
||||
},
|
||||
_legacy_data = model_manager.get_legacy_constants()
|
||||
|
||||
# Paid tier only models
|
||||
"gemini/gemini-2.5-pro": {
|
||||
"aliases": ["google/gemini-2.5-pro"],
|
||||
"pricing": {
|
||||
"input_cost_per_million_tokens": 1.25,
|
||||
"output_cost_per_million_tokens": 10.00
|
||||
},
|
||||
"context_window": 2_000_000, # 2M tokens
|
||||
"tier_availability": ["paid"]
|
||||
},
|
||||
# "openai/gpt-4o": {
|
||||
# "aliases": ["gpt-4o"],
|
||||
# "pricing": {
|
||||
# "input_cost_per_million_tokens": 2.50,
|
||||
# "output_cost_per_million_tokens": 10.00
|
||||
# },
|
||||
# "tier_availability": ["paid"]
|
||||
# },
|
||||
# "openai/gpt-4.1": {
|
||||
# "aliases": ["gpt-4.1"],
|
||||
# "pricing": {
|
||||
# "input_cost_per_million_tokens": 15.00,
|
||||
# "output_cost_per_million_tokens": 60.00
|
||||
# },
|
||||
# "tier_availability": ["paid"]
|
||||
# },
|
||||
"openai/gpt-5": {
|
||||
"aliases": ["gpt-5"],
|
||||
"pricing": {
|
||||
"input_cost_per_million_tokens": 1.25,
|
||||
"output_cost_per_million_tokens": 10.00
|
||||
},
|
||||
"context_window": 400_000, # 400k tokens
|
||||
"tier_availability": ["paid"]
|
||||
},
|
||||
"openai/gpt-5-mini": {
|
||||
"aliases": ["gpt-5-mini"],
|
||||
"pricing": {
|
||||
"input_cost_per_million_tokens": 0.25,
|
||||
"output_cost_per_million_tokens": 2.00
|
||||
},
|
||||
"context_window": 400_000, # 400k tokens
|
||||
"tier_availability": ["free", "paid"]
|
||||
},
|
||||
# "openai/gpt-4.1-mini": {
|
||||
# "aliases": ["gpt-4.1-mini"],
|
||||
# "pricing": {
|
||||
# "input_cost_per_million_tokens": 1.50,
|
||||
# "output_cost_per_million_tokens": 6.00
|
||||
# },
|
||||
# "tier_availability": ["paid"]
|
||||
# },
|
||||
"anthropic/claude-3-7-sonnet-latest": {
|
||||
"aliases": ["sonnet-3.7"],
|
||||
"pricing": {
|
||||
"input_cost_per_million_tokens": 3.00,
|
||||
"output_cost_per_million_tokens": 15.00
|
||||
},
|
||||
"context_window": 200_000, # 200k tokens
|
||||
"tier_availability": ["paid"]
|
||||
},
|
||||
"anthropic/claude-3-5-sonnet-latest": {
|
||||
"aliases": ["sonnet-3.5"],
|
||||
"pricing": {
|
||||
"input_cost_per_million_tokens": 3.00,
|
||||
"output_cost_per_million_tokens": 15.00
|
||||
},
|
||||
"context_window": 200_000, # 200k tokens
|
||||
"tier_availability": ["paid"]
|
||||
},
|
||||
}
|
||||
|
||||
# Derived structures (auto-generated from MODELS)
|
||||
def _generate_model_structures():
|
||||
"""Generate all model structures from the master MODELS dictionary."""
|
||||
|
||||
# Generate tier lists
|
||||
free_models = []
|
||||
paid_models = []
|
||||
|
||||
# Generate aliases
|
||||
aliases = {}
|
||||
|
||||
# Generate pricing
|
||||
pricing = {}
|
||||
|
||||
# Generate context window limits
|
||||
context_windows = {}
|
||||
|
||||
for model_name, config in MODELS.items():
|
||||
# Add to tier lists
|
||||
if "free" in config["tier_availability"]:
|
||||
free_models.append(model_name)
|
||||
if "paid" in config["tier_availability"]:
|
||||
paid_models.append(model_name)
|
||||
|
||||
# Add aliases
|
||||
for alias in config["aliases"]:
|
||||
aliases[alias] = model_name
|
||||
|
||||
# Add pricing
|
||||
pricing[model_name] = config["pricing"]
|
||||
|
||||
# Add context window limits
|
||||
if "context_window" in config:
|
||||
context_windows[model_name] = config["context_window"]
|
||||
|
||||
# Also add pricing and context windows for legacy model name variations
|
||||
if model_name.startswith("openrouter/deepseek/"):
|
||||
legacy_name = model_name.replace("openrouter/", "")
|
||||
pricing[legacy_name] = config["pricing"]
|
||||
if "context_window" in config:
|
||||
context_windows[legacy_name] = config["context_window"]
|
||||
elif model_name.startswith("openrouter/qwen/"):
|
||||
legacy_name = model_name.replace("openrouter/", "")
|
||||
pricing[legacy_name] = config["pricing"]
|
||||
if "context_window" in config:
|
||||
context_windows[legacy_name] = config["context_window"]
|
||||
elif model_name.startswith("gemini/"):
|
||||
legacy_name = model_name.replace("gemini/", "")
|
||||
pricing[legacy_name] = config["pricing"]
|
||||
if "context_window" in config:
|
||||
context_windows[legacy_name] = config["context_window"]
|
||||
elif model_name.startswith("anthropic/"):
|
||||
# Add anthropic/claude-sonnet-4 alias for claude-sonnet-4-20250514
|
||||
if "claude-sonnet-4-20250514" in model_name:
|
||||
pricing["anthropic/claude-sonnet-4"] = config["pricing"]
|
||||
if "context_window" in config:
|
||||
context_windows["anthropic/claude-sonnet-4"] = config["context_window"]
|
||||
elif model_name.startswith("xai/"):
|
||||
# Add pricing for OpenRouter x-ai models
|
||||
openrouter_name = model_name.replace("xai/", "openrouter/x-ai/")
|
||||
pricing[openrouter_name] = config["pricing"]
|
||||
if "context_window" in config:
|
||||
context_windows[openrouter_name] = config["context_window"]
|
||||
|
||||
return free_models, paid_models, aliases, pricing, context_windows
|
||||
|
||||
# Generate all structures
|
||||
FREE_TIER_MODELS, PAID_TIER_MODELS, MODEL_NAME_ALIASES, HARDCODED_MODEL_PRICES, MODEL_CONTEXT_WINDOWS = _generate_model_structures()
|
||||
MODELS = _legacy_data["MODELS"]
|
||||
MODEL_NAME_ALIASES = _legacy_data["MODEL_NAME_ALIASES"]
|
||||
HARDCODED_MODEL_PRICES = _legacy_data["HARDCODED_MODEL_PRICES"]
|
||||
MODEL_CONTEXT_WINDOWS = _legacy_data["MODEL_CONTEXT_WINDOWS"]
|
||||
FREE_TIER_MODELS = _legacy_data["FREE_TIER_MODELS"]
|
||||
PAID_TIER_MODELS = _legacy_data["PAID_TIER_MODELS"]
|
||||
|
||||
MODEL_ACCESS_TIERS = {
|
||||
"free": FREE_TIER_MODELS,
|
||||
|
@ -220,38 +24,4 @@ MODEL_ACCESS_TIERS = {
|
|||
}
|
||||
|
||||
def get_model_context_window(model_name: str, default: int = 31_000) -> int:
|
||||
"""
|
||||
Get the context window size for a given model.
|
||||
|
||||
Args:
|
||||
model_name: The model name or alias
|
||||
default: Default context window if model not found
|
||||
|
||||
Returns:
|
||||
Context window size in tokens
|
||||
"""
|
||||
# Check direct model name first
|
||||
if model_name in MODEL_CONTEXT_WINDOWS:
|
||||
return MODEL_CONTEXT_WINDOWS[model_name]
|
||||
|
||||
# Check if it's an alias
|
||||
if model_name in MODEL_NAME_ALIASES:
|
||||
canonical_name = MODEL_NAME_ALIASES[model_name]
|
||||
if canonical_name in MODEL_CONTEXT_WINDOWS:
|
||||
return MODEL_CONTEXT_WINDOWS[canonical_name]
|
||||
|
||||
# Fallback patterns for common model naming variations
|
||||
if 'sonnet' in model_name.lower():
|
||||
return 200_000 # Claude Sonnet models
|
||||
elif 'gpt-5' in model_name.lower():
|
||||
return 400_000 # GPT-5 models
|
||||
elif 'gemini' in model_name.lower():
|
||||
return 2_000_000 # Gemini models
|
||||
elif 'grok' in model_name.lower():
|
||||
return 128_000 # Grok models
|
||||
elif 'gpt' in model_name.lower():
|
||||
return 128_000 # GPT-4 and variants
|
||||
elif 'deepseek' in model_name.lower():
|
||||
return 128_000 # DeepSeek models
|
||||
|
||||
return default
|
||||
return model_manager.get_context_window(model_name, default)
|
||||
|
|
|
@ -19,12 +19,10 @@ import { cn } from '@/lib/utils';
|
|||
import {
|
||||
useModelSelection,
|
||||
MODELS,
|
||||
STORAGE_KEY_CUSTOM_MODELS,
|
||||
formatModelName,
|
||||
getCustomModels,
|
||||
DEFAULT_FREE_MODEL_ID,
|
||||
DEFAULT_PREMIUM_MODEL_ID
|
||||
} from '@/components/thread/chat-input/_use-model-selection';
|
||||
import { formatModelName, getPrefixedModelId } from '@/lib/stores/model-store';
|
||||
import { useAvailableModels } from '@/hooks/react-query/subscriptions/use-billing';
|
||||
import { isLocalMode } from '@/lib/config';
|
||||
import { CustomModelDialog, CustomModelFormData } from '@/components/thread/chat-input/custom-model-dialog';
|
||||
|
@ -52,36 +50,33 @@ export function AgentModelSelector({
|
|||
variant = 'default',
|
||||
className,
|
||||
}: AgentModelSelectorProps) {
|
||||
const { allModels, canAccessModel, subscriptionStatus } = useModelSelection();
|
||||
const {
|
||||
allModels,
|
||||
canAccessModel,
|
||||
subscriptionStatus,
|
||||
selectedModel: storeSelectedModel,
|
||||
handleModelChange: storeHandleModelChange,
|
||||
customModels: storeCustomModels,
|
||||
addCustomModel: storeAddCustomModel,
|
||||
updateCustomModel: storeUpdateCustomModel,
|
||||
removeCustomModel: storeRemoveCustomModel
|
||||
} = useModelSelection();
|
||||
const { data: modelsData } = useAvailableModels();
|
||||
const [isOpen, setIsOpen] = useState(false);
|
||||
const [searchQuery, setSearchQuery] = useState('');
|
||||
const [highlightedIndex, setHighlightedIndex] = useState<number>(-1);
|
||||
const searchInputRef = useRef<HTMLInputElement>(null);
|
||||
|
||||
// Paywall and billing states
|
||||
const [paywallOpen, setPaywallOpen] = useState(false);
|
||||
const [lockedModel, setLockedModel] = useState<string | null>(null);
|
||||
const [billingModalOpen, setBillingModalOpen] = useState(false);
|
||||
|
||||
// Custom model states for local mode
|
||||
const [customModels, setCustomModels] = useState<CustomModel[]>([]);
|
||||
const [isCustomModelDialogOpen, setIsCustomModelDialogOpen] = useState(false);
|
||||
const [dialogInitialData, setDialogInitialData] = useState<CustomModelFormData>({ id: '', label: '' });
|
||||
const [dialogMode, setDialogMode] = useState<'add' | 'edit'>('add');
|
||||
const [editingModelId, setEditingModelId] = useState<string | null>(null);
|
||||
|
||||
useEffect(() => {
|
||||
if (isLocalMode()) {
|
||||
setCustomModels(getCustomModels());
|
||||
}
|
||||
}, []);
|
||||
|
||||
useEffect(() => {
|
||||
if (isLocalMode() && customModels.length > 0) {
|
||||
localStorage.setItem(STORAGE_KEY_CUSTOM_MODELS, JSON.stringify(customModels));
|
||||
}
|
||||
}, [customModels]);
|
||||
const customModels = storeCustomModels;
|
||||
|
||||
const normalizeModelId = (modelId?: string): string => {
|
||||
if (!modelId) return subscriptionStatus === 'active' ? DEFAULT_PREMIUM_MODEL_ID : DEFAULT_FREE_MODEL_ID;
|
||||
|
@ -103,17 +98,45 @@ export function AgentModelSelector({
|
|||
return modelId;
|
||||
};
|
||||
|
||||
const selectedModel = normalizeModelId(value);
|
||||
const normalizedValue = normalizeModelId(value);
|
||||
|
||||
useEffect(() => {
|
||||
if (normalizedValue && normalizedValue !== storeSelectedModel) {
|
||||
storeHandleModelChange(normalizedValue);
|
||||
}
|
||||
}, [normalizedValue, storeSelectedModel, storeHandleModelChange]);
|
||||
|
||||
const selectedModel = storeSelectedModel;
|
||||
|
||||
const enhancedModelOptions = useMemo(() => {
|
||||
const modelMap = new Map();
|
||||
|
||||
allModels.forEach(model => {
|
||||
modelMap.set(model.id, {
|
||||
...model,
|
||||
isCustom: false
|
||||
if (modelsData?.models) {
|
||||
modelsData.models.forEach(model => {
|
||||
const shortName = model.short_name || model.id;
|
||||
const displayName = model.display_name || shortName;
|
||||
|
||||
modelMap.set(shortName, {
|
||||
id: shortName,
|
||||
label: displayName,
|
||||
requiresSubscription: model.requires_subscription || false,
|
||||
priority: model.priority || 0,
|
||||
recommended: model.recommended || false,
|
||||
top: (model.priority || 0) >= 90,
|
||||
capabilities: model.capabilities || [],
|
||||
contextWindow: model.context_window || 128000,
|
||||
isCustom: false
|
||||
});
|
||||
});
|
||||
});
|
||||
} else {
|
||||
// Fallback to allModels if API data not available
|
||||
allModels.forEach(model => {
|
||||
modelMap.set(model.id, {
|
||||
...model,
|
||||
isCustom: false
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
if (isLocalMode()) {
|
||||
customModels.forEach(model => {
|
||||
|
@ -136,7 +159,7 @@ export function AgentModelSelector({
|
|||
}
|
||||
|
||||
return Array.from(modelMap.values());
|
||||
}, [allModels, customModels]);
|
||||
}, [modelsData?.models, allModels, customModels]);
|
||||
|
||||
const selectedModelDisplay = useMemo(() => {
|
||||
const model = enhancedModelOptions.find(m => m.id === selectedModel);
|
||||
|
@ -162,7 +185,7 @@ export function AgentModelSelector({
|
|||
const freeModels = sortedModels.filter(m => !m.requiresSubscription);
|
||||
const premiumModels = sortedModels.filter(m => m.requiresSubscription);
|
||||
|
||||
const shouldDisplayAll = (!isLocalMode() && subscriptionStatus === 'no_subscription') && premiumModels.length > 0;
|
||||
const shouldDisplayAll = !isLocalMode() && premiumModels.length > 0;
|
||||
|
||||
useEffect(() => {
|
||||
if (isOpen && searchInputRef.current) {
|
||||
|
@ -176,20 +199,18 @@ export function AgentModelSelector({
|
|||
}, [isOpen]);
|
||||
|
||||
const handleSelect = (modelId: string) => {
|
||||
console.log('🔧 AgentModelSelector: Selecting model:', modelId);
|
||||
console.log('🔧 AgentModelSelector: Current selectedModel (normalized):', selectedModel);
|
||||
console.log('🔧 AgentModelSelector: Current value prop:', value);
|
||||
|
||||
const isCustomModel = customModels.some(model => model.id === modelId);
|
||||
|
||||
if (isCustomModel && isLocalMode()) {
|
||||
storeHandleModelChange(modelId);
|
||||
onChange(modelId);
|
||||
setIsOpen(false);
|
||||
return;
|
||||
}
|
||||
|
||||
if (isLocalMode() || canAccessModel(modelId)) {
|
||||
// Don't transform the modelId - pass it as-is to match what useModelSelection expects
|
||||
const hasAccess = isLocalMode() || canAccessModel(modelId);
|
||||
if (hasAccess) {
|
||||
storeHandleModelChange(modelId);
|
||||
onChange(modelId);
|
||||
setIsOpen(false);
|
||||
} else {
|
||||
|
@ -262,22 +283,16 @@ export function AgentModelSelector({
|
|||
closeCustomModelDialog();
|
||||
const newModel = { id: modelId, label: modelLabel };
|
||||
|
||||
const updatedModels = dialogMode === 'add'
|
||||
? [...customModels, newModel]
|
||||
: customModels.map(model => model.id === editingModelId ? newModel : model);
|
||||
|
||||
try {
|
||||
localStorage.setItem(STORAGE_KEY_CUSTOM_MODELS, JSON.stringify(updatedModels));
|
||||
} catch (error) {
|
||||
console.error('Failed to save custom models to localStorage:', error);
|
||||
}
|
||||
|
||||
setCustomModels(updatedModels);
|
||||
|
||||
if (dialogMode === 'add') {
|
||||
storeAddCustomModel(newModel);
|
||||
storeHandleModelChange(modelId);
|
||||
onChange(modelId);
|
||||
} else if (selectedModel === editingModelId) {
|
||||
onChange(modelId);
|
||||
} else {
|
||||
storeUpdateCustomModel(editingModelId!, newModel);
|
||||
if (selectedModel === editingModelId) {
|
||||
storeHandleModelChange(modelId);
|
||||
onChange(modelId);
|
||||
}
|
||||
}
|
||||
|
||||
setIsOpen(false);
|
||||
|
@ -293,20 +308,11 @@ export function AgentModelSelector({
|
|||
e?.stopPropagation();
|
||||
e?.preventDefault();
|
||||
|
||||
const updatedCustomModels = customModels.filter(model => model.id !== modelId);
|
||||
|
||||
if (isLocalMode() && typeof window !== 'undefined') {
|
||||
try {
|
||||
localStorage.setItem(STORAGE_KEY_CUSTOM_MODELS, JSON.stringify(updatedCustomModels));
|
||||
} catch (error) {
|
||||
console.error('Failed to update custom models in localStorage:', error);
|
||||
}
|
||||
}
|
||||
|
||||
setCustomModels(updatedCustomModels);
|
||||
storeRemoveCustomModel(modelId);
|
||||
|
||||
if (selectedModel === modelId) {
|
||||
const defaultModel = subscriptionStatus === 'active' ? DEFAULT_PREMIUM_MODEL_ID : DEFAULT_FREE_MODEL_ID;
|
||||
storeHandleModelChange(defaultModel);
|
||||
onChange(defaultModel);
|
||||
}
|
||||
};
|
||||
|
@ -461,7 +467,7 @@ export function AgentModelSelector({
|
|||
</TooltipProvider>
|
||||
<DropdownMenuContent
|
||||
align={variant === 'menu-item' ? 'end' : 'start'}
|
||||
className="w-72 p-0 overflow-hidden"
|
||||
className="w-76 p-0 overflow-hidden"
|
||||
sideOffset={variant === 'menu-item' ? 8 : 4}
|
||||
>
|
||||
<div className="max-h-[400px] overflow-y-auto scrollbar-thin scrollbar-thumb-zinc-300 dark:scrollbar-thumb-zinc-700 scrollbar-track-transparent w-full">
|
||||
|
@ -535,58 +541,77 @@ export function AgentModelSelector({
|
|||
<div className="mt-4 border-t border-border pt-2">
|
||||
<div className="px-3 py-1.5 text-xs font-medium text-blue-500 flex items-center">
|
||||
<Crown className="h-3.5 w-3.5 mr-1.5" />
|
||||
Additional Models
|
||||
{subscriptionStatus === 'active' ? 'Premium Models' : 'Additional Models'}
|
||||
</div>
|
||||
<div className="relative h-40 overflow-hidden px-2">
|
||||
{premiumModels.slice(0, 3).map((model, index) => (
|
||||
<TooltipProvider key={`premium-${model.id}-${index}`}>
|
||||
<Tooltip>
|
||||
<TooltipTrigger asChild>
|
||||
<div className='w-full'>
|
||||
<DropdownMenuItem
|
||||
className="text-sm px-3 rounded-lg py-2 mx-2 my-0.5 flex items-center justify-between opacity-70 cursor-pointer pointer-events-none"
|
||||
>
|
||||
<div className="flex items-center">
|
||||
<span className="font-medium">{model.label}</span>
|
||||
</div>
|
||||
<div className="flex items-center gap-2">
|
||||
{MODELS[model.id]?.recommended && (
|
||||
<span className="text-xs px-1.5 py-0.5 rounded-sm bg-blue-100 dark:bg-blue-900 text-blue-600 dark:text-blue-300 font-medium whitespace-nowrap">
|
||||
Recommended
|
||||
</span>
|
||||
<div className="relative overflow-hidden" style={{ maxHeight: subscriptionStatus === 'active' ? 'none' : '160px' }}>
|
||||
{(subscriptionStatus === 'active' ? premiumModels : premiumModels.slice(0, 3)).map((model, index) => {
|
||||
const canAccess = isLocalMode() || canAccessModel(model.id);
|
||||
const isRecommended = model.recommended;
|
||||
|
||||
return (
|
||||
<TooltipProvider key={`premium-${model.id}-${index}`}>
|
||||
<Tooltip>
|
||||
<TooltipTrigger asChild>
|
||||
<div className='w-full'>
|
||||
<DropdownMenuItem
|
||||
className={cn(
|
||||
"text-sm px-3 rounded-lg py-2 mx-2 my-0.5 flex items-center justify-between cursor-pointer",
|
||||
!canAccess && "opacity-70"
|
||||
)}
|
||||
<Crown className="h-3.5 w-3.5 text-blue-500" />
|
||||
</div>
|
||||
</DropdownMenuItem>
|
||||
</div>
|
||||
</TooltipTrigger>
|
||||
<TooltipContent side="left" className="text-xs max-w-xs">
|
||||
<p>Requires subscription to access premium model</p>
|
||||
</TooltipContent>
|
||||
</Tooltip>
|
||||
</TooltipProvider>
|
||||
))}
|
||||
<div className="absolute inset-0 bg-gradient-to-t from-background via-background/95 to-transparent flex items-end justify-center">
|
||||
<div className="w-full p-3">
|
||||
<div className="rounded-xl bg-gradient-to-br from-blue-50/80 to-blue-200/70 dark:from-blue-950/40 dark:to-blue-900/30 shadow-sm border border-blue-200/50 dark:border-blue-800/50 p-3">
|
||||
<div className="flex flex-col space-y-2">
|
||||
<div className="flex items-center">
|
||||
<Crown className="h-4 w-4 text-blue-500 mr-2 flex-shrink-0" />
|
||||
<div>
|
||||
<p className="text-sm font-medium">Unlock all models + higher limits</p>
|
||||
onClick={() => handleSelect(model.id)}
|
||||
>
|
||||
<div className="flex items-center">
|
||||
<span className="font-medium">{model.label}</span>
|
||||
</div>
|
||||
<div className="flex items-center gap-2">
|
||||
{isRecommended && (
|
||||
<span className="text-xs px-1.5 py-0.5 rounded-sm bg-blue-100 dark:bg-blue-900 text-blue-600 dark:text-blue-300 font-medium whitespace-nowrap">
|
||||
Recommended
|
||||
</span>
|
||||
)}
|
||||
{!canAccess && <Crown className="h-3.5 w-3.5 text-blue-500" />}
|
||||
{selectedModel === model.id && (
|
||||
<Check className="h-4 w-4 text-blue-500" />
|
||||
)}
|
||||
</div>
|
||||
</DropdownMenuItem>
|
||||
</div>
|
||||
</TooltipTrigger>
|
||||
<TooltipContent side="left" className="text-xs max-w-xs">
|
||||
<p>
|
||||
{canAccess
|
||||
? (isRecommended ? 'Recommended for optimal performance' : 'Premium model')
|
||||
: 'Requires subscription to access premium model'
|
||||
}
|
||||
</p>
|
||||
</TooltipContent>
|
||||
</Tooltip>
|
||||
</TooltipProvider>
|
||||
);
|
||||
})}
|
||||
{subscriptionStatus !== 'active' && (
|
||||
<div className="absolute inset-0 bg-gradient-to-t from-background via-background/95 to-transparent flex items-end justify-center">
|
||||
<div className="w-full p-3">
|
||||
<div className="rounded-xl bg-gradient-to-br from-blue-50/80 to-blue-200/70 dark:from-blue-950/40 dark:to-blue-900/30 shadow-sm border border-blue-200/50 dark:border-blue-800/50 p-3">
|
||||
<div className="flex flex-col space-y-2">
|
||||
<div className="flex items-center">
|
||||
<Crown className="h-4 w-4 text-blue-500 mr-2 flex-shrink-0" />
|
||||
<div>
|
||||
<p className="text-sm font-medium">Unlock all models + higher limits</p>
|
||||
</div>
|
||||
</div>
|
||||
<Button
|
||||
size="sm"
|
||||
className="w-full h-8 font-medium"
|
||||
onClick={handleUpgradeClick}
|
||||
>
|
||||
Upgrade now
|
||||
</Button>
|
||||
</div>
|
||||
<Button
|
||||
size="sm"
|
||||
className="w-full h-8 font-medium"
|
||||
onClick={handleUpgradeClick}
|
||||
>
|
||||
Upgrade now
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</>
|
||||
|
|
|
@ -0,0 +1,230 @@
|
|||
'use client';
|
||||
|
||||
import { useSubscriptionData } from '@/contexts/SubscriptionContext';
|
||||
import { useEffect, useMemo } from 'react';
|
||||
import { isLocalMode } from '@/lib/config';
|
||||
import { useAvailableModels } from '@/hooks/react-query/subscriptions/use-model';
|
||||
import {
|
||||
useModelStore,
|
||||
canAccessModel,
|
||||
formatModelName,
|
||||
getPrefixedModelId,
|
||||
type SubscriptionStatus,
|
||||
type ModelOption,
|
||||
type CustomModel
|
||||
} from '@/lib/stores/model-store';
|
||||
|
||||
export const useModelSelection = () => {
|
||||
const { data: subscriptionData } = useSubscriptionData();
|
||||
const { data: modelsData, isLoading: isLoadingModels } = useAvailableModels({
|
||||
refetchOnMount: false,
|
||||
});
|
||||
|
||||
const {
|
||||
selectedModel,
|
||||
customModels,
|
||||
hasHydrated,
|
||||
setSelectedModel,
|
||||
addCustomModel,
|
||||
updateCustomModel,
|
||||
removeCustomModel,
|
||||
setCustomModels,
|
||||
setHasHydrated,
|
||||
getDefaultModel,
|
||||
resetToDefault,
|
||||
} = useModelStore();
|
||||
|
||||
const subscriptionStatus: SubscriptionStatus = (subscriptionData?.status === 'active' || subscriptionData?.status === 'trialing')
|
||||
? 'active'
|
||||
: 'no_subscription';
|
||||
|
||||
// Load custom models from localStorage for local mode
|
||||
useEffect(() => {
|
||||
if (isLocalMode() && hasHydrated && typeof window !== 'undefined') {
|
||||
try {
|
||||
const storedModels = localStorage.getItem('customModels');
|
||||
if (storedModels) {
|
||||
const parsedModels = JSON.parse(storedModels);
|
||||
if (Array.isArray(parsedModels)) {
|
||||
const validModels = parsedModels.filter((model: any) =>
|
||||
model && typeof model === 'object' &&
|
||||
typeof model.id === 'string' &&
|
||||
typeof model.label === 'string'
|
||||
);
|
||||
setCustomModels(validModels);
|
||||
}
|
||||
}
|
||||
} catch (e) {
|
||||
console.error('Error loading custom models:', e);
|
||||
}
|
||||
}
|
||||
}, [isLocalMode, hasHydrated, setCustomModels]);
|
||||
|
||||
const MODEL_OPTIONS = useMemo(() => {
|
||||
let models: ModelOption[] = [];
|
||||
if (!modelsData?.models || isLoadingModels) {
|
||||
models = [
|
||||
{
|
||||
id: 'moonshotai/kimi-k2',
|
||||
label: 'Kimi K2',
|
||||
requiresSubscription: false,
|
||||
priority: 100,
|
||||
recommended: true
|
||||
},
|
||||
{
|
||||
id: 'claude-sonnet-4',
|
||||
label: 'Claude Sonnet 4',
|
||||
requiresSubscription: true,
|
||||
priority: 100,
|
||||
recommended: true
|
||||
},
|
||||
];
|
||||
} else {
|
||||
models = modelsData.models.map(model => {
|
||||
const shortName = model.short_name || model.id;
|
||||
const displayName = model.display_name || shortName;
|
||||
|
||||
return {
|
||||
id: shortName,
|
||||
label: displayName,
|
||||
requiresSubscription: model.requires_subscription || false,
|
||||
priority: model.priority || 0,
|
||||
recommended: model.recommended || false,
|
||||
top: (model.priority || 0) >= 90,
|
||||
capabilities: model.capabilities || [],
|
||||
contextWindow: model.context_window || 128000
|
||||
};
|
||||
});
|
||||
}
|
||||
|
||||
if (isLocalMode() && customModels.length > 0) {
|
||||
const customModelOptions = customModels.map(model => ({
|
||||
id: model.id,
|
||||
label: model.label || formatModelName(model.id),
|
||||
requiresSubscription: false,
|
||||
top: false,
|
||||
isCustom: true,
|
||||
priority: 30,
|
||||
}));
|
||||
|
||||
models = [...models, ...customModelOptions];
|
||||
}
|
||||
|
||||
const sortedModels = models.sort((a, b) => {
|
||||
if (a.recommended !== b.recommended) {
|
||||
return a.recommended ? -1 : 1;
|
||||
}
|
||||
|
||||
if (a.priority !== b.priority) {
|
||||
return (b.priority || 0) - (a.priority || 0);
|
||||
}
|
||||
|
||||
return a.label.localeCompare(b.label);
|
||||
});
|
||||
|
||||
return sortedModels;
|
||||
}, [modelsData, isLoadingModels, customModels]);
|
||||
|
||||
const availableModels = useMemo(() => {
|
||||
return isLocalMode()
|
||||
? MODEL_OPTIONS
|
||||
: MODEL_OPTIONS.filter(model =>
|
||||
canAccessModel(subscriptionStatus, model.requiresSubscription)
|
||||
);
|
||||
}, [MODEL_OPTIONS, subscriptionStatus]);
|
||||
|
||||
// Validate model selection only after hydration and when subscription status changes
|
||||
useEffect(() => {
|
||||
// Skip validation until hydrated and models are loaded
|
||||
if (!hasHydrated || isLoadingModels || typeof window === 'undefined') {
|
||||
return;
|
||||
}
|
||||
|
||||
// Check if the selected model is still valid
|
||||
const isValidModel = MODEL_OPTIONS.some(model => model.id === selectedModel) ||
|
||||
(isLocalMode() && customModels.some(model => model.id === selectedModel));
|
||||
|
||||
if (!isValidModel) {
|
||||
console.log('🔧 ModelSelection: Invalid model detected, resetting to default');
|
||||
resetToDefault(subscriptionStatus);
|
||||
return;
|
||||
}
|
||||
|
||||
// For non-local mode, check if user still has access to the selected model
|
||||
if (!isLocalMode()) {
|
||||
const modelOption = MODEL_OPTIONS.find(m => m.id === selectedModel);
|
||||
if (modelOption && !canAccessModel(subscriptionStatus, modelOption.requiresSubscription)) {
|
||||
console.log('🔧 ModelSelection: User lost access to model, resetting to default');
|
||||
resetToDefault(subscriptionStatus);
|
||||
}
|
||||
}
|
||||
}, [hasHydrated, selectedModel, subscriptionStatus, MODEL_OPTIONS, customModels, isLoadingModels, resetToDefault]);
|
||||
|
||||
const handleModelChange = (modelId: string) => {
|
||||
const isCustomModel = isLocalMode() && customModels.some(model => model.id === modelId);
|
||||
|
||||
const modelOption = MODEL_OPTIONS.find(option => option.id === modelId);
|
||||
|
||||
if (!modelOption && !isCustomModel) {
|
||||
resetToDefault(subscriptionStatus);
|
||||
return;
|
||||
}
|
||||
|
||||
if (!isCustomModel && !isLocalMode() &&
|
||||
!canAccessModel(subscriptionStatus, modelOption?.requiresSubscription ?? false)) {
|
||||
return;
|
||||
}
|
||||
|
||||
setSelectedModel(modelId);
|
||||
};
|
||||
|
||||
const getActualModelId = (modelId: string): string => {
|
||||
const isCustomModel = isLocalMode() && customModels.some(model => model.id === modelId);
|
||||
return isCustomModel ? getPrefixedModelId(modelId, true) : modelId;
|
||||
};
|
||||
|
||||
// Function to refresh custom models from localStorage
|
||||
const refreshCustomModels = () => {
|
||||
if (isLocalMode() && typeof window !== 'undefined') {
|
||||
try {
|
||||
const storedModels = localStorage.getItem('customModels');
|
||||
if (storedModels) {
|
||||
const parsedModels = JSON.parse(storedModels);
|
||||
if (Array.isArray(parsedModels)) {
|
||||
const validModels = parsedModels.filter((model: any) =>
|
||||
model && typeof model === 'object' &&
|
||||
typeof model.id === 'string' &&
|
||||
typeof model.label === 'string'
|
||||
);
|
||||
setCustomModels(validModels);
|
||||
}
|
||||
}
|
||||
} catch (e) {
|
||||
console.error('Error loading custom models:', e);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
return {
|
||||
selectedModel,
|
||||
handleModelChange,
|
||||
setSelectedModel: handleModelChange, // Alias for backward compatibility
|
||||
availableModels,
|
||||
allModels: MODEL_OPTIONS,
|
||||
customModels,
|
||||
addCustomModel,
|
||||
updateCustomModel,
|
||||
removeCustomModel,
|
||||
refreshCustomModels,
|
||||
getActualModelId,
|
||||
canAccessModel: (modelId: string) => {
|
||||
if (isLocalMode()) return true;
|
||||
const model = MODEL_OPTIONS.find(m => m.id === modelId);
|
||||
return model ? canAccessModel(subscriptionStatus, model.requiresSubscription) : false;
|
||||
},
|
||||
isSubscriptionRequired: (modelId: string) => {
|
||||
return MODEL_OPTIONS.find(m => m.id === modelId)?.requiresSubscription || false;
|
||||
},
|
||||
subscriptionStatus,
|
||||
};
|
||||
};
|
|
@ -169,7 +169,7 @@ const saveModelPreference = (modelId: string): void => {
|
|||
}
|
||||
};
|
||||
|
||||
export const useModelSelection = () => {
|
||||
export const useModelSelectionOld = () => {
|
||||
const [selectedModel, setSelectedModel] = useState(DEFAULT_FREE_MODEL_ID);
|
||||
const [customModels, setCustomModels] = useState<CustomModel[]>([]);
|
||||
const [hasInitialized, setHasInitialized] = useState(false);
|
||||
|
@ -200,52 +200,41 @@ export const useModelSelection = () => {
|
|||
const MODEL_OPTIONS = useMemo(() => {
|
||||
let models = [];
|
||||
|
||||
// Default models if API data not available
|
||||
if (!modelsData?.models || isLoadingModels) {
|
||||
models = [
|
||||
{
|
||||
id: DEFAULT_FREE_MODEL_ID,
|
||||
label: 'KIMI K2',
|
||||
requiresSubscription: false,
|
||||
priority: MODELS[DEFAULT_FREE_MODEL_ID]?.priority || 100
|
||||
},
|
||||
{
|
||||
id: DEFAULT_PREMIUM_MODEL_ID,
|
||||
label: 'Claude Sonnet 4',
|
||||
requiresSubscription: true,
|
||||
priority: MODELS[DEFAULT_PREMIUM_MODEL_ID]?.priority || 100
|
||||
},
|
||||
];
|
||||
// Default models if API data not available
|
||||
if (!modelsData?.models || isLoadingModels) {
|
||||
models = [
|
||||
{
|
||||
id: DEFAULT_FREE_MODEL_ID,
|
||||
label: 'KIMI K2',
|
||||
requiresSubscription: false,
|
||||
priority: 100,
|
||||
recommended: true
|
||||
},
|
||||
{
|
||||
id: DEFAULT_PREMIUM_MODEL_ID,
|
||||
label: 'Claude Sonnet 4',
|
||||
requiresSubscription: true,
|
||||
priority: 100,
|
||||
recommended: true
|
||||
},
|
||||
];
|
||||
} else {
|
||||
// Process API-provided models
|
||||
// Process API-provided models - use clean data from new backend system
|
||||
models = modelsData.models.map(model => {
|
||||
// Use the clean data directly from the API (no more duplicates!)
|
||||
const shortName = model.short_name || model.id;
|
||||
const displayName = model.display_name || shortName;
|
||||
|
||||
// Format the display label
|
||||
let cleanLabel = displayName;
|
||||
if (cleanLabel.includes('/')) {
|
||||
cleanLabel = cleanLabel.split('/').pop() || cleanLabel;
|
||||
}
|
||||
|
||||
cleanLabel = cleanLabel
|
||||
.replace(/-/g, ' ')
|
||||
.split(' ')
|
||||
.map(word => word.charAt(0).toUpperCase() + word.slice(1))
|
||||
.join(' ');
|
||||
|
||||
// Get model data from our central MODELS constant
|
||||
const modelData = MODELS[shortName] || {};
|
||||
const isPremium = model?.requires_subscription || modelData.tier === 'premium' || false;
|
||||
|
||||
return {
|
||||
id: shortName,
|
||||
label: cleanLabel,
|
||||
requiresSubscription: isPremium,
|
||||
top: modelData.priority >= 90, // Mark high-priority models as "top"
|
||||
priority: modelData.priority || 0,
|
||||
lowQuality: modelData.lowQuality || false,
|
||||
recommended: modelData.recommended || false
|
||||
label: displayName,
|
||||
requiresSubscription: model.requires_subscription || false,
|
||||
priority: model.priority || 0,
|
||||
recommended: model.recommended || false,
|
||||
top: (model.priority || 0) >= 90, // Mark high-priority models as "top"
|
||||
lowQuality: false, // All models in new system are quality controlled
|
||||
capabilities: model.capabilities || [],
|
||||
contextWindow: model.context_window || 128000
|
||||
};
|
||||
});
|
||||
}
|
||||
|
@ -516,4 +505,7 @@ export const useModelSelection = () => {
|
|||
};
|
||||
};
|
||||
|
||||
// Export the new model selection hook
|
||||
export { useModelSelection } from './_use-model-selection-new';
|
||||
|
||||
// Export the hook but not any sorting logic - sorting is handled internally
|
|
@ -14,7 +14,7 @@ import { Card, CardContent } from '@/components/ui/card';
|
|||
import { handleFiles } from './file-upload-handler';
|
||||
import { MessageInput } from './message-input';
|
||||
import { AttachmentGroup } from '../attachment-group';
|
||||
import { useModelSelection } from './_use-model-selection';
|
||||
import { useModelSelection } from './_use-model-selection-new';
|
||||
import { useFileDelete } from '@/hooks/react-query/files';
|
||||
import { useQueryClient } from '@tanstack/react-query';
|
||||
import { ToolCallInput } from './floating-tool-preview';
|
||||
|
|
|
@ -452,7 +452,8 @@ export function FileAttachment({
|
|||
"group relative w-full",
|
||||
"rounded-xl border bg-card overflow-hidden pt-10", // Consistent card styling with header space
|
||||
isPdf ? "!min-h-[200px] sm:min-h-0 sm:h-[400px] max-h-[500px] sm:!min-w-[300px]" :
|
||||
standalone ? "min-h-[300px] h-auto" : "h-[300px]", // Better height handling for standalone
|
||||
isHtmlOrMd ? "!min-h-[200px] sm:min-h-0 sm:h-[400px] max-h-[600px] sm:!min-w-[300px]" :
|
||||
standalone ? "min-h-[300px] h-auto" : "h-[300px]", // Better height handling for standalone
|
||||
className
|
||||
)}
|
||||
style={{
|
||||
|
@ -469,8 +470,8 @@ export function FileAttachment({
|
|||
style={{
|
||||
minWidth: 0,
|
||||
width: '100%',
|
||||
containIntrinsicSize: isPdf ? '100% 500px' : undefined,
|
||||
contain: isPdf ? 'layout size' : undefined
|
||||
containIntrinsicSize: (isPdf || isHtmlOrMd) ? '100% 500px' : undefined,
|
||||
contain: (isPdf || isHtmlOrMd) ? 'layout size' : undefined
|
||||
}}
|
||||
>
|
||||
{/* Render PDF or text-based previews */}
|
||||
|
|
|
@ -3,29 +3,68 @@ import {
|
|||
FileDiff,
|
||||
CheckCircle,
|
||||
AlertTriangle,
|
||||
ExternalLink,
|
||||
Loader2,
|
||||
Code,
|
||||
Eye,
|
||||
File,
|
||||
ChevronDown,
|
||||
ChevronUp,
|
||||
Copy,
|
||||
Check,
|
||||
Minus,
|
||||
Plus,
|
||||
} from 'lucide-react';
|
||||
import {
|
||||
extractFilePath,
|
||||
extractFileContent,
|
||||
extractStreamingFileContent,
|
||||
formatTimestamp,
|
||||
getToolTitle,
|
||||
normalizeContentToString,
|
||||
extractToolData,
|
||||
} from '../utils';
|
||||
import {
|
||||
MarkdownRenderer,
|
||||
processUnicodeContent,
|
||||
} from '@/components/file-renderers/markdown-renderer';
|
||||
import { CsvRenderer } from '@/components/file-renderers/csv-renderer';
|
||||
import { XlsxRenderer } from '@/components/file-renderers/xlsx-renderer';
|
||||
import { cn } from '@/lib/utils';
|
||||
import { Card, CardContent, CardHeader, CardTitle } from '@/components/ui/card';
|
||||
import { useTheme } from 'next-themes';
|
||||
import { CodeBlockCode } from '@/components/ui/code-block';
|
||||
import { constructHtmlPreviewUrl } from '@/lib/utils/url';
|
||||
import {
|
||||
Card,
|
||||
CardContent,
|
||||
CardHeader,
|
||||
CardTitle,
|
||||
} from '@/components/ui/card';
|
||||
import { Badge } from '@/components/ui/badge';
|
||||
import { Tooltip, TooltipContent, TooltipProvider, TooltipTrigger } from "@/components/ui/tooltip";
|
||||
import { Button } from '@/components/ui/button';
|
||||
import { Tabs, TabsContent, TabsList, TabsTrigger } from "@/components/ui/tabs";
|
||||
import { ScrollArea } from "@/components/ui/scroll-area";
|
||||
import {
|
||||
extractFileEditData,
|
||||
generateLineDiff,
|
||||
calculateDiffStats,
|
||||
LineDiff,
|
||||
DiffStats
|
||||
DiffStats,
|
||||
getLanguageFromFileName,
|
||||
getOperationType,
|
||||
getOperationConfigs,
|
||||
getFileIcon,
|
||||
processFilePath,
|
||||
getFileName,
|
||||
getFileExtension,
|
||||
isFileType,
|
||||
hasLanguageHighlighting,
|
||||
splitContentIntoLines,
|
||||
type FileOperation,
|
||||
type OperationConfig,
|
||||
} from './_utils';
|
||||
import { formatTimestamp, getToolTitle } from '../utils';
|
||||
import { ToolViewProps } from '../types';
|
||||
import { GenericToolView } from '../GenericToolView';
|
||||
import { LoadingState } from '../shared/LoadingState';
|
||||
import { toast } from 'sonner';
|
||||
import ReactDiffViewer from 'react-diff-viewer-continued';
|
||||
|
||||
const UnifiedDiffView: React.FC<{ oldCode: string; newCode: string }> = ({ oldCode, newCode }) => (
|
||||
|
@ -60,40 +99,6 @@ const UnifiedDiffView: React.FC<{ oldCode: string; newCode: string }> = ({ oldCo
|
|||
/>
|
||||
);
|
||||
|
||||
const SplitDiffView: React.FC<{ oldCode: string; newCode: string }> = ({ oldCode, newCode }) => (
|
||||
<ReactDiffViewer
|
||||
oldValue={oldCode}
|
||||
newValue={newCode}
|
||||
splitView={true}
|
||||
useDarkTheme={document.documentElement.classList.contains('dark')}
|
||||
styles={{
|
||||
variables: {
|
||||
dark: {
|
||||
diffViewerColor: '#e2e8f0',
|
||||
diffViewerBackground: '#09090b',
|
||||
addedBackground: '#104a32',
|
||||
addedColor: '#6ee7b7',
|
||||
removedBackground: '#5c1a2e',
|
||||
removedColor: '#fca5a5',
|
||||
},
|
||||
},
|
||||
diffContainer: {
|
||||
backgroundColor: 'var(--card)',
|
||||
border: 'none',
|
||||
},
|
||||
gutter: {
|
||||
backgroundColor: 'var(--muted)',
|
||||
'&:hover': {
|
||||
backgroundColor: 'var(--accent)',
|
||||
},
|
||||
},
|
||||
line: {
|
||||
fontFamily: 'monospace',
|
||||
},
|
||||
}}
|
||||
/>
|
||||
);
|
||||
|
||||
const ErrorState: React.FC<{ message?: string }> = ({ message }) => (
|
||||
<div className="flex flex-col items-center justify-center h-full py-12 px-6 bg-gradient-to-b from-white to-zinc-50 dark:from-zinc-950 dark:to-zinc-900">
|
||||
<div className="text-center w-full max-w-xs">
|
||||
|
@ -116,8 +121,42 @@ export function FileEditToolView({
|
|||
toolTimestamp,
|
||||
isSuccess = true,
|
||||
isStreaming = false,
|
||||
project,
|
||||
}: ToolViewProps): JSX.Element {
|
||||
const [viewMode, setViewMode] = useState<'unified' | 'split'>('unified');
|
||||
const { resolvedTheme } = useTheme();
|
||||
const isDarkTheme = resolvedTheme === 'dark';
|
||||
|
||||
// Add copy functionality state
|
||||
const [isCopyingContent, setIsCopyingContent] = useState(false);
|
||||
|
||||
// Copy functions
|
||||
const copyToClipboard = async (text: string) => {
|
||||
try {
|
||||
await navigator.clipboard.writeText(text);
|
||||
return true;
|
||||
} catch (err) {
|
||||
console.error('Failed to copy text: ', err);
|
||||
return false;
|
||||
}
|
||||
};
|
||||
|
||||
const handleCopyContent = async () => {
|
||||
if (!updatedContent) return;
|
||||
|
||||
setIsCopyingContent(true);
|
||||
const success = await copyToClipboard(updatedContent);
|
||||
if (success) {
|
||||
toast.success('File content copied to clipboard');
|
||||
} else {
|
||||
toast.error('Failed to copy file content');
|
||||
}
|
||||
setTimeout(() => setIsCopyingContent(false), 500);
|
||||
};
|
||||
|
||||
const operation = getOperationType(name, assistantContent);
|
||||
const configs = getOperationConfigs();
|
||||
const config = configs[operation] || configs['edit']; // fallback to edit config
|
||||
const Icon = FileDiff; // Always use FileDiff for edit operations
|
||||
|
||||
const {
|
||||
filePath,
|
||||
|
@ -135,103 +174,275 @@ export function FileEditToolView({
|
|||
);
|
||||
|
||||
const toolTitle = getToolTitle(name);
|
||||
const processedFilePath = processFilePath(filePath);
|
||||
const fileName = getFileName(processedFilePath);
|
||||
const fileExtension = getFileExtension(fileName);
|
||||
|
||||
const isMarkdown = isFileType.markdown(fileExtension);
|
||||
const isHtml = isFileType.html(fileExtension);
|
||||
const isCsv = isFileType.csv(fileExtension);
|
||||
const isXlsx = isFileType.xlsx(fileExtension);
|
||||
|
||||
const language = getLanguageFromFileName(fileName);
|
||||
const hasHighlighting = hasLanguageHighlighting(language);
|
||||
const contentLines = splitContentIntoLines(updatedContent);
|
||||
|
||||
const htmlPreviewUrl =
|
||||
isHtml && project?.sandbox?.sandbox_url && processedFilePath
|
||||
? constructHtmlPreviewUrl(project.sandbox.sandbox_url, processedFilePath)
|
||||
: undefined;
|
||||
|
||||
const FileIcon = getFileIcon(fileName);
|
||||
|
||||
const lineDiff = originalContent && updatedContent ? generateLineDiff(originalContent, updatedContent) : [];
|
||||
const stats: DiffStats = calculateDiffStats(lineDiff);
|
||||
|
||||
const shouldShowError = !isStreaming && (!actualIsSuccess || (actualIsSuccess && (originalContent === null || updatedContent === null)));
|
||||
|
||||
return (
|
||||
<Card className="gap-0 flex border shadow-none border-t border-b-0 border-x-0 p-0 rounded-none flex-col h-full bg-card">
|
||||
<CardHeader className="h-14 bg-zinc-50/80 dark:bg-zinc-900/80 backdrop-blur-sm border-b p-2 px-4 space-y-2">
|
||||
<div className="flex flex-row items-center justify-between">
|
||||
<div className="flex items-center gap-2">
|
||||
<div className="relative p-2 rounded-lg bg-gradient-to-br from-blue-500/20 to-blue-600/10 border border-blue-500/20">
|
||||
<FileDiff className="w-5 h-5 text-blue-500 dark:text-blue-400" />
|
||||
</div>
|
||||
<CardTitle className="text-base font-medium text-zinc-900 dark:text-zinc-100">
|
||||
{toolTitle}
|
||||
</CardTitle>
|
||||
</div>
|
||||
if (!isStreaming && !processedFilePath && !updatedContent) {
|
||||
return (
|
||||
<GenericToolView
|
||||
name={name || 'edit-file'}
|
||||
assistantContent={assistantContent}
|
||||
toolContent={toolContent}
|
||||
assistantTimestamp={assistantTimestamp}
|
||||
toolTimestamp={toolTimestamp}
|
||||
isSuccess={isSuccess}
|
||||
isStreaming={isStreaming}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
{!isStreaming && (
|
||||
<Badge
|
||||
variant="secondary"
|
||||
className={
|
||||
actualIsSuccess
|
||||
? "bg-gradient-to-b from-emerald-200 to-emerald-100 text-emerald-700 dark:from-emerald-800/50 dark:to-emerald-900/60 dark:text-emerald-300"
|
||||
: "bg-gradient-to-b from-rose-200 to-rose-100 text-rose-700 dark:from-rose-800/50 dark:to-rose-900/60 dark:text-rose-300"
|
||||
}
|
||||
>
|
||||
{actualIsSuccess ? (
|
||||
<CheckCircle className="h-3.5 w-3.5 mr-1" />
|
||||
) : (
|
||||
<AlertTriangle className="h-3.5 w-3.5 mr-1" />
|
||||
)}
|
||||
{actualIsSuccess ? 'Edit applied' : 'Edit failed'}
|
||||
</Badge>
|
||||
)}
|
||||
const renderFilePreview = () => {
|
||||
if (!updatedContent) {
|
||||
return (
|
||||
<div className="flex items-center justify-center h-full p-12">
|
||||
<div className="text-center">
|
||||
<FileIcon className="h-12 w-12 mx-auto mb-4 text-zinc-400" />
|
||||
<p className="text-sm text-zinc-500 dark:text-zinc-400">No content to preview</p>
|
||||
</div>
|
||||
</div>
|
||||
</CardHeader>
|
||||
);
|
||||
}
|
||||
|
||||
<CardContent className="p-0 flex-1 flex flex-col min-h-0">
|
||||
{isStreaming ? (
|
||||
<LoadingState
|
||||
icon={FileDiff}
|
||||
iconColor="text-blue-500 dark:text-blue-400"
|
||||
bgColor="bg-gradient-to-b from-blue-100 to-blue-50 shadow-inner dark:from-blue-800/40 dark:to-blue-900/60 dark:shadow-blue-950/20"
|
||||
title="Applying File Edit"
|
||||
filePath={filePath || 'Processing file...'}
|
||||
progressText="Analyzing changes"
|
||||
subtitle="Please wait while the file is being modified"
|
||||
if (isHtml && htmlPreviewUrl) {
|
||||
return (
|
||||
<div className="flex flex-col h-[calc(100vh-16rem)]">
|
||||
<iframe
|
||||
src={htmlPreviewUrl}
|
||||
title={`HTML Preview of ${fileName}`}
|
||||
className="flex-grow border-0"
|
||||
sandbox="allow-same-origin allow-scripts"
|
||||
/>
|
||||
) : shouldShowError ? (
|
||||
<ErrorState message={errorMessage} />
|
||||
) : (
|
||||
<div className="flex-1 flex flex-col min-h-0">
|
||||
<div className="shrink-0 p-3 border-b border-zinc-200 dark:border-zinc-800 bg-accent flex items-center justify-between">
|
||||
<div className="flex items-center">
|
||||
<File className="h-4 w-4 mr-2 text-zinc-500 dark:text-zinc-400" />
|
||||
<code className="text-sm font-mono text-zinc-700 dark:text-zinc-300">
|
||||
{filePath || 'Unknown file'}
|
||||
</code>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
<div className="flex items-center gap-2">
|
||||
<div className="flex items-center text-xs text-zinc-500 dark:text-zinc-400 gap-3">
|
||||
{stats.additions === 0 && stats.deletions === 0 ? (
|
||||
<Badge variant="outline" className="text-xs font-normal">No changes</Badge>
|
||||
) : (
|
||||
<>
|
||||
<div className="flex items-center">
|
||||
<Plus className="h-3.5 w-3.5 text-emerald-500 mr-1" />
|
||||
<span>{stats.additions}</span>
|
||||
</div>
|
||||
<div className="flex items-center">
|
||||
<Minus className="h-3.5 w-3.5 text-red-500 mr-1" />
|
||||
<span>{stats.deletions}</span>
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
<Tabs value={viewMode} onValueChange={(v) => setViewMode(v as 'unified' | 'split')} className="w-auto">
|
||||
<TabsList className="h-7 p-0.5">
|
||||
<TabsTrigger value="unified" className="text-xs h-6 px-2">Unified</TabsTrigger>
|
||||
<TabsTrigger value="split" className="text-xs h-6 px-2">Split</TabsTrigger>
|
||||
</TabsList>
|
||||
</Tabs>
|
||||
if (isMarkdown) {
|
||||
return (
|
||||
<div className="p-1 py-0 prose dark:prose-invert prose-zinc max-w-none">
|
||||
<MarkdownRenderer
|
||||
content={processUnicodeContent(updatedContent)}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
if (isCsv) {
|
||||
return (
|
||||
<div className="h-full w-full p-4">
|
||||
<div className="h-[calc(100vh-17rem)] w-full bg-muted/20 border rounded-xl overflow-auto">
|
||||
<CsvRenderer content={processUnicodeContent(updatedContent)} />
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
if (isXlsx) {
|
||||
return (
|
||||
<div className="h-full w-full p-4">
|
||||
<div className="h-[calc(100vh-17rem)] w-full bg-muted/20 border rounded-xl overflow-auto">
|
||||
<XlsxRenderer
|
||||
content={updatedContent}
|
||||
filePath={processedFilePath}
|
||||
fileName={fileName}
|
||||
project={project}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="p-4">
|
||||
<div className='w-full h-full bg-muted/20 border rounded-xl px-4 py-2 pb-6'>
|
||||
<pre className="text-sm font-mono text-zinc-800 dark:text-zinc-300 whitespace-pre-wrap break-words">
|
||||
{processUnicodeContent(updatedContent)}
|
||||
</pre>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
const renderSourceCode = () => {
|
||||
if (!originalContent || !updatedContent) {
|
||||
return (
|
||||
<div className="flex items-center justify-center h-full p-12">
|
||||
<div className="text-center">
|
||||
<FileIcon className="h-12 w-12 mx-auto mb-4 text-zinc-400" />
|
||||
<p className="text-sm text-zinc-500 dark:text-zinc-400">No diff to display</p>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
// Show unified diff view in source tab
|
||||
return (
|
||||
<div className="flex-1 overflow-auto min-h-0 text-xs">
|
||||
<UnifiedDiffView oldCode={originalContent} newCode={updatedContent} />
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
return (
|
||||
<Card className="flex border shadow-none border-t border-b-0 border-x-0 p-0 rounded-none flex-col h-full overflow-hidden bg-card">
|
||||
<Tabs defaultValue={isMarkdown || isHtml || isCsv || isXlsx ? 'preview' : 'code'} className="w-full h-full">
|
||||
<CardHeader className="h-14 bg-zinc-50/80 dark:bg-zinc-900/80 backdrop-blur-sm border-b p-2 px-4 space-y-2 mb-0">
|
||||
<div className="flex flex-row items-center justify-between">
|
||||
<div className="flex items-center gap-2">
|
||||
<div className="relative p-2 rounded-lg bg-gradient-to-br from-blue-500/20 to-blue-600/10 border border-blue-500/20">
|
||||
<FileDiff className="w-5 h-5 text-blue-500 dark:text-blue-400" />
|
||||
</div>
|
||||
<div>
|
||||
<CardTitle className="text-base font-medium text-zinc-900 dark:text-zinc-100">
|
||||
{toolTitle}
|
||||
</CardTitle>
|
||||
</div>
|
||||
</div>
|
||||
<div className="flex-1 overflow-auto min-h-0 text-xs">
|
||||
{viewMode === 'unified' ? (
|
||||
<UnifiedDiffView oldCode={originalContent!} newCode={updatedContent!} />
|
||||
) : (
|
||||
<SplitDiffView oldCode={originalContent!} newCode={updatedContent!} />
|
||||
<div className='flex items-center gap-2'>
|
||||
{isHtml && htmlPreviewUrl && !isStreaming && (
|
||||
<Button variant="outline" size="sm" className="h-8 text-xs bg-white dark:bg-muted/50 hover:bg-zinc-100 dark:hover:bg-zinc-800 shadow-none" asChild>
|
||||
<a href={htmlPreviewUrl} target="_blank" rel="noopener noreferrer">
|
||||
<ExternalLink className="h-3.5 w-3.5 mr-1.5" />
|
||||
Open in Browser
|
||||
</a>
|
||||
</Button>
|
||||
)}
|
||||
{/* Copy button - only show when there's file content */}
|
||||
{updatedContent && !isStreaming && (
|
||||
<Button
|
||||
variant="outline"
|
||||
size="sm"
|
||||
onClick={handleCopyContent}
|
||||
disabled={isCopyingContent}
|
||||
className="h-8 text-xs bg-white dark:bg-muted/50 hover:bg-zinc-100 dark:hover:bg-zinc-800 shadow-none"
|
||||
title="Copy file content"
|
||||
>
|
||||
{isCopyingContent ? (
|
||||
<Check className="h-3.5 w-3.5 mr-1.5" />
|
||||
) : (
|
||||
<Copy className="h-3.5 w-3.5 mr-1.5" />
|
||||
)}
|
||||
<span className="hidden sm:inline">Copy</span>
|
||||
</Button>
|
||||
)}
|
||||
{/* Diff mode selector for source tab */}
|
||||
{originalContent && updatedContent && (
|
||||
<div className="flex items-center gap-2">
|
||||
<div className="flex items-center text-xs text-zinc-500 dark:text-zinc-400 gap-3">
|
||||
{stats.additions === 0 && stats.deletions === 0 && (
|
||||
<Badge variant="outline" className="text-xs font-normal">No changes</Badge>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
<TabsList className="h-8 bg-muted/50 border border-border/50 p-0.5 gap-1">
|
||||
<TabsTrigger
|
||||
value="code"
|
||||
className="flex items-center gap-1.5 px-4 py-2 text-xs font-medium transition-all [&[data-state=active]]:bg-white [&[data-state=active]]:dark:bg-primary/10 [&[data-state=active]]:text-foreground hover:bg-background/50 text-muted-foreground shadow-none"
|
||||
>
|
||||
<Code className="h-3.5 w-3.5" />
|
||||
Source
|
||||
</TabsTrigger>
|
||||
<TabsTrigger
|
||||
value="preview"
|
||||
className="flex items-center gap-1.5 px-4 py-2 text-xs font-medium transition-all [&[data-state=active]]:bg-white [&[data-state=active]]:dark:bg-primary/10 [&[data-state=active]]:text-foreground hover:bg-background/50 text-muted-foreground shadow-none"
|
||||
>
|
||||
<Eye className="h-3.5 w-3.5" />
|
||||
Preview
|
||||
</TabsTrigger>
|
||||
</TabsList>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</CardContent>
|
||||
</CardHeader>
|
||||
|
||||
<CardContent className="p-0 -my-2 h-full flex-1 overflow-hidden relative">
|
||||
<TabsContent value="code" className="flex-1 h-full mt-0 p-0 overflow-hidden">
|
||||
<ScrollArea className="h-screen w-full min-h-0">
|
||||
{isStreaming && !updatedContent ? (
|
||||
<LoadingState
|
||||
icon={FileDiff}
|
||||
iconColor="text-blue-500 dark:text-blue-400"
|
||||
bgColor="bg-gradient-to-b from-blue-100 to-blue-50 shadow-inner dark:from-blue-800/40 dark:to-blue-900/60 dark:shadow-blue-950/20"
|
||||
title="Applying File Edit"
|
||||
filePath={processedFilePath || 'Processing file...'}
|
||||
subtitle="Please wait while the file is being modified"
|
||||
showProgress={false}
|
||||
/>
|
||||
) : shouldShowError ? (
|
||||
<ErrorState message={errorMessage} />
|
||||
) : (
|
||||
renderSourceCode()
|
||||
)}
|
||||
</ScrollArea>
|
||||
</TabsContent>
|
||||
|
||||
<TabsContent value="preview" className="w-full flex-1 h-full mt-0 p-0 overflow-hidden">
|
||||
<ScrollArea className="h-full w-full min-h-0">
|
||||
{isStreaming && !updatedContent ? (
|
||||
<LoadingState
|
||||
icon={FileDiff}
|
||||
iconColor="text-blue-500 dark:text-blue-400"
|
||||
bgColor="bg-gradient-to-b from-blue-100 to-blue-50 shadow-inner dark:from-blue-800/40 dark:to-blue-900/60 dark:shadow-blue-950/20"
|
||||
title="Applying File Edit"
|
||||
filePath={processedFilePath || 'Processing file...'}
|
||||
subtitle="Please wait while the file is being modified"
|
||||
showProgress={false}
|
||||
/>
|
||||
) : shouldShowError ? (
|
||||
<ErrorState message={errorMessage} />
|
||||
) : (
|
||||
renderFilePreview()
|
||||
)}
|
||||
{isStreaming && updatedContent && (
|
||||
<div className="sticky bottom-4 right-4 float-right mr-4 mb-4">
|
||||
<Badge className="bg-blue-500/90 text-white border-none shadow-lg animate-pulse">
|
||||
<Loader2 className="h-3 w-3 animate-spin mr-1" />
|
||||
Streaming...
|
||||
</Badge>
|
||||
</div>
|
||||
)}
|
||||
</ScrollArea>
|
||||
</TabsContent>
|
||||
</CardContent>
|
||||
|
||||
<div className="px-4 py-2 h-10 bg-gradient-to-r from-zinc-50/90 to-zinc-100/90 dark:from-zinc-900/90 dark:to-zinc-800/90 backdrop-blur-sm border-t border-zinc-200 dark:border-zinc-800 flex justify-between items-center gap-4">
|
||||
<div className="h-full flex items-center gap-2 text-sm text-zinc-500 dark:text-zinc-400">
|
||||
<Badge variant="outline" className="py-0.5 h-6">
|
||||
<FileIcon className="h-3 w-3" />
|
||||
{hasHighlighting ? language.toUpperCase() : fileExtension.toUpperCase() || 'TEXT'}
|
||||
</Badge>
|
||||
</div>
|
||||
|
||||
<div className="text-xs text-zinc-500 dark:text-zinc-400">
|
||||
{actualToolTimestamp && !isStreaming
|
||||
? formatTimestamp(actualToolTimestamp)
|
||||
: assistantTimestamp
|
||||
? formatTimestamp(assistantTimestamp)
|
||||
: ''}
|
||||
</div>
|
||||
</div>
|
||||
</Tabs>
|
||||
</Card>
|
||||
);
|
||||
}
|
|
@ -1695,6 +1695,10 @@ export interface Model {
|
|||
input_cost_per_million_tokens?: number | null;
|
||||
output_cost_per_million_tokens?: number | null;
|
||||
max_tokens?: number | null;
|
||||
context_window?: number;
|
||||
capabilities?: string[];
|
||||
recommended?: boolean;
|
||||
priority?: number;
|
||||
}
|
||||
|
||||
export interface AvailableModelsResponse {
|
||||
|
|
|
@ -0,0 +1,141 @@
|
|||
import { create } from 'zustand';
|
||||
import { persist } from 'zustand/middleware';
|
||||
import { isLocalMode } from '@/lib/config';
|
||||
|
||||
/**
 * A user-added model entry (see `getPrefixedModelId`, which prefixes custom
 * ids with `openrouter/`, so these are presumably OpenRouter model ids —
 * TODO confirm against the model picker).
 */
export interface CustomModel {
  // Model identifier, also used for selection and de-duplication.
  id: string;
  // Human-readable name shown in the UI.
  label: string;
}
|
||||
|
||||
/** A selectable model as presented in the model picker UI. */
export interface ModelOption {
  id: string;
  label: string;
  /** True when the model is gated behind a paid subscription (see `canAccessModel`). */
  requiresSubscription: boolean;
  description?: string;
  /** NOTE(review): presumably pins the model to the top of the picker — confirm in the picker component. */
  top?: boolean;
  /** True for user-added models (see `CustomModel`). */
  isCustom?: boolean;
  /** NOTE(review): presumably a sort weight for ordering options — confirm at usage site. */
  priority?: number;
  recommended?: boolean;
  capabilities?: string[];
  contextWindow?: number;
}
|
||||
|
||||
/** Billing state used to gate premium models and pick the default model. */
export type SubscriptionStatus = 'no_subscription' | 'active';
|
||||
|
||||
/** State and actions for model selection, persisted via zustand's `persist` middleware. */
interface ModelStore {
  // Id of the currently selected model.
  selectedModel: string;
  // User-added models.
  customModels: CustomModel[];
  // True once persisted state has been rehydrated from storage (not persisted itself).
  hasHydrated: boolean;

  setSelectedModel: (model: string) => void;
  addCustomModel: (model: CustomModel) => void;
  updateCustomModel: (id: string, model: CustomModel) => void;
  removeCustomModel: (id: string) => void;
  setCustomModels: (models: CustomModel[]) => void;
  setHasHydrated: (hydrated: boolean) => void;

  // Returns the default model id for the given subscription tier.
  getDefaultModel: (subscriptionStatus: SubscriptionStatus) => string;
  // Resets the selection to that default.
  resetToDefault: (subscriptionStatus: SubscriptionStatus) => void;
}
|
||||
|
||||
// Default model ids per subscription tier.
const DEFAULT_FREE_MODEL_ID = 'moonshotai/kimi-k2';
const DEFAULT_PREMIUM_MODEL_ID = 'claude-sonnet-4';
|
||||
|
||||
export const useModelStore = create<ModelStore>()(
|
||||
persist(
|
||||
(set, get) => ({
|
||||
selectedModel: DEFAULT_FREE_MODEL_ID,
|
||||
customModels: [],
|
||||
hasHydrated: false,
|
||||
|
||||
setSelectedModel: (model: string) => {
|
||||
set({ selectedModel: model });
|
||||
},
|
||||
|
||||
addCustomModel: (model: CustomModel) => {
|
||||
const { customModels } = get();
|
||||
if (customModels.some(existing => existing.id === model.id)) {
|
||||
return;
|
||||
}
|
||||
const newCustomModels = [...customModels, model];
|
||||
set({ customModels: newCustomModels });
|
||||
},
|
||||
|
||||
updateCustomModel: (id: string, model: CustomModel) => {
|
||||
const { customModels } = get();
|
||||
const newCustomModels = customModels.map(existing =>
|
||||
existing.id === id ? model : existing
|
||||
);
|
||||
set({ customModels: newCustomModels });
|
||||
},
|
||||
|
||||
removeCustomModel: (id: string) => {
|
||||
const { customModels, selectedModel } = get();
|
||||
const newCustomModels = customModels.filter(model => model.id !== id);
|
||||
|
||||
const updates: Partial<ModelStore> = { customModels: newCustomModels };
|
||||
if (selectedModel === id) {
|
||||
updates.selectedModel = DEFAULT_FREE_MODEL_ID;
|
||||
}
|
||||
|
||||
set(updates);
|
||||
},
|
||||
|
||||
setCustomModels: (models: CustomModel[]) => {
|
||||
set({ customModels: models });
|
||||
},
|
||||
|
||||
setHasHydrated: (hydrated: boolean) => {
|
||||
set({ hasHydrated: hydrated });
|
||||
},
|
||||
|
||||
getDefaultModel: (subscriptionStatus: SubscriptionStatus) => {
|
||||
if (isLocalMode()) {
|
||||
return DEFAULT_PREMIUM_MODEL_ID;
|
||||
}
|
||||
return subscriptionStatus === 'active' ? DEFAULT_PREMIUM_MODEL_ID : DEFAULT_FREE_MODEL_ID;
|
||||
},
|
||||
|
||||
resetToDefault: (subscriptionStatus: SubscriptionStatus) => {
|
||||
const defaultModel = get().getDefaultModel(subscriptionStatus);
|
||||
set({ selectedModel: defaultModel });
|
||||
},
|
||||
}),
|
||||
{
|
||||
name: 'suna-model-selection-v2',
|
||||
partialize: (state) => ({
|
||||
selectedModel: state.selectedModel,
|
||||
customModels: state.customModels,
|
||||
}),
|
||||
onRehydrateStorage: () => (state) => {
|
||||
if (state) {
|
||||
state.setHasHydrated(true);
|
||||
}
|
||||
},
|
||||
}
|
||||
)
|
||||
);
|
||||
|
||||
export const canAccessModel = (
|
||||
subscriptionStatus: SubscriptionStatus,
|
||||
requiresSubscription: boolean,
|
||||
): boolean => {
|
||||
if (isLocalMode()) {
|
||||
return true;
|
||||
}
|
||||
return subscriptionStatus === 'active' || !requiresSubscription;
|
||||
};
|
||||
|
||||
export const formatModelName = (name: string): string => {
|
||||
return name
|
||||
.split('-')
|
||||
.map(word => word.charAt(0).toUpperCase() + word.slice(1))
|
||||
.join(' ');
|
||||
};
|
||||
|
||||
export const getPrefixedModelId = (modelId: string, isCustom: boolean): string => {
|
||||
if (isCustom && !modelId.startsWith('openrouter/')) {
|
||||
return `openrouter/${modelId}`;
|
||||
}
|
||||
return modelId;
|
||||
};
|
Loading…
Reference in New Issue