diff --git a/backend/run_agent_background.py b/backend/run_agent_background.py index cbae733c..547eafcb 100644 --- a/backend/run_agent_background.py +++ b/backend/run_agent_background.py @@ -15,6 +15,7 @@ from services import redis from dramatiq.brokers.rabbitmq import RabbitmqBroker import os from services.langfuse import langfuse +from utils.retry import retry rabbitmq_host = os.getenv('RABBITMQ_HOST', 'rabbitmq') rabbitmq_port = int(os.getenv('RABBITMQ_PORT', 5672)) @@ -28,18 +29,12 @@ instance_id = "single" async def initialize(): """Initialize the agent API with resources from the main API.""" global db, instance_id, _initialized - if _initialized: - try: await redis.client.ping() - except Exception as e: - logger.warning(f"Redis connection failed, re-initializing: {e}") - await redis.initialize_async(force=True) - return # Use provided instance_id or generate a new one if not instance_id: # Generate instance ID instance_id = str(uuid.uuid4())[:8] - await redis.initialize_async() + await retry(lambda: redis.initialize_async()) await db.initialize() _initialized = True @@ -136,7 +131,12 @@ async def run_agent_background( try: # Setup Pub/Sub listener for control signals pubsub = await redis.create_pubsub() - await pubsub.subscribe(instance_control_channel, global_control_channel) + try: + await retry(lambda: pubsub.subscribe(instance_control_channel, global_control_channel)) + except Exception as e: + logger.error(f"Redis failed to subscribe to control channels: {e}", exc_info=True) + raise e + logger.debug(f"Subscribed to control channels: {instance_control_channel}, {global_control_channel}") stop_checker = asyncio.create_task(check_for_stop_signal()) diff --git a/backend/services/redis.py b/backend/services/redis.py index 67a1df0a..12598790 100644 --- a/backend/services/redis.py +++ b/backend/services/redis.py @@ -4,6 +4,7 @@ from dotenv import load_dotenv import asyncio from utils.logger import logger from typing import List, Any +from utils.retry import retry # Redis client client: redis.Redis | None = None @@ -22,12 +23,12 @@ def initialize(): load_dotenv() # Get Redis configuration - redis_host = os.getenv('REDIS_HOST', 'redis') - redis_port = int(os.getenv('REDIS_PORT', 6379)) - redis_password = os.getenv('REDIS_PASSWORD', '') + redis_host = os.getenv("REDIS_HOST", "redis") + redis_port = int(os.getenv("REDIS_PORT", 6379)) + redis_password = os.getenv("REDIS_PASSWORD", "") # Convert string 'True'/'False' to boolean - redis_ssl_str = os.getenv('REDIS_SSL', 'False') - redis_ssl = redis_ssl_str.lower() == 'true' + redis_ssl_str = os.getenv("REDIS_SSL", "False") + redis_ssl = redis_ssl_str.lower() == "true" logger.info(f"Initializing Redis connection to {redis_host}:{redis_port}") @@ -41,37 +42,30 @@ def initialize(): socket_timeout=5.0, socket_connect_timeout=5.0, retry_on_timeout=True, - health_check_interval=30 + health_check_interval=30, ) return client -async def initialize_async(force: bool = False): +async def initialize_async(): """Initialize Redis connection asynchronously.""" global client, _initialized async with _init_lock: - if _initialized and force: - logger.info("Redis connection already initialized, closing and re-initializing") - _initialized = False - try: - await close() - except Exception as e: - logger.warning(f"Failed to close Redis connection, proceeding with re-initialization anyway: {e}") - if not _initialized: logger.info("Initializing Redis connection") initialize() - try: - await client.ping() - logger.info("Successfully connected to Redis") - _initialized = True - except Exception as e: - logger.error(f"Failed to connect to Redis: {e}") - client = None - raise + try: + await client.ping() + logger.info("Successfully connected to Redis") + _initialized = True + except Exception as e: + logger.error(f"Failed to connect to Redis: {e}") + client = None + _initialized = False + raise return client @@ -91,7 +85,7 @@ async def get_client(): """Get the Redis client, initializing if necessary.""" global client, _initialized if client is None or not _initialized: - await initialize_async() + await retry(lambda: initialize_async()) return client @@ -156,4 +150,4 @@ async def expire(key: str, time: int): async def keys(pattern: str) -> List[str]: """Get keys matching a pattern.""" redis_client = await get_client() - return await redis_client.keys(pattern) \ No newline at end of file + return await redis_client.keys(pattern) diff --git a/backend/utils/retry.py b/backend/utils/retry.py new file mode 100644 index 00000000..992c0454 --- /dev/null +++ b/backend/utils/retry.py @@ -0,0 +1,58 @@ +import asyncio +from typing import TypeVar, Callable, Awaitable, Optional + +T = TypeVar("T") + + +async def retry( + fn: Callable[[], Awaitable[T]], + max_attempts: int = 3, + delay_seconds: int = 1, +) -> T: + """ + Retry an async function with exponential backoff. + + Args: + fn: The async function to retry + max_attempts: Maximum number of attempts + delay_seconds: Delay between attempts in seconds + + Returns: + The result of the function call + + Raises: + The last exception if all attempts fail + + Example: + ```python + async def fetch_data(): + # Some operation that might fail + return await api_call() + + try: + result = await retry(fetch_data, max_attempts=3, delay_seconds=2) + print(f"Success: {result}") + except Exception as e: + print(f"Failed after all retries: {e}") + ``` + """ + if max_attempts <= 0: + raise ValueError("max_attempts must be greater than zero") + + last_error: Optional[Exception] = None + + for attempt in range(1, max_attempts + 1): + try: + return await fn() + except Exception as error: + last_error = error + + if attempt == max_attempts: + break + + await asyncio.sleep(delay_seconds) + + if last_error: + raise last_error + + raise RuntimeError("Unexpected: last_error is None")