mirror of https://github.com/kortix-ai/suna.git
Compare commits
13 Commits
d3106431b5
...
80b798634c
Author | SHA1 | Date |
---|---|---|
|
80b798634c | |
|
f7aeab90e3 | |
|
07ac985275 | |
|
c92b7d3688 | |
|
05a5fb65a5 | |
|
054e1c6825 | |
|
76619ce3d9 | |
|
674e4d92d0 | |
|
94a3e787b2 | |
|
028b447c25 | |
|
9c5e73ef80 | |
|
9576936bdd | |
|
8be8b5face |
|
@ -6,6 +6,8 @@ on:
|
|||
- main
|
||||
- PRODUCTION
|
||||
workflow_dispatch:
|
||||
repository_dispatch:
|
||||
types: [production-updated]
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
@ -16,11 +18,21 @@ jobs:
|
|||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
ref: ${{ github.event_name == 'repository_dispatch' && 'PRODUCTION' || github.ref }}
|
||||
|
||||
- name: Get tag name
|
||||
shell: bash
|
||||
run: |
|
||||
if [[ "${GITHUB_REF#refs/heads/}" == "main" ]]; then
|
||||
echo "Event name: ${{ github.event_name }}"
|
||||
echo "Current ref: ${{ github.ref }}"
|
||||
echo "Branch: ${GITHUB_REF#refs/heads/}"
|
||||
|
||||
if [[ "${{ github.event_name }}" == "repository_dispatch" ]]; then
|
||||
echo "Triggered by repository dispatch - setting prod environment"
|
||||
echo "branch=prod" >> $GITHUB_OUTPUT
|
||||
echo "environment=prod" >> $GITHUB_OUTPUT
|
||||
elif [[ "${GITHUB_REF#refs/heads/}" == "main" ]]; then
|
||||
echo "branch=latest" >> $GITHUB_OUTPUT
|
||||
echo "environment=staging" >> $GITHUB_OUTPUT
|
||||
elif [[ "${GITHUB_REF#refs/heads/}" == "PRODUCTION" ]]; then
|
||||
|
|
|
@ -327,7 +327,7 @@ class BrowserTool(SandboxToolsBase):
|
|||
"type": "function",
|
||||
"function": {
|
||||
"name": "browser_act",
|
||||
"description": "Perform any browser action using natural language description. CRITICAL: This tool automatically provides a screenshot with every action. For data entry actions (filling forms, entering text, selecting options), you MUST review the provided screenshot to verify that displayed values exactly match what was intended. Report mismatches immediately.",
|
||||
"description": "Perform any browser action using natural language description. CRITICAL: This tool automatically provides a screenshot with every action. For data entry actions (filling forms, entering text, selecting options), you MUST review the provided screenshot to verify that displayed values exactly match what was intended. Report mismatches immediately. CRITICAL FILE UPLOAD RULE: ANY action that involves clicking, interacting with, or locating upload buttons, file inputs, resume upload sections, or any element that might trigger a choose file dialog MUST include the filePath parameter with filePath. This includes actions like 'click upload button', 'locate resume section', 'find file input' etc. Always err on the side of caution - if there's any possibility the action might lead to a file dialog, include filePath. This prevents accidental file dialog triggers without proper file handling.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
|
@ -345,6 +345,10 @@ class BrowserTool(SandboxToolsBase):
|
|||
"type": "boolean",
|
||||
"description": "Whether to include iframe content in the action. Set to true if the target element is inside an iframe.",
|
||||
"default": True
|
||||
},
|
||||
"filePath": {
|
||||
"type": "string",
|
||||
"description": "CRITICAL: REQUIRED for ANY action that might involve file uploads. This includes: clicking upload buttons, locating resume sections, finding file inputs, scrolling to upload areas, or any action that could potentially trigger a file dialog. Always include this parameter when dealing with upload-related elements to prevent accidental file dialog triggers. The tool will automatically handle the file upload after the action is performed.",
|
||||
}
|
||||
},
|
||||
"required": ["action"]
|
||||
|
@ -359,11 +363,20 @@ class BrowserTool(SandboxToolsBase):
|
|||
<parameter name="iframes">true</parameter>
|
||||
</invoke>
|
||||
</function_calls>
|
||||
|
||||
<function_calls>
|
||||
<invoke name="browser_act">
|
||||
<parameter name="action">click on upload resume button</parameter>
|
||||
<parameter name="filePath">/workspace/downloads/document.pdf</parameter>
|
||||
</invoke>
|
||||
</function_calls>
|
||||
''')
|
||||
async def browser_act(self, action: str, variables: dict = None, iframes: bool = False) -> ToolResult:
|
||||
async def browser_act(self, action: str, variables: dict = None, iframes: bool = False, filePath: dict = None) -> ToolResult:
|
||||
"""Perform any browser action using Stagehand."""
|
||||
logger.debug(f"Browser acting: {action} (variables={'***' if variables else None}, iframes={iframes})")
|
||||
logger.debug(f"Browser acting: {action} (variables={'***' if variables else None}, iframes={iframes}), filePath={filePath}")
|
||||
params = {"action": action, "iframes": iframes, "variables": variables}
|
||||
if filePath:
|
||||
params["filePath"] = filePath
|
||||
return await self._execute_stagehand_api("act", params)
|
||||
|
||||
@openapi_schema({
|
||||
|
|
|
@ -133,7 +133,20 @@ async def log_requests_middleware(request: Request, call_next):
|
|||
allowed_origins = ["https://www.suna.so", "https://suna.so"]
|
||||
allow_origin_regex = None
|
||||
|
||||
# Add staging-specific origins
|
||||
# Add Claude Code origins for MCP
|
||||
allowed_origins.extend([
|
||||
"https://claude.ai",
|
||||
"https://www.claude.ai",
|
||||
"https://app.claude.ai",
|
||||
"http://localhost",
|
||||
"http://127.0.0.1",
|
||||
"http://192.168.1.1"
|
||||
])
|
||||
|
||||
# Add wildcard for local development and Claude Code CLI
|
||||
allow_origin_regex = r"https://.*\.claude\.ai|http://localhost.*|http://127\.0\.0\.1.*|http://192\.168\..*|http://10\..*"
|
||||
|
||||
# Add local environment origins
|
||||
if config.ENV_MODE == EnvMode.LOCAL:
|
||||
allowed_origins.append("http://localhost:3000")
|
||||
|
||||
|
@ -149,7 +162,7 @@ app.add_middleware(
|
|||
allow_origin_regex=allow_origin_regex,
|
||||
allow_credentials=True,
|
||||
allow_methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"],
|
||||
allow_headers=["Content-Type", "Authorization", "X-Project-Id", "X-MCP-URL", "X-MCP-Type", "X-MCP-Headers", "X-Refresh-Token", "X-API-Key"],
|
||||
allow_headers=["Content-Type", "Authorization", "X-Project-Id", "X-MCP-URL", "X-MCP-Type", "X-MCP-Headers", "X-Refresh-Token", "X-API-Key", "Mcp-Session-Id"],
|
||||
)
|
||||
|
||||
# Create a main API router
|
||||
|
@ -191,6 +204,41 @@ api_router.include_router(admin_api.router)
|
|||
from composio_integration import api as composio_api
|
||||
api_router.include_router(composio_api.router)
|
||||
|
||||
# Include MCP Kortix Layer
|
||||
from mcp_kortix_layer import mcp_router
|
||||
api_router.include_router(mcp_router)
|
||||
|
||||
# Add OAuth discovery endpoints at root level for Claude Code MCP
|
||||
@api_router.get("/.well-known/oauth-authorization-server")
|
||||
async def oauth_authorization_server():
|
||||
"""OAuth authorization server metadata for Claude Code MCP"""
|
||||
return {
|
||||
"issuer": "https://api2.restoned.app",
|
||||
"authorization_endpoint": "https://api2.restoned.app/api/mcp/oauth/authorize",
|
||||
"token_endpoint": "https://api2.restoned.app/api/mcp/oauth/token",
|
||||
"registration_endpoint": "https://api2.restoned.app/register",
|
||||
"response_types_supported": ["code"],
|
||||
"grant_types_supported": ["authorization_code"],
|
||||
"token_endpoint_auth_methods_supported": ["none"]
|
||||
}
|
||||
|
||||
@api_router.get("/.well-known/oauth-protected-resource")
|
||||
async def oauth_protected_resource():
|
||||
"""OAuth protected resource metadata for Claude Code MCP"""
|
||||
return {
|
||||
"resource": "https://api2.restoned.app/api/mcp",
|
||||
"authorization_servers": ["https://api2.restoned.app"]
|
||||
}
|
||||
|
||||
@api_router.post("/register")
|
||||
async def oauth_register():
|
||||
"""OAuth client registration for Claude Code MCP"""
|
||||
return {
|
||||
"client_id": "claude-code-mcp-client",
|
||||
"client_secret": "not-required-for-api-key-auth",
|
||||
"message": "AgentPress MCP uses API key authentication - provide your key via Authorization header"
|
||||
}
|
||||
|
||||
@api_router.get("/health")
|
||||
async def health_check():
|
||||
logger.debug("Health check endpoint called")
|
||||
|
|
|
@ -0,0 +1,867 @@
|
|||
"""
|
||||
MCP Layer for AgentPress Agent Invocation
|
||||
|
||||
Allows Claude Code to discover and invoke your custom AgentPress agents and workflows
|
||||
through MCP (Model Context Protocol) instead of using generic capabilities.
|
||||
|
||||
🚀 ADD TO CLAUDE CODE:
|
||||
```bash
|
||||
claude mcp add --transport http "AgentPress" "https://your-backend-domain.com/api/mcp?key=pk_your_key:sk_your_secret"
|
||||
```
|
||||
claude mcp add AgentPress https://api2.restoned.app/api/mcp --header
|
||||
"Authorization=Bearer pk_your_key:sk_your_secret"
|
||||
|
||||
📋 SETUP STEPS:
|
||||
1. Deploy your AgentPress backend with this MCP layer
|
||||
2. Get your API key from your-frontend-domain.com/settings/api-keys
|
||||
3. Replace your-backend-domain.com and API key in the command above
|
||||
4. Run the command in Claude Code
|
||||
|
||||
🎯 BENEFITS:
|
||||
✅ Claude Code uses YOUR specialized agents instead of generic ones
|
||||
✅ Real execution of your custom prompts and workflows
|
||||
✅ Uses existing AgentPress authentication and infrastructure
|
||||
✅ Transforms your agents into Claude Code tools
|
||||
|
||||
📡 TOOLS PROVIDED:
|
||||
- get_agent_list: List your agents
|
||||
- get_agent_workflows: List agent workflows
|
||||
- run_agent: Execute agents with prompts or workflows
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Request, Response
|
||||
from typing import Dict, Any, Union, Optional
|
||||
from pydantic import BaseModel
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
|
||||
from utils.logger import logger
|
||||
from services.supabase import DBConnection
|
||||
|
||||
|
||||
# Create MCP router that wraps existing endpoints
|
||||
mcp_router = APIRouter(prefix="/mcp", tags=["MCP Kortix Layer"])
|
||||
|
||||
# Initialize database connection
|
||||
db = DBConnection()
|
||||
|
||||
|
||||
class JSONRPCRequest(BaseModel):
|
||||
"""JSON-RPC 2.0 request format"""
|
||||
jsonrpc: str = "2.0"
|
||||
method: str
|
||||
params: Optional[Dict[str, Any]] = None
|
||||
id: Union[str, int, None] = None
|
||||
|
||||
|
||||
class JSONRPCSuccessResponse(BaseModel):
|
||||
"""JSON-RPC 2.0 success response format"""
|
||||
jsonrpc: str = "2.0"
|
||||
result: Any
|
||||
id: Union[str, int, None]
|
||||
|
||||
|
||||
class JSONRPCError(BaseModel):
|
||||
"""JSON-RPC 2.0 error object"""
|
||||
code: int
|
||||
message: str
|
||||
data: Optional[Any] = None
|
||||
|
||||
|
||||
class JSONRPCErrorResponse(BaseModel):
|
||||
"""JSON-RPC 2.0 error response format"""
|
||||
jsonrpc: str = "2.0"
|
||||
error: JSONRPCError
|
||||
id: Union[str, int, None]
|
||||
|
||||
|
||||
# JSON-RPC error codes
|
||||
class JSONRPCErrorCodes:
|
||||
PARSE_ERROR = -32700
|
||||
INVALID_REQUEST = -32600
|
||||
METHOD_NOT_FOUND = -32601
|
||||
INVALID_PARAMS = -32602
|
||||
INTERNAL_ERROR = -32603
|
||||
UNAUTHORIZED = -32001 # Custom error for auth failures
|
||||
|
||||
|
||||
def extract_api_key_from_request(request: Request) -> tuple[str, str]:
|
||||
"""Extract and parse API key from request URL parameters or Authorization header."""
|
||||
# Try Authorization header first (for Claude Code)
|
||||
auth_header = request.headers.get("authorization")
|
||||
if auth_header:
|
||||
if auth_header.startswith("Bearer "):
|
||||
key_param = auth_header[7:] # Remove "Bearer " prefix
|
||||
else:
|
||||
key_param = auth_header
|
||||
|
||||
if ":" in key_param:
|
||||
try:
|
||||
public_key, secret_key = key_param.split(":", 1)
|
||||
if public_key.startswith("pk_") and secret_key.startswith("sk_"):
|
||||
return public_key, secret_key
|
||||
except:
|
||||
pass
|
||||
|
||||
# Fallback to URL parameter (for curl testing)
|
||||
key_param = request.query_params.get("key")
|
||||
if key_param:
|
||||
if ":" not in key_param:
|
||||
raise ValueError("Invalid key format. Expected 'pk_xxx:sk_xxx'")
|
||||
|
||||
try:
|
||||
public_key, secret_key = key_param.split(":", 1)
|
||||
|
||||
if not public_key.startswith("pk_") or not secret_key.startswith("sk_"):
|
||||
raise ValueError("Invalid key format. Expected 'pk_xxx:sk_xxx'")
|
||||
|
||||
return public_key, secret_key
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to parse API key: {str(e)}")
|
||||
|
||||
# No valid auth found
|
||||
raise ValueError("Missing API key. Provide via Authorization header: 'Bearer pk_xxx:sk_xxx' or URL parameter: '?key=pk_xxx:sk_xxx'")
|
||||
|
||||
|
||||
async def authenticate_api_key(public_key: str, secret_key: str) -> str:
|
||||
"""Authenticate API key and return account_id."""
|
||||
try:
|
||||
# Use the existing API key service for validation
|
||||
from services.api_keys import APIKeyService
|
||||
api_key_service = APIKeyService(db)
|
||||
|
||||
# Validate the API key
|
||||
validation_result = await api_key_service.validate_api_key(public_key, secret_key)
|
||||
|
||||
if not validation_result.is_valid:
|
||||
raise ValueError(validation_result.error_message or "Invalid API key")
|
||||
|
||||
account_id = str(validation_result.account_id)
|
||||
logger.info(f"API key authenticated for account_id: {account_id}")
|
||||
return account_id
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"API key authentication failed: {str(e)}")
|
||||
raise ValueError(f"Authentication failed: {str(e)}")
|
||||
|
||||
|
||||
def extract_last_message(full_output: str) -> str:
|
||||
"""Extract the last meaningful message from agent output."""
|
||||
if not full_output.strip():
|
||||
return "No output received"
|
||||
|
||||
lines = full_output.strip().split('\n')
|
||||
|
||||
# Look for the last substantial message
|
||||
for line in reversed(lines):
|
||||
if line.strip() and not line.startswith('#') and not line.startswith('```'):
|
||||
try:
|
||||
line_index = lines.index(line)
|
||||
start_index = max(0, line_index - 3)
|
||||
return '\n'.join(lines[start_index:]).strip()
|
||||
except ValueError:
|
||||
return line.strip()
|
||||
|
||||
# Fallback: return last 20% of the output
|
||||
return full_output[-len(full_output)//5:].strip() if len(full_output) > 100 else full_output
|
||||
|
||||
|
||||
def truncate_from_end(text: str, max_tokens: int) -> str:
|
||||
"""Truncate text from the beginning, keeping the end."""
|
||||
if max_tokens <= 0:
|
||||
return ""
|
||||
|
||||
max_chars = max_tokens * 4 # Rough token estimation
|
||||
|
||||
if len(text) <= max_chars:
|
||||
return text
|
||||
|
||||
truncated = text[-max_chars:]
|
||||
return f"...[truncated {len(text) - max_chars} characters]...\n{truncated}"
|
||||
|
||||
|
||||
@mcp_router.post("/")
|
||||
@mcp_router.post("") # Handle requests without trailing slash
|
||||
async def mcp_handler(
|
||||
request: JSONRPCRequest,
|
||||
http_request: Request
|
||||
):
|
||||
"""Main MCP endpoint handling JSON-RPC 2.0 requests."""
|
||||
try:
|
||||
# Authenticate API key from URL parameters
|
||||
try:
|
||||
public_key, secret_key = extract_api_key_from_request(http_request)
|
||||
account_id = await authenticate_api_key(public_key, secret_key)
|
||||
except ValueError as auth_error:
|
||||
logger.warning(f"Authentication failed: {str(auth_error)}")
|
||||
return JSONRPCErrorResponse(
|
||||
error=JSONRPCError(
|
||||
code=JSONRPCErrorCodes.UNAUTHORIZED,
|
||||
message=f"Authentication failed: {str(auth_error)}"
|
||||
),
|
||||
id=request.id
|
||||
)
|
||||
|
||||
# Validate JSON-RPC format
|
||||
if request.jsonrpc != "2.0":
|
||||
return JSONRPCErrorResponse(
|
||||
error=JSONRPCError(
|
||||
code=JSONRPCErrorCodes.INVALID_REQUEST,
|
||||
message="Invalid JSON-RPC version"
|
||||
),
|
||||
id=request.id
|
||||
)
|
||||
|
||||
method = request.method
|
||||
params = request.params or {}
|
||||
|
||||
logger.info(f"MCP JSON-RPC call: {method} for account: {account_id}")
|
||||
|
||||
# Handle different MCP methods
|
||||
if method == "initialize":
|
||||
result = {
|
||||
"protocolVersion": "2024-11-05",
|
||||
"capabilities": {
|
||||
"tools": {}
|
||||
},
|
||||
"serverInfo": {
|
||||
"name": "agentpress",
|
||||
"version": "1.0.0"
|
||||
}
|
||||
}
|
||||
elif method == "tools/list":
|
||||
result = await handle_tools_list()
|
||||
elif method == "tools/call":
|
||||
tool_name = params.get("name")
|
||||
arguments = params.get("arguments", {})
|
||||
|
||||
if not tool_name:
|
||||
return JSONRPCErrorResponse(
|
||||
error=JSONRPCError(
|
||||
code=JSONRPCErrorCodes.INVALID_PARAMS,
|
||||
message="Missing 'name' parameter for tools/call"
|
||||
),
|
||||
id=request.id
|
||||
)
|
||||
|
||||
result = await handle_tool_call(tool_name, arguments, account_id)
|
||||
else:
|
||||
return JSONRPCErrorResponse(
|
||||
error=JSONRPCError(
|
||||
code=JSONRPCErrorCodes.METHOD_NOT_FOUND,
|
||||
message=f"Method '{method}' not found"
|
||||
),
|
||||
id=request.id
|
||||
)
|
||||
|
||||
return JSONRPCSuccessResponse(
|
||||
result=result,
|
||||
id=request.id
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in MCP JSON-RPC handler: {str(e)}")
|
||||
return JSONRPCErrorResponse(
|
||||
error=JSONRPCError(
|
||||
code=JSONRPCErrorCodes.INTERNAL_ERROR,
|
||||
message=f"Internal error: {str(e)}"
|
||||
),
|
||||
id=getattr(request, 'id', None)
|
||||
)
|
||||
|
||||
|
||||
async def handle_tools_list():
|
||||
"""Handle tools/list method."""
|
||||
tools = [
|
||||
{
|
||||
"name": "get_agent_list",
|
||||
"description": "Get a list of all available agents in your account. Always call this tool first.",
|
||||
"inputSchema": {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": []
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "get_agent_workflows",
|
||||
"description": "Get a list of available workflows for a specific agent.",
|
||||
"inputSchema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"agent_id": {
|
||||
"type": "string",
|
||||
"description": "The ID of the agent to get workflows for"
|
||||
}
|
||||
},
|
||||
"required": ["agent_id"]
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "run_agent",
|
||||
"description": "Run a specific agent with a message and get formatted output.",
|
||||
"inputSchema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"agent_id": {
|
||||
"type": "string",
|
||||
"description": "The ID of the agent to run"
|
||||
},
|
||||
"message": {
|
||||
"type": "string",
|
||||
"description": "The message/prompt to send to the agent"
|
||||
},
|
||||
"execution_mode": {
|
||||
"type": "string",
|
||||
"enum": ["prompt", "workflow"],
|
||||
"description": "Either 'prompt' for custom prompt execution or 'workflow' for workflow execution"
|
||||
},
|
||||
"workflow_id": {
|
||||
"type": "string",
|
||||
"description": "Required when execution_mode is 'workflow' - the ID of the workflow to run"
|
||||
},
|
||||
"output_mode": {
|
||||
"type": "string",
|
||||
"enum": ["last_message", "full"],
|
||||
"description": "How to format output: 'last_message' (default) or 'full'"
|
||||
},
|
||||
"max_tokens": {
|
||||
"type": "integer",
|
||||
"description": "Maximum tokens in response"
|
||||
},
|
||||
"model_name": {
|
||||
"type": "string",
|
||||
"description": "Model to use for the agent execution. If not specified, uses the agent's configured model or fallback."
|
||||
}
|
||||
},
|
||||
"required": ["agent_id", "message"]
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
return {"tools": tools}
|
||||
|
||||
|
||||
async def handle_tool_call(tool_name: str, arguments: Dict[str, Any], account_id: str):
|
||||
"""Handle tools/call method."""
|
||||
try:
|
||||
if tool_name == "get_agent_list":
|
||||
result = await call_get_agents_endpoint(account_id)
|
||||
elif tool_name == "get_agent_workflows":
|
||||
agent_id = arguments.get("agent_id")
|
||||
if not agent_id:
|
||||
raise ValueError("agent_id is required")
|
||||
result = await call_get_agent_workflows_endpoint(account_id, agent_id)
|
||||
elif tool_name == "run_agent":
|
||||
agent_id = arguments.get("agent_id")
|
||||
message = arguments.get("message")
|
||||
execution_mode = arguments.get("execution_mode", "prompt")
|
||||
workflow_id = arguments.get("workflow_id")
|
||||
output_mode = arguments.get("output_mode", "last_message")
|
||||
max_tokens = arguments.get("max_tokens", 1000)
|
||||
model_name = arguments.get("model_name")
|
||||
|
||||
if not agent_id or not message:
|
||||
raise ValueError("agent_id and message are required")
|
||||
|
||||
if execution_mode == "workflow" and not workflow_id:
|
||||
raise ValueError("workflow_id is required when execution_mode is 'workflow'")
|
||||
|
||||
result = await call_run_agent_endpoint(
|
||||
account_id, agent_id, message, execution_mode, workflow_id, output_mode, max_tokens, model_name
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown tool: {tool_name}")
|
||||
|
||||
# Return MCP-compatible tool call result
|
||||
return {
|
||||
"content": [{
|
||||
"type": "text",
|
||||
"text": result
|
||||
}]
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in tool call {tool_name}: {str(e)}")
|
||||
raise e
|
||||
|
||||
|
||||
async def call_get_agents_endpoint(account_id: str) -> str:
|
||||
"""Call the existing /agents endpoint and format for MCP."""
|
||||
try:
|
||||
# Import the get_agents function from agent.api
|
||||
from agent.api import get_agents
|
||||
|
||||
# Call the existing endpoint
|
||||
response = await get_agents(
|
||||
user_id=account_id,
|
||||
page=1,
|
||||
limit=100, # Get all agents
|
||||
search=None,
|
||||
sort_by="created_at",
|
||||
sort_order="desc",
|
||||
has_default=None,
|
||||
has_mcp_tools=None,
|
||||
has_agentpress_tools=None,
|
||||
tools=None
|
||||
)
|
||||
|
||||
# Handle both dict and object response formats
|
||||
if hasattr(response, 'agents'):
|
||||
agents = response.agents
|
||||
elif isinstance(response, dict) and 'agents' in response:
|
||||
agents = response['agents']
|
||||
else:
|
||||
logger.error(f"Unexpected response format from get_agents: {type(response)}")
|
||||
return f"Error: Unexpected response format from get_agents: {response}"
|
||||
|
||||
if not agents:
|
||||
return "No agents found in your account. Create some agents first in your frontend."
|
||||
|
||||
agent_list = "🤖 Available Agents in Your Account:\n\n"
|
||||
for i, agent in enumerate(agents, 1):
|
||||
# Handle both dict and object formats for individual agents
|
||||
agent_id = agent.agent_id if hasattr(agent, 'agent_id') else agent.get('agent_id')
|
||||
name = agent.name if hasattr(agent, 'name') else agent.get('name')
|
||||
description = agent.description if hasattr(agent, 'description') else agent.get('description')
|
||||
|
||||
agent_list += f"{i}. Agent ID: {agent_id}\n"
|
||||
agent_list += f" Name: {name}\n"
|
||||
if description:
|
||||
agent_list += f" Description: {description}\n"
|
||||
agent_list += "\n"
|
||||
|
||||
agent_list += "📝 Use the 'run_agent' tool with the Agent ID to invoke any of these agents."
|
||||
|
||||
logger.info(f"Listed {len(agents)} agents via MCP")
|
||||
return agent_list
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in get_agent_list: {str(e)}")
|
||||
return f"Error listing agents: {str(e)}"
|
||||
|
||||
|
||||
async def verify_agent_access(agent_id: str, account_id: str):
|
||||
"""Verify account has access to the agent."""
|
||||
try:
|
||||
client = await db.client
|
||||
result = await client.table('agents').select('agent_id').eq('agent_id', agent_id).eq('account_id', account_id).execute()
|
||||
|
||||
if not result.data:
|
||||
raise ValueError("Agent not found or access denied")
|
||||
except Exception as e:
|
||||
logger.error(f"Database error in verify_agent_access: {str(e)}")
|
||||
raise ValueError("Database connection error")
|
||||
|
||||
|
||||
async def call_get_agent_workflows_endpoint(account_id: str, agent_id: str) -> str:
|
||||
"""Get workflows for a specific agent."""
|
||||
try:
|
||||
# Verify agent access
|
||||
await verify_agent_access(agent_id, account_id)
|
||||
|
||||
# Get workflows from database
|
||||
client = await db.client
|
||||
result = await client.table('agent_workflows').select('*').eq('agent_id', agent_id).order('created_at', desc=True).execute()
|
||||
|
||||
if not result.data:
|
||||
return f"No workflows found for agent {agent_id}. This agent can only be run with custom prompts."
|
||||
|
||||
workflow_list = f"🔄 Available Workflows for Agent {agent_id}:\n\n"
|
||||
for i, workflow in enumerate(result.data, 1):
|
||||
workflow_list += f"{i}. Workflow ID: {workflow['id']}\n"
|
||||
workflow_list += f" Name: {workflow['name']}\n"
|
||||
if workflow.get('description'):
|
||||
workflow_list += f" Description: {workflow['description']}\n"
|
||||
workflow_list += f" Status: {workflow.get('status', 'unknown')}\n"
|
||||
workflow_list += "\n"
|
||||
|
||||
workflow_list += "📝 Use the 'run_agent' tool with execution_mode='workflow' and the Workflow ID to run a workflow."
|
||||
|
||||
logger.info(f"Listed {len(result.data)} workflows for agent {agent_id} via MCP")
|
||||
return workflow_list
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in get_agent_workflows: {str(e)}")
|
||||
return f"Error listing workflows: {str(e)}"
|
||||
|
||||
|
||||
async def call_run_agent_endpoint(
|
||||
account_id: str,
|
||||
agent_id: str,
|
||||
message: str,
|
||||
execution_mode: str = "prompt",
|
||||
workflow_id: Optional[str] = None,
|
||||
output_mode: str = "last_message",
|
||||
max_tokens: int = 1000,
|
||||
model_name: Optional[str] = None
|
||||
) -> str:
|
||||
"""Call the existing agent run endpoints and format for MCP."""
|
||||
try:
|
||||
# Validate execution mode and workflow parameters
|
||||
if execution_mode not in ["prompt", "workflow"]:
|
||||
return "Error: execution_mode must be either 'prompt' or 'workflow'"
|
||||
|
||||
if execution_mode == "workflow" and not workflow_id:
|
||||
return "Error: workflow_id is required when execution_mode is 'workflow'"
|
||||
|
||||
# Verify agent access
|
||||
await verify_agent_access(agent_id, account_id)
|
||||
|
||||
# Apply model fallback if no model specified
|
||||
if not model_name:
|
||||
# Use a reliable free-tier model as fallback
|
||||
model_name = "openrouter/google/gemini-2.5-flash"
|
||||
logger.info(f"No model specified for agent {agent_id}, using fallback: {model_name}")
|
||||
|
||||
if execution_mode == "workflow":
|
||||
# Execute workflow using the existing workflow execution endpoint
|
||||
result = await execute_agent_workflow_internal(agent_id, workflow_id, message, account_id, model_name)
|
||||
else:
|
||||
# Execute agent with prompt using existing agent endpoints
|
||||
result = await execute_agent_prompt_internal(agent_id, message, account_id, model_name)
|
||||
|
||||
# Process the output based on the requested mode
|
||||
if output_mode == "last_message":
|
||||
processed_output = extract_last_message(result)
|
||||
else:
|
||||
processed_output = result
|
||||
|
||||
# Apply token limiting
|
||||
final_output = truncate_from_end(processed_output, max_tokens)
|
||||
|
||||
logger.info(f"MCP agent run completed for agent {agent_id} in {execution_mode} mode")
|
||||
return final_output
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error running agent {agent_id}: {str(e)}")
|
||||
return f"Error running agent: {str(e)}"
|
||||
|
||||
|
||||
async def execute_agent_workflow_internal(agent_id: str, workflow_id: str, message: str, account_id: str, model_name: Optional[str] = None) -> str:
|
||||
"""Execute an agent workflow."""
|
||||
try:
|
||||
client = await db.client
|
||||
|
||||
# Verify workflow exists and is active
|
||||
workflow_result = await client.table('agent_workflows').select('*').eq('id', workflow_id).eq('agent_id', agent_id).execute()
|
||||
if not workflow_result.data:
|
||||
return f"Error: Workflow {workflow_id} not found for agent {agent_id}"
|
||||
|
||||
workflow = workflow_result.data[0]
|
||||
if workflow.get('status') != 'active':
|
||||
return f"Error: Workflow {workflow['name']} is not active (status: {workflow.get('status')})"
|
||||
|
||||
# Execute workflow through the execution service
|
||||
try:
|
||||
from triggers.execution_service import execute_workflow
|
||||
|
||||
# Execute the workflow with the provided message
|
||||
execution_result = await execute_workflow(
|
||||
workflow_id=workflow_id,
|
||||
agent_id=agent_id,
|
||||
input_data={"message": message},
|
||||
user_id=account_id
|
||||
)
|
||||
|
||||
if execution_result.get('success'):
|
||||
return execution_result.get('output', f"Workflow '{workflow['name']}' executed successfully")
|
||||
else:
|
||||
return f"Workflow execution failed: {execution_result.get('error', 'Unknown error')}"
|
||||
|
||||
except ImportError:
|
||||
logger.warning("Execution service not available, using fallback workflow execution")
|
||||
# Fallback: Create a thread and run the agent with workflow context
|
||||
from agent.api import create_thread, add_message_to_thread, start_agent, AgentStartRequest
|
||||
|
||||
# Create thread with workflow context
|
||||
thread_response = await create_thread(
|
||||
name=f"Workflow: {workflow['name']}",
|
||||
user_id=account_id
|
||||
)
|
||||
thread_id = thread_response.get('thread_id') if isinstance(thread_response, dict) else thread_response.thread_id
|
||||
|
||||
# Add workflow context message
|
||||
workflow_context = f"Executing workflow '{workflow['name']}'"
|
||||
if workflow.get('description'):
|
||||
workflow_context += f": {workflow['description']}"
|
||||
workflow_context += f"\n\nUser message: {message}"
|
||||
|
||||
await add_message_to_thread(
|
||||
thread_id=thread_id,
|
||||
message=workflow_context,
|
||||
user_id=account_id
|
||||
)
|
||||
|
||||
# Start agent with workflow execution
|
||||
agent_request = AgentStartRequest(
|
||||
agent_id=agent_id,
|
||||
enable_thinking=False,
|
||||
stream=False,
|
||||
model_name=model_name
|
||||
)
|
||||
|
||||
await start_agent(
|
||||
thread_id=thread_id,
|
||||
body=agent_request,
|
||||
user_id=account_id
|
||||
)
|
||||
|
||||
# Wait for completion (similar to prompt execution)
|
||||
client = await db.client
|
||||
max_wait = 90 # Longer timeout for workflows
|
||||
poll_interval = 3
|
||||
elapsed = 0
|
||||
|
||||
while elapsed < max_wait:
|
||||
messages_result = await client.table('messages').select('*').eq('thread_id', thread_id).order('created_at', desc=True).limit(5).execute()
|
||||
|
||||
if messages_result.data:
|
||||
for msg in messages_result.data:
|
||||
# Parse JSON content to check role
|
||||
content = msg.get('content')
|
||||
if content:
|
||||
try:
|
||||
import json
|
||||
if isinstance(content, str):
|
||||
parsed_content = json.loads(content)
|
||||
else:
|
||||
parsed_content = content
|
||||
|
||||
if parsed_content.get('role') == 'assistant':
|
||||
return parsed_content.get('content', '')
|
||||
except:
|
||||
# If parsing fails, check if it's a direct assistant message
|
||||
if msg.get('type') == 'assistant':
|
||||
return content
|
||||
|
||||
runs_result = await client.table('agent_runs').select('status, error').eq('thread_id', thread_id).order('created_at', desc=True).limit(1).execute()
|
||||
|
||||
if runs_result.data:
|
||||
run = runs_result.data[0]
|
||||
if run['status'] in ['completed', 'failed', 'cancelled']:
|
||||
if run['status'] == 'failed':
|
||||
return f"Workflow execution failed: {run.get('error', 'Unknown error')}"
|
||||
elif run['status'] == 'cancelled':
|
||||
return "Workflow execution was cancelled"
|
||||
break
|
||||
|
||||
await asyncio.sleep(poll_interval)
|
||||
elapsed += poll_interval
|
||||
|
||||
return f"Workflow '{workflow['name']}' execution timed out after {max_wait}s. Thread ID: {thread_id}"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error executing workflow {workflow_id}: {str(e)}")
|
||||
return f"Error executing workflow: {str(e)}"
|
||||
|
||||
|
||||
async def execute_agent_prompt_internal(agent_id: str, message: str, account_id: str, model_name: Optional[str] = None) -> str:
|
||||
"""Execute an agent with a custom prompt."""
|
||||
try:
|
||||
# Import existing agent execution functions
|
||||
from agent.api import create_thread, add_message_to_thread, start_agent, AgentStartRequest
|
||||
|
||||
# Create a new thread
|
||||
thread_response = await create_thread(name="MCP Agent Run", user_id=account_id)
|
||||
thread_id = thread_response.get('thread_id') if isinstance(thread_response, dict) else thread_response.thread_id
|
||||
|
||||
# Add the message to the thread
|
||||
await add_message_to_thread(
|
||||
thread_id=thread_id,
|
||||
message=message,
|
||||
user_id=account_id
|
||||
)
|
||||
|
||||
# Start the agent
|
||||
agent_request = AgentStartRequest(
|
||||
agent_id=agent_id,
|
||||
enable_thinking=False,
|
||||
stream=False,
|
||||
model_name=model_name
|
||||
)
|
||||
|
||||
# Start the agent
|
||||
await start_agent(
|
||||
thread_id=thread_id,
|
||||
body=agent_request,
|
||||
user_id=account_id
|
||||
)
|
||||
|
||||
# Wait for agent completion and get response
|
||||
client = await db.client
|
||||
|
||||
# Poll for completion (max 60 seconds)
|
||||
max_wait = 60
|
||||
poll_interval = 2
|
||||
elapsed = 0
|
||||
|
||||
while elapsed < max_wait:
|
||||
# Check thread messages for agent response
|
||||
messages_result = await client.table('messages').select('*').eq('thread_id', thread_id).order('created_at', desc=True).limit(5).execute()
|
||||
|
||||
if messages_result.data:
|
||||
# Look for the most recent agent message (not user message)
|
||||
for msg in messages_result.data:
|
||||
# Parse JSON content to check role
|
||||
content = msg.get('content')
|
||||
if content:
|
||||
try:
|
||||
import json
|
||||
if isinstance(content, str):
|
||||
parsed_content = json.loads(content)
|
||||
else:
|
||||
parsed_content = content
|
||||
|
||||
if parsed_content.get('role') == 'assistant':
|
||||
return parsed_content.get('content', '')
|
||||
except:
|
||||
# If parsing fails, check if it's a direct assistant message
|
||||
if msg.get('type') == 'assistant':
|
||||
return content
|
||||
|
||||
# Check if agent run is complete by checking agent_runs table
|
||||
runs_result = await client.table('agent_runs').select('status, error').eq('thread_id', thread_id).order('created_at', desc=True).limit(1).execute()
|
||||
|
||||
if runs_result.data:
|
||||
run = runs_result.data[0]
|
||||
if run['status'] in ['completed', 'failed', 'cancelled']:
|
||||
if run['status'] == 'failed':
|
||||
return f"Agent execution failed: {run.get('error', 'Unknown error')}"
|
||||
elif run['status'] == 'cancelled':
|
||||
return "Agent execution was cancelled"
|
||||
# If completed, continue to check for messages
|
||||
break
|
||||
|
||||
await asyncio.sleep(poll_interval)
|
||||
elapsed += poll_interval
|
||||
|
||||
# Timeout fallback - get latest messages
|
||||
messages_result = await client.table('messages').select('*').eq('thread_id', thread_id).order('created_at', desc=True).limit(10).execute()
|
||||
|
||||
if messages_result.data:
|
||||
# Return the most recent assistant message or fallback message
|
||||
for msg in messages_result.data:
|
||||
# Parse JSON content to check role
|
||||
content = msg.get('content')
|
||||
if content:
|
||||
try:
|
||||
import json
|
||||
if isinstance(content, str):
|
||||
parsed_content = json.loads(content)
|
||||
else:
|
||||
parsed_content = content
|
||||
|
||||
if parsed_content.get('role') == 'assistant':
|
||||
return parsed_content.get('content', '')
|
||||
except:
|
||||
# If parsing fails, check if it's a direct assistant message
|
||||
if msg.get('type') == 'assistant':
|
||||
return content
|
||||
|
||||
return f"Agent execution timed out after {max_wait}s. Thread ID: {thread_id}"
|
||||
|
||||
return f"No response received from agent {agent_id}. Thread ID: {thread_id}"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error executing agent prompt: {str(e)}")
|
||||
return f"Error executing agent: {str(e)}"
|
||||
|
||||
|
||||
|
||||
@mcp_router.get("/health")
|
||||
async def mcp_health_check():
|
||||
"""Health check for MCP layer"""
|
||||
return {"status": "healthy", "service": "mcp-kortix-layer"}
|
||||
|
||||
|
||||
# OAuth 2.0 endpoints for Claude Code compatibility
|
||||
@mcp_router.get("/oauth/authorize")
|
||||
async def oauth_authorize(
|
||||
response_type: str = None,
|
||||
client_id: str = None,
|
||||
redirect_uri: str = None,
|
||||
scope: str = None,
|
||||
state: str = None,
|
||||
code_challenge: str = None,
|
||||
code_challenge_method: str = None
|
||||
):
|
||||
"""OAuth authorization endpoint - redirect with authorization code"""
|
||||
from fastapi.responses import RedirectResponse
|
||||
import secrets
|
||||
|
||||
# Generate a dummy authorization code (since we use API keys)
|
||||
auth_code = f"ac_{secrets.token_urlsafe(32)}"
|
||||
|
||||
# Build redirect URL with authorization code and state
|
||||
redirect_url = f"{redirect_uri}?code={auth_code}"
|
||||
if state:
|
||||
redirect_url += f"&state={state}"
|
||||
|
||||
logger.info(f"OAuth authorize redirecting to: {redirect_url}")
|
||||
return RedirectResponse(url=redirect_url)
|
||||
|
||||
|
||||
@mcp_router.post("/oauth/token")
|
||||
async def oauth_token(
|
||||
grant_type: str = None,
|
||||
code: str = None,
|
||||
redirect_uri: str = None,
|
||||
client_id: str = None,
|
||||
client_secret: str = None,
|
||||
code_verifier: str = None
|
||||
):
|
||||
"""OAuth token endpoint - simplified for API key flow with PKCE support"""
|
||||
return {
|
||||
"access_token": "use_api_key_instead",
|
||||
"token_type": "bearer",
|
||||
"message": "AgentPress MCP Server uses API key authentication",
|
||||
"instructions": "Use your API key as Bearer token: Authorization: Bearer pk_xxx:sk_xxx",
|
||||
"pkce_supported": True
|
||||
}
|
||||
|
||||
|
||||
@mcp_router.options("/")
|
||||
@mcp_router.options("") # Handle OPTIONS without trailing slash
|
||||
async def mcp_options():
|
||||
"""Handle CORS preflight for MCP endpoint"""
|
||||
return Response(
|
||||
status_code=200,
|
||||
headers={
|
||||
"Access-Control-Allow-Origin": "*",
|
||||
"Access-Control-Allow-Methods": "POST, OPTIONS",
|
||||
"Access-Control-Allow-Headers": "Content-Type, Authorization, Mcp-Session-Id"
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@mcp_router.get("/.well-known/mcp")
|
||||
async def mcp_discovery():
|
||||
"""MCP discovery endpoint for Claude Code"""
|
||||
return {
|
||||
"mcpVersion": "2024-11-05",
|
||||
"capabilities": {
|
||||
"tools": {}
|
||||
},
|
||||
"implementation": {
|
||||
"name": "AgentPress MCP Server",
|
||||
"version": "1.0.0"
|
||||
},
|
||||
"oauth": {
|
||||
"authorization_endpoint": "/api/mcp/oauth/authorize",
|
||||
"token_endpoint": "/api/mcp/oauth/token",
|
||||
"supported_flows": ["authorization_code"]
|
||||
},
|
||||
"instructions": "Use API key authentication via Authorization header: Bearer pk_xxx:sk_xxx"
|
||||
}
|
||||
|
||||
|
||||
@mcp_router.get("/")
|
||||
@mcp_router.get("")
|
||||
async def mcp_health():
|
||||
"""Health check endpoint for MCP server"""
|
||||
return {
|
||||
"status": "healthy",
|
||||
"service": "agentpress-mcp-server",
|
||||
"version": "1.0.0",
|
||||
"timestamp": datetime.now().isoformat()
|
||||
}
|
|
@ -20,7 +20,7 @@ You can modify the sandbox environment for development or to add new capabilitie
|
|||
```
|
||||
cd backend/sandbox/docker
|
||||
docker compose build
|
||||
docker push kortix/suna:0.1.3.7
|
||||
docker push kortix/suna:0.1.3.9
|
||||
```
|
||||
3. Test your changes locally using docker-compose
|
||||
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
import express from 'express';
|
||||
import { Stagehand, type LogLine, type Page } from '@browserbasehq/stagehand';
|
||||
import { FileChooser } from 'playwright';
|
||||
|
||||
const app = express();
|
||||
app.use(express.json());
|
||||
|
@ -224,9 +225,21 @@ class BrowserAutomation {
|
|||
}
|
||||
|
||||
async act(req: express.Request, res: express.Response): Promise<void> {
|
||||
let fileChooseHandler: ((fileChooser: FileChooser) => Promise<void>) | null = null;
|
||||
try {
|
||||
if (this.page && this.browserInitialized) {
|
||||
const { action, iframes, variables } = req.body;
|
||||
const { action, iframes, variables, filePath } = req.body;
|
||||
|
||||
const fileChooseHandler = async (fileChooser: FileChooser) => {
|
||||
if(filePath){
|
||||
await fileChooser.setFiles(filePath);
|
||||
} else {
|
||||
await fileChooser.setFiles([]);
|
||||
}
|
||||
};
|
||||
|
||||
this.page.on('filechooser', fileChooseHandler);
|
||||
|
||||
const result = await this.page.act({action, iframes: iframes || true, variables});
|
||||
const page_info = await this.get_stagehand_state();
|
||||
const response: BrowserActionResult = {
|
||||
|
@ -255,7 +268,12 @@ class BrowserAutomation {
|
|||
screenshot_base64: page_info.screenshot_base64,
|
||||
error
|
||||
})
|
||||
} finally {
|
||||
if (this.page && fileChooseHandler) {
|
||||
this.page.off('filechooser', fileChooseHandler);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
async extract(req: express.Request, res: express.Response): Promise<void> {
|
||||
|
|
|
@ -6,7 +6,7 @@ services:
|
|||
dockerfile: ${DOCKERFILE:-Dockerfile}
|
||||
args:
|
||||
TARGETPLATFORM: ${TARGETPLATFORM:-linux/amd64}
|
||||
image: kortix/suna:0.1.3.7
|
||||
image: kortix/suna:0.1.3.9
|
||||
ports:
|
||||
- "6080:6080" # noVNC web interface
|
||||
- "5901:5901" # VNC port
|
||||
|
|
|
@ -111,8 +111,8 @@ async def create_sandbox(password: str, project_id: str = None) -> AsyncSandbox:
|
|||
memory=4,
|
||||
disk=5,
|
||||
),
|
||||
auto_stop_interval=120,
|
||||
auto_archive_interval=2 * 60,
|
||||
auto_stop_interval=15,
|
||||
auto_archive_interval=30,
|
||||
)
|
||||
|
||||
# Create the sandbox
|
||||
|
|
|
@ -412,165 +412,271 @@ async def calculate_monthly_usage(client, user_id: str) -> float:
|
|||
|
||||
async def get_usage_logs(client, user_id: str, page: int = 0, items_per_page: int = 1000) -> Dict:
|
||||
"""Get detailed usage logs for a user with pagination, including credit usage info."""
|
||||
# Get start of current month in UTC
|
||||
now = datetime.now(timezone.utc)
|
||||
start_of_month = datetime(now.year, now.month, 1, tzinfo=timezone.utc)
|
||||
logger.info(f"[USAGE_LOGS] Starting get_usage_logs for user_id={user_id}, page={page}, items_per_page={items_per_page}")
|
||||
|
||||
# Use fixed cutoff date: June 26, 2025 midnight UTC
|
||||
# Ignore all token counts before this date
|
||||
cutoff_date = datetime(2025, 6, 30, 9, 0, 0, tzinfo=timezone.utc)
|
||||
|
||||
start_of_month = max(start_of_month, cutoff_date)
|
||||
|
||||
# First get all threads for this user in batches
|
||||
batch_size = 1000
|
||||
offset = 0
|
||||
all_threads = []
|
||||
|
||||
while True:
|
||||
threads_batch = await client.table('threads') \
|
||||
.select('thread_id, agent_runs(thread_id)') \
|
||||
.eq('account_id', user_id) \
|
||||
.gte('agent_runs.created_at', start_of_month.isoformat()) \
|
||||
.range(offset, offset + batch_size - 1) \
|
||||
.execute()
|
||||
try:
|
||||
# Get start of current month in UTC
|
||||
now = datetime.now(timezone.utc)
|
||||
start_of_month = datetime(now.year, now.month, 1, tzinfo=timezone.utc)
|
||||
|
||||
if not threads_batch.data:
|
||||
break
|
||||
|
||||
all_threads.extend(threads_batch.data)
|
||||
# Use fixed cutoff date: June 26, 2025 midnight UTC
|
||||
# Ignore all token counts before this date
|
||||
cutoff_date = datetime(2025, 6, 30, 9, 0, 0, tzinfo=timezone.utc)
|
||||
|
||||
start_of_month = max(start_of_month, cutoff_date)
|
||||
logger.debug(f"[USAGE_LOGS] user_id={user_id} - Using start_of_month: {start_of_month.isoformat()}")
|
||||
|
||||
# First get all threads for this user in batches
|
||||
batch_size = 1000
|
||||
offset = 0
|
||||
all_threads = []
|
||||
|
||||
logger.debug(f"[USAGE_LOGS] user_id={user_id} - Fetching threads in batches")
|
||||
while True:
|
||||
try:
|
||||
threads_batch = await client.table('threads') \
|
||||
.select('thread_id, agent_runs(thread_id)') \
|
||||
.eq('account_id', user_id) \
|
||||
.gte('agent_runs.created_at', start_of_month.isoformat()) \
|
||||
.range(offset, offset + batch_size - 1) \
|
||||
.execute()
|
||||
|
||||
if not threads_batch.data:
|
||||
break
|
||||
|
||||
all_threads.extend(threads_batch.data)
|
||||
logger.debug(f"[USAGE_LOGS] user_id={user_id} - Fetched {len(threads_batch.data)} threads in batch (offset={offset})")
|
||||
|
||||
# If we got less than batch_size, we've reached the end
|
||||
if len(threads_batch.data) < batch_size:
|
||||
break
|
||||
|
||||
offset += batch_size
|
||||
except Exception as thread_error:
|
||||
logger.error(f"[USAGE_LOGS] user_id={user_id} - Error fetching threads batch at offset {offset}: {str(thread_error)}")
|
||||
raise
|
||||
|
||||
logger.info(f"[USAGE_LOGS] user_id={user_id} - Found {len(all_threads)} total threads")
|
||||
|
||||
if not all_threads:
|
||||
logger.info(f"[USAGE_LOGS] user_id={user_id} - No threads found, returning empty result")
|
||||
return {"logs": [], "has_more": False}
|
||||
|
||||
thread_ids = [t['thread_id'] for t in all_threads]
|
||||
logger.debug(f"[USAGE_LOGS] user_id={user_id} - Thread IDs: {thread_ids[:5]}..." if len(thread_ids) > 5 else f"[USAGE_LOGS] user_id={user_id} - Thread IDs: {thread_ids}")
|
||||
|
||||
# Fetch usage messages with pagination, including thread project info
|
||||
start_time = time.time()
|
||||
logger.debug(f"[USAGE_LOGS] user_id={user_id} - Starting messages query")
|
||||
|
||||
# If we got less than batch_size, we've reached the end
|
||||
if len(threads_batch.data) < batch_size:
|
||||
break
|
||||
|
||||
offset += batch_size
|
||||
|
||||
if not all_threads:
|
||||
return {"logs": [], "has_more": False}
|
||||
|
||||
thread_ids = [t['thread_id'] for t in all_threads]
|
||||
|
||||
# Fetch usage messages with pagination, including thread project info
|
||||
start_time = time.time()
|
||||
messages_result = await client.table('messages') \
|
||||
.select(
|
||||
'message_id, thread_id, created_at, content, threads!inner(project_id)'
|
||||
) \
|
||||
.in_('thread_id', thread_ids) \
|
||||
.eq('type', 'assistant_response_end') \
|
||||
.gte('created_at', start_of_month.isoformat()) \
|
||||
.order('created_at', desc=True) \
|
||||
.range(page * items_per_page, (page + 1) * items_per_page - 1) \
|
||||
.execute()
|
||||
|
||||
end_time = time.time()
|
||||
execution_time = end_time - start_time
|
||||
logger.debug(f"Database query for usage logs took {execution_time:.3f} seconds")
|
||||
|
||||
if not messages_result.data:
|
||||
return {"logs": [], "has_more": False}
|
||||
|
||||
# Get the user's subscription tier info for credit checking
|
||||
subscription = await get_user_subscription(user_id)
|
||||
price_id = config.STRIPE_FREE_TIER_ID # Default to free
|
||||
if subscription and subscription.get('items'):
|
||||
items = subscription['items'].get('data', [])
|
||||
if items:
|
||||
price_id = items[0]['price']['id']
|
||||
|
||||
tier_info = SUBSCRIPTION_TIERS.get(price_id, SUBSCRIPTION_TIERS[config.STRIPE_FREE_TIER_ID])
|
||||
subscription_limit = tier_info['cost']
|
||||
|
||||
# Get credit usage records for this month to match with messages
|
||||
credit_usage_result = await client.table('credit_usage') \
|
||||
.select('message_id, amount_dollars, created_at') \
|
||||
.eq('user_id', user_id) \
|
||||
.gte('created_at', start_of_month.isoformat()) \
|
||||
.execute()
|
||||
|
||||
# Create a map of message_id to credit usage
|
||||
credit_usage_map = {}
|
||||
if credit_usage_result.data:
|
||||
for usage in credit_usage_result.data:
|
||||
if usage.get('message_id'):
|
||||
credit_usage_map[usage['message_id']] = {
|
||||
'amount': float(usage['amount_dollars']),
|
||||
'created_at': usage['created_at']
|
||||
}
|
||||
|
||||
# Track cumulative usage to determine when credits started being used
|
||||
cumulative_cost = 0.0
|
||||
|
||||
# Process messages into usage log entries
|
||||
processed_logs = []
|
||||
|
||||
for message in messages_result.data:
|
||||
try:
|
||||
# Safely extract usage data with defaults
|
||||
content = message.get('content', {})
|
||||
usage = content.get('usage', {})
|
||||
messages_result = await client.table('messages') \
|
||||
.select(
|
||||
'message_id, thread_id, created_at, content, threads!inner(project_id)'
|
||||
) \
|
||||
.in_('thread_id', thread_ids) \
|
||||
.eq('type', 'assistant_response_end') \
|
||||
.gte('created_at', start_of_month.isoformat()) \
|
||||
.order('created_at', desc=True) \
|
||||
.range(page * items_per_page, (page + 1) * items_per_page - 1) \
|
||||
.execute()
|
||||
except Exception as query_error:
|
||||
logger.error(f"[USAGE_LOGS] user_id={user_id} - Database query failed: {str(query_error)}")
|
||||
logger.error(f"[USAGE_LOGS] user_id={user_id} - Query details: page={page}, items_per_page={items_per_page}, thread_count={len(thread_ids)}")
|
||||
raise
|
||||
|
||||
end_time = time.time()
|
||||
execution_time = end_time - start_time
|
||||
logger.debug(f"[USAGE_LOGS] user_id={user_id} - Database query for usage logs took {execution_time:.3f} seconds")
|
||||
|
||||
if not messages_result.data:
|
||||
logger.info(f"[USAGE_LOGS] user_id={user_id} - No messages found, returning empty result")
|
||||
return {"logs": [], "has_more": False}
|
||||
|
||||
logger.info(f"[USAGE_LOGS] user_id={user_id} - Found {len(messages_result.data)} messages to process")
|
||||
|
||||
# Get the user's subscription tier info for credit checking
|
||||
logger.debug(f"[USAGE_LOGS] user_id={user_id} - Getting subscription info")
|
||||
try:
|
||||
subscription = await get_user_subscription(user_id)
|
||||
price_id = config.STRIPE_FREE_TIER_ID # Default to free
|
||||
if subscription and subscription.get('items'):
|
||||
items = subscription['items'].get('data', [])
|
||||
if items:
|
||||
price_id = items[0]['price']['id']
|
||||
|
||||
# Ensure usage has required fields with safe defaults
|
||||
prompt_tokens = usage.get('prompt_tokens', 0)
|
||||
completion_tokens = usage.get('completion_tokens', 0)
|
||||
model = content.get('model', 'unknown')
|
||||
tier_info = SUBSCRIPTION_TIERS.get(price_id, SUBSCRIPTION_TIERS[config.STRIPE_FREE_TIER_ID])
|
||||
subscription_limit = tier_info['cost']
|
||||
logger.debug(f"[USAGE_LOGS] user_id={user_id} - Subscription limit: {subscription_limit}, price_id: {price_id}")
|
||||
except Exception as sub_error:
|
||||
logger.error(f"[USAGE_LOGS] user_id={user_id} - Error getting subscription info: {str(sub_error)}")
|
||||
# Use free tier as fallback
|
||||
tier_info = SUBSCRIPTION_TIERS[config.STRIPE_FREE_TIER_ID]
|
||||
subscription_limit = tier_info['cost']
|
||||
|
||||
# Get credit usage records for this month to match with messages
|
||||
logger.debug(f"[USAGE_LOGS] user_id={user_id} - Fetching credit usage records")
|
||||
try:
|
||||
credit_usage_result = await client.table('credit_usage') \
|
||||
.select('message_id, amount_dollars, created_at') \
|
||||
.eq('user_id', user_id) \
|
||||
.gte('created_at', start_of_month.isoformat()) \
|
||||
.execute()
|
||||
|
||||
# Safely calculate total tokens
|
||||
total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
|
||||
|
||||
# Calculate estimated cost using the same logic as calculate_monthly_usage
|
||||
estimated_cost = calculate_token_cost(
|
||||
prompt_tokens,
|
||||
completion_tokens,
|
||||
model
|
||||
)
|
||||
|
||||
cumulative_cost += estimated_cost
|
||||
|
||||
# Safely extract project_id from threads relationship
|
||||
project_id = 'unknown'
|
||||
if message.get('threads') and isinstance(message['threads'], list) and len(message['threads']) > 0:
|
||||
project_id = message['threads'][0].get('project_id', 'unknown')
|
||||
|
||||
# Check if credits were used for this message
|
||||
message_id = message.get('message_id')
|
||||
credit_used = credit_usage_map.get(message_id, {})
|
||||
|
||||
log_entry = {
|
||||
'message_id': message_id or 'unknown',
|
||||
'thread_id': message.get('thread_id', 'unknown'),
|
||||
'created_at': message.get('created_at', None),
|
||||
'content': {
|
||||
'usage': {
|
||||
'prompt_tokens': prompt_tokens,
|
||||
'completion_tokens': completion_tokens
|
||||
logger.debug(f"[USAGE_LOGS] user_id={user_id} - Found {len(credit_usage_result.data) if credit_usage_result.data else 0} credit usage records")
|
||||
except Exception as credit_error:
|
||||
logger.error(f"[USAGE_LOGS] user_id={user_id} - Error fetching credit usage: {str(credit_error)}")
|
||||
credit_usage_result = None
|
||||
|
||||
# Create a map of message_id to credit usage
|
||||
credit_usage_map = {}
|
||||
if credit_usage_result and credit_usage_result.data:
|
||||
for usage in credit_usage_result.data:
|
||||
if usage.get('message_id'):
|
||||
try:
|
||||
credit_usage_map[usage['message_id']] = {
|
||||
'amount': float(usage['amount_dollars']),
|
||||
'created_at': usage['created_at']
|
||||
}
|
||||
except Exception as parse_error:
|
||||
logger.warning(f"[USAGE_LOGS] user_id={user_id} - Error parsing credit usage record: {str(parse_error)}")
|
||||
continue
|
||||
|
||||
# Track cumulative usage to determine when credits started being used
|
||||
cumulative_cost = 0.0
|
||||
|
||||
# Process messages into usage log entries
|
||||
processed_logs = []
|
||||
logger.debug(f"[USAGE_LOGS] user_id={user_id} - Starting to process {len(messages_result.data)} messages")
|
||||
|
||||
for i, message in enumerate(messages_result.data):
|
||||
try:
|
||||
message_id = message.get('message_id', f'unknown_{i}')
|
||||
logger.debug(f"[USAGE_LOGS] user_id={user_id} - Processing message {i+1}/{len(messages_result.data)}: {message_id}")
|
||||
|
||||
# Safely extract usage data with defaults
|
||||
content = message.get('content', {})
|
||||
usage = content.get('usage', {})
|
||||
|
||||
# Ensure usage has required fields with safe defaults
|
||||
prompt_tokens = usage.get('prompt_tokens', 0)
|
||||
completion_tokens = usage.get('completion_tokens', 0)
|
||||
model = content.get('model', 'unknown')
|
||||
|
||||
# Validate token values
|
||||
if not isinstance(prompt_tokens, (int, float)) or prompt_tokens is None:
|
||||
logger.warning(f"[USAGE_LOGS] user_id={user_id} - Invalid prompt_tokens for message {message_id}: {prompt_tokens}")
|
||||
prompt_tokens = 0
|
||||
if not isinstance(completion_tokens, (int, float)) or completion_tokens is None:
|
||||
logger.warning(f"[USAGE_LOGS] user_id={user_id} - Invalid completion_tokens for message {message_id}: {completion_tokens}")
|
||||
completion_tokens = 0
|
||||
|
||||
# Safely calculate total tokens
|
||||
total_tokens = int(prompt_tokens or 0) + int(completion_tokens or 0)
|
||||
|
||||
# Calculate estimated cost using the same logic as calculate_monthly_usage
|
||||
try:
|
||||
estimated_cost = calculate_token_cost(
|
||||
prompt_tokens,
|
||||
completion_tokens,
|
||||
model
|
||||
)
|
||||
except Exception as cost_error:
|
||||
logger.warning(f"[USAGE_LOGS] user_id={user_id} - Error calculating cost for message {message_id}: {str(cost_error)}")
|
||||
estimated_cost = 0.0
|
||||
|
||||
cumulative_cost += estimated_cost
|
||||
|
||||
# Safely extract project_id from threads relationship
|
||||
project_id = 'unknown'
|
||||
try:
|
||||
if message.get('threads') and isinstance(message['threads'], list) and len(message['threads']) > 0:
|
||||
project_id = message['threads'][0].get('project_id', 'unknown')
|
||||
except Exception as project_error:
|
||||
logger.warning(f"[USAGE_LOGS] user_id={user_id} - Error extracting project_id for message {message_id}: {str(project_error)}")
|
||||
|
||||
# Check if credits were used for this message
|
||||
credit_used = credit_usage_map.get(message_id, {})
|
||||
|
||||
# Safely handle datetime serialization for created_at
|
||||
created_at = message.get('created_at')
|
||||
if created_at and isinstance(created_at, datetime):
|
||||
created_at = created_at.isoformat()
|
||||
elif created_at and not isinstance(created_at, str):
|
||||
try:
|
||||
created_at = str(created_at)
|
||||
except Exception:
|
||||
logger.warning(f"[USAGE_LOGS] user_id={user_id} - Could not convert created_at to string for message {message_id}")
|
||||
created_at = None
|
||||
|
||||
log_entry = {
|
||||
'message_id': str(message_id) if message_id else 'unknown',
|
||||
'thread_id': str(message.get('thread_id', 'unknown')),
|
||||
'created_at': created_at,
|
||||
'content': {
|
||||
'usage': {
|
||||
'prompt_tokens': int(prompt_tokens),
|
||||
'completion_tokens': int(completion_tokens)
|
||||
},
|
||||
'model': str(model)
|
||||
},
|
||||
'model': model
|
||||
},
|
||||
'total_tokens': total_tokens,
|
||||
'estimated_cost': estimated_cost,
|
||||
'project_id': project_id,
|
||||
# Add credit usage info
|
||||
'credit_used': credit_used.get('amount', 0) if credit_used else 0,
|
||||
'payment_method': 'credits' if credit_used else 'subscription',
|
||||
'was_over_limit': cumulative_cost > subscription_limit if not credit_used else True
|
||||
'total_tokens': int(total_tokens),
|
||||
'estimated_cost': float(estimated_cost),
|
||||
'project_id': str(project_id),
|
||||
# Add credit usage info
|
||||
'credit_used': float(credit_used.get('amount', 0)) if credit_used else 0.0,
|
||||
'payment_method': 'credits' if credit_used else 'subscription',
|
||||
'was_over_limit': bool(cumulative_cost > subscription_limit if not credit_used else True)
|
||||
}
|
||||
|
||||
# Test JSON serialization of this entry before adding it
|
||||
try:
|
||||
json.dumps(log_entry, default=str)
|
||||
except Exception as json_error:
|
||||
logger.error(f"[USAGE_LOGS] user_id={user_id} - JSON serialization failed for message {message_id}: {str(json_error)}")
|
||||
logger.error(f"[USAGE_LOGS] user_id={user_id} - Problematic log_entry: {log_entry}")
|
||||
continue
|
||||
|
||||
processed_logs.append(log_entry)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[USAGE_LOGS] user_id={user_id} - Error processing usage log entry for message {message.get('message_id', 'unknown')}: {str(e)}")
|
||||
continue
|
||||
|
||||
logger.info(f"[USAGE_LOGS] user_id={user_id} - Successfully processed {len(processed_logs)} messages")
|
||||
|
||||
# Check if there are more results
|
||||
has_more = len(processed_logs) == items_per_page
|
||||
|
||||
result = {
|
||||
"logs": processed_logs,
|
||||
"has_more": bool(has_more),
|
||||
"subscription_limit": float(subscription_limit),
|
||||
"cumulative_cost": float(cumulative_cost)
|
||||
}
|
||||
|
||||
# Validate final JSON serialization
|
||||
try:
|
||||
json.dumps(result, default=str)
|
||||
logger.debug(f"[USAGE_LOGS] user_id={user_id} - Final result JSON validation passed")
|
||||
except Exception as final_json_error:
|
||||
logger.error(f"[USAGE_LOGS] user_id={user_id} - Final result JSON serialization failed: {str(final_json_error)}")
|
||||
logger.error(f"[USAGE_LOGS] user_id={user_id} - Problematic result keys: {list(result.keys())}")
|
||||
# Return safe fallback
|
||||
return {
|
||||
"logs": [],
|
||||
"has_more": False,
|
||||
"subscription_limit": float(subscription_limit),
|
||||
"cumulative_cost": 0.0,
|
||||
"error": "Failed to serialize usage data"
|
||||
}
|
||||
|
||||
processed_logs.append(log_entry)
|
||||
except Exception as e:
|
||||
logger.warning(f"Error processing usage log entry for message {message.get('message_id', 'unknown')}: {str(e)}")
|
||||
continue
|
||||
|
||||
# Check if there are more results
|
||||
has_more = len(processed_logs) == items_per_page
|
||||
|
||||
return {
|
||||
"logs": processed_logs,
|
||||
"has_more": has_more,
|
||||
"subscription_limit": subscription_limit,
|
||||
"cumulative_cost": cumulative_cost
|
||||
}
|
||||
|
||||
logger.info(f"[USAGE_LOGS] user_id={user_id} - Returning {len(processed_logs)} logs, has_more={has_more}")
|
||||
return result
|
||||
|
||||
except Exception as outer_error:
|
||||
logger.error(f"[USAGE_LOGS] user_id={user_id} - Outer exception in get_usage_logs: {str(outer_error)}")
|
||||
raise
|
||||
|
||||
|
||||
def calculate_token_cost(prompt_tokens: int, completion_tokens: int, model: str) -> float:
|
||||
|
@ -2036,6 +2142,8 @@ async def get_usage_logs_endpoint(
|
|||
current_user_id: str = Depends(get_current_user_id_from_jwt)
|
||||
):
|
||||
"""Get detailed usage logs for a user with pagination."""
|
||||
logger.info(f"[USAGE_LOGS_ENDPOINT] Starting get_usage_logs_endpoint for user_id={current_user_id}, page={page}, items_per_page={items_per_page}")
|
||||
|
||||
try:
|
||||
# Get Supabase client
|
||||
db = DBConnection()
|
||||
|
@ -2043,7 +2151,7 @@ async def get_usage_logs_endpoint(
|
|||
|
||||
# Check if we're in local development mode
|
||||
if config.ENV_MODE == EnvMode.LOCAL:
|
||||
logger.debug("Running in local development mode - usage logs are not available")
|
||||
logger.debug(f"[USAGE_LOGS_ENDPOINT] user_id={current_user_id} - Running in local development mode - usage logs are not available")
|
||||
return {
|
||||
"logs": [],
|
||||
"has_more": False,
|
||||
|
@ -2052,20 +2160,35 @@ async def get_usage_logs_endpoint(
|
|||
|
||||
# Validate pagination parameters
|
||||
if page < 0:
|
||||
logger.error(f"[USAGE_LOGS_ENDPOINT] user_id={current_user_id} - Invalid page parameter: {page}")
|
||||
raise HTTPException(status_code=400, detail="Page must be non-negative")
|
||||
if items_per_page < 1 or items_per_page > 1000:
|
||||
logger.error(f"[USAGE_LOGS_ENDPOINT] user_id={current_user_id} - Invalid items_per_page parameter: {items_per_page}")
|
||||
raise HTTPException(status_code=400, detail="Items per page must be between 1 and 1000")
|
||||
|
||||
# Get usage logs
|
||||
logger.debug(f"[USAGE_LOGS_ENDPOINT] user_id={current_user_id} - Calling get_usage_logs")
|
||||
result = await get_usage_logs(client, current_user_id, page, items_per_page)
|
||||
|
||||
# Check if result contains an error
|
||||
if isinstance(result, dict) and result.get('error'):
|
||||
logger.error(f"[USAGE_LOGS_ENDPOINT] user_id={current_user_id} - Usage logs returned error: {result['error']}")
|
||||
raise HTTPException(status_code=400, detail=f"Failed to retrieve usage logs: {result['error']}")
|
||||
|
||||
logger.info(f"[USAGE_LOGS_ENDPOINT] user_id={current_user_id} - Successfully returned {len(result.get('logs', []))} usage logs")
|
||||
return result
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting usage logs: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"Error getting usage logs: {str(e)}")
|
||||
logger.exception(f"[USAGE_LOGS_ENDPOINT] user_id={current_user_id} - Error getting usage logs: {str(e)}")
|
||||
|
||||
# Check if this is a JSON serialization error
|
||||
if "JSON could not be generated" in str(e) or "JSON" in str(e):
|
||||
logger.error(f"[USAGE_LOGS_ENDPOINT] user_id={current_user_id} - Detected JSON serialization error")
|
||||
raise HTTPException(status_code=400, detail=f"Data serialization error: {str(e)}")
|
||||
else:
|
||||
raise HTTPException(status_code=500, detail=f"Error getting usage logs: {str(e)}")
|
||||
|
||||
@router.get("/subscription-commitment/{subscription_id}")
|
||||
async def get_subscription_commitment(
|
||||
|
|
|
@ -280,8 +280,8 @@ class Configuration:
|
|||
STRIPE_PRODUCT_ID_STAGING: str = 'prod_SCgIj3G7yPOAWY'
|
||||
|
||||
# Sandbox configuration
|
||||
SANDBOX_IMAGE_NAME = "kortix/suna:0.1.3.7"
|
||||
SANDBOX_SNAPSHOT_NAME = "kortix/suna:0.1.3.7"
|
||||
SANDBOX_IMAGE_NAME = "kortix/suna:0.1.3.9"
|
||||
SANDBOX_SNAPSHOT_NAME = "kortix/suna:0.1.3.9"
|
||||
SANDBOX_ENTRYPOINT = "/usr/bin/supervisord -n -c /etc/supervisor/conf.d/supervisord.conf"
|
||||
|
||||
# LangFuse configuration
|
||||
|
|
|
@ -0,0 +1,257 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Script to stop all Daytona sandboxes in "STARTED" state.
|
||||
|
||||
This script connects to Daytona API, lists all sandboxes, filters for those
|
||||
in "STARTED" state, and stops them. Useful for cleanup operations or
|
||||
resource management.
|
||||
|
||||
Usage:
|
||||
python stop_started_sandboxes.py [--dry-run] [--save-json] [--json-file filename]
|
||||
|
||||
Examples:
|
||||
# Dry run to see what would be stopped
|
||||
python stop_started_sandboxes.py --dry-run
|
||||
|
||||
# Actually stop all started sandboxes
|
||||
python stop_started_sandboxes.py
|
||||
|
||||
# Save list of sandboxes to JSON file before stopping
|
||||
python stop_started_sandboxes.py --save-json --json-file started_sandboxes.json
|
||||
"""
|
||||
|
||||
PROD_DAYTONA_API_KEY = "" # Your production Daytona API key
|
||||
|
||||
import dotenv
|
||||
import os
|
||||
dotenv.load_dotenv(".env")
|
||||
|
||||
# Override with production credentials if provided
|
||||
if PROD_DAYTONA_API_KEY:
|
||||
os.environ['DAYTONA_API_KEY'] = PROD_DAYTONA_API_KEY
|
||||
|
||||
import sys
|
||||
import argparse
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import List, Dict, Optional
|
||||
from utils.config import config
|
||||
from utils.logger import logger
|
||||
|
||||
try:
|
||||
from daytona import Daytona
|
||||
except ImportError:
|
||||
print("Error: Daytona Python SDK not found. Please install it with: pip install daytona")
|
||||
sys.exit(1)
|
||||
|
||||
def save_sandboxes_as_json(sandboxes_list: List, filename: Optional[str] = None) -> Optional[str]:
|
||||
"""Save sandboxes list as JSON file for debugging/auditing."""
|
||||
if filename is None:
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
filename = f"started_sandboxes_{timestamp}.json"
|
||||
|
||||
logger.info(f"Saving sandboxes list to {filename}")
|
||||
|
||||
try:
|
||||
# Convert sandbox objects to serializable format
|
||||
serializable_data = []
|
||||
for sandbox in sandboxes_list:
|
||||
sandbox_data = {
|
||||
'id': getattr(sandbox, 'id', 'unknown'),
|
||||
'name': getattr(sandbox, 'name', 'unknown'),
|
||||
'state': getattr(sandbox, 'state', 'unknown'),
|
||||
'created_at': str(getattr(sandbox, 'created_at', 'unknown')),
|
||||
'updated_at': str(getattr(sandbox, 'updated_at', 'unknown')),
|
||||
}
|
||||
serializable_data.append(sandbox_data)
|
||||
|
||||
with open(filename, 'w') as f:
|
||||
json.dump(serializable_data, f, indent=2, default=str)
|
||||
|
||||
logger.info(f"✓ Successfully saved sandboxes list to {filename}")
|
||||
return filename
|
||||
except Exception as e:
|
||||
logger.error(f"✗ Failed to save JSON file: {e}")
|
||||
return None
|
||||
|
||||
def stop_started_sandboxes(dry_run: bool = False, save_json: bool = False, json_filename: Optional[str] = None) -> Dict[str, int]:
|
||||
"""
|
||||
Stop all sandboxes in STARTED state.
|
||||
|
||||
Args:
|
||||
dry_run: If True, only simulate the action without actually stopping
|
||||
save_json: If True, save the list of sandboxes to JSON file
|
||||
json_filename: Custom filename for JSON output
|
||||
|
||||
Returns:
|
||||
Dictionary with statistics about the operation
|
||||
"""
|
||||
# Initialize Daytona client
|
||||
try:
|
||||
daytona = Daytona()
|
||||
logger.info("✓ Connected to Daytona")
|
||||
except Exception as e:
|
||||
logger.error(f"✗ Failed to connect to Daytona: {e}")
|
||||
return {"error": 1}
|
||||
|
||||
# Get all sandboxes
|
||||
try:
|
||||
all_sandboxes = daytona.list()
|
||||
logger.info(f"✓ Found {len(all_sandboxes)} total sandboxes")
|
||||
|
||||
# Print sample sandbox data for debugging
|
||||
if all_sandboxes:
|
||||
logger.info("Sample sandbox data structure:")
|
||||
sample_sandbox = all_sandboxes[0]
|
||||
logger.info(f" - ID: {getattr(sample_sandbox, 'id', 'N/A')}")
|
||||
logger.info(f" - State: {getattr(sample_sandbox, 'state', 'N/A')}")
|
||||
logger.info(f" - Name: {getattr(sample_sandbox, 'name', 'N/A')}")
|
||||
|
||||
# Show a few more samples to see different states
|
||||
logger.info("Additional samples:")
|
||||
for i, sb in enumerate(all_sandboxes[1:6]): # Show next 5 samples
|
||||
logger.info(f" Sample {i+2}: State='{getattr(sb, 'state', 'N/A')}', ID={getattr(sb, 'id', 'N/A')[:20]}...")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"✗ Failed to list sandboxes: {e}")
|
||||
return {"error": 1}
|
||||
|
||||
# Filter for STARTED sandboxes
|
||||
started_sandboxes = [sb for sb in all_sandboxes if getattr(sb, 'state', None) == 'started']
|
||||
logger.info(f"✓ Found {len(started_sandboxes)} sandboxes in STARTED state")
|
||||
|
||||
# Save to JSON if requested
|
||||
if save_json and started_sandboxes:
|
||||
save_sandboxes_as_json(started_sandboxes, json_filename)
|
||||
|
||||
if not started_sandboxes:
|
||||
logger.info("No sandboxes to stop")
|
||||
return {
|
||||
"total_sandboxes": len(all_sandboxes),
|
||||
"started_sandboxes": 0,
|
||||
"stopped": 0,
|
||||
"errors": 0
|
||||
}
|
||||
|
||||
# Track statistics
|
||||
stats = {
|
||||
"total_sandboxes": len(all_sandboxes),
|
||||
"started_sandboxes": len(started_sandboxes),
|
||||
"stopped": 0,
|
||||
"errors": 0
|
||||
}
|
||||
|
||||
# Log some sample IDs for verification
|
||||
sample_ids = [getattr(sb, 'id', 'unknown') for sb in started_sandboxes[:5]]
|
||||
logger.info(f"Sample STARTED sandbox IDs: {sample_ids}...")
|
||||
|
||||
# Stop each started sandbox
|
||||
for i, sandbox in enumerate(started_sandboxes):
|
||||
sandbox_id = getattr(sandbox, 'id', 'unknown')
|
||||
sandbox_name = getattr(sandbox, 'name', 'unknown')
|
||||
|
||||
logger.info(f"[{i+1}/{len(started_sandboxes)}] Processing sandbox: {sandbox_id} ({sandbox_name})")
|
||||
|
||||
try:
|
||||
if dry_run:
|
||||
logger.info(f" [DRY RUN] Would stop sandbox: {sandbox_id}")
|
||||
stats["stopped"] += 1
|
||||
else:
|
||||
logger.info(f" Stopping sandbox: {sandbox_id}")
|
||||
|
||||
# Stop the sandbox
|
||||
sandbox.stop()
|
||||
|
||||
# Wait for sandbox to stop (with timeout)
|
||||
try:
|
||||
sandbox.wait_for_sandbox_stop()
|
||||
logger.info(f" ✓ Successfully stopped sandbox: {sandbox_id}")
|
||||
stats["stopped"] += 1
|
||||
except Exception as wait_error:
|
||||
logger.warning(f" ⚠ Sandbox {sandbox_id} stop command sent, but wait failed: {wait_error}")
|
||||
# Still count as success since stop command was sent
|
||||
stats["stopped"] += 1
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f" ✗ Failed to stop sandbox {sandbox_id}: {e}")
|
||||
stats["errors"] += 1
|
||||
|
||||
return stats
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Stop all Daytona sandboxes in STARTED state",
|
||||
epilog="""
|
||||
Examples:
|
||||
# Dry run to see what would be stopped
|
||||
python stop_started_sandboxes.py --dry-run
|
||||
|
||||
# Actually stop all started sandboxes
|
||||
python stop_started_sandboxes.py
|
||||
|
||||
# Save list to JSON and stop sandboxes
|
||||
python stop_started_sandboxes.py --save-json --json-file started_sandboxes.json
|
||||
""",
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter
|
||||
)
|
||||
parser.add_argument('--dry-run', action='store_true', help='Show what would be stopped without actually stopping')
|
||||
parser.add_argument('--save-json', action='store_true', help='Save list of started sandboxes as JSON file')
|
||||
parser.add_argument('--json-file', type=str, help='Custom filename for JSON output (default: started_sandboxes_TIMESTAMP.json)')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Verify configuration
|
||||
logger.info("Configuration check:")
|
||||
logger.info(f" Daytona API Key: {'✓ Configured' if config.DAYTONA_API_KEY else '✗ Missing'}")
|
||||
logger.info(f" Daytona API URL: {config.DAYTONA_SERVER_URL}")
|
||||
logger.info("")
|
||||
|
||||
if args.dry_run:
|
||||
logger.info("=== DRY RUN MODE ===")
|
||||
logger.info("No sandboxes will actually be stopped")
|
||||
logger.info("")
|
||||
|
||||
# Run the stop operation
|
||||
try:
|
||||
stats = stop_started_sandboxes(
|
||||
dry_run=args.dry_run,
|
||||
save_json=args.save_json,
|
||||
json_filename=args.json_file
|
||||
)
|
||||
|
||||
# Print summary
|
||||
logger.info("")
|
||||
logger.info("=== SUMMARY ===")
|
||||
logger.info(f"Total sandboxes: {stats.get('total_sandboxes', 0)}")
|
||||
logger.info(f"Started sandboxes: {stats.get('started_sandboxes', 0)}")
|
||||
logger.info(f"Stopped: {stats.get('stopped', 0)}")
|
||||
logger.info(f"Errors: {stats.get('errors', 0)}")
|
||||
|
||||
if args.dry_run and stats.get('started_sandboxes', 0) > 0:
|
||||
logger.info("")
|
||||
logger.info("To actually stop these sandboxes, run the script without --dry-run")
|
||||
|
||||
success = stats.get('errors', 0) == 0 and 'error' not in stats
|
||||
sys.exit(0 if success else 1)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
logger.warning("\n⚠️ Operation cancelled by user")
|
||||
sys.exit(1)
|
||||
except Exception as e:
|
||||
logger.error(f"Script failed: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
|
||||
# Usage examples:
|
||||
#
|
||||
# 1. Dry run to see what would be stopped:
|
||||
# uv run python -m utils.scripts.stop_started_sandboxes --dry-run
|
||||
#
|
||||
# 2. Actually stop all started sandboxes:
|
||||
# uv run python -m utils.scripts.stop_started_sandboxes
|
||||
#
|
||||
# 3. Save list to JSON and stop sandboxes:
|
||||
# uv run python -m utils.scripts.stop_started_sandboxes --save-json --json-file started_sandboxes.json
|
|
@ -127,8 +127,8 @@ As part of the setup, you'll need to:
|
|||
1. Create a Daytona account
|
||||
2. Generate an API key
|
||||
3. Create a Snapshot:
|
||||
- Name: `kortix/suna:0.1.3.7`
|
||||
- Image name: `kortix/suna:0.1.3.7`
|
||||
- Name: `kortix/suna:0.1.3.9`
|
||||
- Image name: `kortix/suna:0.1.3.9`
|
||||
- Entrypoint: `/usr/bin/supervisord -n -c /etc/supervisor/conf.d/supervisord.conf`
|
||||
|
||||
## Manual Configuration
|
||||
|
|
|
@ -241,7 +241,7 @@ export default function APIKeysPage() {
|
|||
<h1 className="text-2xl font-bold">API Keys</h1>
|
||||
</div>
|
||||
<p className="text-muted-foreground">
|
||||
Manage your API keys for programmatic access to Suna
|
||||
Manage your API keys for programmatic access to your agents
|
||||
</p>
|
||||
</div>
|
||||
|
||||
|
@ -285,6 +285,59 @@ export default function APIKeysPage() {
|
|||
</CardContent>
|
||||
</Card>
|
||||
|
||||
{/* Claude Code Integration Notice */}
|
||||
<Card className="border-purple-200/60 bg-gradient-to-br from-purple-50/80 to-violet-50/40 dark:from-purple-950/20 dark:to-violet-950/10 dark:border-purple-800/30">
|
||||
<CardContent className="p-6">
|
||||
<div className="flex items-start gap-4">
|
||||
<div className="relative">
|
||||
<div className="flex h-12 w-12 items-center justify-center rounded-2xl bg-gradient-to-br from-purple-500/20 to-violet-600/10 border border-purple-500/20">
|
||||
<Shield className="w-6 h-6 text-purple-600 dark:text-purple-400" />
|
||||
</div>
|
||||
<div className="absolute -top-1 -right-1">
|
||||
<Badge variant="secondary" className="h-5 px-1.5 text-xs bg-purple-100 text-purple-800 border-purple-200 dark:bg-purple-900/30 dark:text-purple-300 dark:border-purple-700">
|
||||
New
|
||||
</Badge>
|
||||
</div>
|
||||
</div>
|
||||
<div className="flex-1 space-y-3">
|
||||
<div>
|
||||
<h3 className="text-base font-semibold text-purple-900 dark:text-purple-100 mb-1">
|
||||
Claude Code Integration
|
||||
</h3>
|
||||
<p className="text-sm text-purple-700 dark:text-purple-300 leading-relaxed mb-3">
|
||||
Connect your agents to Claude Code for seamless AI-powered collaboration.
|
||||
Use your API key to add an MCP server in Claude Code.
|
||||
</p>
|
||||
</div>
|
||||
<div className="space-y-2">
|
||||
<p className="text-xs font-medium text-purple-800 dark:text-purple-200 mb-1">
|
||||
Connection Command:
|
||||
</p>
|
||||
<div className="bg-purple-900/10 dark:bg-purple-900/30 border border-purple-200/50 dark:border-purple-700/50 rounded-lg p-3">
|
||||
<code className="text-xs font-mono text-purple-800 dark:text-purple-200 break-all">
|
||||
claude mcp add AgentPress https://YOUR_DOMAIN/api/mcp --header "Authorization=Bearer YOUR_API_KEY"
|
||||
</code>
|
||||
</div>
|
||||
<p className="text-xs text-purple-600 dark:text-purple-400">
|
||||
Replace <code className="bg-purple-100 dark:bg-purple-900/50 px-1 rounded">YOUR_DOMAIN</code> and <code className="bg-purple-100 dark:bg-purple-900/50 px-1 rounded">YOUR_API_KEY</code> with your actual API key from below.
|
||||
</p>
|
||||
</div>
|
||||
<div className="flex items-center gap-3">
|
||||
<a
|
||||
href="https://docs.anthropic.com/en/docs/claude-code/mcp"
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
className="inline-flex items-center gap-2 text-sm font-medium text-purple-600 hover:text-purple-800 dark:text-purple-400 dark:hover:text-purple-300 transition-colors"
|
||||
>
|
||||
<span>Learn about Claude Code MCP</span>
|
||||
<ExternalLink className="w-4 h-4" />
|
||||
</a>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</CardContent>
|
||||
</Card>
|
||||
|
||||
{/* Header Actions */}
|
||||
<div className="flex justify-between items-center">
|
||||
<div className="flex items-center gap-2 text-sm text-muted-foreground">
|
||||
|
|
4
setup.py
4
setup.py
|
@ -669,8 +669,8 @@ class SetupWizard:
|
|||
f"Visit {Colors.GREEN}https://app.daytona.io/dashboard/snapshots{Colors.ENDC}{Colors.CYAN} to create a snapshot."
|
||||
)
|
||||
print_info("Create a snapshot with these exact settings:")
|
||||
print_info(f" - Name:\t\t{Colors.GREEN}kortix/suna:0.1.3.7{Colors.ENDC}")
|
||||
print_info(f" - Snapshot name:\t{Colors.GREEN}kortix/suna:0.1.3.7{Colors.ENDC}")
|
||||
print_info(f" - Name:\t\t{Colors.GREEN}kortix/suna:0.1.3.9{Colors.ENDC}")
|
||||
print_info(f" - Snapshot name:\t{Colors.GREEN}kortix/suna:0.1.3.9{Colors.ENDC}")
|
||||
print_info(
|
||||
f" - Entrypoint:\t{Colors.GREEN}/usr/bin/supervisord -n -c /etc/supervisor/conf.d/supervisord.conf{Colors.ENDC}"
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue