mirror of https://github.com/kortix-ai/suna.git
Merge pull request #1475 from ffrankan/fix/redis-connection-optimization
Fix Redis connection optimization in SSE streaming
This commit is contained in:
commit
fc842cb601
|
@ -739,8 +739,7 @@ async def stream_agent_run(
|
|||
async def stream_generator(agent_run_data):
|
||||
logger.debug(f"Streaming responses for {agent_run_id} using Redis list {response_list_key} and channel {response_channel}")
|
||||
last_processed_index = -1
|
||||
pubsub_response = None
|
||||
pubsub_control = None
|
||||
# Single pubsub used for response + control
|
||||
listener_task = None
|
||||
terminate_stream = False
|
||||
initial_yield_complete = False
|
||||
|
@ -769,49 +768,38 @@ async def stream_agent_run(
|
|||
thread_id=agent_run_data.get('thread_id'),
|
||||
)
|
||||
|
||||
# 3. Set up Pub/Sub listeners for new responses and control signals concurrently
|
||||
pubsub_response_task = asyncio.create_task(redis.create_pubsub())
|
||||
pubsub_control_task = asyncio.create_task(redis.create_pubsub())
|
||||
|
||||
pubsub_response, pubsub_control = await asyncio.gather(pubsub_response_task, pubsub_control_task)
|
||||
|
||||
# Subscribe to channels concurrently
|
||||
response_subscribe_task = asyncio.create_task(pubsub_response.subscribe(response_channel))
|
||||
control_subscribe_task = asyncio.create_task(pubsub_control.subscribe(control_channel))
|
||||
|
||||
await asyncio.gather(response_subscribe_task, control_subscribe_task)
|
||||
|
||||
logger.debug(f"Subscribed to response channel: {response_channel}")
|
||||
logger.debug(f"Subscribed to control channel: {control_channel}")
|
||||
# 3. Use a single Pub/Sub connection subscribed to both channels
|
||||
pubsub = await redis.create_pubsub()
|
||||
await pubsub.subscribe(response_channel, control_channel)
|
||||
logger.debug(f"Subscribed to channels: {response_channel}, {control_channel}")
|
||||
|
||||
# Queue to communicate between listeners and the main generator loop
|
||||
message_queue = asyncio.Queue()
|
||||
|
||||
async def listen_messages():
|
||||
response_reader = pubsub_response.listen()
|
||||
control_reader = pubsub_control.listen()
|
||||
tasks = [asyncio.create_task(response_reader.__anext__()), asyncio.create_task(control_reader.__anext__())]
|
||||
listener = pubsub.listen()
|
||||
task = asyncio.create_task(listener.__anext__())
|
||||
|
||||
while not terminate_stream:
|
||||
done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
|
||||
for task in done:
|
||||
done, _ = await asyncio.wait([task], return_when=asyncio.FIRST_COMPLETED)
|
||||
for finished in done:
|
||||
try:
|
||||
message = task.result()
|
||||
message = finished.result()
|
||||
if message and isinstance(message, dict) and message.get("type") == "message":
|
||||
channel = message.get("channel")
|
||||
data = message.get("data")
|
||||
if isinstance(data, bytes): data = data.decode('utf-8')
|
||||
if isinstance(data, bytes):
|
||||
data = data.decode('utf-8')
|
||||
|
||||
if channel == response_channel and data == "new":
|
||||
await message_queue.put({"type": "new_response"})
|
||||
elif channel == control_channel and data in ["STOP", "END_STREAM", "ERROR"]:
|
||||
logger.debug(f"Received control signal '{data}' for {agent_run_id}")
|
||||
await message_queue.put({"type": "control", "data": data})
|
||||
return # Stop listening on control signal
|
||||
return # Stop listening on control signal
|
||||
|
||||
except StopAsyncIteration:
|
||||
logger.warning(f"Listener {task} stopped.")
|
||||
# Decide how to handle listener stopping, maybe terminate?
|
||||
logger.warning(f"Listener stopped for {agent_run_id}.")
|
||||
await message_queue.put({"type": "error", "data": "Listener stopped unexpectedly"})
|
||||
return
|
||||
except Exception as e:
|
||||
|
@ -819,17 +807,9 @@ async def stream_agent_run(
|
|||
await message_queue.put({"type": "error", "data": "Listener failed"})
|
||||
return
|
||||
finally:
|
||||
# Reschedule the completed listener task
|
||||
if task in tasks:
|
||||
tasks.remove(task)
|
||||
if message and isinstance(message, dict) and message.get("channel") == response_channel:
|
||||
tasks.append(asyncio.create_task(response_reader.__anext__()))
|
||||
elif message and isinstance(message, dict) and message.get("channel") == control_channel:
|
||||
tasks.append(asyncio.create_task(control_reader.__anext__()))
|
||||
|
||||
# Cancel pending listener tasks on exit
|
||||
for p_task in pending: p_task.cancel()
|
||||
for task in tasks: task.cancel()
|
||||
# Resubscribe to the next message if continuing
|
||||
if not terminate_stream:
|
||||
task = asyncio.create_task(listener.__anext__())
|
||||
|
||||
|
||||
listener_task = asyncio.create_task(listen_messages())
|
||||
|
@ -888,10 +868,12 @@ async def stream_agent_run(
|
|||
finally:
|
||||
terminate_stream = True
|
||||
# Graceful shutdown order: unsubscribe → close → cancel
|
||||
if pubsub_response: await pubsub_response.unsubscribe(response_channel)
|
||||
if pubsub_control: await pubsub_control.unsubscribe(control_channel)
|
||||
if pubsub_response: await pubsub_response.close()
|
||||
if pubsub_control: await pubsub_control.close()
|
||||
try:
|
||||
if 'pubsub' in locals() and pubsub:
|
||||
await pubsub.unsubscribe(response_channel, control_channel)
|
||||
await pubsub.close()
|
||||
except Exception as e:
|
||||
logger.debug(f"Error during pubsub cleanup for {agent_run_id}: {e}")
|
||||
|
||||
if listener_task:
|
||||
listener_task.cancel()
|
||||
|
|
Loading…
Reference in New Issue