mirror of https://github.com/kortix-ai/suna.git
feat(deepai): add deepai specific endpoints
This commit is contained in:
parent
f4dc33ab13
commit
37f8e63cf1
|
@ -2545,3 +2545,105 @@ async def admin_install_suna_for_user(
|
||||||
detail=f"Failed to install Suna agent for user {account_id}"
|
detail=f"Failed to install Suna agent for user {account_id}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@router.get("/threads")
|
||||||
|
async def get_user_threads(
|
||||||
|
user_id: str = Depends(get_current_user_id_from_jwt),
|
||||||
|
page: Optional[int] = Query(1, ge=1, description="Page number (1-based)"),
|
||||||
|
limit: Optional[int] = Query(1000, ge=1, le=1000, description="Number of items per page (max 1000)")
|
||||||
|
):
|
||||||
|
"""Get all threads for the current user (account_id = user_id) with pagination."""
|
||||||
|
logger.info(f"Fetching threads for user: {user_id} (page={page}, limit={limit})")
|
||||||
|
client = await db.client
|
||||||
|
try:
|
||||||
|
offset = (page - 1) * limit
|
||||||
|
threads_query = client.table('threads').select('*').eq('account_id', user_id).order('created_at', desc=True)
|
||||||
|
total_result = await threads_query.execute()
|
||||||
|
total_count = len(total_result.data) if total_result.data else 0
|
||||||
|
threads_result = await threads_query.range(offset, offset + limit - 1).execute()
|
||||||
|
threads = threads_result.data or []
|
||||||
|
total_pages = (total_count + limit - 1) // limit if total_count else 0
|
||||||
|
return {
|
||||||
|
"threads": threads,
|
||||||
|
"pagination": {
|
||||||
|
"page": page,
|
||||||
|
"limit": limit,
|
||||||
|
"total": total_count,
|
||||||
|
"pages": total_pages
|
||||||
|
}
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error fetching threads for user {user_id}: {str(e)}")
|
||||||
|
raise HTTPException(status_code=500, detail=f"Failed to fetch threads: {str(e)}")
|
||||||
|
|
||||||
|
@router.get("/threads/{thread_id}/messages")
|
||||||
|
async def get_thread_messages(
|
||||||
|
thread_id: str,
|
||||||
|
user_id: str = Depends(get_current_user_id_from_jwt),
|
||||||
|
order: str = Query("desc", description="Order by created_at: 'asc' or 'desc'")
|
||||||
|
):
|
||||||
|
"""Get all messages for a thread, fetching in batches of 1000 from the DB to avoid large queries."""
|
||||||
|
logger.info(f"Fetching all messages for thread: {thread_id}, order={order}")
|
||||||
|
client = await db.client
|
||||||
|
await verify_thread_access(client, thread_id, user_id)
|
||||||
|
try:
|
||||||
|
batch_size = 1000
|
||||||
|
offset = 0
|
||||||
|
all_messages = []
|
||||||
|
while True:
|
||||||
|
query = client.table('messages').select('*').eq('thread_id', thread_id)
|
||||||
|
query = query.order('created_at', desc=(order == "desc"))
|
||||||
|
query = query.range(offset, offset + batch_size - 1)
|
||||||
|
messages_result = await query.execute()
|
||||||
|
batch = messages_result.data or []
|
||||||
|
all_messages.extend(batch)
|
||||||
|
logger.debug(f"Fetched batch of {len(batch)} messages (offset {offset})")
|
||||||
|
if len(batch) < batch_size:
|
||||||
|
break
|
||||||
|
offset += batch_size
|
||||||
|
return {"messages": all_messages}
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error fetching messages for thread {thread_id}: {str(e)}")
|
||||||
|
raise HTTPException(status_code=500, detail=f"Failed to fetch messages: {str(e)}")
|
||||||
|
|
||||||
|
@router.get("/agent-runs/{agent_run_id}")
|
||||||
|
async def get_agent_run(
|
||||||
|
agent_run_id: str,
|
||||||
|
user_id: str = Depends(get_current_user_id_from_jwt),
|
||||||
|
):
|
||||||
|
"""Get an agent run by ID"""
|
||||||
|
logger.info(f"Fetching agent run: {agent_run_id}")
|
||||||
|
client = await db.client
|
||||||
|
try:
|
||||||
|
agent_run_result = await client.table('agent_runs').select('*').eq('agent_run_id', agent_run_id).eq('account_id', user_id).execute()
|
||||||
|
if not agent_run_result.data:
|
||||||
|
raise HTTPException(status_code=404, detail="Agent run not found")
|
||||||
|
return agent_run_result.data[0]
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error fetching agent run {agent_run_id}: {str(e)}")
|
||||||
|
raise HTTPException(status_code=500, detail=f"Failed to fetch agent run: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/threads/{thread_id}/messages/add")
|
||||||
|
async def add_message_to_thread(
|
||||||
|
thread_id: str,
|
||||||
|
message: str,
|
||||||
|
user_id: str = Depends(get_current_user_id_from_jwt),
|
||||||
|
):
|
||||||
|
"""Add a message to a thread"""
|
||||||
|
logger.info(f"Adding message to thread: {thread_id}")
|
||||||
|
client = await db.client
|
||||||
|
await verify_thread_access(client, thread_id, user_id)
|
||||||
|
try:
|
||||||
|
message_result = await client.table('messages').insert({
|
||||||
|
'thread_id': thread_id,
|
||||||
|
'type': 'user',
|
||||||
|
'is_llm_message': True,
|
||||||
|
'content': {
|
||||||
|
"role": "user",
|
||||||
|
"content": message
|
||||||
|
}
|
||||||
|
}).execute()
|
||||||
|
return message_result.data[0]
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error adding message to thread {thread_id}: {str(e)}")
|
||||||
|
raise HTTPException(status_code=500, detail=f"Failed to add message: {str(e)}")
|
||||||
|
|
|
@ -155,7 +155,7 @@ app.add_middleware(
|
||||||
allow_origin_regex=allow_origin_regex,
|
allow_origin_regex=allow_origin_regex,
|
||||||
allow_credentials=True,
|
allow_credentials=True,
|
||||||
allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
|
allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
|
||||||
allow_headers=["Content-Type", "Authorization", "X-Project-Id", "X-MCP-URL", "X-MCP-Type", "X-MCP-Headers", "X-Refresh-Token"],
|
allow_headers=["Content-Type", "Authorization", "X-Project-Id", "X-MCP-URL", "X-MCP-Type", "X-MCP-Headers", "X-Refresh-Token", "X-API-Key"],
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create a main API router
|
# Create a main API router
|
||||||
|
|
|
@ -11,7 +11,7 @@ from datetime import datetime, timezone
|
||||||
from utils.logger import logger
|
from utils.logger import logger
|
||||||
from utils.config import config, EnvMode
|
from utils.config import config, EnvMode
|
||||||
from services.supabase import DBConnection
|
from services.supabase import DBConnection
|
||||||
from utils.auth_utils import get_current_user_id_from_jwt
|
from utils.auth_utils import get_current_user_id_from_jwt, is_deepai_user
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from utils.constants import MODEL_ACCESS_TIERS, MODEL_NAME_ALIASES, HARDCODED_MODEL_PRICES
|
from utils.constants import MODEL_ACCESS_TIERS, MODEL_NAME_ALIASES, HARDCODED_MODEL_PRICES
|
||||||
from litellm.cost_calculator import cost_per_token
|
from litellm.cost_calculator import cost_per_token
|
||||||
|
@ -419,7 +419,9 @@ async def get_allowed_models_for_user(client, user_id: str):
|
||||||
Returns:
|
Returns:
|
||||||
List of model names allowed for the user's subscription tier.
|
List of model names allowed for the user's subscription tier.
|
||||||
"""
|
"""
|
||||||
|
if is_deepai_user(user_id):
|
||||||
|
# DeepAI users get all models
|
||||||
|
return list(set(MODEL_NAME_ALIASES.values()))
|
||||||
subscription = await get_user_subscription(user_id)
|
subscription = await get_user_subscription(user_id)
|
||||||
tier_name = 'free'
|
tier_name = 'free'
|
||||||
|
|
||||||
|
@ -447,7 +449,9 @@ async def can_use_model(client, user_id: str, model_name: str):
|
||||||
"plan_name": "Local Development",
|
"plan_name": "Local Development",
|
||||||
"minutes_limit": "no limit"
|
"minutes_limit": "no limit"
|
||||||
}
|
}
|
||||||
|
if is_deepai_user(user_id):
|
||||||
|
# DeepAI users can use any model
|
||||||
|
return True, "DeepAI user: all models allowed", list(set(MODEL_NAME_ALIASES.values()))
|
||||||
allowed_models = await get_allowed_models_for_user(client, user_id)
|
allowed_models = await get_allowed_models_for_user(client, user_id)
|
||||||
resolved_model = MODEL_NAME_ALIASES.get(model_name, model_name)
|
resolved_model = MODEL_NAME_ALIASES.get(model_name, model_name)
|
||||||
if resolved_model in allowed_models:
|
if resolved_model in allowed_models:
|
||||||
|
@ -469,7 +473,13 @@ async def check_billing_status(client, user_id: str) -> Tuple[bool, str, Optiona
|
||||||
"plan_name": "Local Development",
|
"plan_name": "Local Development",
|
||||||
"minutes_limit": "no limit"
|
"minutes_limit": "no limit"
|
||||||
}
|
}
|
||||||
|
if is_deepai_user(user_id):
|
||||||
|
# DeepAI users have infinite usage
|
||||||
|
return True, "DeepAI user: unlimited usage", {
|
||||||
|
"price_id": "deepai",
|
||||||
|
"plan_name": "DeepAI",
|
||||||
|
"minutes_limit": "no limit"
|
||||||
|
}
|
||||||
# Get current subscription
|
# Get current subscription
|
||||||
subscription = await get_user_subscription(user_id)
|
subscription = await get_user_subscription(user_id)
|
||||||
# print("Current subscription:", subscription)
|
# print("Current subscription:", subscription)
|
||||||
|
@ -1114,6 +1124,23 @@ async def get_available_models(
|
||||||
"total_models": len(model_info)
|
"total_models": len(model_info)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if is_deepai_user(current_user_id):
|
||||||
|
# DeepAI users get all models, all available
|
||||||
|
model_info = []
|
||||||
|
for short_name, full_name in MODEL_NAME_ALIASES.items():
|
||||||
|
model_info.append({
|
||||||
|
"id": full_name,
|
||||||
|
"display_name": short_name,
|
||||||
|
"short_name": short_name,
|
||||||
|
"requires_subscription": False,
|
||||||
|
"is_available": True
|
||||||
|
})
|
||||||
|
return {
|
||||||
|
"models": model_info,
|
||||||
|
"subscription_tier": "DeepAI",
|
||||||
|
"total_models": len(model_info)
|
||||||
|
}
|
||||||
|
|
||||||
# For non-local mode, get list of allowed models for this user
|
# For non-local mode, get list of allowed models for this user
|
||||||
allowed_models = await get_allowed_models_for_user(client, current_user_id)
|
allowed_models = await get_allowed_models_for_user(client, current_user_id)
|
||||||
free_tier_models = MODEL_ACCESS_TIERS.get('free', [])
|
free_tier_models = MODEL_ACCESS_TIERS.get('free', [])
|
||||||
|
|
|
@ -5,6 +5,36 @@ import jwt
|
||||||
from jwt.exceptions import PyJWTError
|
from jwt.exceptions import PyJWTError
|
||||||
from utils.logger import structlog
|
from utils.logger import structlog
|
||||||
from utils.config import config
|
from utils.config import config
|
||||||
|
import os
|
||||||
|
|
||||||
|
# DeepAI user UUID
|
||||||
|
DEEPAI_USER_ID = os.getenv('DEEPAI_USER_ID', '00000000-0000-0000-0000-000000000000')
|
||||||
|
DEEPAI_API_KEY = os.getenv('DEEPAI_API_KEY', '00000000-0000-0000-0000-000000000000')
|
||||||
|
|
||||||
|
def str_safe_compare(str1: str, str2: str) -> bool:
|
||||||
|
"""
|
||||||
|
Compare two strings by first SHA256 hashing them and then comparing the hashes.
|
||||||
|
This is a safe way to compare strings that are not known to be equal.
|
||||||
|
"""
|
||||||
|
import hashlib
|
||||||
|
return hashlib.sha256(str1.encode()).hexdigest() == hashlib.sha256(str2.encode()).hexdigest()
|
||||||
|
|
||||||
|
def is_deepai_user(user_id: str) -> bool:
|
||||||
|
"""
|
||||||
|
Check if a user ID belongs to an deepai account.
|
||||||
|
|
||||||
|
This function is maintained for backward compatibility. The deepai user
|
||||||
|
now works like any other user with a proper UUID, so this check is mainly
|
||||||
|
used for legacy code and documentation purposes.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The user ID to check
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if the user is an deepai user, False otherwise
|
||||||
|
"""
|
||||||
|
# Check for the specific deepai user UUID or the old string format for backward compatibility
|
||||||
|
return str_safe_compare(user_id, DEEPAI_USER_ID)
|
||||||
|
|
||||||
# This function extracts the user ID from Supabase JWT
|
# This function extracts the user ID from Supabase JWT
|
||||||
async def get_current_user_id_from_jwt(request: Request) -> str:
|
async def get_current_user_id_from_jwt(request: Request) -> str:
|
||||||
|
@ -23,6 +53,12 @@ async def get_current_user_id_from_jwt(request: Request) -> str:
|
||||||
Raises:
|
Raises:
|
||||||
HTTPException: If no valid token is found or if the token is invalid
|
HTTPException: If no valid token is found or if the token is invalid
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
x_api_key = request.headers.get('x-api-key')
|
||||||
|
|
||||||
|
if str_safe_compare(x_api_key, DEEPAI_API_KEY):
|
||||||
|
return DEEPAI_USER_ID
|
||||||
|
|
||||||
auth_header = request.headers.get('Authorization')
|
auth_header = request.headers.get('Authorization')
|
||||||
|
|
||||||
if not auth_header or not auth_header.startswith('Bearer '):
|
if not auth_header or not auth_header.startswith('Bearer '):
|
||||||
|
@ -126,6 +162,9 @@ async def get_user_id_from_stream_auth(
|
||||||
try:
|
try:
|
||||||
# Try to get user_id from token in query param (for EventSource which can't set headers)
|
# Try to get user_id from token in query param (for EventSource which can't set headers)
|
||||||
if token:
|
if token:
|
||||||
|
if str_safe_compare(token, DEEPAI_API_KEY):
|
||||||
|
return DEEPAI_USER_ID
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# For Supabase JWT, we just need to decode and extract the user ID
|
# For Supabase JWT, we just need to decode and extract the user ID
|
||||||
payload = jwt.decode(token, options={"verify_signature": False})
|
payload = jwt.decode(token, options={"verify_signature": False})
|
||||||
|
|
Loading…
Reference in New Issue