diff --git a/navi/core/container.py b/navi/core/container.py index 7656595..50e3ad6 100644 --- a/navi/core/container.py +++ b/navi/core/container.py @@ -13,6 +13,7 @@ from navi.core import Agent, BackendRegistry, ProfileRegistry, SessionStore, ToolRegistry from navi.core.orchestrator import AgentSessionOrchestrator from navi.core.scheduler import RecallScheduler + from navi.db import Database from navi.memory import MemoryStore from navi.mcp import McpManager from navi.store import KvStore @@ -23,15 +24,16 @@ class AppContainer: """Holds all application-level singletons created at startup.""" - memory_store: "MemoryStore" - session_store: "SessionStore" - kv_store: "KvStore" - scheduler: "RecallScheduler" - tool_registry: "ToolRegistry" - profile_registry: "ProfileRegistry" - backend_registry: "BackendRegistry" - cp_registry: "ContextProviderRegistry" - workers: list["Worker"] + database: "Database | None" = None + memory_store: "MemoryStore" = None # type: ignore[assignment] + session_store: "SessionStore" = None # type: ignore[assignment] + kv_store: "KvStore" = None # type: ignore[assignment] + scheduler: "RecallScheduler" = None # type: ignore[assignment] + tool_registry: "ToolRegistry" = None # type: ignore[assignment] + profile_registry: "ProfileRegistry" = None # type: ignore[assignment] + backend_registry: "BackendRegistry" = None # type: ignore[assignment] + cp_registry: "ContextProviderRegistry" = None # type: ignore[assignment] + workers: list["Worker"] = field(default_factory=list) mcp_manager: "McpManager | None" = None orchestrator: "AgentSessionOrchestrator | None" = None @@ -55,34 +57,14 @@ async def shutdown(self) -> None: """Close all resources that need explicit cleanup.""" - # MCP if self.mcp_manager is not None: try: await self.mcp_manager.disconnect_all() except Exception: pass - # Session store pool - if hasattr(self.session_store, "_pool") and self.session_store._pool is not None: + if self.database is not None: try: - await self.session_store._pool.close() - except Exception: - pass - # Memory store pool - if hasattr(self.memory_store, "_pool") and self.memory_store._pool is not None: - try: - await self.memory_store._pool.close() - except Exception: - pass - # KV store pool - if hasattr(self.kv_store, "_pool") and self.kv_store._pool is not None: - try: - await self.kv_store._pool.close() - except Exception: - pass - # Scheduler pool - if hasattr(self.scheduler, "_pool") and self.scheduler._pool is not None: - try: - await self.scheduler._pool.close() + await self.database.close() except Exception: pass @@ -92,6 +74,7 @@ from navi.core.registry import build_default_registries from navi.core.scheduler import RecallScheduler from navi.core.pg_session_store import PgSessionStore + from navi.db import Database from navi.memory import MemoryStore from navi.mcp import McpManager, load_mcp_servers from navi.mcp.tools import McpTool @@ -101,10 +84,13 @@ if not settings.database_url: raise RuntimeError("DATABASE_URL is required. SQLite support has been removed.") - session_store = PgSessionStore(settings.database_url) - memory_store = MemoryStore(settings.database_url) - kv_store = KvStore(settings.database_url) - scheduler = RecallScheduler(settings.database_url) + database = Database(settings.database_url) + pool = await database.pool() + + session_store = PgSessionStore(pool) + memory_store = MemoryStore(pool) + kv_store = KvStore(pool) + scheduler = RecallScheduler(pool) tool_registry, profile_registry, backend_registry, cp_registry = build_default_registries( memory_store=memory_store, @@ -154,6 +140,7 @@ pass container = AppContainer( + database=database, memory_store=memory_store, session_store=session_store, kv_store=kv_store, diff --git a/navi/core/pg_session_store.py b/navi/core/pg_session_store.py index d7f3b01..e17170d 100644 --- a/navi/core/pg_session_store.py +++ b/navi/core/pg_session_store.py @@ -46,21 +46,19 @@ class PgSessionStore(SessionStore): - def __init__(self, dsn: str) -> None: - self._dsn = dsn - self._pool: asyncpg.Pool | None = None + def __init__(self, pool: asyncpg.Pool) -> None: + self._pool = pool + self._initialized = False self._lock = asyncio.Lock() async def _get_pool(self) -> asyncpg.Pool: - if self._pool is not None: - return self._pool - async with self._lock: - if self._pool is None: - pool = await asyncpg.create_pool(self._dsn) - async with pool.acquire() as conn: - await conn.execute(_DDL) - await conn.execute(_MIGRATE) - self._pool = pool + if not self._initialized: + async with self._lock: + if not self._initialized: + async with self._pool.acquire() as conn: + await conn.execute(_DDL) + await conn.execute(_MIGRATE) + self._initialized = True return self._pool async def create(self, profile_id: str, user_id: str | None = None) -> Session: diff --git a/navi/core/scheduler.py b/navi/core/scheduler.py index be3ad89..96a1923 100644 --- a/navi/core/scheduler.py +++ b/navi/core/scheduler.py @@ -82,23 +82,20 @@ class RecallScheduler: """PostgreSQL-backed scheduler for session recalls.""" - def __init__(self, dsn: str) -> None: - self._dsn = dsn - self._pool: Any | None = None + def __init__(self, pool: Any) -> None: + self._pool = pool + self._initialized = False self._lock = asyncio.Lock() async def _get_pool(self) -> Any: - if self._pool is not None: - return self._pool - async with self._lock: - if self._pool is not None: - return self._pool - import asyncpg + if not self._initialized: + async with self._lock: + if not self._initialized: + import asyncpg - pool = await asyncpg.create_pool(self._dsn) - async with pool.acquire() as conn: - await conn.execute(_DDL) - self._pool = pool + async with self._pool.acquire() as conn: + await conn.execute(_DDL) + self._initialized = True return self._pool async def ensure_tables(self) -> None: diff --git a/navi/db.py b/navi/db.py new file mode 100644 index 0000000..7959111 --- /dev/null +++ b/navi/db.py @@ -0,0 +1,34 @@ +"""Database service — manages a single asyncpg connection pool for the application.""" + +import asyncio + +import asyncpg +import structlog + +log = structlog.get_logger() + + +class Database: + """Manages a single asyncpg pool shared by all stores.""" + + def __init__(self, dsn: str) -> None: + self._dsn = dsn + self._pool: asyncpg.Pool | None = None + self._lock = asyncio.Lock() + + async def pool(self) -> asyncpg.Pool: + """Return the connection pool, creating it lazily if needed.""" + if self._pool is not None: + return self._pool + async with self._lock: + if self._pool is not None: + return self._pool + self._pool = await asyncpg.create_pool(self._dsn) + log.info("db.pool_created") + return self._pool + + async def close(self) -> None: + """Close the pool if it exists.""" + if self._pool is not None: + await self._pool.close() + self._pool = None diff --git a/navi/main.py b/navi/main.py index c620079..3fea4eb 100644 --- a/navi/main.py +++ b/navi/main.py @@ -59,7 +59,7 @@ # Apply persisted profile overrides try: - pool = await container.session_store._get_pool() + pool = await container.database.pool() await ensure_table(pool) overrides = await load_overrides(pool) if overrides: diff --git a/navi/memory/store.py b/navi/memory/store.py index 90e6172..0053662 100644 --- a/navi/memory/store.py +++ b/navi/memory/store.py @@ -13,8 +13,6 @@ import asyncpg import structlog -from navi.config import settings - from ._ddl import _build_ddl from ._embeddings import EmbeddingMixin from ._facts import FactMixin @@ -28,9 +26,9 @@ class MemoryStore(EmbeddingMixin, FactMixin, SummaryMixin, SessionStateMixin): - def __init__(self, dsn: str, embedding_backend: "LLMBackend | None" = None) -> None: - self._dsn = dsn - self._pool: asyncpg.Pool | None = None + def __init__(self, pool: asyncpg.Pool, embedding_backend: "LLMBackend | None" = None) -> None: + self._pool = pool + self._initialized = False self._lock = asyncio.Lock() self._embedding_backend = embedding_backend self._pgvector_checked = False @@ -40,24 +38,22 @@ self._embedding_backend = backend async def _get_pool(self) -> asyncpg.Pool: - if self._pool is not None: - return self._pool - async with self._lock: - if self._pool is None: - pool = await asyncpg.create_pool(self._dsn) - async with pool.acquire() as conn: - pgvector_available = False - try: - await conn.execute("CREATE EXTENSION IF NOT EXISTS vector") - row = await conn.fetchval("SELECT 1 FROM pg_extension WHERE extname = 'vector'") - pgvector_available = bool(row) - except Exception: - log.warning("memory.pgvector_not_available", exc_info=True) - - for stmt in _build_ddl(pgvector_available): + if not self._initialized: + async with self._lock: + if not self._initialized: + async with self._pool.acquire() as conn: + pgvector_available = False try: - await conn.execute(stmt) + await conn.execute("CREATE EXTENSION IF NOT EXISTS vector") + row = await conn.fetchval("SELECT 1 FROM pg_extension WHERE extname = 'vector'") + pgvector_available = bool(row) except Exception: - log.warning("memory.ddl_failed", stmt=stmt[:80], exc_info=True) - self._pool = pool + log.warning("memory.pgvector_not_available", exc_info=True) + + for stmt in _build_ddl(pgvector_available): + try: + await conn.execute(stmt) + except Exception: + log.warning("memory.ddl_failed", stmt=stmt[:80], exc_info=True) + self._initialized = True return self._pool diff --git a/navi/store/__init__.py b/navi/store/__init__.py index 4bca5a7..32f932c 100644 --- a/navi/store/__init__.py +++ b/navi/store/__init__.py @@ -33,20 +33,18 @@ class KvStore: """Simple key-value persistence with user + session + scope scoping.""" - def __init__(self, dsn: str) -> None: - self._dsn = dsn - self._pool: asyncpg.Pool | None = None + def __init__(self, pool: asyncpg.Pool) -> None: + self._pool = pool + self._initialized = False self._lock = asyncio.Lock() async def _get_pool(self) -> asyncpg.Pool: - if self._pool is not None: - return self._pool - async with self._lock: - if self._pool is None: - pool = await asyncpg.create_pool(self._dsn) - async with pool.acquire() as conn: - await conn.execute(_DDL) - self._pool = pool + if not self._initialized: + async with self._lock: + if not self._initialized: + async with self._pool.acquire() as conn: + await conn.execute(_DDL) + self._initialized = True return self._pool async def get(self, user_id: str | None, session_id: str, scope: str, key: str) -> str | None: diff --git a/tests/conftest_factory.py b/tests/conftest_factory.py index 37b83c3..03a0da4 100644 --- a/tests/conftest_factory.py +++ b/tests/conftest_factory.py @@ -264,8 +264,9 @@ """Build a MemoryStore wired to a FakePool.""" from navi.memory.store import MemoryStore - store = MemoryStore(dsn="fake://test") - store._pool = FakePool(conn) + pool = FakePool(conn) + store = MemoryStore(pool=pool) + store._initialized = True return store @@ -273,6 +274,7 @@ """Build a RecallScheduler wired to a FakePool.""" from navi.core.scheduler import RecallScheduler - scheduler = RecallScheduler(dsn="fake://test") - scheduler._pool = FakePool(conn) + pool = FakePool(conn) + scheduler = RecallScheduler(pool=pool) + scheduler._initialized = True return scheduler diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 7a5f08e..29db170 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -67,6 +67,7 @@ # Build container directly — no more module-level singletons from navi.core.container import AppContainer container = AppContainer( + database=None, memory_store=None, session_store=store, kv_store=None, diff --git a/tests/unit/store/test_kv_store.py b/tests/unit/store/test_kv_store.py index 5de2bb3..d66e38f 100644 --- a/tests/unit/store/test_kv_store.py +++ b/tests/unit/store/test_kv_store.py @@ -9,8 +9,8 @@ class TestKvStore: @pytest.fixture def store(self): - s = KvStore(dsn="postgresql://fake") - s._pool = FakePool() + s = KvStore(pool=FakePool()) + s._initialized = True return s @pytest.mark.asyncio