feat(deepai): add deepai specific endpoints

This commit is contained in:
mykonos-ibiza 2025-07-26 18:47:22 +05:30
parent f4dc33ab13
commit 37f8e63cf1
4 changed files with 173 additions and 5 deletions

View File

@ -2545,3 +2545,105 @@ async def admin_install_suna_for_user(
detail=f"Failed to install Suna agent for user {account_id}"
)
@router.get("/threads")
async def get_user_threads(
user_id: str = Depends(get_current_user_id_from_jwt),
page: Optional[int] = Query(1, ge=1, description="Page number (1-based)"),
limit: Optional[int] = Query(1000, ge=1, le=1000, description="Number of items per page (max 1000)")
):
"""Get all threads for the current user (account_id = user_id) with pagination."""
logger.info(f"Fetching threads for user: {user_id} (page={page}, limit={limit})")
client = await db.client
try:
offset = (page - 1) * limit
threads_query = client.table('threads').select('*').eq('account_id', user_id).order('created_at', desc=True)
total_result = await threads_query.execute()
total_count = len(total_result.data) if total_result.data else 0
threads_result = await threads_query.range(offset, offset + limit - 1).execute()
threads = threads_result.data or []
total_pages = (total_count + limit - 1) // limit if total_count else 0
return {
"threads": threads,
"pagination": {
"page": page,
"limit": limit,
"total": total_count,
"pages": total_pages
}
}
except Exception as e:
logger.error(f"Error fetching threads for user {user_id}: {str(e)}")
raise HTTPException(status_code=500, detail=f"Failed to fetch threads: {str(e)}")
@router.get("/threads/{thread_id}/messages")
async def get_thread_messages(
thread_id: str,
user_id: str = Depends(get_current_user_id_from_jwt),
order: str = Query("desc", description="Order by created_at: 'asc' or 'desc'")
):
"""Get all messages for a thread, fetching in batches of 1000 from the DB to avoid large queries."""
logger.info(f"Fetching all messages for thread: {thread_id}, order={order}")
client = await db.client
await verify_thread_access(client, thread_id, user_id)
try:
batch_size = 1000
offset = 0
all_messages = []
while True:
query = client.table('messages').select('*').eq('thread_id', thread_id)
query = query.order('created_at', desc=(order == "desc"))
query = query.range(offset, offset + batch_size - 1)
messages_result = await query.execute()
batch = messages_result.data or []
all_messages.extend(batch)
logger.debug(f"Fetched batch of {len(batch)} messages (offset {offset})")
if len(batch) < batch_size:
break
offset += batch_size
return {"messages": all_messages}
except Exception as e:
logger.error(f"Error fetching messages for thread {thread_id}: {str(e)}")
raise HTTPException(status_code=500, detail=f"Failed to fetch messages: {str(e)}")
@router.get("/agent-runs/{agent_run_id}")
async def get_agent_run(
agent_run_id: str,
user_id: str = Depends(get_current_user_id_from_jwt),
):
"""Get an agent run by ID"""
logger.info(f"Fetching agent run: {agent_run_id}")
client = await db.client
try:
agent_run_result = await client.table('agent_runs').select('*').eq('agent_run_id', agent_run_id).eq('account_id', user_id).execute()
if not agent_run_result.data:
raise HTTPException(status_code=404, detail="Agent run not found")
return agent_run_result.data[0]
except Exception as e:
logger.error(f"Error fetching agent run {agent_run_id}: {str(e)}")
raise HTTPException(status_code=500, detail=f"Failed to fetch agent run: {str(e)}")
@router.post("/threads/{thread_id}/messages/add")
async def add_message_to_thread(
thread_id: str,
message: str,
user_id: str = Depends(get_current_user_id_from_jwt),
):
"""Add a message to a thread"""
logger.info(f"Adding message to thread: {thread_id}")
client = await db.client
await verify_thread_access(client, thread_id, user_id)
try:
message_result = await client.table('messages').insert({
'thread_id': thread_id,
'type': 'user',
'is_llm_message': True,
'content': {
"role": "user",
"content": message
}
}).execute()
return message_result.data[0]
except Exception as e:
logger.error(f"Error adding message to thread {thread_id}: {str(e)}")
raise HTTPException(status_code=500, detail=f"Failed to add message: {str(e)}")

View File

@ -155,7 +155,7 @@ app.add_middleware(
allow_origin_regex=allow_origin_regex,
allow_credentials=True,
allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
allow_headers=["Content-Type", "Authorization", "X-Project-Id", "X-MCP-URL", "X-MCP-Type", "X-MCP-Headers", "X-Refresh-Token"],
allow_headers=["Content-Type", "Authorization", "X-Project-Id", "X-MCP-URL", "X-MCP-Type", "X-MCP-Headers", "X-Refresh-Token", "X-API-Key"],
)
# Create a main API router

View File

@ -11,7 +11,7 @@ from datetime import datetime, timezone
from utils.logger import logger
from utils.config import config, EnvMode
from services.supabase import DBConnection
from utils.auth_utils import get_current_user_id_from_jwt
from utils.auth_utils import get_current_user_id_from_jwt, is_deepai_user
from pydantic import BaseModel
from utils.constants import MODEL_ACCESS_TIERS, MODEL_NAME_ALIASES, HARDCODED_MODEL_PRICES
from litellm.cost_calculator import cost_per_token
@ -419,7 +419,9 @@ async def get_allowed_models_for_user(client, user_id: str):
Returns:
List of model names allowed for the user's subscription tier.
"""
if is_deepai_user(user_id):
# DeepAI users get all models
return list(set(MODEL_NAME_ALIASES.values()))
subscription = await get_user_subscription(user_id)
tier_name = 'free'
@ -447,7 +449,9 @@ async def can_use_model(client, user_id: str, model_name: str):
"plan_name": "Local Development",
"minutes_limit": "no limit"
}
if is_deepai_user(user_id):
# DeepAI users can use any model
return True, "DeepAI user: all models allowed", list(set(MODEL_NAME_ALIASES.values()))
allowed_models = await get_allowed_models_for_user(client, user_id)
resolved_model = MODEL_NAME_ALIASES.get(model_name, model_name)
if resolved_model in allowed_models:
@ -469,7 +473,13 @@ async def check_billing_status(client, user_id: str) -> Tuple[bool, str, Optiona
"plan_name": "Local Development",
"minutes_limit": "no limit"
}
if is_deepai_user(user_id):
# DeepAI users have infinite usage
return True, "DeepAI user: unlimited usage", {
"price_id": "deepai",
"plan_name": "DeepAI",
"minutes_limit": "no limit"
}
# Get current subscription
subscription = await get_user_subscription(user_id)
# print("Current subscription:", subscription)
@ -1114,6 +1124,23 @@ async def get_available_models(
"total_models": len(model_info)
}
if is_deepai_user(current_user_id):
# DeepAI users get all models, all available
model_info = []
for short_name, full_name in MODEL_NAME_ALIASES.items():
model_info.append({
"id": full_name,
"display_name": short_name,
"short_name": short_name,
"requires_subscription": False,
"is_available": True
})
return {
"models": model_info,
"subscription_tier": "DeepAI",
"total_models": len(model_info)
}
# For non-local mode, get list of allowed models for this user
allowed_models = await get_allowed_models_for_user(client, current_user_id)
free_tier_models = MODEL_ACCESS_TIERS.get('free', [])

View File

@ -5,6 +5,36 @@ import jwt
from jwt.exceptions import PyJWTError
from utils.logger import structlog
from utils.config import config
import os
# DeepAI user UUID
DEEPAI_USER_ID = os.getenv('DEEPAI_USER_ID', '00000000-0000-0000-0000-000000000000')
DEEPAI_API_KEY = os.getenv('DEEPAI_API_KEY', '00000000-0000-0000-0000-000000000000')
def str_safe_compare(str1: str, str2: str) -> bool:
"""
Compare two strings by first SHA256 hashing them and then comparing the hashes.
This is a safe way to compare strings that are not known to be equal.
"""
import hashlib
return hashlib.sha256(str1.encode()).hexdigest() == hashlib.sha256(str2.encode()).hexdigest()
def is_deepai_user(user_id: str) -> bool:
"""
Check if a user ID belongs to an deepai account.
This function is maintained for backward compatibility. The deepai user
now works like any other user with a proper UUID, so this check is mainly
used for legacy code and documentation purposes.
Args:
user_id: The user ID to check
Returns:
bool: True if the user is an deepai user, False otherwise
"""
# Check for the specific deepai user UUID or the old string format for backward compatibility
return str_safe_compare(user_id, DEEPAI_USER_ID)
# This function extracts the user ID from Supabase JWT
async def get_current_user_id_from_jwt(request: Request) -> str:
@ -23,6 +53,12 @@ async def get_current_user_id_from_jwt(request: Request) -> str:
Raises:
HTTPException: If no valid token is found or if the token is invalid
"""
x_api_key = request.headers.get('x-api-key')
if str_safe_compare(x_api_key, DEEPAI_API_KEY):
return DEEPAI_USER_ID
auth_header = request.headers.get('Authorization')
if not auth_header or not auth_header.startswith('Bearer '):
@ -126,6 +162,9 @@ async def get_user_id_from_stream_auth(
try:
# Try to get user_id from token in query param (for EventSource which can't set headers)
if token:
if str_safe_compare(token, DEEPAI_API_KEY):
return DEEPAI_USER_ID
try:
# For Supabase JWT, we just need to decode and extract the user ID
payload = jwt.decode(token, options={"verify_signature": False})