mirror of https://github.com/kortix-ai/suna.git
feat(deepai): add deepai specific endpoints
This commit is contained in:
parent
f4dc33ab13
commit
37f8e63cf1
|
@ -2545,3 +2545,105 @@ async def admin_install_suna_for_user(
|
|||
detail=f"Failed to install Suna agent for user {account_id}"
|
||||
)
|
||||
|
||||
@router.get("/threads")
|
||||
async def get_user_threads(
|
||||
user_id: str = Depends(get_current_user_id_from_jwt),
|
||||
page: Optional[int] = Query(1, ge=1, description="Page number (1-based)"),
|
||||
limit: Optional[int] = Query(1000, ge=1, le=1000, description="Number of items per page (max 1000)")
|
||||
):
|
||||
"""Get all threads for the current user (account_id = user_id) with pagination."""
|
||||
logger.info(f"Fetching threads for user: {user_id} (page={page}, limit={limit})")
|
||||
client = await db.client
|
||||
try:
|
||||
offset = (page - 1) * limit
|
||||
threads_query = client.table('threads').select('*').eq('account_id', user_id).order('created_at', desc=True)
|
||||
total_result = await threads_query.execute()
|
||||
total_count = len(total_result.data) if total_result.data else 0
|
||||
threads_result = await threads_query.range(offset, offset + limit - 1).execute()
|
||||
threads = threads_result.data or []
|
||||
total_pages = (total_count + limit - 1) // limit if total_count else 0
|
||||
return {
|
||||
"threads": threads,
|
||||
"pagination": {
|
||||
"page": page,
|
||||
"limit": limit,
|
||||
"total": total_count,
|
||||
"pages": total_pages
|
||||
}
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching threads for user {user_id}: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"Failed to fetch threads: {str(e)}")
|
||||
|
||||
@router.get("/threads/{thread_id}/messages")
|
||||
async def get_thread_messages(
|
||||
thread_id: str,
|
||||
user_id: str = Depends(get_current_user_id_from_jwt),
|
||||
order: str = Query("desc", description="Order by created_at: 'asc' or 'desc'")
|
||||
):
|
||||
"""Get all messages for a thread, fetching in batches of 1000 from the DB to avoid large queries."""
|
||||
logger.info(f"Fetching all messages for thread: {thread_id}, order={order}")
|
||||
client = await db.client
|
||||
await verify_thread_access(client, thread_id, user_id)
|
||||
try:
|
||||
batch_size = 1000
|
||||
offset = 0
|
||||
all_messages = []
|
||||
while True:
|
||||
query = client.table('messages').select('*').eq('thread_id', thread_id)
|
||||
query = query.order('created_at', desc=(order == "desc"))
|
||||
query = query.range(offset, offset + batch_size - 1)
|
||||
messages_result = await query.execute()
|
||||
batch = messages_result.data or []
|
||||
all_messages.extend(batch)
|
||||
logger.debug(f"Fetched batch of {len(batch)} messages (offset {offset})")
|
||||
if len(batch) < batch_size:
|
||||
break
|
||||
offset += batch_size
|
||||
return {"messages": all_messages}
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching messages for thread {thread_id}: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"Failed to fetch messages: {str(e)}")
|
||||
|
||||
@router.get("/agent-runs/{agent_run_id}")
|
||||
async def get_agent_run(
|
||||
agent_run_id: str,
|
||||
user_id: str = Depends(get_current_user_id_from_jwt),
|
||||
):
|
||||
"""Get an agent run by ID"""
|
||||
logger.info(f"Fetching agent run: {agent_run_id}")
|
||||
client = await db.client
|
||||
try:
|
||||
agent_run_result = await client.table('agent_runs').select('*').eq('agent_run_id', agent_run_id).eq('account_id', user_id).execute()
|
||||
if not agent_run_result.data:
|
||||
raise HTTPException(status_code=404, detail="Agent run not found")
|
||||
return agent_run_result.data[0]
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching agent run {agent_run_id}: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"Failed to fetch agent run: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/threads/{thread_id}/messages/add")
|
||||
async def add_message_to_thread(
|
||||
thread_id: str,
|
||||
message: str,
|
||||
user_id: str = Depends(get_current_user_id_from_jwt),
|
||||
):
|
||||
"""Add a message to a thread"""
|
||||
logger.info(f"Adding message to thread: {thread_id}")
|
||||
client = await db.client
|
||||
await verify_thread_access(client, thread_id, user_id)
|
||||
try:
|
||||
message_result = await client.table('messages').insert({
|
||||
'thread_id': thread_id,
|
||||
'type': 'user',
|
||||
'is_llm_message': True,
|
||||
'content': {
|
||||
"role": "user",
|
||||
"content": message
|
||||
}
|
||||
}).execute()
|
||||
return message_result.data[0]
|
||||
except Exception as e:
|
||||
logger.error(f"Error adding message to thread {thread_id}: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"Failed to add message: {str(e)}")
|
||||
|
|
|
@ -155,7 +155,7 @@ app.add_middleware(
|
|||
allow_origin_regex=allow_origin_regex,
|
||||
allow_credentials=True,
|
||||
allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
|
||||
allow_headers=["Content-Type", "Authorization", "X-Project-Id", "X-MCP-URL", "X-MCP-Type", "X-MCP-Headers", "X-Refresh-Token"],
|
||||
allow_headers=["Content-Type", "Authorization", "X-Project-Id", "X-MCP-URL", "X-MCP-Type", "X-MCP-Headers", "X-Refresh-Token", "X-API-Key"],
|
||||
)
|
||||
|
||||
# Create a main API router
|
||||
|
|
|
@ -11,7 +11,7 @@ from datetime import datetime, timezone
|
|||
from utils.logger import logger
|
||||
from utils.config import config, EnvMode
|
||||
from services.supabase import DBConnection
|
||||
from utils.auth_utils import get_current_user_id_from_jwt
|
||||
from utils.auth_utils import get_current_user_id_from_jwt, is_deepai_user
|
||||
from pydantic import BaseModel
|
||||
from utils.constants import MODEL_ACCESS_TIERS, MODEL_NAME_ALIASES, HARDCODED_MODEL_PRICES
|
||||
from litellm.cost_calculator import cost_per_token
|
||||
|
@ -419,7 +419,9 @@ async def get_allowed_models_for_user(client, user_id: str):
|
|||
Returns:
|
||||
List of model names allowed for the user's subscription tier.
|
||||
"""
|
||||
|
||||
if is_deepai_user(user_id):
|
||||
# DeepAI users get all models
|
||||
return list(set(MODEL_NAME_ALIASES.values()))
|
||||
subscription = await get_user_subscription(user_id)
|
||||
tier_name = 'free'
|
||||
|
||||
|
@ -447,7 +449,9 @@ async def can_use_model(client, user_id: str, model_name: str):
|
|||
"plan_name": "Local Development",
|
||||
"minutes_limit": "no limit"
|
||||
}
|
||||
|
||||
if is_deepai_user(user_id):
|
||||
# DeepAI users can use any model
|
||||
return True, "DeepAI user: all models allowed", list(set(MODEL_NAME_ALIASES.values()))
|
||||
allowed_models = await get_allowed_models_for_user(client, user_id)
|
||||
resolved_model = MODEL_NAME_ALIASES.get(model_name, model_name)
|
||||
if resolved_model in allowed_models:
|
||||
|
@ -469,7 +473,13 @@ async def check_billing_status(client, user_id: str) -> Tuple[bool, str, Optiona
|
|||
"plan_name": "Local Development",
|
||||
"minutes_limit": "no limit"
|
||||
}
|
||||
|
||||
if is_deepai_user(user_id):
|
||||
# DeepAI users have infinite usage
|
||||
return True, "DeepAI user: unlimited usage", {
|
||||
"price_id": "deepai",
|
||||
"plan_name": "DeepAI",
|
||||
"minutes_limit": "no limit"
|
||||
}
|
||||
# Get current subscription
|
||||
subscription = await get_user_subscription(user_id)
|
||||
# print("Current subscription:", subscription)
|
||||
|
@ -1114,6 +1124,23 @@ async def get_available_models(
|
|||
"total_models": len(model_info)
|
||||
}
|
||||
|
||||
if is_deepai_user(current_user_id):
|
||||
# DeepAI users get all models, all available
|
||||
model_info = []
|
||||
for short_name, full_name in MODEL_NAME_ALIASES.items():
|
||||
model_info.append({
|
||||
"id": full_name,
|
||||
"display_name": short_name,
|
||||
"short_name": short_name,
|
||||
"requires_subscription": False,
|
||||
"is_available": True
|
||||
})
|
||||
return {
|
||||
"models": model_info,
|
||||
"subscription_tier": "DeepAI",
|
||||
"total_models": len(model_info)
|
||||
}
|
||||
|
||||
# For non-local mode, get list of allowed models for this user
|
||||
allowed_models = await get_allowed_models_for_user(client, current_user_id)
|
||||
free_tier_models = MODEL_ACCESS_TIERS.get('free', [])
|
||||
|
|
|
@ -5,6 +5,36 @@ import jwt
|
|||
from jwt.exceptions import PyJWTError
|
||||
from utils.logger import structlog
|
||||
from utils.config import config
|
||||
import os
|
||||
|
||||
# DeepAI user UUID
|
||||
DEEPAI_USER_ID = os.getenv('DEEPAI_USER_ID', '00000000-0000-0000-0000-000000000000')
|
||||
DEEPAI_API_KEY = os.getenv('DEEPAI_API_KEY', '00000000-0000-0000-0000-000000000000')
|
||||
|
||||
def str_safe_compare(str1: str, str2: str) -> bool:
|
||||
"""
|
||||
Compare two strings by first SHA256 hashing them and then comparing the hashes.
|
||||
This is a safe way to compare strings that are not known to be equal.
|
||||
"""
|
||||
import hashlib
|
||||
return hashlib.sha256(str1.encode()).hexdigest() == hashlib.sha256(str2.encode()).hexdigest()
|
||||
|
||||
def is_deepai_user(user_id: str) -> bool:
|
||||
"""
|
||||
Check if a user ID belongs to an deepai account.
|
||||
|
||||
This function is maintained for backward compatibility. The deepai user
|
||||
now works like any other user with a proper UUID, so this check is mainly
|
||||
used for legacy code and documentation purposes.
|
||||
|
||||
Args:
|
||||
user_id: The user ID to check
|
||||
|
||||
Returns:
|
||||
bool: True if the user is an deepai user, False otherwise
|
||||
"""
|
||||
# Check for the specific deepai user UUID or the old string format for backward compatibility
|
||||
return str_safe_compare(user_id, DEEPAI_USER_ID)
|
||||
|
||||
# This function extracts the user ID from Supabase JWT
|
||||
async def get_current_user_id_from_jwt(request: Request) -> str:
|
||||
|
@ -23,6 +53,12 @@ async def get_current_user_id_from_jwt(request: Request) -> str:
|
|||
Raises:
|
||||
HTTPException: If no valid token is found or if the token is invalid
|
||||
"""
|
||||
|
||||
x_api_key = request.headers.get('x-api-key')
|
||||
|
||||
if str_safe_compare(x_api_key, DEEPAI_API_KEY):
|
||||
return DEEPAI_USER_ID
|
||||
|
||||
auth_header = request.headers.get('Authorization')
|
||||
|
||||
if not auth_header or not auth_header.startswith('Bearer '):
|
||||
|
@ -126,6 +162,9 @@ async def get_user_id_from_stream_auth(
|
|||
try:
|
||||
# Try to get user_id from token in query param (for EventSource which can't set headers)
|
||||
if token:
|
||||
if str_safe_compare(token, DEEPAI_API_KEY):
|
||||
return DEEPAI_USER_ID
|
||||
|
||||
try:
|
||||
# For Supabase JWT, we just need to decode and extract the user ID
|
||||
payload = jwt.decode(token, options={"verify_signature": False})
|
||||
|
|
Loading…
Reference in New Issue