Merge pull request #704 from tnfssc/fix/retry-redis-pubsub

This commit is contained in:
Sharath 2025-06-10 17:59:05 +05:30 committed by GitHub
commit 9f0534dcb4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 85 additions and 33 deletions

View File

@ -15,6 +15,7 @@ from services import redis
from dramatiq.brokers.rabbitmq import RabbitmqBroker from dramatiq.brokers.rabbitmq import RabbitmqBroker
import os import os
from services.langfuse import langfuse from services.langfuse import langfuse
from utils.retry import retry
rabbitmq_host = os.getenv('RABBITMQ_HOST', 'rabbitmq') rabbitmq_host = os.getenv('RABBITMQ_HOST', 'rabbitmq')
rabbitmq_port = int(os.getenv('RABBITMQ_PORT', 5672)) rabbitmq_port = int(os.getenv('RABBITMQ_PORT', 5672))
@ -28,18 +29,12 @@ instance_id = "single"
async def initialize(): async def initialize():
"""Initialize the agent API with resources from the main API.""" """Initialize the agent API with resources from the main API."""
global db, instance_id, _initialized global db, instance_id, _initialized
if _initialized:
try: await redis.client.ping()
except Exception as e:
logger.warning(f"Redis connection failed, re-initializing: {e}")
await redis.initialize_async(force=True)
return
# Use provided instance_id or generate a new one # Use provided instance_id or generate a new one
if not instance_id: if not instance_id:
# Generate instance ID # Generate instance ID
instance_id = str(uuid.uuid4())[:8] instance_id = str(uuid.uuid4())[:8]
await redis.initialize_async() await retry(lambda: redis.initialize_async())
await db.initialize() await db.initialize()
_initialized = True _initialized = True
@ -136,7 +131,12 @@ async def run_agent_background(
try: try:
# Setup Pub/Sub listener for control signals # Setup Pub/Sub listener for control signals
pubsub = await redis.create_pubsub() pubsub = await redis.create_pubsub()
await pubsub.subscribe(instance_control_channel, global_control_channel) try:
await retry(lambda: pubsub.subscribe(instance_control_channel, global_control_channel))
except Exception as e:
logger.error(f"Redis failed to subscribe to control channels: {e}", exc_info=True)
raise e
logger.debug(f"Subscribed to control channels: {instance_control_channel}, {global_control_channel}") logger.debug(f"Subscribed to control channels: {instance_control_channel}, {global_control_channel}")
stop_checker = asyncio.create_task(check_for_stop_signal()) stop_checker = asyncio.create_task(check_for_stop_signal())

View File

@ -4,6 +4,7 @@ from dotenv import load_dotenv
import asyncio import asyncio
from utils.logger import logger from utils.logger import logger
from typing import List, Any from typing import List, Any
from utils.retry import retry
# Redis client # Redis client
client: redis.Redis | None = None client: redis.Redis | None = None
@ -22,12 +23,12 @@ def initialize():
load_dotenv() load_dotenv()
# Get Redis configuration # Get Redis configuration
redis_host = os.getenv('REDIS_HOST', 'redis') redis_host = os.getenv("REDIS_HOST", "redis")
redis_port = int(os.getenv('REDIS_PORT', 6379)) redis_port = int(os.getenv("REDIS_PORT", 6379))
redis_password = os.getenv('REDIS_PASSWORD', '') redis_password = os.getenv("REDIS_PASSWORD", "")
# Convert string 'True'/'False' to boolean # Convert string 'True'/'False' to boolean
redis_ssl_str = os.getenv('REDIS_SSL', 'False') redis_ssl_str = os.getenv("REDIS_SSL", "False")
redis_ssl = redis_ssl_str.lower() == 'true' redis_ssl = redis_ssl_str.lower() == "true"
logger.info(f"Initializing Redis connection to {redis_host}:{redis_port}") logger.info(f"Initializing Redis connection to {redis_host}:{redis_port}")
@ -41,37 +42,30 @@ def initialize():
socket_timeout=5.0, socket_timeout=5.0,
socket_connect_timeout=5.0, socket_connect_timeout=5.0,
retry_on_timeout=True, retry_on_timeout=True,
health_check_interval=30 health_check_interval=30,
) )
return client return client
async def initialize_async(force: bool = False): async def initialize_async():
"""Initialize Redis connection asynchronously.""" """Initialize Redis connection asynchronously."""
global client, _initialized global client, _initialized
async with _init_lock: async with _init_lock:
if _initialized and force:
logger.info("Redis connection already initialized, closing and re-initializing")
_initialized = False
try:
await close()
except Exception as e:
logger.warning(f"Failed to close Redis connection, proceeding with re-initialization anyway: {e}")
if not _initialized: if not _initialized:
logger.info("Initializing Redis connection") logger.info("Initializing Redis connection")
initialize() initialize()
try: try:
await client.ping() await client.ping()
logger.info("Successfully connected to Redis") logger.info("Successfully connected to Redis")
_initialized = True _initialized = True
except Exception as e: except Exception as e:
logger.error(f"Failed to connect to Redis: {e}") logger.error(f"Failed to connect to Redis: {e}")
client = None client = None
raise _initialized = False
raise
return client return client
@ -91,7 +85,7 @@ async def get_client():
"""Get the Redis client, initializing if necessary.""" """Get the Redis client, initializing if necessary."""
global client, _initialized global client, _initialized
if client is None or not _initialized: if client is None or not _initialized:
await initialize_async() await retry(lambda: initialize_async())
return client return client

58
backend/utils/retry.py Normal file
View File

@ -0,0 +1,58 @@
import asyncio
from typing import TypeVar, Callable, Awaitable, Optional
T = TypeVar("T")


async def retry(
    fn: Callable[[], Awaitable[T]],
    max_attempts: int = 3,
    delay_seconds: float = 1,
) -> T:
    """
    Retry an async function, waiting a fixed delay between attempts.

    The delay is constant (no backoff growth): after each failed attempt
    except the last, the coroutine sleeps for ``delay_seconds`` and then
    calls ``fn`` again.

    Args:
        fn: Zero-argument callable returning a fresh awaitable on every call
            (a coroutine object cannot be awaited twice, so pass the function
            or a lambda, not an already-created coroutine).
        max_attempts: Maximum number of attempts; must be greater than zero.
        delay_seconds: Seconds to sleep between attempts (may be fractional).

    Returns:
        The result of the first successful call to ``fn``.

    Raises:
        ValueError: If ``max_attempts`` is not greater than zero.
        Exception: The exception raised by the final attempt, re-raised
            unchanged, if every attempt fails.

    Example:
        ```python
        async def fetch_data():
            # Some operation that might fail
            return await api_call()

        try:
            result = await retry(fetch_data, max_attempts=3, delay_seconds=2)
            print(f"Success: {result}")
        except Exception as e:
            print(f"Failed after all retries: {e}")
        ```
    """
    if max_attempts <= 0:
        raise ValueError("max_attempts must be greater than zero")

    for attempt in range(1, max_attempts + 1):
        try:
            return await fn()
        except Exception:
            # Only ``Exception`` subclasses are retried; anything else
            # (e.g. KeyboardInterrupt) propagates immediately.
            if attempt == max_attempts:
                # Out of attempts: bare ``raise`` re-raises the active
                # exception with its original traceback intact.
                raise
        # Sleep outside the ``except`` block so the failed attempt's
        # exception is not kept alive while waiting.
        await asyncio.sleep(delay_seconds)

    # Unreachable: the loop always returns or raises because
    # max_attempts >= 1. Kept so static checkers see no implicit ``None``.
    raise RuntimeError("Unexpected: retry loop exited without a result")