mirror of https://github.com/kortix-ai/suna.git
Merge pull request #704 from tnfssc/fix/retry-redis-pubsub
This commit is contained in:
commit
9f0534dcb4
|
@ -15,6 +15,7 @@ from services import redis
|
||||||
from dramatiq.brokers.rabbitmq import RabbitmqBroker
|
from dramatiq.brokers.rabbitmq import RabbitmqBroker
|
||||||
import os
|
import os
|
||||||
from services.langfuse import langfuse
|
from services.langfuse import langfuse
|
||||||
|
from utils.retry import retry
|
||||||
|
|
||||||
rabbitmq_host = os.getenv('RABBITMQ_HOST', 'rabbitmq')
|
rabbitmq_host = os.getenv('RABBITMQ_HOST', 'rabbitmq')
|
||||||
rabbitmq_port = int(os.getenv('RABBITMQ_PORT', 5672))
|
rabbitmq_port = int(os.getenv('RABBITMQ_PORT', 5672))
|
||||||
|
@ -28,18 +29,12 @@ instance_id = "single"
|
||||||
async def initialize():
|
async def initialize():
|
||||||
"""Initialize the agent API with resources from the main API."""
|
"""Initialize the agent API with resources from the main API."""
|
||||||
global db, instance_id, _initialized
|
global db, instance_id, _initialized
|
||||||
if _initialized:
|
|
||||||
try: await redis.client.ping()
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Redis connection failed, re-initializing: {e}")
|
|
||||||
await redis.initialize_async(force=True)
|
|
||||||
return
|
|
||||||
|
|
||||||
# Use provided instance_id or generate a new one
|
# Use provided instance_id or generate a new one
|
||||||
if not instance_id:
|
if not instance_id:
|
||||||
# Generate instance ID
|
# Generate instance ID
|
||||||
instance_id = str(uuid.uuid4())[:8]
|
instance_id = str(uuid.uuid4())[:8]
|
||||||
await redis.initialize_async()
|
await retry(lambda: redis.initialize_async())
|
||||||
await db.initialize()
|
await db.initialize()
|
||||||
|
|
||||||
_initialized = True
|
_initialized = True
|
||||||
|
@ -136,7 +131,12 @@ async def run_agent_background(
|
||||||
try:
|
try:
|
||||||
# Setup Pub/Sub listener for control signals
|
# Setup Pub/Sub listener for control signals
|
||||||
pubsub = await redis.create_pubsub()
|
pubsub = await redis.create_pubsub()
|
||||||
await pubsub.subscribe(instance_control_channel, global_control_channel)
|
try:
|
||||||
|
await retry(lambda: pubsub.subscribe(instance_control_channel, global_control_channel))
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Redis failed to subscribe to control channels: {e}", exc_info=True)
|
||||||
|
raise e
|
||||||
|
|
||||||
logger.debug(f"Subscribed to control channels: {instance_control_channel}, {global_control_channel}")
|
logger.debug(f"Subscribed to control channels: {instance_control_channel}, {global_control_channel}")
|
||||||
stop_checker = asyncio.create_task(check_for_stop_signal())
|
stop_checker = asyncio.create_task(check_for_stop_signal())
|
||||||
|
|
||||||
|
|
|
@ -4,6 +4,7 @@ from dotenv import load_dotenv
|
||||||
import asyncio
|
import asyncio
|
||||||
from utils.logger import logger
|
from utils.logger import logger
|
||||||
from typing import List, Any
|
from typing import List, Any
|
||||||
|
from utils.retry import retry
|
||||||
|
|
||||||
# Redis client
|
# Redis client
|
||||||
client: redis.Redis | None = None
|
client: redis.Redis | None = None
|
||||||
|
@ -22,12 +23,12 @@ def initialize():
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
# Get Redis configuration
|
# Get Redis configuration
|
||||||
redis_host = os.getenv('REDIS_HOST', 'redis')
|
redis_host = os.getenv("REDIS_HOST", "redis")
|
||||||
redis_port = int(os.getenv('REDIS_PORT', 6379))
|
redis_port = int(os.getenv("REDIS_PORT", 6379))
|
||||||
redis_password = os.getenv('REDIS_PASSWORD', '')
|
redis_password = os.getenv("REDIS_PASSWORD", "")
|
||||||
# Convert string 'True'/'False' to boolean
|
# Convert string 'True'/'False' to boolean
|
||||||
redis_ssl_str = os.getenv('REDIS_SSL', 'False')
|
redis_ssl_str = os.getenv("REDIS_SSL", "False")
|
||||||
redis_ssl = redis_ssl_str.lower() == 'true'
|
redis_ssl = redis_ssl_str.lower() == "true"
|
||||||
|
|
||||||
logger.info(f"Initializing Redis connection to {redis_host}:{redis_port}")
|
logger.info(f"Initializing Redis connection to {redis_host}:{redis_port}")
|
||||||
|
|
||||||
|
@ -41,37 +42,30 @@ def initialize():
|
||||||
socket_timeout=5.0,
|
socket_timeout=5.0,
|
||||||
socket_connect_timeout=5.0,
|
socket_connect_timeout=5.0,
|
||||||
retry_on_timeout=True,
|
retry_on_timeout=True,
|
||||||
health_check_interval=30
|
health_check_interval=30,
|
||||||
)
|
)
|
||||||
|
|
||||||
return client
|
return client
|
||||||
|
|
||||||
|
|
||||||
async def initialize_async(force: bool = False):
|
async def initialize_async():
|
||||||
"""Initialize Redis connection asynchronously."""
|
"""Initialize Redis connection asynchronously."""
|
||||||
global client, _initialized
|
global client, _initialized
|
||||||
|
|
||||||
async with _init_lock:
|
async with _init_lock:
|
||||||
if _initialized and force:
|
|
||||||
logger.info("Redis connection already initialized, closing and re-initializing")
|
|
||||||
_initialized = False
|
|
||||||
try:
|
|
||||||
await close()
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Failed to close Redis connection, proceeding with re-initialization anyway: {e}")
|
|
||||||
|
|
||||||
if not _initialized:
|
if not _initialized:
|
||||||
logger.info("Initializing Redis connection")
|
logger.info("Initializing Redis connection")
|
||||||
initialize()
|
initialize()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await client.ping()
|
await client.ping()
|
||||||
logger.info("Successfully connected to Redis")
|
logger.info("Successfully connected to Redis")
|
||||||
_initialized = True
|
_initialized = True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to connect to Redis: {e}")
|
logger.error(f"Failed to connect to Redis: {e}")
|
||||||
client = None
|
client = None
|
||||||
raise
|
_initialized = False
|
||||||
|
raise
|
||||||
|
|
||||||
return client
|
return client
|
||||||
|
|
||||||
|
@ -91,7 +85,7 @@ async def get_client():
|
||||||
"""Get the Redis client, initializing if necessary."""
|
"""Get the Redis client, initializing if necessary."""
|
||||||
global client, _initialized
|
global client, _initialized
|
||||||
if client is None or not _initialized:
|
if client is None or not _initialized:
|
||||||
await initialize_async()
|
await retry(lambda: initialize_async())
|
||||||
return client
|
return client
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,58 @@
|
||||||
|
import asyncio
|
||||||
|
from typing import TypeVar, Callable, Awaitable, Optional
|
||||||
|
|
||||||
|
T = TypeVar("T")


async def retry(
    fn: Callable[[], Awaitable[T]],
    max_attempts: int = 3,
    delay_seconds: float = 1,
    *,
    backoff_factor: float = 1.0,
) -> T:
    """
    Retry an async callable, sleeping between failed attempts.

    ``fn`` is called with no arguments and its result is awaited. If the
    call or the await raises an ``Exception``, the helper sleeps and tries
    again, up to ``max_attempts`` total attempts. With the default
    ``backoff_factor`` of 1.0 the delay between attempts is constant; pass a
    factor greater than 1 (e.g. 2.0) to get exponential backoff.

    Only ``Exception`` subclasses trigger a retry. Other ``BaseException``
    subclasses (e.g. ``KeyboardInterrupt``) propagate immediately.

    Args:
        fn: Zero-argument callable returning an awaitable. It is invoked
            afresh on every attempt.
        max_attempts: Maximum number of attempts; must be at least 1.
        delay_seconds: Delay before the second attempt, in seconds. May be
            fractional.
        backoff_factor: Multiplier applied to the delay after each failed
            attempt. 1.0 keeps the delay fixed.

    Returns:
        The result of the first successful ``await fn()``.

    Raises:
        ValueError: If ``max_attempts`` is not greater than zero.
        Exception: The exception raised by the final attempt, if every
            attempt fails.

    Example:
        ```python
        async def fetch_data():
            # Some operation that might fail
            return await api_call()

        try:
            result = await retry(fetch_data, max_attempts=3, delay_seconds=2)
            print(f"Success: {result}")
        except Exception as e:
            print(f"Failed after all retries: {e}")
        ```
    """
    if max_attempts <= 0:
        raise ValueError("max_attempts must be greater than zero")

    delay = delay_seconds
    for attempt in range(1, max_attempts + 1):
        try:
            return await fn()
        except Exception:
            # Out of attempts: re-raise the failure from this final attempt
            # with its original traceback intact.
            if attempt == max_attempts:
                raise

        # Wait before the next attempt, then scale the delay for the one
        # after that (a factor of 1.0 leaves it unchanged).
        await asyncio.sleep(delay)
        delay *= backoff_factor

    # Not reachable: with max_attempts >= 1 the loop always returns or raises.
    # Kept so static type checkers see that no code path falls through.
    raise RuntimeError("Unexpected: retry loop exited without a result")
|
Loading…
Reference in New Issue