Merge pull request #1475 from ffrankan/fix/redis-connection-optimization

Fix Redis connection optimization in SSE streaming
This commit is contained in:
Marko Kraemer 2025-08-28 16:11:40 -07:00 committed by GitHub
commit fc842cb601
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 23 additions and 41 deletions

View File

@ -739,8 +739,7 @@ async def stream_agent_run(
async def stream_generator(agent_run_data):
logger.debug(f"Streaming responses for {agent_run_id} using Redis list {response_list_key} and channel {response_channel}")
last_processed_index = -1
pubsub_response = None
pubsub_control = None
# Single pubsub used for response + control
listener_task = None
terminate_stream = False
initial_yield_complete = False
@ -769,38 +768,28 @@ async def stream_agent_run(
thread_id=agent_run_data.get('thread_id'),
)
# 3. Set up Pub/Sub listeners for new responses and control signals concurrently
pubsub_response_task = asyncio.create_task(redis.create_pubsub())
pubsub_control_task = asyncio.create_task(redis.create_pubsub())
pubsub_response, pubsub_control = await asyncio.gather(pubsub_response_task, pubsub_control_task)
# Subscribe to channels concurrently
response_subscribe_task = asyncio.create_task(pubsub_response.subscribe(response_channel))
control_subscribe_task = asyncio.create_task(pubsub_control.subscribe(control_channel))
await asyncio.gather(response_subscribe_task, control_subscribe_task)
logger.debug(f"Subscribed to response channel: {response_channel}")
logger.debug(f"Subscribed to control channel: {control_channel}")
# 3. Use a single Pub/Sub connection subscribed to both channels
pubsub = await redis.create_pubsub()
await pubsub.subscribe(response_channel, control_channel)
logger.debug(f"Subscribed to channels: {response_channel}, {control_channel}")
# Queue to communicate between listeners and the main generator loop
message_queue = asyncio.Queue()
async def listen_messages():
response_reader = pubsub_response.listen()
control_reader = pubsub_control.listen()
tasks = [asyncio.create_task(response_reader.__anext__()), asyncio.create_task(control_reader.__anext__())]
listener = pubsub.listen()
task = asyncio.create_task(listener.__anext__())
while not terminate_stream:
done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
for task in done:
done, _ = await asyncio.wait([task], return_when=asyncio.FIRST_COMPLETED)
for finished in done:
try:
message = task.result()
message = finished.result()
if message and isinstance(message, dict) and message.get("type") == "message":
channel = message.get("channel")
data = message.get("data")
if isinstance(data, bytes): data = data.decode('utf-8')
if isinstance(data, bytes):
data = data.decode('utf-8')
if channel == response_channel and data == "new":
await message_queue.put({"type": "new_response"})
@ -810,8 +799,7 @@ async def stream_agent_run(
return # Stop listening on control signal
except StopAsyncIteration:
logger.warning(f"Listener {task} stopped.")
# Decide how to handle listener stopping, maybe terminate?
logger.warning(f"Listener stopped for {agent_run_id}.")
await message_queue.put({"type": "error", "data": "Listener stopped unexpectedly"})
return
except Exception as e:
@ -819,17 +807,9 @@ async def stream_agent_run(
await message_queue.put({"type": "error", "data": "Listener failed"})
return
finally:
# Reschedule the completed listener task
if task in tasks:
tasks.remove(task)
if message and isinstance(message, dict) and message.get("channel") == response_channel:
tasks.append(asyncio.create_task(response_reader.__anext__()))
elif message and isinstance(message, dict) and message.get("channel") == control_channel:
tasks.append(asyncio.create_task(control_reader.__anext__()))
# Cancel pending listener tasks on exit
for p_task in pending: p_task.cancel()
for task in tasks: task.cancel()
# Resubscribe to the next message if continuing
if not terminate_stream:
task = asyncio.create_task(listener.__anext__())
listener_task = asyncio.create_task(listen_messages())
@ -888,10 +868,12 @@ async def stream_agent_run(
finally:
terminate_stream = True
# Graceful shutdown order: unsubscribe → close → cancel
if pubsub_response: await pubsub_response.unsubscribe(response_channel)
if pubsub_control: await pubsub_control.unsubscribe(control_channel)
if pubsub_response: await pubsub_response.close()
if pubsub_control: await pubsub_control.close()
try:
if 'pubsub' in locals() and pubsub:
await pubsub.unsubscribe(response_channel, control_channel)
await pubsub.close()
except Exception as e:
logger.debug(f"Error during pubsub cleanup for {agent_run_id}: {e}")
if listener_task:
listener_task.cancel()