mirror of https://github.com/kortix-ai/suna.git
Optimize agent initiate, start and stream
commit 24deeb2c5e
parent d070f3ece5
@@ -279,12 +279,15 @@ async def stop_agent_run(agent_run_id: str, error_message: Optional[str] = None):
     logger.info(f"Successfully initiated stop process for agent run: {agent_run_id}")

 async def get_agent_run_with_access_check(client, agent_run_id: str, user_id: str):
-    agent_run = await client.table('agent_runs').select('*').eq('id', agent_run_id).execute()
+    agent_run = await client.table('agent_runs').select('*, threads(account_id)').eq('id', agent_run_id).execute()
     if not agent_run.data:
         raise HTTPException(status_code=404, detail="Agent run not found")

     agent_run_data = agent_run.data[0]
     thread_id = agent_run_data['thread_id']
+    account_id = agent_run_data['threads']['account_id']
+    if account_id == user_id:
+        return agent_run_data
     await verify_thread_access(client, thread_id, user_id)
     return agent_run_data

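The reworked access check embeds the owning account in the agent-run query via the `threads(account_id)` relationship, so the common case where the caller owns the thread needs no second query; `verify_thread_access` only runs as a fallback. A minimal sketch of that fast path, with an invented row shape standing in for what the embedded PostgREST/Supabase-style select returns:

```python
# Illustrative only: `row` mimics one result of select('*, threads(account_id)')
# from a Supabase/PostgREST-style client; all values are made up.
row = {
    "id": "run-123",
    "thread_id": "thread-456",
    "status": "running",
    "threads": {"account_id": "account-789"},  # embedded parent row from threads
}

def owner_fast_path(row: dict, user_id: str) -> bool:
    """True when the caller owns the thread, so no extra access query is needed."""
    return row["threads"]["account_id"] == user_id

assert owner_fast_path(row, "account-789")
```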
@@ -321,7 +324,6 @@ async def start_agent(
     logger.info(f"Starting new agent for thread: {thread_id} with config: model={model_name}, thinking={body.enable_thinking}, effort={body.reasoning_effort}, stream={body.stream}, context_manager={body.enable_context_manager} (Instance: {instance_id})")
     client = await db.client

-    await verify_thread_access(client, thread_id, user_id)

     thread_result = await client.table('threads').select('project_id', 'account_id', 'metadata').eq('thread_id', thread_id).execute()

@@ -332,6 +334,9 @@ async def start_agent(
     account_id = thread_data.get('account_id')
     thread_metadata = thread_data.get('metadata', {})

+    if account_id != user_id:
+        await verify_thread_access(client, thread_id, user_id)
+
     structlog.contextvars.bind_contextvars(
         project_id=project_id,
         account_id=account_id,
@@ -433,18 +438,23 @@ async def start_agent(
     logger.info(f"[AGENT LOAD] Agent config keys: {list(agent_config.keys())}")
     logger.info(f"Using agent {agent_config['agent_id']} for this agent run (thread remains agent-agnostic)")

-    can_use, model_message, allowed_models = await can_use_model(client, account_id, model_name)
+    # Run all checks concurrently
+    model_check_task = asyncio.create_task(can_use_model(client, account_id, model_name))
+    billing_check_task = asyncio.create_task(check_billing_status(client, account_id))
+    limit_check_task = asyncio.create_task(check_agent_run_limit(client, account_id))
+
+    # Wait for all checks to complete
+    (can_use, model_message, allowed_models), (can_run, message, subscription), limit_check = await asyncio.gather(
+        model_check_task, billing_check_task, limit_check_task
+    )
+
+    # Check results and raise appropriate errors
     if not can_use:
         raise HTTPException(status_code=403, detail={"message": model_message, "allowed_models": allowed_models})

-    can_run, message, subscription = await check_billing_status(client, account_id)
-
     if not can_run:
         raise HTTPException(status_code=402, detail={"message": message, "subscription": subscription})

-    limit_check = await check_agent_run_limit(client, account_id)
-
     if not limit_check['can_start']:
         error_detail = {
             "message": f"Maximum of {config.MAX_PARALLEL_AGENT_RUNS} parallel agent runs allowed within 24 hours. You currently have {limit_check['running_count']} running.",
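The hunk above replaces three sequential awaits with a single `asyncio.gather`, so the model, billing, and run-limit checks overlap instead of adding up. A self-contained sketch of the same pattern with stand-in coroutines (the real helpers `can_use_model`, `check_billing_status`, and `check_agent_run_limit` take a database client and account id; these stubs only simulate independent I/O):

```python
import asyncio

# Stand-ins for the real checks; each simulates an independent I/O-bound call.
async def check_model() -> tuple[bool, str]:
    await asyncio.sleep(0.1)
    return True, "model allowed"

async def check_billing() -> tuple[bool, str]:
    await asyncio.sleep(0.1)
    return True, "billing ok"

async def check_run_limit() -> dict:
    await asyncio.sleep(0.1)
    return {"can_start": True, "running_count": 0}

async def preflight() -> None:
    # Independent awaitables run concurrently, so latency is roughly the
    # slowest single check rather than the sum of all three.
    (model_ok, _), (billing_ok, _), limit = await asyncio.gather(
        check_model(), check_billing(), check_run_limit()
    )
    if not (model_ok and billing_ok and limit["can_start"]):
        raise RuntimeError("pre-flight check failed")

asyncio.run(preflight())
```

Passing coroutines straight to `gather` is equivalent here to the `create_task` + `gather` form used in the diff; with `create_task` the work merely starts being scheduled slightly earlier.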
@@ -707,8 +717,8 @@ async def stream_agent_run(
     logger.info(f"Starting stream for agent run: {agent_run_id}")
     client = await db.client

-    user_id = await get_user_id_from_stream_auth(request, token)
-    agent_run_data = await get_agent_run_with_access_check(client, agent_run_id, user_id)
+    user_id = await get_user_id_from_stream_auth(request, token) # practically instant
+    agent_run_data = await get_agent_run_with_access_check(client, agent_run_id, user_id) # 1 db query

     structlog.contextvars.bind_contextvars(
         agent_run_id=agent_run_id,
@@ -719,7 +729,7 @@ async def stream_agent_run(
     response_channel = f"agent_run:{agent_run_id}:new_response"
     control_channel = f"agent_run:{agent_run_id}:control" # Global control channel

-    async def stream_generator():
+    async def stream_generator(agent_run_data):
         logger.debug(f"Streaming responses for {agent_run_id} using Redis list {response_list_key} and channel {response_channel}")
         last_processed_index = -1
         pubsub_response = None
@@ -740,9 +750,8 @@ async def stream_agent_run(
             last_processed_index = len(initial_responses) - 1
             initial_yield_complete = True

-            # 2. Check run status *after* yielding initial data
-            run_status = await client.table('agent_runs').select('status', 'thread_id').eq("id", agent_run_id).maybe_single().execute()
-            current_status = run_status.data.get('status') if run_status.data else None
+            # 2. Check run status
+            current_status = agent_run_data.get('status') if agent_run_data else None

             if current_status != 'running':
                 logger.info(f"Agent run {agent_run_id} is not running (status: {current_status}). Ending stream.")
@@ -750,16 +759,22 @@ async def stream_agent_run(
                 return

             structlog.contextvars.bind_contextvars(
-                thread_id=run_status.data.get('thread_id'),
+                thread_id=agent_run_data.get('thread_id'),
             )

-            # 3. Set up Pub/Sub listeners for new responses and control signals
-            pubsub_response = await redis.create_pubsub()
-            await pubsub_response.subscribe(response_channel)
-            logger.debug(f"Subscribed to response channel: {response_channel}")
+            # 3. Set up Pub/Sub listeners for new responses and control signals concurrently
+            pubsub_response_task = asyncio.create_task(redis.create_pubsub())
+            pubsub_control_task = asyncio.create_task(redis.create_pubsub())
+
+            pubsub_response, pubsub_control = await asyncio.gather(pubsub_response_task, pubsub_control_task)

-            pubsub_control = await redis.create_pubsub()
-            await pubsub_control.subscribe(control_channel)
+            # Subscribe to channels concurrently
+            response_subscribe_task = asyncio.create_task(pubsub_response.subscribe(response_channel))
+            control_subscribe_task = asyncio.create_task(pubsub_control.subscribe(control_channel))
+
+            await asyncio.gather(response_subscribe_task, control_subscribe_task)
+
+            logger.debug(f"Subscribed to response channel: {response_channel}")
             logger.debug(f"Subscribed to control channel: {control_channel}")

             # Queue to communicate between listeners and the main generator loop
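Opening the two pub/sub connections and issuing both subscriptions through `asyncio.gather` removes a round trip of serial latency per stream. A rough sketch of the same setup using redis-py's asyncio client directly (the project wraps Redis in its own `redis` module with `create_pubsub()`, so the client construction below is an assumption for illustration only):

```python
import asyncio

import redis.asyncio as aioredis  # assumption: redis-py asyncio client, not the project's wrapper


async def open_subscriptions(agent_run_id: str):
    client = aioredis.Redis()  # assumes a local Redis on the default port
    response_ps = client.pubsub()
    control_ps = client.pubsub()
    # Subscribe to the response and control channels concurrently, mirroring
    # the gather-based setup in the hunk above.
    await asyncio.gather(
        response_ps.subscribe(f"agent_run:{agent_run_id}:new_response"),
        control_ps.subscribe(f"agent_run:{agent_run_id}:control"),
    )
    return response_ps, control_ps
```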
@@ -883,7 +898,7 @@ async def stream_agent_run(
             await asyncio.sleep(0.1)
         logger.debug(f"Streaming cleanup complete for agent run: {agent_run_id}")

-    return StreamingResponse(stream_generator(), media_type="text/event-stream", headers={
+    return StreamingResponse(stream_generator(agent_run_data), media_type="text/event-stream", headers={
         "Cache-Control": "no-cache, no-transform", "Connection": "keep-alive",
         "X-Accel-Buffering": "no", "Content-Type": "text/event-stream",
         "Access-Control-Allow-Origin": "*"
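Passing the already-fetched `agent_run_data` into `stream_generator` means the generator no longer re-queries the run before streaming. A minimal standalone sketch of that shape, with a hypothetical endpoint and hard-coded payload, assuming only FastAPI's `StreamingResponse`:

```python
from fastapi import FastAPI
from fastapi.responses import StreamingResponse

app = FastAPI()


@app.get("/stream-demo")  # hypothetical endpoint, not part of the real API
async def stream_demo():
    # Fetched once up front (hard-coded here); the generator receives it as an
    # argument instead of issuing its own status query.
    prefetched = {"status": "running", "responses": ["one", "two", "three"]}

    async def gen(run_data: dict):
        if run_data["status"] != "running":
            return
        for item in run_data["responses"]:
            yield f"data: {item}\n\n"

    return StreamingResponse(gen(prefetched), media_type="text/event-stream")
```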
@@ -1054,7 +1069,17 @@ async def initiate_agent_with_files(
     if agent_config:
         logger.info(f"[AGENT INITIATE] Agent config keys: {list(agent_config.keys())}")

-    can_use, model_message, allowed_models = await can_use_model(client, account_id, model_name)
+    # Run all checks concurrently
+    model_check_task = asyncio.create_task(can_use_model(client, account_id, model_name))
+    billing_check_task = asyncio.create_task(check_billing_status(client, account_id))
+    limit_check_task = asyncio.create_task(check_agent_run_limit(client, account_id))
+
+    # Wait for all checks to complete
+    (can_use, model_message, allowed_models), (can_run, message, subscription), limit_check = await asyncio.gather(
+        model_check_task, billing_check_task, limit_check_task
+    )
+
+    # Check results and raise appropriate errors
     if not can_use:
         raise HTTPException(status_code=403, detail={"message": model_message, "allowed_models": allowed_models})

@@ -1095,6 +1120,7 @@ async def initiate_agent_with_files(
     token = None

     if files:
+        # 3. Create Sandbox (lazy): only create now if files were uploaded and need the
         try:
             sandbox_pass = str(uuid.uuid4())
             sandbox = await create_sandbox(sandbox_pass, project_id)

@@ -16,7 +16,6 @@
     "@hookform/resolvers": "^5.0.1",
     "@next/third-parties": "^15.3.1",
     "@number-flow/react": "^0.5.7",
-    "@pipedream/sdk": "^1.7.0",
     "@radix-ui/react-accordion": "^1.2.11",
     "@radix-ui/react-alert-dialog": "^1.1.11",
     "@radix-ui/react-avatar": "^1.1.4",
@@ -2952,30 +2951,6 @@
         "node": ">=0.10"
       }
     },
-    "node_modules/@pipedream/sdk": {
-      "version": "1.7.0",
-      "resolved": "https://registry.npmjs.org/@pipedream/sdk/-/sdk-1.7.0.tgz",
-      "integrity": "sha512-Vxz1ehT9EfFGN1txLQlh2KspRdjwqU1NCczibJdV7NyNh0PQcptbohlOTDiH7kdYwhL90vaeQXyVaO8RfnnOJQ==",
-      "license": "SEE LICENSE IN LICENSE",
-      "dependencies": {
-        "@rails/actioncable": "^8.0.0",
-        "commander": "^12.1.0",
-        "oauth4webapi": "^3.1.4",
-        "ws": "^8.18.0"
-      },
-      "engines": {
-        "node": ">=18.0.0"
-      }
-    },
-    "node_modules/@pipedream/sdk/node_modules/commander": {
-      "version": "12.1.0",
-      "resolved": "https://registry.npmjs.org/commander/-/commander-12.1.0.tgz",
-      "integrity": "sha512-Vw8qHK3bZM9y/P10u3Vib8o/DdkvA2OtPtZvD871QKjy74Wj1WSKFILMPRPSdUSx5RFK1arlJzEtA4PkFgnbuA==",
-      "license": "MIT",
-      "engines": {
-        "node": ">=18"
-      }
-    },
     "node_modules/@preact/signals": {
       "version": "1.3.2",
       "resolved": "https://registry.npmjs.org/@preact/signals/-/signals-1.3.2.tgz",
@@ -4163,12 +4138,6 @@
       "integrity": "sha512-HPwpGIzkl28mWyZqG52jiqDJ12waP11Pa1lGoiyUkIEuMLBP0oeK/C89esbXrxsky5we7dfd8U58nm0SgAWpVw==",
       "license": "MIT"
     },
-    "node_modules/@rails/actioncable": {
-      "version": "8.0.200",
-      "resolved": "https://registry.npmjs.org/@rails/actioncable/-/actioncable-8.0.200.tgz",
-      "integrity": "sha512-EDqWyxck22BHmv1e+mD8Kl6GmtNkhEPdRfGFT7kvsv1yoXd9iYrqHDVAaR8bKmU/syC5eEZ2I5aWWxtB73ukMw==",
-      "license": "MIT"
-    },
     "node_modules/@react-pdf/fns": {
       "version": "3.1.2",
       "resolved": "https://registry.npmjs.org/@react-pdf/fns/-/fns-3.1.2.tgz",
@@ -12003,15 +11972,6 @@
         "esm-env": "^1.1.4"
       }
     },
-    "node_modules/oauth4webapi": {
-      "version": "3.5.5",
-      "resolved": "https://registry.npmjs.org/oauth4webapi/-/oauth4webapi-3.5.5.tgz",
-      "integrity": "sha512-1K88D2GiAydGblHo39NBro5TebGXa+7tYoyIbxvqv3+haDDry7CBE1eSYuNbOSsYCCU6y0gdynVZAkm4YPw4hg==",
-      "license": "MIT",
-      "funding": {
-        "url": "https://github.com/sponsors/panva"
-      }
-    },
     "node_modules/object-assign": {
       "version": "4.1.1",
       "resolved": "https://registry.npmjs.org/object-assign/-/object-assign-4.1.1.tgz",