suna/backend/services/redis.py

165 lines
4.5 KiB
Python
Raw Normal View History

import redis.asyncio as redis
import os
from dotenv import load_dotenv
import asyncio
2025-04-02 02:49:35 +08:00
from utils.logger import logger
2025-04-25 05:15:40 +08:00
from typing import List, Any
2025-06-10 12:52:09 +08:00
from utils.retry import retry
# Shared module-level singletons: the async client and the connection pool it
# draws from. Both start as None; initialize() assigns them and close() resets them.
client: redis.Redis | None = None
pool: redis.ConnectionPool | None = None

# Flipped to True by initialize_async() only after a successful PING.
_initialized = False

# Held by initialize_async() so concurrent first callers initialise only once.
_init_lock = asyncio.Lock()

# Safety-net TTL in seconds (24 hours).
REDIS_KEY_TTL = 3600 * 24

def initialize():
    """Initialize Redis connection pool and client using environment variables.

    Environment variables read (after ``load_dotenv()``):
        REDIS_HOST             -- host name, default ``"redis"``.
        REDIS_PORT             -- port, default ``6379``.
        REDIS_PASSWORD         -- password, default ``""``.
        REDIS_MAX_CONNECTIONS  -- pool size cap, default ``1024``.
        REDIS_RETRY_ON_TIMEOUT -- enabled only when the value is ``"true"``
                                  (case-insensitive); default ``"True"``.

    Side effects:
        Assigns the module-level ``pool`` and ``client`` globals. No
        connectivity check is performed here; ``initialize_async`` pings.

    Returns:
        The newly created ``redis.Redis`` client.

    Raises:
        ValueError: if REDIS_PORT or REDIS_MAX_CONNECTIONS is not an integer.
    """
    global client, pool

    # Load environment variables if not already loaded
    load_dotenv()

    # Get Redis configuration
    redis_host = os.getenv("REDIS_HOST", "redis")
    redis_port = int(os.getenv("REDIS_PORT", 6379))
    redis_password = os.getenv("REDIS_PASSWORD", "")
    # Connection pool configuration
    max_connections = int(os.getenv("REDIS_MAX_CONNECTIONS", 1024))
    # Only the literal "true" (any case) enables retry; any other value disables it.
    retry_on_timeout = os.getenv("REDIS_RETRY_ON_TIMEOUT", "True").lower() == "true"

    logger.info(f"Initializing Redis connection pool to {redis_host}:{redis_port} with max {max_connections} connections")

    # Create connection pool
    pool = redis.ConnectionPool(
        host=redis_host,
        port=redis_port,
        password=redis_password,
        decode_responses=True,
        socket_timeout=5.0,
        socket_connect_timeout=5.0,
        retry_on_timeout=retry_on_timeout,
        health_check_interval=30,
        max_connections=max_connections,
    )

    # Create Redis client from connection pool
    client = redis.Redis(connection_pool=pool)
    return client


async def initialize_async():
    """Initialize Redis connection asynchronously and verify it with a PING.

    Runs under ``_init_lock`` so concurrent callers initialise only once.
    On success ``_initialized`` is set to True. On failure the pool just
    created by ``initialize()`` is closed (best effort), ``client``/``pool``
    are reset to None, and the original exception is re-raised.

    Returns:
        The module-level ``client``.

    Raises:
        Exception: whatever ``client.ping()`` raised (e.g. Redis unreachable).
    """
    global client, pool, _initialized

    async with _init_lock:
        if not _initialized:
            logger.info("Initializing Redis connection")
            initialize()

            try:
                await client.ping()
                logger.info("Successfully connected to Redis")
                _initialized = True
            except Exception as e:
                logger.error(f"Failed to connect to Redis: {e}")
                # Release the pool created by initialize(); get_client() retries
                # this function, so otherwise every failed attempt leaks a pool.
                failed_pool = pool
                client = None
                pool = None
                _initialized = False
                if failed_pool is not None:
                    try:
                        await failed_pool.aclose()
                    except Exception as close_err:
                        # Best effort: never mask the original connection error.
                        logger.warning(f"Error closing Redis pool after failed connect: {close_err}")
                # Bare raise re-raises the connection error `e`, not close_err.
                raise

    return client


async def close():
    """Close Redis connection and connection pool.

    Module state (``client``, ``pool``, ``_initialized``) is reset before any
    await. If closing the client raises, ``get_client()`` therefore can no longer
    return the half-closed client. The pool is closed in ``finally`` so it is
    released even when closing the client fails; that error still propagates.
    Safe to call when nothing has been initialized.
    """
    global client, pool, _initialized

    old_client, old_pool = client, pool
    client = None
    pool = None
    _initialized = False

    try:
        if old_client:
            logger.info("Closing Redis connection")
            await old_client.aclose()
    finally:
        if old_pool:
            logger.info("Closing Redis connection pool")
            await old_pool.aclose()

    logger.info("Redis connection and pool closed")

async def get_client():
    """Return the shared Redis client, connecting it first when necessary.

    If the client is missing or has not passed its initial PING,
    ``initialize_async`` is invoked through the project's ``retry`` helper.
    """
    # Fast path: already created and verified.
    if client is not None and _initialized:
        return client

    await retry(lambda: initialize_async())
    return client


# Basic Redis operations
async def set(key: str, value: str, ex: int | None = None, nx: bool = False):
    """Set a Redis key.

    Args:
        key: Key to write.
        value: Value to store.
        ex: Optional expiry in seconds; ``None`` means no expiry.
        nx: When True, only set the key if it does not already exist.

    Returns:
        The result of redis-py's ``Redis.set`` for these arguments.
    """
    redis_client = await get_client()
    return await redis_client.set(key, value, ex=ex, nx=nx)


async def get(key: str, default: str | None = None) -> str | None:
    """Get a Redis key.

    Args:
        key: Key to read.
        default: Returned instead of ``None`` when the key is missing.

    Returns:
        The stored value (``str``, since the pool uses
        ``decode_responses=True``), or ``default`` when Redis returns ``None``.
    """
    redis_client = await get_client()
    result = await redis_client.get(key)
    return result if result is not None else default


async def delete(key: str):
    """Remove ``key`` from Redis.

    Returns:
        redis-py's result for DEL (the number of keys removed).
    """
    conn = await get_client()
    removed = await conn.delete(key)
    return removed


async def publish(channel: str, message: str):
    """Send ``message`` on the Redis pub/sub ``channel``.

    Returns:
        redis-py's result for PUBLISH (the number of receiving clients).
    """
    conn = await get_client()
    receivers = await conn.publish(channel, message)
    return receivers


async def create_pubsub():
    """Build a new PubSub object from the shared client.

    No channels are subscribed here; that is left to the caller.
    """
    conn = await get_client()
    pubsub = conn.pubsub()
    return pubsub


# List operations
async def rpush(key: str, *values: Any):
    """Append every item in ``values`` to the tail of the list at ``key``.

    Returns:
        redis-py's result for RPUSH (the list length after the push).
    """
    conn = await get_client()
    new_length = await conn.rpush(key, *values)
    return new_length


async def lrange(key: str, start: int, end: int) -> List[str]:
    """Fetch elements ``start`` through ``end`` of the list at ``key``.

    Follows Redis LRANGE semantics: both bounds are inclusive and negative
    indices count back from the tail.
    """
    conn = await get_client()
    items = await conn.lrange(key, start, end)
    return items


async def llen(key: str) -> int:
    """Report how many elements the list at ``key`` holds (LLEN)."""
    conn = await get_client()
    length = await conn.llen(key)
    return length


# Key management
async def expire(key: str, time: int):
    """Give ``key`` a time-to-live of ``time`` seconds.

    Returns:
        redis-py's result for EXPIRE.
    """
    conn = await get_client()
    outcome = await conn.expire(key, time)
    return outcome


async def keys(pattern: str) -> List[str]:
    """Get keys matching a glob-style ``pattern``.

    Uses incremental SCAN rather than KEYS. KEYS walks the whole keyspace
    in one O(N) call and blocks the Redis server while it runs, which the
    Redis documentation warns against in production.

    SCAN may report the same key more than once, so results are
    de-duplicated in first-seen order. Unlike KEYS this is not an atomic
    snapshot: keys created or deleted while the scan runs may or may not
    appear.
    """
    redis_client = await get_client()
    # A dict keeps insertion order and de-duplicates. The builtin set() is
    # not used because this module defines its own `set` function.
    found: dict[str, None] = {}
    async for key in redis_client.scan_iter(match=pattern):
        found[key] = None
    return list(found)