mirror of https://github.com/kortix-ai/suna.git
125 lines
4.4 KiB
Python
125 lines
4.4 KiB
Python
"""
|
|
Centralized database connection management for AgentPress.
|
|
"""
|
|
|
|
import aiosqlite
|
|
import logging
|
|
from contextlib import asynccontextmanager
|
|
import os
|
|
import asyncio
|
|
|
|
class DBConnection:
|
|
"""Singleton database connection manager."""
|
|
|
|
_instance = None
|
|
_initialized = False
|
|
_db_path = os.path.join(os.getcwd(), "agentpress.db")
|
|
_init_lock = asyncio.Lock()
|
|
_initialization_task = None
|
|
|
|
def __new__(cls):
|
|
if cls._instance is None:
|
|
cls._instance = super().__new__(cls)
|
|
# Start initialization when instance is first created
|
|
cls._initialization_task = asyncio.create_task(cls._instance._initialize())
|
|
return cls._instance
|
|
|
|
def __init__(self):
|
|
"""No initialization needed in __init__ as it's handled in __new__"""
|
|
pass
|
|
|
|
@classmethod
|
|
async def _initialize(cls):
|
|
"""Internal initialization method."""
|
|
if cls._initialized:
|
|
return
|
|
|
|
async with cls._init_lock:
|
|
if cls._initialized: # Double-check after acquiring lock
|
|
return
|
|
|
|
try:
|
|
async with aiosqlite.connect(cls._db_path) as db:
|
|
# Threads table
|
|
await db.execute("""
|
|
CREATE TABLE IF NOT EXISTS threads (
|
|
thread_id TEXT PRIMARY KEY,
|
|
messages TEXT,
|
|
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
|
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
|
)
|
|
""")
|
|
|
|
# State stores table
|
|
await db.execute("""
|
|
CREATE TABLE IF NOT EXISTS state_stores (
|
|
store_id TEXT PRIMARY KEY,
|
|
store_data TEXT,
|
|
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
|
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
|
)
|
|
""")
|
|
|
|
await db.commit()
|
|
cls._initialized = True
|
|
logging.info("Database schema initialized")
|
|
except Exception as e:
|
|
logging.error(f"Database initialization error: {e}")
|
|
raise
|
|
|
|
@classmethod
|
|
def set_db_path(cls, db_path: str):
|
|
"""Set custom database path."""
|
|
if cls._initialized:
|
|
raise RuntimeError("Cannot change database path after initialization")
|
|
cls._db_path = db_path
|
|
logging.info(f"Updated database path to: {db_path}")
|
|
|
|
@asynccontextmanager
|
|
async def connection(self):
|
|
"""Get a database connection."""
|
|
# Wait for initialization to complete if it hasn't already
|
|
if self._initialization_task and not self._initialized:
|
|
await self._initialization_task
|
|
|
|
async with aiosqlite.connect(self._db_path) as conn:
|
|
try:
|
|
yield conn
|
|
except Exception as e:
|
|
logging.error(f"Database error: {e}")
|
|
raise
|
|
|
|
@asynccontextmanager
|
|
async def transaction(self):
|
|
"""Execute operations in a transaction."""
|
|
async with self.connection() as db:
|
|
try:
|
|
yield db
|
|
await db.commit()
|
|
except Exception as e:
|
|
await db.rollback()
|
|
logging.error(f"Transaction error: {e}")
|
|
raise
|
|
|
|
async def execute(self, query: str, params: tuple = ()):
|
|
"""Execute a single query."""
|
|
async with self.connection() as db:
|
|
try:
|
|
result = await db.execute(query, params)
|
|
await db.commit()
|
|
return result
|
|
except Exception as e:
|
|
logging.error(f"Query execution error: {e}")
|
|
raise
|
|
|
|
async def fetch_one(self, query: str, params: tuple = ()):
|
|
"""Fetch a single row."""
|
|
async with self.connection() as db:
|
|
async with db.execute(query, params) as cursor:
|
|
return await cursor.fetchone()
|
|
|
|
async def fetch_all(self, query: str, params: tuple = ()):
|
|
"""Fetch all rows."""
|
|
async with self.connection() as db:
|
|
async with db.execute(query, params) as cursor:
|
|
return await cursor.fetchall() |