From 37f8e63cf12eab241535109e22d87e417588d25c Mon Sep 17 00:00:00 2001 From: mykonos-ibiza <222371740+mykonos-ibiza@users.noreply.github.com> Date: Sat, 26 Jul 2025 18:47:22 +0530 Subject: [PATCH] feat(deepai): add deepai specific endpoints --- backend/agent/api.py | 102 ++++++++++++++++++++++++++++++++++++ backend/api.py | 2 +- backend/services/billing.py | 35 +++++++++++-- backend/utils/auth_utils.py | 39 ++++++++++++++ 4 files changed, 173 insertions(+), 5 deletions(-) diff --git a/backend/agent/api.py b/backend/agent/api.py index 2c7c4601..77d7e394 100644 --- a/backend/agent/api.py +++ b/backend/agent/api.py @@ -2545,3 +2545,105 @@ async def admin_install_suna_for_user( detail=f"Failed to install Suna agent for user {account_id}" ) +@router.get("/threads") +async def get_user_threads( + user_id: str = Depends(get_current_user_id_from_jwt), + page: Optional[int] = Query(1, ge=1, description="Page number (1-based)"), + limit: Optional[int] = Query(1000, ge=1, le=1000, description="Number of items per page (max 1000)") +): + """Get all threads for the current user (account_id = user_id) with pagination.""" + logger.info(f"Fetching threads for user: {user_id} (page={page}, limit={limit})") + client = await db.client + try: + offset = (page - 1) * limit + threads_query = client.table('threads').select('*').eq('account_id', user_id).order('created_at', desc=True) + total_result = await threads_query.execute() + total_count = len(total_result.data) if total_result.data else 0 + threads_result = await threads_query.range(offset, offset + limit - 1).execute() + threads = threads_result.data or [] + total_pages = (total_count + limit - 1) // limit if total_count else 0 + return { + "threads": threads, + "pagination": { + "page": page, + "limit": limit, + "total": total_count, + "pages": total_pages + } + } + except Exception as e: + logger.error(f"Error fetching threads for user {user_id}: {str(e)}") + raise HTTPException(status_code=500, detail=f"Failed to fetch threads: {str(e)}") + +@router.get("/threads/{thread_id}/messages") +async def get_thread_messages( + thread_id: str, + user_id: str = Depends(get_current_user_id_from_jwt), + order: str = Query("desc", description="Order by created_at: 'asc' or 'desc'") +): + """Get all messages for a thread, fetching in batches of 1000 from the DB to avoid large queries.""" + logger.info(f"Fetching all messages for thread: {thread_id}, order={order}") + client = await db.client + await verify_thread_access(client, thread_id, user_id) + try: + batch_size = 1000 + offset = 0 + all_messages = [] + while True: + query = client.table('messages').select('*').eq('thread_id', thread_id) + query = query.order('created_at', desc=(order == "desc")) + query = query.range(offset, offset + batch_size - 1) + messages_result = await query.execute() + batch = messages_result.data or [] + all_messages.extend(batch) + logger.debug(f"Fetched batch of {len(batch)} messages (offset {offset})") + if len(batch) < batch_size: + break + offset += batch_size + return {"messages": all_messages} + except Exception as e: + logger.error(f"Error fetching messages for thread {thread_id}: {str(e)}") + raise HTTPException(status_code=500, detail=f"Failed to fetch messages: {str(e)}") + +@router.get("/agent-runs/{agent_run_id}") +async def get_agent_run( + agent_run_id: str, + user_id: str = Depends(get_current_user_id_from_jwt), +): + """Get an agent run by ID""" + logger.info(f"Fetching agent run: {agent_run_id}") + client = await db.client + try: + agent_run_result = await client.table('agent_runs').select('*').eq('agent_run_id', agent_run_id).eq('account_id', user_id).execute() + if not agent_run_result.data: + raise HTTPException(status_code=404, detail="Agent run not found") + return agent_run_result.data[0] + except Exception as e: + logger.error(f"Error fetching agent run {agent_run_id}: {str(e)}") + raise HTTPException(status_code=500, detail=f"Failed to fetch agent run: {str(e)}") + + +@router.post("/threads/{thread_id}/messages/add") +async def add_message_to_thread( + thread_id: str, + message: str, + user_id: str = Depends(get_current_user_id_from_jwt), +): + """Add a message to a thread""" + logger.info(f"Adding message to thread: {thread_id}") + client = await db.client + await verify_thread_access(client, thread_id, user_id) + try: + message_result = await client.table('messages').insert({ + 'thread_id': thread_id, + 'type': 'user', + 'is_llm_message': True, + 'content': { + "role": "user", + "content": message + } + }).execute() + return message_result.data[0] + except Exception as e: + logger.error(f"Error adding message to thread {thread_id}: {str(e)}") + raise HTTPException(status_code=500, detail=f"Failed to add message: {str(e)}") diff --git a/backend/api.py b/backend/api.py index 0d274c34..52acd76d 100644 --- a/backend/api.py +++ b/backend/api.py @@ -155,7 +155,7 @@ app.add_middleware( allow_origin_regex=allow_origin_regex, allow_credentials=True, allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"], - allow_headers=["Content-Type", "Authorization", "X-Project-Id", "X-MCP-URL", "X-MCP-Type", "X-MCP-Headers", "X-Refresh-Token"], + allow_headers=["Content-Type", "Authorization", "X-Project-Id", "X-MCP-URL", "X-MCP-Type", "X-MCP-Headers", "X-Refresh-Token", "X-API-Key"], ) # Create a main API router diff --git a/backend/services/billing.py b/backend/services/billing.py index 117c79a2..584f9ba3 100644 --- a/backend/services/billing.py +++ b/backend/services/billing.py @@ -11,7 +11,7 @@ from datetime import datetime, timezone from utils.logger import logger from utils.config import config, EnvMode from services.supabase import DBConnection -from utils.auth_utils import get_current_user_id_from_jwt +from utils.auth_utils import get_current_user_id_from_jwt, is_deepai_user from pydantic import BaseModel from utils.constants import MODEL_ACCESS_TIERS, MODEL_NAME_ALIASES, HARDCODED_MODEL_PRICES from litellm.cost_calculator import cost_per_token @@ -419,7 +419,9 @@ async def get_allowed_models_for_user(client, user_id: str): Returns: List of model names allowed for the user's subscription tier. """ - + if is_deepai_user(user_id): + # DeepAI users get all models + return list(set(MODEL_NAME_ALIASES.values())) subscription = await get_user_subscription(user_id) tier_name = 'free' @@ -447,7 +449,9 @@ async def can_use_model(client, user_id: str, model_name: str): "plan_name": "Local Development", "minutes_limit": "no limit" } - + if is_deepai_user(user_id): + # DeepAI users can use any model + return True, "DeepAI user: all models allowed", list(set(MODEL_NAME_ALIASES.values())) allowed_models = await get_allowed_models_for_user(client, user_id) resolved_model = MODEL_NAME_ALIASES.get(model_name, model_name) if resolved_model in allowed_models: @@ -469,7 +473,13 @@ async def check_billing_status(client, user_id: str) -> Tuple[bool, str, Optiona "plan_name": "Local Development", "minutes_limit": "no limit" } - + if is_deepai_user(user_id): + # DeepAI users have infinite usage + return True, "DeepAI user: unlimited usage", { + "price_id": "deepai", + "plan_name": "DeepAI", + "minutes_limit": "no limit" + } # Get current subscription subscription = await get_user_subscription(user_id) # print("Current subscription:", subscription) @@ -1114,6 +1124,23 @@ async def get_available_models( "total_models": len(model_info) } + if is_deepai_user(current_user_id): + # DeepAI users get all models, all available + model_info = [] + for short_name, full_name in MODEL_NAME_ALIASES.items(): + model_info.append({ + "id": full_name, + "display_name": short_name, + "short_name": short_name, + "requires_subscription": False, + "is_available": True + }) + return { + "models": model_info, + "subscription_tier": "DeepAI", + "total_models": len(model_info) + } + # For non-local mode, get list of allowed models for this user allowed_models = await get_allowed_models_for_user(client, current_user_id) free_tier_models = MODEL_ACCESS_TIERS.get('free', []) diff --git a/backend/utils/auth_utils.py b/backend/utils/auth_utils.py index b9a1f21d..3759ec61 100644 --- a/backend/utils/auth_utils.py +++ b/backend/utils/auth_utils.py @@ -5,6 +5,36 @@ import jwt from jwt.exceptions import PyJWTError from utils.logger import structlog from utils.config import config +import os + +# DeepAI user UUID +DEEPAI_USER_ID = os.getenv('DEEPAI_USER_ID', '00000000-0000-0000-0000-000000000000') +DEEPAI_API_KEY = os.getenv('DEEPAI_API_KEY', '00000000-0000-0000-0000-000000000000') + +def str_safe_compare(str1: str, str2: str) -> bool: + """ + Compare two strings by first SHA256 hashing them and then comparing the hashes. + This is a safe way to compare strings that are not known to be equal. + """ + import hashlib + return hashlib.sha256(str1.encode()).hexdigest() == hashlib.sha256(str2.encode()).hexdigest() + +def is_deepai_user(user_id: str) -> bool: + """ + Check if a user ID belongs to an deepai account. + + This function is maintained for backward compatibility. The deepai user + now works like any other user with a proper UUID, so this check is mainly + used for legacy code and documentation purposes. + + Args: + user_id: The user ID to check + + Returns: + bool: True if the user is an deepai user, False otherwise + """ + # Check for the specific deepai user UUID or the old string format for backward compatibility + return str_safe_compare(user_id, DEEPAI_USER_ID) # This function extracts the user ID from Supabase JWT async def get_current_user_id_from_jwt(request: Request) -> str: @@ -23,6 +53,12 @@ async def get_current_user_id_from_jwt(request: Request) -> str: Raises: HTTPException: If no valid token is found or if the token is invalid """ + + x_api_key = request.headers.get('x-api-key') + + if str_safe_compare(x_api_key, DEEPAI_API_KEY): + return DEEPAI_USER_ID + auth_header = request.headers.get('Authorization') if not auth_header or not auth_header.startswith('Bearer '): @@ -126,6 +162,9 @@ async def get_user_id_from_stream_auth( try: # Try to get user_id from token in query param (for EventSource which can't set headers) if token: + if str_safe_compare(token, DEEPAI_API_KEY): + return DEEPAI_USER_ID + try: # For Supabase JWT, we just need to decode and extract the user ID payload = jwt.decode(token, options={"verify_signature": False})