feat(deepai): add deepai specific endpoints

This commit is contained in:
mykonos-ibiza 2025-07-26 18:47:22 +05:30
parent f4dc33ab13
commit 37f8e63cf1
4 changed files with 173 additions and 5 deletions

View File

@ -2545,3 +2545,105 @@ async def admin_install_suna_for_user(
detail=f"Failed to install Suna agent for user {account_id}" detail=f"Failed to install Suna agent for user {account_id}"
) )
@router.get("/threads")
async def get_user_threads(
user_id: str = Depends(get_current_user_id_from_jwt),
page: Optional[int] = Query(1, ge=1, description="Page number (1-based)"),
limit: Optional[int] = Query(1000, ge=1, le=1000, description="Number of items per page (max 1000)")
):
"""Get all threads for the current user (account_id = user_id) with pagination."""
logger.info(f"Fetching threads for user: {user_id} (page={page}, limit={limit})")
client = await db.client
try:
offset = (page - 1) * limit
threads_query = client.table('threads').select('*').eq('account_id', user_id).order('created_at', desc=True)
total_result = await threads_query.execute()
total_count = len(total_result.data) if total_result.data else 0
threads_result = await threads_query.range(offset, offset + limit - 1).execute()
threads = threads_result.data or []
total_pages = (total_count + limit - 1) // limit if total_count else 0
return {
"threads": threads,
"pagination": {
"page": page,
"limit": limit,
"total": total_count,
"pages": total_pages
}
}
except Exception as e:
logger.error(f"Error fetching threads for user {user_id}: {str(e)}")
raise HTTPException(status_code=500, detail=f"Failed to fetch threads: {str(e)}")
@router.get("/threads/{thread_id}/messages")
async def get_thread_messages(
thread_id: str,
user_id: str = Depends(get_current_user_id_from_jwt),
order: str = Query("desc", description="Order by created_at: 'asc' or 'desc'")
):
"""Get all messages for a thread, fetching in batches of 1000 from the DB to avoid large queries."""
logger.info(f"Fetching all messages for thread: {thread_id}, order={order}")
client = await db.client
await verify_thread_access(client, thread_id, user_id)
try:
batch_size = 1000
offset = 0
all_messages = []
while True:
query = client.table('messages').select('*').eq('thread_id', thread_id)
query = query.order('created_at', desc=(order == "desc"))
query = query.range(offset, offset + batch_size - 1)
messages_result = await query.execute()
batch = messages_result.data or []
all_messages.extend(batch)
logger.debug(f"Fetched batch of {len(batch)} messages (offset {offset})")
if len(batch) < batch_size:
break
offset += batch_size
return {"messages": all_messages}
except Exception as e:
logger.error(f"Error fetching messages for thread {thread_id}: {str(e)}")
raise HTTPException(status_code=500, detail=f"Failed to fetch messages: {str(e)}")
@router.get("/agent-runs/{agent_run_id}")
async def get_agent_run(
agent_run_id: str,
user_id: str = Depends(get_current_user_id_from_jwt),
):
"""Get an agent run by ID"""
logger.info(f"Fetching agent run: {agent_run_id}")
client = await db.client
try:
agent_run_result = await client.table('agent_runs').select('*').eq('agent_run_id', agent_run_id).eq('account_id', user_id).execute()
if not agent_run_result.data:
raise HTTPException(status_code=404, detail="Agent run not found")
return agent_run_result.data[0]
except Exception as e:
logger.error(f"Error fetching agent run {agent_run_id}: {str(e)}")
raise HTTPException(status_code=500, detail=f"Failed to fetch agent run: {str(e)}")
@router.post("/threads/{thread_id}/messages/add")
async def add_message_to_thread(
thread_id: str,
message: str,
user_id: str = Depends(get_current_user_id_from_jwt),
):
"""Add a message to a thread"""
logger.info(f"Adding message to thread: {thread_id}")
client = await db.client
await verify_thread_access(client, thread_id, user_id)
try:
message_result = await client.table('messages').insert({
'thread_id': thread_id,
'type': 'user',
'is_llm_message': True,
'content': {
"role": "user",
"content": message
}
}).execute()
return message_result.data[0]
except Exception as e:
logger.error(f"Error adding message to thread {thread_id}: {str(e)}")
raise HTTPException(status_code=500, detail=f"Failed to add message: {str(e)}")

View File

@ -155,7 +155,7 @@ app.add_middleware(
allow_origin_regex=allow_origin_regex, allow_origin_regex=allow_origin_regex,
allow_credentials=True, allow_credentials=True,
allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"], allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
allow_headers=["Content-Type", "Authorization", "X-Project-Id", "X-MCP-URL", "X-MCP-Type", "X-MCP-Headers", "X-Refresh-Token"], allow_headers=["Content-Type", "Authorization", "X-Project-Id", "X-MCP-URL", "X-MCP-Type", "X-MCP-Headers", "X-Refresh-Token", "X-API-Key"],
) )
# Create a main API router # Create a main API router

View File

@ -11,7 +11,7 @@ from datetime import datetime, timezone
from utils.logger import logger from utils.logger import logger
from utils.config import config, EnvMode from utils.config import config, EnvMode
from services.supabase import DBConnection from services.supabase import DBConnection
from utils.auth_utils import get_current_user_id_from_jwt from utils.auth_utils import get_current_user_id_from_jwt, is_deepai_user
from pydantic import BaseModel from pydantic import BaseModel
from utils.constants import MODEL_ACCESS_TIERS, MODEL_NAME_ALIASES, HARDCODED_MODEL_PRICES from utils.constants import MODEL_ACCESS_TIERS, MODEL_NAME_ALIASES, HARDCODED_MODEL_PRICES
from litellm.cost_calculator import cost_per_token from litellm.cost_calculator import cost_per_token
@ -419,7 +419,9 @@ async def get_allowed_models_for_user(client, user_id: str):
Returns: Returns:
List of model names allowed for the user's subscription tier. List of model names allowed for the user's subscription tier.
""" """
if is_deepai_user(user_id):
# DeepAI users get all models
return list(set(MODEL_NAME_ALIASES.values()))
subscription = await get_user_subscription(user_id) subscription = await get_user_subscription(user_id)
tier_name = 'free' tier_name = 'free'
@ -447,7 +449,9 @@ async def can_use_model(client, user_id: str, model_name: str):
"plan_name": "Local Development", "plan_name": "Local Development",
"minutes_limit": "no limit" "minutes_limit": "no limit"
} }
if is_deepai_user(user_id):
# DeepAI users can use any model
return True, "DeepAI user: all models allowed", list(set(MODEL_NAME_ALIASES.values()))
allowed_models = await get_allowed_models_for_user(client, user_id) allowed_models = await get_allowed_models_for_user(client, user_id)
resolved_model = MODEL_NAME_ALIASES.get(model_name, model_name) resolved_model = MODEL_NAME_ALIASES.get(model_name, model_name)
if resolved_model in allowed_models: if resolved_model in allowed_models:
@ -469,7 +473,13 @@ async def check_billing_status(client, user_id: str) -> Tuple[bool, str, Optiona
"plan_name": "Local Development", "plan_name": "Local Development",
"minutes_limit": "no limit" "minutes_limit": "no limit"
} }
if is_deepai_user(user_id):
# DeepAI users have infinite usage
return True, "DeepAI user: unlimited usage", {
"price_id": "deepai",
"plan_name": "DeepAI",
"minutes_limit": "no limit"
}
# Get current subscription # Get current subscription
subscription = await get_user_subscription(user_id) subscription = await get_user_subscription(user_id)
# print("Current subscription:", subscription) # print("Current subscription:", subscription)
@ -1114,6 +1124,23 @@ async def get_available_models(
"total_models": len(model_info) "total_models": len(model_info)
} }
if is_deepai_user(current_user_id):
# DeepAI users get all models, all available
model_info = []
for short_name, full_name in MODEL_NAME_ALIASES.items():
model_info.append({
"id": full_name,
"display_name": short_name,
"short_name": short_name,
"requires_subscription": False,
"is_available": True
})
return {
"models": model_info,
"subscription_tier": "DeepAI",
"total_models": len(model_info)
}
# For non-local mode, get list of allowed models for this user # For non-local mode, get list of allowed models for this user
allowed_models = await get_allowed_models_for_user(client, current_user_id) allowed_models = await get_allowed_models_for_user(client, current_user_id)
free_tier_models = MODEL_ACCESS_TIERS.get('free', []) free_tier_models = MODEL_ACCESS_TIERS.get('free', [])

View File

@ -5,6 +5,36 @@ import jwt
from jwt.exceptions import PyJWTError from jwt.exceptions import PyJWTError
from utils.logger import structlog from utils.logger import structlog
from utils.config import config from utils.config import config
import os
# DeepAI user UUID
DEEPAI_USER_ID = os.getenv('DEEPAI_USER_ID', '00000000-0000-0000-0000-000000000000')
DEEPAI_API_KEY = os.getenv('DEEPAI_API_KEY', '00000000-0000-0000-0000-000000000000')
def str_safe_compare(str1: str, str2: str) -> bool:
"""
Compare two strings by first SHA256 hashing them and then comparing the hashes.
This is a safe way to compare strings that are not known to be equal.
"""
import hashlib
return hashlib.sha256(str1.encode()).hexdigest() == hashlib.sha256(str2.encode()).hexdigest()
def is_deepai_user(user_id: str) -> bool:
"""
Check if a user ID belongs to an deepai account.
This function is maintained for backward compatibility. The deepai user
now works like any other user with a proper UUID, so this check is mainly
used for legacy code and documentation purposes.
Args:
user_id: The user ID to check
Returns:
bool: True if the user is an deepai user, False otherwise
"""
# Check for the specific deepai user UUID or the old string format for backward compatibility
return str_safe_compare(user_id, DEEPAI_USER_ID)
# This function extracts the user ID from Supabase JWT # This function extracts the user ID from Supabase JWT
async def get_current_user_id_from_jwt(request: Request) -> str: async def get_current_user_id_from_jwt(request: Request) -> str:
@ -23,6 +53,12 @@ async def get_current_user_id_from_jwt(request: Request) -> str:
Raises: Raises:
HTTPException: If no valid token is found or if the token is invalid HTTPException: If no valid token is found or if the token is invalid
""" """
x_api_key = request.headers.get('x-api-key')
if str_safe_compare(x_api_key, DEEPAI_API_KEY):
return DEEPAI_USER_ID
auth_header = request.headers.get('Authorization') auth_header = request.headers.get('Authorization')
if not auth_header or not auth_header.startswith('Bearer '): if not auth_header or not auth_header.startswith('Bearer '):
@ -126,6 +162,9 @@ async def get_user_id_from_stream_auth(
try: try:
# Try to get user_id from token in query param (for EventSource which can't set headers) # Try to get user_id from token in query param (for EventSource which can't set headers)
if token: if token:
if str_safe_compare(token, DEEPAI_API_KEY):
return DEEPAI_USER_ID
try: try:
# For Supabase JWT, we just need to decode and extract the user ID # For Supabase JWT, we just need to decode and extract the user ID
payload = jwt.decode(token, options={"verify_signature": False}) payload = jwt.decode(token, options={"verify_signature": False})