suna/backend/agent/api.py

1482 lines
68 KiB
Python

from fastapi import APIRouter, HTTPException, Depends, Request, Body, File, UploadFile, Form
from fastapi.responses import StreamingResponse
import asyncio
import json
import traceback
from datetime import datetime, timezone
import uuid
from typing import Optional, List, Dict, Any
import jwt
from pydantic import BaseModel
import tempfile
import os
from agentpress.thread_manager import ThreadManager
from services.supabase import DBConnection
from services import redis
from utils.auth_utils import get_current_user_id_from_jwt, get_user_id_from_stream_auth, verify_thread_access
from utils.logger import logger
from services.billing import check_billing_status, can_use_model
from utils.config import config
from sandbox.sandbox import create_sandbox, get_or_start_sandbox
from services.llm import make_llm_api_call
from run_agent_background import run_agent_background, _cleanup_redis_response_list, update_agent_run_status
from utils.constants import MODEL_NAME_ALIASES
# Initialize shared resources
router = APIRouter()
db = None
instance_id = None # Global instance ID for this backend instance
# TTL for Redis response lists (24 hours)
REDIS_RESPONSE_LIST_TTL = 3600 * 24
class AgentStartRequest(BaseModel):
model_name: Optional[str] = None # Will be set from config.MODEL_TO_USE in the endpoint
enable_thinking: Optional[bool] = False
reasoning_effort: Optional[str] = 'low'
stream: Optional[bool] = True
enable_context_manager: Optional[bool] = False
agent_id: Optional[str] = None # Custom agent to use
class InitiateAgentResponse(BaseModel):
thread_id: str
agent_run_id: Optional[str] = None
class AgentCreateRequest(BaseModel):
name: str
description: Optional[str] = None
system_prompt: str
configured_mcps: Optional[List[Dict[str, Any]]] = []
agentpress_tools: Optional[Dict[str, Any]] = {}
is_default: Optional[bool] = False
avatar: Optional[str] = None
avatar_color: Optional[str] = None
class AgentUpdateRequest(BaseModel):
name: Optional[str] = None
description: Optional[str] = None
system_prompt: Optional[str] = None
configured_mcps: Optional[List[Dict[str, Any]]] = None
agentpress_tools: Optional[Dict[str, Any]] = None
is_default: Optional[bool] = None
avatar: Optional[str] = None
avatar_color: Optional[str] = None
class AgentResponse(BaseModel):
agent_id: str
account_id: str
name: str
description: Optional[str]
system_prompt: str
configured_mcps: List[Dict[str, Any]]
agentpress_tools: Dict[str, Any]
is_default: bool
is_public: Optional[bool] = False
marketplace_published_at: Optional[str] = None
download_count: Optional[int] = 0
tags: Optional[List[str]] = []
avatar: Optional[str]
avatar_color: Optional[str]
created_at: str
updated_at: str
class ThreadAgentResponse(BaseModel):
agent: Optional[AgentResponse]
source: str # "thread", "default", "none", "missing"
message: str
class AgentBuilderChatRequest(BaseModel):
message: str
conversation_history: List[Dict[str, str]] = []
partial_config: Optional[Dict[str, Any]] = None
class AgentBuilderChatResponse(BaseModel):
response: str
suggested_config: Optional[Dict[str, Any]] = None
next_step: Optional[str] = None
def initialize(
_db: DBConnection,
_instance_id: str = None
):
"""Initialize the agent API with resources from the main API."""
global db, instance_id
db = _db
# Use provided instance_id or generate a new one
if _instance_id:
instance_id = _instance_id
else:
# Generate instance ID
instance_id = str(uuid.uuid4())[:8]
logger.info(f"Initialized agent API with instance ID: {instance_id}")
# Note: Redis will be initialized in the lifespan function in api.py
async def cleanup():
"""Clean up resources and stop running agents on shutdown."""
logger.info("Starting cleanup of agent API resources")
# Use the instance_id to find and clean up this instance's keys
try:
if instance_id: # Ensure instance_id is set
running_keys = await redis.keys(f"active_run:{instance_id}:*")
logger.info(f"Found {len(running_keys)} running agent runs for instance {instance_id} to clean up")
for key in running_keys:
# Key format: active_run:{instance_id}:{agent_run_id}
parts = key.split(":")
if len(parts) == 3:
agent_run_id = parts[2]
await stop_agent_run(agent_run_id, error_message=f"Instance {instance_id} shutting down")
else:
logger.warning(f"Unexpected key format found: {key}")
else:
logger.warning("Instance ID not set, cannot clean up instance-specific agent runs.")
except Exception as e:
logger.error(f"Failed to clean up running agent runs: {str(e)}")
# Close Redis connection
await redis.close()
logger.info("Completed cleanup of agent API resources")
async def stop_agent_run(agent_run_id: str, error_message: Optional[str] = None):
"""Update database and publish stop signal to Redis."""
logger.info(f"Stopping agent run: {agent_run_id}")
client = await db.client
final_status = "failed" if error_message else "stopped"
# Attempt to fetch final responses from Redis
response_list_key = f"agent_run:{agent_run_id}:responses"
all_responses = []
try:
all_responses_json = await redis.lrange(response_list_key, 0, -1)
all_responses = [json.loads(r) for r in all_responses_json]
logger.info(f"Fetched {len(all_responses)} responses from Redis for DB update on stop/fail: {agent_run_id}")
except Exception as e:
logger.error(f"Failed to fetch responses from Redis for {agent_run_id} during stop/fail: {e}")
# Try fetching from DB as a fallback? Or proceed without responses? Proceeding without for now.
# Update the agent run status in the database
update_success = await update_agent_run_status(
client, agent_run_id, final_status, error=error_message, responses=all_responses
)
if not update_success:
logger.error(f"Failed to update database status for stopped/failed run {agent_run_id}")
# Send STOP signal to the global control channel
global_control_channel = f"agent_run:{agent_run_id}:control"
try:
await redis.publish(global_control_channel, "STOP")
logger.debug(f"Published STOP signal to global channel {global_control_channel}")
except Exception as e:
logger.error(f"Failed to publish STOP signal to global channel {global_control_channel}: {str(e)}")
# Find all instances handling this agent run and send STOP to instance-specific channels
try:
instance_keys = await redis.keys(f"active_run:*:{agent_run_id}")
logger.debug(f"Found {len(instance_keys)} active instance keys for agent run {agent_run_id}")
for key in instance_keys:
# Key format: active_run:{instance_id}:{agent_run_id}
parts = key.split(":")
if len(parts) == 3:
instance_id_from_key = parts[1]
instance_control_channel = f"agent_run:{agent_run_id}:control:{instance_id_from_key}"
try:
await redis.publish(instance_control_channel, "STOP")
logger.debug(f"Published STOP signal to instance channel {instance_control_channel}")
except Exception as e:
logger.warning(f"Failed to publish STOP signal to instance channel {instance_control_channel}: {str(e)}")
else:
logger.warning(f"Unexpected key format found: {key}")
# Clean up the response list immediately on stop/fail
await _cleanup_redis_response_list(agent_run_id)
except Exception as e:
logger.error(f"Failed to find or signal active instances for {agent_run_id}: {str(e)}")
logger.info(f"Successfully initiated stop process for agent run: {agent_run_id}")
# async def restore_running_agent_runs():
# """Mark agent runs that were still 'running' in the database as failed and clean up Redis resources."""
# logger.info("Restoring running agent runs after server restart")
# client = await db.client
# running_agent_runs = await client.table('agent_runs').select('id').eq("status", "running").execute()
# for run in running_agent_runs.data:
# agent_run_id = run['id']
# logger.warning(f"Found running agent run {agent_run_id} from before server restart")
# # Clean up Redis resources for this run
# try:
# # Clean up active run key
# active_run_key = f"active_run:{instance_id}:{agent_run_id}"
# await redis.delete(active_run_key)
# # Clean up response list
# response_list_key = f"agent_run:{agent_run_id}:responses"
# await redis.delete(response_list_key)
# # Clean up control channels
# control_channel = f"agent_run:{agent_run_id}:control"
# instance_control_channel = f"agent_run:{agent_run_id}:control:{instance_id}"
# await redis.delete(control_channel)
# await redis.delete(instance_control_channel)
# logger.info(f"Cleaned up Redis resources for agent run {agent_run_id}")
# except Exception as e:
# logger.error(f"Error cleaning up Redis resources for agent run {agent_run_id}: {e}")
# # Call stop_agent_run to handle status update and cleanup
# await stop_agent_run(agent_run_id, error_message="Server restarted while agent was running")
async def check_for_active_project_agent_run(client, project_id: str):
"""
Check if there is an active agent run for any thread in the given project.
If found, returns the ID of the active run, otherwise returns None.
"""
project_threads = await client.table('threads').select('thread_id').eq('project_id', project_id).execute()
project_thread_ids = [t['thread_id'] for t in project_threads.data]
if project_thread_ids:
active_runs = await client.table('agent_runs').select('id').in_('thread_id', project_thread_ids).eq('status', 'running').execute()
if active_runs.data and len(active_runs.data) > 0:
return active_runs.data[0]['id']
return None
async def get_agent_run_with_access_check(client, agent_run_id: str, user_id: str):
"""Get agent run data after verifying user access."""
agent_run = await client.table('agent_runs').select('*').eq('id', agent_run_id).execute()
if not agent_run.data:
raise HTTPException(status_code=404, detail="Agent run not found")
agent_run_data = agent_run.data[0]
thread_id = agent_run_data['thread_id']
await verify_thread_access(client, thread_id, user_id)
return agent_run_data
async def enhance_system_prompt(agent_name: str, description: str, user_system_prompt: str) -> str:
"""
Enhance a basic system prompt using GPT-4o to create a more comprehensive and effective system prompt.
Args:
agent_name: Name of the agent
description: Description of the agent
user_system_prompt: User's basic system prompt/instructions
Returns:
Enhanced system prompt generated by GPT-4o
"""
try:
system_message = """You are an expert at creating comprehensive system prompts for AI agents. Your task is to take basic agent information and transform it into a detailed, effective system prompt that will help the agent perform optimally.
Guidelines for creating system prompts:
1. Be specific about the agent's role, expertise, and capabilities
2. Include clear behavioral guidelines and interaction style
3. Specify the agent's knowledge domains and areas of expertise
4. Include guidance on how to handle different types of requests
5. Set appropriate boundaries and limitations
6. Make the prompt engaging and easy to understand
7. Ensure the prompt is comprehensive but not overly verbose
8. Include relevant context about tools and capabilities the agent might have
The enhanced prompt should be professional, clear, and actionable."""
user_message = f"""Please create an enhanced system prompt for an AI agent with the following details:
Agent Name: {agent_name}
Agent Description: {description}
User's Instructions: {user_system_prompt}
Transform this basic information into a comprehensive, effective system prompt that will help the agent perform at its best. The prompt should be detailed enough to guide the agent's behavior while remaining clear and actionable."""
messages = [
{"role": "system", "content": system_message},
{"role": "user", "content": user_message}
]
logger.info(f"Enhancing system prompt for agent: {agent_name}")
response = await make_llm_api_call(
messages=messages,
model_name="openai/gpt-4o",
max_tokens=2000,
temperature=0.7
)
if response and response.get('choices') and response['choices'][0].get('message'):
enhanced_prompt = response['choices'][0]['message'].get('content', '').strip()
if enhanced_prompt:
logger.info(f"Successfully enhanced system prompt for agent: {agent_name}")
return enhanced_prompt
else:
logger.warning(f"GPT-4o returned empty enhanced prompt for agent: {agent_name}")
return user_system_prompt
else:
logger.warning(f"Failed to get valid response from GPT-4o for agent: {agent_name}")
return user_system_prompt
except Exception as e:
logger.error(f"Error enhancing system prompt for agent {agent_name}: {str(e)}")
# Return the original prompt if enhancement fails
return user_system_prompt
@router.post("/thread/{thread_id}/agent/start")
async def start_agent(
thread_id: str,
body: AgentStartRequest = Body(...),
user_id: str = Depends(get_current_user_id_from_jwt)
):
"""Start an agent for a specific thread in the background."""
global instance_id # Ensure instance_id is accessible
if not instance_id:
raise HTTPException(status_code=500, detail="Agent API not initialized with instance ID")
# Use model from config if not specified in the request
model_name = body.model_name
logger.info(f"Original model_name from request: {model_name}")
if model_name is None:
model_name = config.MODEL_TO_USE
logger.info(f"Using model from config: {model_name}")
# Log the model name after alias resolution
resolved_model = MODEL_NAME_ALIASES.get(model_name, model_name)
logger.info(f"Resolved model name: {resolved_model}")
# Update model_name to use the resolved version
model_name = resolved_model
logger.info(f"Starting new agent for thread: {thread_id} with config: model={model_name}, thinking={body.enable_thinking}, effort={body.reasoning_effort}, stream={body.stream}, context_manager={body.enable_context_manager} (Instance: {instance_id})")
client = await db.client
await verify_thread_access(client, thread_id, user_id)
thread_result = await client.table('threads').select('project_id', 'account_id', 'agent_id').eq('thread_id', thread_id).execute()
if not thread_result.data:
raise HTTPException(status_code=404, detail="Thread not found")
thread_data = thread_result.data[0]
project_id = thread_data.get('project_id')
account_id = thread_data.get('account_id')
thread_agent_id = thread_data.get('agent_id')
# Load agent configuration
agent_config = None
effective_agent_id = body.agent_id or thread_agent_id # Use provided agent_id or the one stored in thread
if effective_agent_id:
agent_result = await client.table('agents').select('*').eq('agent_id', effective_agent_id).eq('account_id', account_id).execute()
if not agent_result.data:
if body.agent_id:
raise HTTPException(status_code=404, detail="Agent not found or access denied")
else:
logger.warning(f"Stored agent_id {effective_agent_id} not found, falling back to default")
effective_agent_id = None
else:
agent_config = agent_result.data[0]
source = "request" if body.agent_id else "thread"
logger.info(f"Using agent from {source}: {agent_config['name']} ({effective_agent_id})")
# If no agent found yet, try to get default agent for the account
if not agent_config:
default_agent_result = await client.table('agents').select('*').eq('account_id', account_id).eq('is_default', True).execute()
if default_agent_result.data:
agent_config = default_agent_result.data[0]
logger.info(f"Using default agent: {agent_config['name']} ({agent_config['agent_id']})")
# Update thread's agent_id if a different agent was explicitly requested
if body.agent_id and body.agent_id != thread_agent_id and agent_config:
try:
await client.table('threads').update({"agent_id": agent_config['agent_id']}).eq('thread_id', thread_id).execute()
logger.info(f"Updated thread {thread_id} to use agent {agent_config['agent_id']}")
except Exception as e:
logger.warning(f"Failed to update thread agent_id: {e}")
can_use, model_message, allowed_models = await can_use_model(client, account_id, model_name)
if not can_use:
raise HTTPException(status_code=403, detail={"message": model_message, "allowed_models": allowed_models})
can_run, message, subscription = await check_billing_status(client, account_id)
if not can_run:
raise HTTPException(status_code=402, detail={"message": message, "subscription": subscription})
active_run_id = await check_for_active_project_agent_run(client, project_id)
if active_run_id:
logger.info(f"Stopping existing agent run {active_run_id} for project {project_id}")
await stop_agent_run(active_run_id)
try:
# Get project data to find sandbox ID
project_result = await client.table('projects').select('*').eq('project_id', project_id).execute()
if not project_result.data:
raise HTTPException(status_code=404, detail="Project not found")
project_data = project_result.data[0]
sandbox_info = project_data.get('sandbox', {})
if not sandbox_info.get('id'):
raise HTTPException(status_code=404, detail="No sandbox found for this project")
sandbox_id = sandbox_info['id']
sandbox = await get_or_start_sandbox(sandbox_id)
logger.info(f"Successfully started sandbox {sandbox_id} for project {project_id}")
except Exception as e:
logger.error(f"Failed to start sandbox for project {project_id}: {str(e)}")
raise HTTPException(status_code=500, detail=f"Failed to initialize sandbox: {str(e)}")
agent_run = await client.table('agent_runs').insert({
"thread_id": thread_id, "status": "running",
"started_at": datetime.now(timezone.utc).isoformat()
}).execute()
agent_run_id = agent_run.data[0]['id']
logger.info(f"Created new agent run: {agent_run_id}")
# Register this run in Redis with TTL using instance ID
instance_key = f"active_run:{instance_id}:{agent_run_id}"
try:
await redis.set(instance_key, "running", ex=redis.REDIS_KEY_TTL)
except Exception as e:
logger.warning(f"Failed to register agent run in Redis ({instance_key}): {str(e)}")
# Run the agent in the background
run_agent_background.send(
agent_run_id=agent_run_id, thread_id=thread_id, instance_id=instance_id,
project_id=project_id,
model_name=model_name, # Already resolved above
enable_thinking=body.enable_thinking, reasoning_effort=body.reasoning_effort,
stream=body.stream, enable_context_manager=body.enable_context_manager,
agent_config=agent_config # Pass agent configuration
)
return {"agent_run_id": agent_run_id, "status": "running"}
@router.post("/agent-run/{agent_run_id}/stop")
async def stop_agent(agent_run_id: str, user_id: str = Depends(get_current_user_id_from_jwt)):
"""Stop a running agent."""
logger.info(f"Received request to stop agent run: {agent_run_id}")
client = await db.client
await get_agent_run_with_access_check(client, agent_run_id, user_id)
await stop_agent_run(agent_run_id)
return {"status": "stopped"}
@router.get("/thread/{thread_id}/agent-runs")
async def get_agent_runs(thread_id: str, user_id: str = Depends(get_current_user_id_from_jwt)):
"""Get all agent runs for a thread."""
logger.info(f"Fetching agent runs for thread: {thread_id}")
client = await db.client
await verify_thread_access(client, thread_id, user_id)
agent_runs = await client.table('agent_runs').select('*').eq("thread_id", thread_id).order('created_at', desc=True).execute()
logger.debug(f"Found {len(agent_runs.data)} agent runs for thread: {thread_id}")
return {"agent_runs": agent_runs.data}
@router.get("/agent-run/{agent_run_id}")
async def get_agent_run(agent_run_id: str, user_id: str = Depends(get_current_user_id_from_jwt)):
"""Get agent run status and responses."""
logger.info(f"Fetching agent run details: {agent_run_id}")
client = await db.client
agent_run_data = await get_agent_run_with_access_check(client, agent_run_id, user_id)
# Note: Responses are not included here by default, they are in the stream or DB
return {
"id": agent_run_data['id'],
"threadId": agent_run_data['thread_id'],
"status": agent_run_data['status'],
"startedAt": agent_run_data['started_at'],
"completedAt": agent_run_data['completed_at'],
"error": agent_run_data['error']
}
@router.get("/thread/{thread_id}/agent", response_model=ThreadAgentResponse)
async def get_thread_agent(thread_id: str, user_id: str = Depends(get_current_user_id_from_jwt)):
"""Get the agent details for a specific thread."""
logger.info(f"Fetching agent details for thread: {thread_id}")
client = await db.client
try:
# Verify thread access and get thread data including agent_id
await verify_thread_access(client, thread_id, user_id)
thread_result = await client.table('threads').select('agent_id', 'account_id').eq('thread_id', thread_id).execute()
if not thread_result.data:
raise HTTPException(status_code=404, detail="Thread not found")
thread_data = thread_result.data[0]
thread_agent_id = thread_data.get('agent_id')
account_id = thread_data.get('account_id')
# If no agent_id is set in the thread, try to get the default agent
effective_agent_id = thread_agent_id
agent_source = "thread"
if not effective_agent_id:
# No agent set in thread, get default agent for the account
default_agent_result = await client.table('agents').select('agent_id').eq('account_id', account_id).eq('is_default', True).execute()
if default_agent_result.data:
effective_agent_id = default_agent_result.data[0]['agent_id']
agent_source = "default"
else:
# No default agent found
return {
"agent": None,
"source": "none",
"message": "No agent configured for this thread"
}
# Fetch the agent details
agent_result = await client.table('agents').select('*').eq('agent_id', effective_agent_id).eq('account_id', account_id).execute()
if not agent_result.data:
# Agent was deleted or doesn't exist
return {
"agent": None,
"source": "missing",
"message": f"Agent {effective_agent_id} not found or was deleted"
}
agent_data = agent_result.data[0]
return {
"agent": AgentResponse(
agent_id=agent_data['agent_id'],
account_id=agent_data['account_id'],
name=agent_data['name'],
description=agent_data.get('description'),
system_prompt=agent_data['system_prompt'],
configured_mcps=agent_data.get('configured_mcps', []),
agentpress_tools=agent_data.get('agentpress_tools', {}),
is_default=agent_data.get('is_default', False),
is_public=agent_data.get('is_public', False),
marketplace_published_at=agent_data.get('marketplace_published_at'),
download_count=agent_data.get('download_count', 0),
tags=agent_data.get('tags', []),
avatar=agent_data.get('avatar'),
avatar_color=agent_data.get('avatar_color'),
created_at=agent_data['created_at'],
updated_at=agent_data['updated_at']
),
"source": agent_source,
"message": f"Using {agent_source} agent: {agent_data['name']}"
}
except HTTPException:
raise
except Exception as e:
logger.error(f"Error fetching agent for thread {thread_id}: {str(e)}")
raise HTTPException(status_code=500, detail=f"Failed to fetch thread agent: {str(e)}")
@router.get("/agent-run/{agent_run_id}/stream")
async def stream_agent_run(
agent_run_id: str,
token: Optional[str] = None,
request: Request = None
):
"""Stream the responses of an agent run using Redis Lists and Pub/Sub."""
logger.info(f"Starting stream for agent run: {agent_run_id}")
client = await db.client
user_id = await get_user_id_from_stream_auth(request, token)
agent_run_data = await get_agent_run_with_access_check(client, agent_run_id, user_id)
response_list_key = f"agent_run:{agent_run_id}:responses"
response_channel = f"agent_run:{agent_run_id}:new_response"
control_channel = f"agent_run:{agent_run_id}:control" # Global control channel
async def stream_generator():
logger.debug(f"Streaming responses for {agent_run_id} using Redis list {response_list_key} and channel {response_channel}")
last_processed_index = -1
pubsub_response = None
pubsub_control = None
listener_task = None
terminate_stream = False
initial_yield_complete = False
try:
# 1. Fetch and yield initial responses from Redis list
initial_responses_json = await redis.lrange(response_list_key, 0, -1)
initial_responses = []
if initial_responses_json:
initial_responses = [json.loads(r) for r in initial_responses_json]
logger.debug(f"Sending {len(initial_responses)} initial responses for {agent_run_id}")
for response in initial_responses:
yield f"data: {json.dumps(response)}\n\n"
last_processed_index = len(initial_responses) - 1
initial_yield_complete = True
# 2. Check run status *after* yielding initial data
run_status = await client.table('agent_runs').select('status').eq("id", agent_run_id).maybe_single().execute()
current_status = run_status.data.get('status') if run_status.data else None
if current_status != 'running':
logger.info(f"Agent run {agent_run_id} is not running (status: {current_status}). Ending stream.")
yield f"data: {json.dumps({'type': 'status', 'status': 'completed'})}\n\n"
return
# 3. Set up Pub/Sub listeners for new responses and control signals
pubsub_response = await redis.create_pubsub()
await pubsub_response.subscribe(response_channel)
logger.debug(f"Subscribed to response channel: {response_channel}")
pubsub_control = await redis.create_pubsub()
await pubsub_control.subscribe(control_channel)
logger.debug(f"Subscribed to control channel: {control_channel}")
# Queue to communicate between listeners and the main generator loop
message_queue = asyncio.Queue()
async def listen_messages():
response_reader = pubsub_response.listen()
control_reader = pubsub_control.listen()
tasks = [asyncio.create_task(response_reader.__anext__()), asyncio.create_task(control_reader.__anext__())]
while not terminate_stream:
done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
for task in done:
try:
message = task.result()
if message and isinstance(message, dict) and message.get("type") == "message":
channel = message.get("channel")
data = message.get("data")
if isinstance(data, bytes): data = data.decode('utf-8')
if channel == response_channel and data == "new":
await message_queue.put({"type": "new_response"})
elif channel == control_channel and data in ["STOP", "END_STREAM", "ERROR"]:
logger.info(f"Received control signal '{data}' for {agent_run_id}")
await message_queue.put({"type": "control", "data": data})
return # Stop listening on control signal
except StopAsyncIteration:
logger.warning(f"Listener {task} stopped.")
# Decide how to handle listener stopping, maybe terminate?
await message_queue.put({"type": "error", "data": "Listener stopped unexpectedly"})
return
except Exception as e:
logger.error(f"Error in listener for {agent_run_id}: {e}")
await message_queue.put({"type": "error", "data": "Listener failed"})
return
finally:
# Reschedule the completed listener task
if task in tasks:
tasks.remove(task)
if message and isinstance(message, dict) and message.get("channel") == response_channel:
tasks.append(asyncio.create_task(response_reader.__anext__()))
elif message and isinstance(message, dict) and message.get("channel") == control_channel:
tasks.append(asyncio.create_task(control_reader.__anext__()))
# Cancel pending listener tasks on exit
for p_task in pending: p_task.cancel()
for task in tasks: task.cancel()
listener_task = asyncio.create_task(listen_messages())
# 4. Main loop to process messages from the queue
while not terminate_stream:
try:
queue_item = await message_queue.get()
if queue_item["type"] == "new_response":
# Fetch new responses from Redis list starting after the last processed index
new_start_index = last_processed_index + 1
new_responses_json = await redis.lrange(response_list_key, new_start_index, -1)
if new_responses_json:
new_responses = [json.loads(r) for r in new_responses_json]
num_new = len(new_responses)
# logger.debug(f"Received {num_new} new responses for {agent_run_id} (index {new_start_index} onwards)")
for response in new_responses:
yield f"data: {json.dumps(response)}\n\n"
# Check if this response signals completion
if response.get('type') == 'status' and response.get('status') in ['completed', 'failed', 'stopped']:
logger.info(f"Detected run completion via status message in stream: {response.get('status')}")
terminate_stream = True
break # Stop processing further new responses
last_processed_index += num_new
if terminate_stream: break
elif queue_item["type"] == "control":
control_signal = queue_item["data"]
terminate_stream = True # Stop the stream on any control signal
yield f"data: {json.dumps({'type': 'status', 'status': control_signal})}\n\n"
break
elif queue_item["type"] == "error":
logger.error(f"Listener error for {agent_run_id}: {queue_item['data']}")
terminate_stream = True
yield f"data: {json.dumps({'type': 'status', 'status': 'error'})}\n\n"
break
except asyncio.CancelledError:
logger.info(f"Stream generator main loop cancelled for {agent_run_id}")
terminate_stream = True
break
except Exception as loop_err:
logger.error(f"Error in stream generator main loop for {agent_run_id}: {loop_err}", exc_info=True)
terminate_stream = True
yield f"data: {json.dumps({'type': 'status', 'status': 'error', 'message': f'Stream failed: {loop_err}'})}\n\n"
break
except Exception as e:
logger.error(f"Error setting up stream for agent run {agent_run_id}: {e}", exc_info=True)
# Only yield error if initial yield didn't happen
if not initial_yield_complete:
yield f"data: {json.dumps({'type': 'status', 'status': 'error', 'message': f'Failed to start stream: {e}'})}\n\n"
finally:
terminate_stream = True
# Graceful shutdown order: unsubscribe → close → cancel
if pubsub_response: await pubsub_response.unsubscribe(response_channel)
if pubsub_control: await pubsub_control.unsubscribe(control_channel)
if pubsub_response: await pubsub_response.close()
if pubsub_control: await pubsub_control.close()
if listener_task:
listener_task.cancel()
try:
await listener_task # Reap inner tasks & swallow their errors
except asyncio.CancelledError:
pass
except Exception as e:
logger.debug(f"listener_task ended with: {e}")
# Wait briefly for tasks to cancel
await asyncio.sleep(0.1)
logger.debug(f"Streaming cleanup complete for agent run: {agent_run_id}")
return StreamingResponse(stream_generator(), media_type="text/event-stream", headers={
"Cache-Control": "no-cache, no-transform", "Connection": "keep-alive",
"X-Accel-Buffering": "no", "Content-Type": "text/event-stream",
"Access-Control-Allow-Origin": "*"
})
async def generate_and_update_project_name(project_id: str, prompt: str):
"""Generates a project name using an LLM and updates the database."""
logger.info(f"Starting background task to generate name for project: {project_id}")
try:
db_conn = DBConnection()
client = await db_conn.client
model_name = "openai/gpt-4o-mini"
system_prompt = "You are a helpful assistant that generates extremely concise titles (2-4 words maximum) for chat threads based on the user's message. Respond with only the title, no other text or punctuation."
user_message = f"Generate an extremely brief title (2-4 words only) for a chat thread that starts with this message: \"{prompt}\""
messages = [{"role": "system", "content": system_prompt}, {"role": "user", "content": user_message}]
logger.debug(f"Calling LLM ({model_name}) for project {project_id} naming.")
response = await make_llm_api_call(messages=messages, model_name=model_name, max_tokens=20, temperature=0.7)
generated_name = None
if response and response.get('choices') and response['choices'][0].get('message'):
raw_name = response['choices'][0]['message'].get('content', '').strip()
cleaned_name = raw_name.strip('\'" \n\t')
if cleaned_name:
generated_name = cleaned_name
logger.info(f"LLM generated name for project {project_id}: '{generated_name}'")
else:
logger.warning(f"LLM returned an empty name for project {project_id}.")
else:
logger.warning(f"Failed to get valid response from LLM for project {project_id} naming. Response: {response}")
if generated_name:
update_result = await client.table('projects').update({"name": generated_name}).eq("project_id", project_id).execute()
if hasattr(update_result, 'data') and update_result.data:
logger.info(f"Successfully updated project {project_id} name to '{generated_name}'")
else:
logger.error(f"Failed to update project {project_id} name in database. Update result: {update_result}")
else:
logger.warning(f"No generated name, skipping database update for project {project_id}.")
except Exception as e:
logger.error(f"Error in background naming task for project {project_id}: {str(e)}\n{traceback.format_exc()}")
finally:
# No need to disconnect DBConnection singleton instance here
logger.info(f"Finished background naming task for project: {project_id}")
@router.post("/agent/initiate", response_model=InitiateAgentResponse)
async def initiate_agent_with_files(
prompt: str = Form(...),
model_name: Optional[str] = Form(None), # Default to None to use config.MODEL_TO_USE
enable_thinking: Optional[bool] = Form(False),
reasoning_effort: Optional[str] = Form("low"),
stream: Optional[bool] = Form(True),
enable_context_manager: Optional[bool] = Form(False),
agent_id: Optional[str] = Form(None), # Add agent_id parameter
files: List[UploadFile] = File(default=[]),
is_agent_builder: Optional[bool] = Form(False),
user_id: str = Depends(get_current_user_id_from_jwt)
):
"""Initiate a new agent session with optional file attachments."""
global instance_id # Ensure instance_id is accessible
if not instance_id:
raise HTTPException(status_code=500, detail="Agent API not initialized with instance ID")
# Use model from config if not specified in the request
logger.info(f"Original model_name from request: {model_name}")
if model_name is None:
model_name = config.MODEL_TO_USE
logger.info(f"Using model from config: {model_name}")
# Log the model name after alias resolution
resolved_model = MODEL_NAME_ALIASES.get(model_name, model_name)
logger.info(f"Resolved model name: {resolved_model}")
# Update model_name to use the resolved version
model_name = resolved_model
logger.info(f"[\033[91mDEBUG\033[0m] Initiating new agent with prompt and {len(files)} files (Instance: {instance_id}), model: {model_name}, enable_thinking: {enable_thinking}")
client = await db.client
account_id = user_id # In Basejump, personal account_id is the same as user_id
# Load agent configuration if agent_id is provided
agent_config = None
if agent_id:
agent_result = await client.table('agents').select('*').eq('agent_id', agent_id).eq('account_id', account_id).execute()
if not agent_result.data:
raise HTTPException(status_code=404, detail="Agent not found or access denied")
agent_config = agent_result.data[0]
logger.info(f"Using custom agent: {agent_config['name']} ({agent_id})")
else:
# Try to get default agent for the account
default_agent_result = await client.table('agents').select('*').eq('account_id', account_id).eq('is_default', True).execute()
if default_agent_result.data:
agent_config = default_agent_result.data[0]
logger.info(f"Using default agent: {agent_config['name']} ({agent_config['agent_id']})")
can_use, model_message, allowed_models = await can_use_model(client, account_id, model_name)
if not can_use:
raise HTTPException(status_code=403, detail={"message": model_message, "allowed_models": allowed_models})
can_run, message, subscription = await check_billing_status(client, account_id)
if not can_run:
raise HTTPException(status_code=402, detail={"message": message, "subscription": subscription})
try:
# 1. Create Project
placeholder_name = f"{prompt[:30]}..." if len(prompt) > 30 else prompt
project = await client.table('projects').insert({
"project_id": str(uuid.uuid4()), "account_id": account_id, "name": placeholder_name,
"created_at": datetime.now(timezone.utc).isoformat()
}).execute()
project_id = project.data[0]['project_id']
logger.info(f"Created new project: {project_id}")
# 2. Create Thread
thread_data = {
"thread_id": str(uuid.uuid4()),
"project_id": project_id,
"account_id": account_id,
"created_at": datetime.now(timezone.utc).isoformat()
}
# Store the agent_id in the thread if we have one
if agent_config:
thread_data["agent_id"] = agent_config['agent_id']
logger.info(f"Storing agent_id {agent_config['agent_id']} in thread")
thread = await client.table('threads').insert(thread_data).execute()
thread_id = thread.data[0]['thread_id']
logger.info(f"Created new thread: {thread_id}")
# Trigger Background Naming Task
asyncio.create_task(generate_and_update_project_name(project_id=project_id, prompt=prompt))
# 3. Create Sandbox
sandbox_pass = str(uuid.uuid4())
sandbox = create_sandbox(sandbox_pass, project_id)
sandbox_id = sandbox.id
logger.info(f"Created new sandbox {sandbox_id} for project {project_id}")
# Get preview links
vnc_link = sandbox.get_preview_link(6080)
website_link = sandbox.get_preview_link(8080)
vnc_url = vnc_link.url if hasattr(vnc_link, 'url') else str(vnc_link).split("url='")[1].split("'")[0]
website_url = website_link.url if hasattr(website_link, 'url') else str(website_link).split("url='")[1].split("'")[0]
token = None
if hasattr(vnc_link, 'token'):
token = vnc_link.token
elif "token='" in str(vnc_link):
token = str(vnc_link).split("token='")[1].split("'")[0]
# Update project with sandbox info
update_result = await client.table('projects').update({
'sandbox': {
'id': sandbox_id, 'pass': sandbox_pass, 'vnc_preview': vnc_url,
'sandbox_url': website_url, 'token': token
}
}).eq('project_id', project_id).execute()
if not update_result.data:
logger.error(f"Failed to update project {project_id} with new sandbox {sandbox_id}")
raise Exception("Database update failed")
# 4. Upload Files to Sandbox (if any)
message_content = prompt
if files:
successful_uploads = []
failed_uploads = []
for file in files:
if file.filename:
try:
safe_filename = file.filename.replace('/', '_').replace('\\', '_')
target_path = f"/workspace/{safe_filename}"
logger.info(f"Attempting to upload {safe_filename} to {target_path} in sandbox {sandbox_id}")
content = await file.read()
upload_successful = False
try:
if hasattr(sandbox, 'fs') and hasattr(sandbox.fs, 'upload_file'):
import inspect
if inspect.iscoroutinefunction(sandbox.fs.upload_file):
await sandbox.fs.upload_file(target_path, content)
else:
sandbox.fs.upload_file(target_path, content)
logger.debug(f"Called sandbox.fs.upload_file for {target_path}")
upload_successful = True
else:
raise NotImplementedError("Suitable upload method not found on sandbox object.")
except Exception as upload_error:
logger.error(f"Error during sandbox upload call for {safe_filename}: {str(upload_error)}", exc_info=True)
if upload_successful:
try:
await asyncio.sleep(0.2)
parent_dir = os.path.dirname(target_path)
files_in_dir = sandbox.fs.list_files(parent_dir)
file_names_in_dir = [f.name for f in files_in_dir]
if safe_filename in file_names_in_dir:
successful_uploads.append(target_path)
logger.info(f"Successfully uploaded and verified file {safe_filename} to sandbox path {target_path}")
else:
logger.error(f"Verification failed for {safe_filename}: File not found in {parent_dir} after upload attempt.")
failed_uploads.append(safe_filename)
except Exception as verify_error:
logger.error(f"Error verifying file {safe_filename} after upload: {str(verify_error)}", exc_info=True)
failed_uploads.append(safe_filename)
else:
failed_uploads.append(safe_filename)
except Exception as file_error:
logger.error(f"Error processing file {file.filename}: {str(file_error)}", exc_info=True)
failed_uploads.append(file.filename)
finally:
await file.close()
if successful_uploads:
message_content += "\n\n" if message_content else ""
for file_path in successful_uploads: message_content += f"[Uploaded File: {file_path}]\n"
if failed_uploads:
message_content += "\n\nThe following files failed to upload:\n"
for failed_file in failed_uploads: message_content += f"- {failed_file}\n"
# 5. Add initial user message to thread
message_id = str(uuid.uuid4())
message_payload = {"role": "user", "content": message_content}
await client.table('messages').insert({
"message_id": message_id, "thread_id": thread_id, "type": "user",
"is_llm_message": True, "content": json.dumps(message_payload),
"created_at": datetime.now(timezone.utc).isoformat()
}).execute()
# 6. Start Agent Run
agent_run = await client.table('agent_runs').insert({
"thread_id": thread_id, "status": "running",
"started_at": datetime.now(timezone.utc).isoformat()
}).execute()
agent_run_id = agent_run.data[0]['id']
logger.info(f"Created new agent run: {agent_run_id}")
# Register run in Redis
instance_key = f"active_run:{instance_id}:{agent_run_id}"
try:
await redis.set(instance_key, "running", ex=redis.REDIS_KEY_TTL)
except Exception as e:
logger.warning(f"Failed to register agent run in Redis ({instance_key}): {str(e)}")
# Run agent in background
run_agent_background.send(
agent_run_id=agent_run_id, thread_id=thread_id, instance_id=instance_id,
project_id=project_id,
model_name=model_name, # Already resolved above
enable_thinking=enable_thinking, reasoning_effort=reasoning_effort,
stream=stream, enable_context_manager=enable_context_manager,
agent_config=agent_config, # Pass agent configuration
is_agent_builder=is_agent_builder
)
return {"thread_id": thread_id, "agent_run_id": agent_run_id}
except Exception as e:
logger.error(f"Error in agent initiation: {str(e)}\n{traceback.format_exc()}")
# TODO: Clean up created project/thread if initiation fails mid-way
raise HTTPException(status_code=500, detail=f"Failed to initiate agent session: {str(e)}")
@router.get("/agents", response_model=List[AgentResponse])
async def get_agents(user_id: str = Depends(get_current_user_id_from_jwt)):
"""Get all agents for the current user."""
logger.info(f"Fetching agents for user: {user_id}")
client = await db.client
try:
# Get agents for the current user's account
agents = await client.table('agents').select('*').eq("account_id", user_id).order('created_at', desc=True).execute()
if not agents.data:
logger.info(f"No agents found for user: {user_id}")
return []
# Format the response
agent_list = []
for agent in agents.data:
agent_list.append(AgentResponse(
agent_id=agent['agent_id'],
account_id=agent['account_id'],
name=agent['name'],
description=agent.get('description'),
system_prompt=agent['system_prompt'],
configured_mcps=agent.get('configured_mcps', []),
agentpress_tools=agent.get('agentpress_tools', {}),
is_default=agent.get('is_default', False),
is_public=agent.get('is_public', False),
marketplace_published_at=agent.get('marketplace_published_at'),
download_count=agent.get('download_count', 0),
tags=agent.get('tags', []),
avatar=agent.get('avatar'),
avatar_color=agent.get('avatar_color'),
created_at=agent['created_at'],
updated_at=agent['updated_at']
))
logger.info(f"Found {len(agent_list)} agents for user: {user_id}")
return agent_list
except Exception as e:
logger.error(f"Error fetching agents for user {user_id}: {str(e)}")
raise HTTPException(status_code=500, detail=f"Failed to fetch agents: {str(e)}")
@router.get("/agents/{agent_id}", response_model=AgentResponse)
async def get_agent(agent_id: str, user_id: str = Depends(get_current_user_id_from_jwt)):
"""Get a specific agent by ID. Only the owner can access non-public agents."""
logger.info(f"Fetching agent {agent_id} for user: {user_id}")
client = await db.client
try:
# Get agent with access check - only owner or public agents
agent = await client.table('agents').select('*').eq("agent_id", agent_id).execute()
if not agent.data:
raise HTTPException(status_code=404, detail="Agent not found")
agent_data = agent.data[0]
# Check ownership - only owner can access non-public agents
if agent_data['account_id'] != user_id and not agent_data.get('is_public', False):
raise HTTPException(status_code=403, detail="Access denied")
return AgentResponse(
agent_id=agent_data['agent_id'],
account_id=agent_data['account_id'],
name=agent_data['name'],
description=agent_data.get('description'),
system_prompt=agent_data['system_prompt'],
configured_mcps=agent_data.get('configured_mcps', []),
agentpress_tools=agent_data.get('agentpress_tools', {}),
is_default=agent_data.get('is_default', False),
is_public=agent_data.get('is_public', False),
marketplace_published_at=agent_data.get('marketplace_published_at'),
download_count=agent_data.get('download_count', 0),
tags=agent_data.get('tags', []),
avatar=agent_data.get('avatar'),
avatar_color=agent_data.get('avatar_color'),
created_at=agent_data['created_at'],
updated_at=agent_data['updated_at']
)
except HTTPException:
raise
except Exception as e:
logger.error(f"Error fetching agent {agent_id} for user {user_id}: {str(e)}")
raise HTTPException(status_code=500, detail=f"Failed to fetch agent: {str(e)}")
@router.post("/agents", response_model=AgentResponse)
async def create_agent(
agent_data: AgentCreateRequest,
user_id: str = Depends(get_current_user_id_from_jwt)
):
"""Create a new agent."""
logger.info(f"Creating new agent for user: {user_id}")
client = await db.client
try:
# If this is set as default, we need to unset other defaults first
if agent_data.is_default:
await client.table('agents').update({"is_default": False}).eq("account_id", user_id).eq("is_default", True).execute()
# enhanced_system_prompt = await enhance_system_prompt(
# agent_name=agent_data.name,
# description=agent_data.description or "",
# user_system_prompt=agent_data.system_prompt
# )
insert_data = {
"account_id": user_id,
"name": agent_data.name,
"description": agent_data.description,
"system_prompt": agent_data.system_prompt,
"configured_mcps": agent_data.configured_mcps or [],
"agentpress_tools": agent_data.agentpress_tools or {},
"is_default": agent_data.is_default or False,
"avatar": agent_data.avatar,
"avatar_color": agent_data.avatar_color
}
new_agent = await client.table('agents').insert(insert_data).execute()
if not new_agent.data:
raise HTTPException(status_code=500, detail="Failed to create agent")
agent = new_agent.data[0]
logger.info(f"Created agent {agent['agent_id']} for user: {user_id}")
return AgentResponse(
agent_id=agent['agent_id'],
account_id=agent['account_id'],
name=agent['name'],
description=agent.get('description'),
system_prompt=agent['system_prompt'],
configured_mcps=agent.get('configured_mcps', []),
agentpress_tools=agent.get('agentpress_tools', {}),
is_default=agent.get('is_default', False),
is_public=agent.get('is_public', False),
marketplace_published_at=agent.get('marketplace_published_at'),
download_count=agent.get('download_count', 0),
tags=agent.get('tags', []),
avatar=agent.get('avatar'),
avatar_color=agent.get('avatar_color'),
created_at=agent['created_at'],
updated_at=agent['updated_at']
)
except HTTPException:
raise
except Exception as e:
logger.error(f"Error creating agent for user {user_id}: {str(e)}")
raise HTTPException(status_code=500, detail=f"Failed to create agent: {str(e)}")
@router.put("/agents/{agent_id}", response_model=AgentResponse)
async def update_agent(
agent_id: str,
agent_data: AgentUpdateRequest,
user_id: str = Depends(get_current_user_id_from_jwt)
):
"""Update an existing agent."""
logger.info(f"Updating agent {agent_id} for user: {user_id}")
client = await db.client
try:
# First verify the agent exists and belongs to the user
existing_agent = await client.table('agents').select('*').eq("agent_id", agent_id).eq("account_id", user_id).maybe_single().execute()
if not existing_agent.data:
raise HTTPException(status_code=404, detail="Agent not found")
existing_data = existing_agent.data
# Prepare update data (only include fields that are not None)
update_data = {}
if agent_data.name is not None:
update_data["name"] = agent_data.name
if agent_data.description is not None:
update_data["description"] = agent_data.description
if agent_data.system_prompt is not None:
# Enhance the system prompt using GPT-4o if it's being updated
# enhanced_system_prompt = await enhance_system_prompt(
# agent_name=agent_data.name or existing_data['name'],
# description=agent_data.description or existing_data.get('description', ''),
# user_system_prompt=agent_data.system_prompt
# )
update_data["system_prompt"] = agent_data.system_prompt
if agent_data.configured_mcps is not None:
update_data["configured_mcps"] = agent_data.configured_mcps
if agent_data.agentpress_tools is not None:
update_data["agentpress_tools"] = agent_data.agentpress_tools
if agent_data.is_default is not None:
update_data["is_default"] = agent_data.is_default
# If setting as default, unset other defaults first
if agent_data.is_default:
await client.table('agents').update({"is_default": False}).eq("account_id", user_id).eq("is_default", True).neq("agent_id", agent_id).execute()
if agent_data.avatar is not None:
update_data["avatar"] = agent_data.avatar
if agent_data.avatar_color is not None:
update_data["avatar_color"] = agent_data.avatar_color
if not update_data:
# No fields to update, return existing agent
agent = existing_agent.data
else:
# Update the agent
update_result = await client.table('agents').update(update_data).eq("agent_id", agent_id).eq("account_id", user_id).execute()
if not update_result.data:
raise HTTPException(status_code=500, detail="Failed to update agent")
# Fetch the updated agent data
updated_agent = await client.table('agents').select('*').eq("agent_id", agent_id).eq("account_id", user_id).maybe_single().execute()
if not updated_agent.data:
raise HTTPException(status_code=500, detail="Failed to fetch updated agent")
agent = updated_agent.data
logger.info(f"Updated agent {agent_id} for user: {user_id}")
return AgentResponse(
agent_id=agent['agent_id'],
account_id=agent['account_id'],
name=agent['name'],
description=agent.get('description'),
system_prompt=agent['system_prompt'],
configured_mcps=agent.get('configured_mcps', []),
agentpress_tools=agent.get('agentpress_tools', {}),
is_default=agent.get('is_default', False),
is_public=agent.get('is_public', False),
marketplace_published_at=agent.get('marketplace_published_at'),
download_count=agent.get('download_count', 0),
tags=agent.get('tags', []),
avatar=agent.get('avatar'),
avatar_color=agent.get('avatar_color'),
created_at=agent['created_at'],
updated_at=agent['updated_at']
)
except HTTPException:
raise
except Exception as e:
logger.error(f"Error updating agent {agent_id} for user {user_id}: {str(e)}")
raise HTTPException(status_code=500, detail=f"Failed to update agent: {str(e)}")
@router.delete("/agents/{agent_id}")
async def delete_agent(agent_id: str, user_id: str = Depends(get_current_user_id_from_jwt)):
"""Delete an agent."""
logger.info(f"Deleting agent: {agent_id}")
client = await db.client
try:
# Verify agent ownership
agent_result = await client.table('agents').select('*').eq('agent_id', agent_id).execute()
if not agent_result.data:
raise HTTPException(status_code=404, detail="Agent not found")
agent = agent_result.data[0]
if agent['account_id'] != user_id:
raise HTTPException(status_code=403, detail="Access denied")
# Check if agent is default
if agent['is_default']:
raise HTTPException(status_code=400, detail="Cannot delete default agent")
# Delete the agent
await client.table('agents').delete().eq('agent_id', agent_id).execute()
logger.info(f"Successfully deleted agent: {agent_id}")
return {"message": "Agent deleted successfully"}
except HTTPException:
raise
except Exception as e:
logger.error(f"Error deleting agent {agent_id}: {str(e)}")
raise HTTPException(status_code=500, detail="Internal server error")
# Marketplace Models
class MarketplaceAgent(BaseModel):
agent_id: str
name: str
description: Optional[str]
system_prompt: str
configured_mcps: List[Dict[str, Any]]
agentpress_tools: Dict[str, Any]
tags: Optional[List[str]]
download_count: int
marketplace_published_at: str
created_at: str
creator_name: str
avatar: Optional[str]
avatar_color: Optional[str]
class MarketplaceAgentsResponse(BaseModel):
agents: List[MarketplaceAgent]
class PublishAgentRequest(BaseModel):
tags: Optional[List[str]] = []
@router.get("/marketplace/agents", response_model=MarketplaceAgentsResponse)
async def get_marketplace_agents(
search: Optional[str] = None,
tags: Optional[str] = None, # Comma-separated string
limit: Optional[int] = 50,
offset: Optional[int] = 0
):
"""Get public agents from the marketplace."""
logger.info(f"Fetching marketplace agents with search='{search}', tags='{tags}', limit={limit}, offset={offset}")
client = await db.client
try:
# Convert tags string to array if provided
tags_array = None
if tags:
tags_array = [tag.strip() for tag in tags.split(',') if tag.strip()]
# Call the database function
result = await client.rpc('get_marketplace_agents', {
'p_search': search,
'p_tags': tags_array,
'p_limit': limit,
'p_offset': offset
}).execute()
if result.data is None:
result.data = []
logger.info(f"Found {len(result.data)} marketplace agents")
return {"agents": result.data}
except Exception as e:
logger.error(f"Error fetching marketplace agents: {str(e)}")
raise HTTPException(status_code=500, detail="Internal server error")
@router.post("/agents/{agent_id}/publish")
async def publish_agent_to_marketplace(
agent_id: str,
publish_data: PublishAgentRequest,
user_id: str = Depends(get_current_user_id_from_jwt)
):
"""Publish an agent to the marketplace."""
logger.info(f"Publishing agent {agent_id} to marketplace")
client = await db.client
try:
# Verify agent ownership
agent_result = await client.table('agents').select('*').eq('agent_id', agent_id).execute()
if not agent_result.data:
raise HTTPException(status_code=404, detail="Agent not found")
agent = agent_result.data[0]
if agent['account_id'] != user_id:
raise HTTPException(status_code=403, detail="Access denied")
# Update agent with marketplace data
update_data = {
'is_public': True,
'marketplace_published_at': datetime.now(timezone.utc).isoformat()
}
if publish_data.tags:
update_data['tags'] = publish_data.tags
await client.table('agents').update(update_data).eq('agent_id', agent_id).execute()
logger.info(f"Successfully published agent {agent_id} to marketplace")
return {"message": "Agent published to marketplace successfully"}
except HTTPException:
raise
except Exception as e:
logger.error(f"Error publishing agent {agent_id}: {str(e)}")
raise HTTPException(status_code=500, detail="Internal server error")
@router.post("/agents/{agent_id}/unpublish")
async def unpublish_agent_from_marketplace(
agent_id: str,
user_id: str = Depends(get_current_user_id_from_jwt)
):
"""Unpublish an agent from the marketplace."""
logger.info(f"Unpublishing agent {agent_id} from marketplace")
client = await db.client
try:
# Verify agent ownership
agent_result = await client.table('agents').select('*').eq('agent_id', agent_id).execute()
if not agent_result.data:
raise HTTPException(status_code=404, detail="Agent not found")
agent = agent_result.data[0]
if agent['account_id'] != user_id:
raise HTTPException(status_code=403, detail="Access denied")
# Update agent to remove from marketplace
await client.table('agents').update({
'is_public': False,
'marketplace_published_at': None
}).eq('agent_id', agent_id).execute()
logger.info(f"Successfully unpublished agent {agent_id} from marketplace")
return {"message": "Agent removed from marketplace successfully"}
except HTTPException:
raise
except Exception as e:
logger.error(f"Error unpublishing agent {agent_id}: {str(e)}")
raise HTTPException(status_code=500, detail="Internal server error")
@router.post("/marketplace/agents/{agent_id}/add-to-library")
async def add_agent_to_library(
agent_id: str,
user_id: str = Depends(get_current_user_id_from_jwt)
):
"""Add an agent from the marketplace to user's library."""
logger.info(f"Adding marketplace agent {agent_id} to user {user_id} library")
client = await db.client
try:
# Call the database function with user_id
result = await client.rpc('add_agent_to_library', {
'p_original_agent_id': agent_id,
'p_user_account_id': user_id
}).execute()
if result.data:
new_agent_id = result.data
logger.info(f"Successfully added agent {agent_id} to library as {new_agent_id}")
return {"message": "Agent added to library successfully", "new_agent_id": new_agent_id}
else:
raise HTTPException(status_code=400, detail="Failed to add agent to library")
except Exception as e:
error_msg = str(e)
logger.error(f"Error adding agent {agent_id} to library: {error_msg}")
if "Agent not found or not public" in error_msg:
raise HTTPException(status_code=404, detail="Agent not found or not public")
elif "Agent already in your library" in error_msg:
raise HTTPException(status_code=409, detail="Agent already in your library")
else:
raise HTTPException(status_code=500, detail="Internal server error")
@router.get("/user/agent-library")
async def get_user_agent_library(user_id: str = Depends(get_current_user_id_from_jwt)):
"""Get user's agent library (agents added from marketplace)."""
logger.info(f"Fetching agent library for user {user_id}")
client = await db.client
try:
result = await client.table('user_agent_library').select("""
*,
original_agent:agents!user_agent_library_original_agent_id_fkey(
agent_id,
name,
description,
download_count
),
agent:agents!user_agent_library_agent_id_fkey(
agent_id,
name,
description,
system_prompt
)
""").eq('user_account_id', user_id).order('added_at', desc=True).execute()
logger.info(f"Found {len(result.data or [])} agents in user library")
return {"library": result.data or []}
except Exception as e:
logger.error(f"Error fetching user agent library: {str(e)}")
raise HTTPException(status_code=500, detail="Internal server error")