Newer
Older
navi-1 / navi / core / pg_session_store.py
"""PostgreSQL-backed session store using asyncpg connection pool."""

import asyncio
import json
from datetime import datetime, timezone

import asyncpg

from navi.llm.base import Message

from .session import Session, SessionStore

# Schema for a fresh ``sessions`` table, executed by PgSessionStore._get_pool
# immediately before _MIGRATE. ``planning_logs`` is NOT declared here; it only
# comes into existence through _MIGRATE. The ``user_id`` foreign key requires
# the ``navi_users`` table to exist before this statement runs.
_DDL = """
CREATE TABLE IF NOT EXISTS sessions (
    id                  TEXT PRIMARY KEY,
    profile_id          TEXT NOT NULL,
    user_id             TEXT REFERENCES navi_users(id) ON DELETE SET NULL,
    messages            TEXT NOT NULL DEFAULT '[]',
    context             TEXT NOT NULL DEFAULT '',
    pinned              BOOLEAN NOT NULL DEFAULT FALSE,
    created_at          TIMESTAMPTZ NOT NULL,
    last_active         TIMESTAMPTZ NOT NULL,
    context_token_count INTEGER NOT NULL DEFAULT 0,
    name                TEXT
)
"""

# Idempotent column additions (ADD COLUMN IF NOT EXISTS), run right after _DDL
# once per PgSessionStore instance. Presumably these bring tables created by an
# earlier version of _DDL up to date — confirm. The whole string is sent to
# ``conn.execute`` with no arguments, as a single multi-statement query.
_MIGRATE = """
ALTER TABLE sessions ADD COLUMN IF NOT EXISTS name TEXT;
ALTER TABLE sessions ADD COLUMN IF NOT EXISTS planning_logs TEXT NOT NULL DEFAULT '[]';
ALTER TABLE sessions ADD COLUMN IF NOT EXISTS user_id TEXT REFERENCES navi_users(id) ON DELETE SET NULL
"""


def _serialize(messages: list[Message]) -> str:
    """Encode *messages* as a JSON array string for a TEXT column.

    Each message is converted with ``model_dump(mode="json", exclude_none=True)``
    so fields that are ``None`` are omitted. Non-ASCII characters are written
    as-is rather than ``\\u``-escaped.
    """
    dumped = [
        message.model_dump(mode="json", exclude_none=True)
        for message in messages
    ]
    return json.dumps(dumped, ensure_ascii=False)


def _deserialize(raw: str) -> list[Message]:
    """Decode a JSON array string (as written by `_serialize`) into messages.

    A falsy *raw* (e.g. the empty string) yields an empty list without
    attempting to parse it.
    """
    if raw:
        return list(map(Message.model_validate, json.loads(raw)))
    return []


class PgSessionStore(SessionStore):
    """`SessionStore` backed by a PostgreSQL ``sessions`` table via an asyncpg pool.

    The table is created and migrated lazily on first use (see `_get_pool`).
    Messages, context and planning logs are stored as JSON in TEXT columns.
    """

    # Column list shared by every SELECT so `_row_to_session` always receives
    # the same fields.
    _COLUMNS = (
        "id, profile_id, user_id, messages, context, pinned, created_at, "
        "last_active, context_token_count, name, planning_logs"
    )

    # Whitelist for `search_list(sort_by=...)`. The chosen name is interpolated
    # directly into SQL, so it must never be taken verbatim from the caller.
    _SORTABLE_COLUMNS = frozenset(
        {"created_at", "last_active", "name", "profile_id", "user_id", "pinned"}
    )

    def __init__(self, pool: asyncpg.Pool) -> None:
        self._pool = pool
        # Becomes True once _DDL and _MIGRATE have both run successfully.
        self._initialized = False
        # Serializes the one-time schema setup across concurrent first callers.
        self._lock = asyncio.Lock()

    async def _get_pool(self) -> asyncpg.Pool:
        """Return the pool, creating/migrating the schema on the first call.

        Checks the flag, takes the lock, then re-checks, so concurrent first
        callers run the DDL only once. If the DDL raises, ``_initialized``
        stays False and the next call retries.
        """
        if not self._initialized:
            async with self._lock:
                if not self._initialized:
                    async with self._pool.acquire() as conn:
                        await conn.execute(_DDL)
                        await conn.execute(_MIGRATE)
                    self._initialized = True
        return self._pool

    @staticmethod
    def _like_pattern(search: str) -> str:
        """Wrap *search* in ``%...%`` for a substring ILIKE match.

        Backslash, ``%`` and ``_`` are escaped with PostgreSQL's default LIKE
        escape character (backslash), so they match literally instead of
        acting as wildcards.
        """
        escaped = (
            search.replace("\\", "\\\\")
            .replace("%", "\\%")
            .replace("_", "\\_")
        )
        return f"%{escaped}%"

    @classmethod
    def _build_where(
        cls,
        *,
        user_id: str | None,
        is_admin: bool,
        profile_id: str | None = None,
        search: str | None = None,
    ) -> tuple[str, list]:
        """Build a WHERE clause (possibly empty) and its positional parameters.

        Placeholders are numbered ``$1..$n`` in the order the values appear in
        the returned list. Callers may append more parameters (e.g. LIMIT and
        OFFSET) and number them from ``len(params) + 1``.
        """
        conditions: list[str] = []
        params: list = []

        def add_param(value) -> str:
            params.append(value)
            return f"${len(params)}"

        # NOTE(review): a non-admin caller with user_id=None is not filtered
        # and therefore sees every session; confirm this is intended.
        if not is_admin and user_id is not None:
            conditions.append(f"user_id = {add_param(user_id)}")
        if profile_id:
            conditions.append(f"profile_id = {add_param(profile_id)}")
        if search:
            # One placeholder, referenced by every searched column.
            p = add_param(cls._like_pattern(search))
            conditions.append(
                f"(id ILIKE {p} OR name ILIKE {p} OR user_id ILIKE {p} "
                f"OR profile_id ILIKE {p} OR messages ILIKE {p})"
            )

        where = f"WHERE {' AND '.join(conditions)}" if conditions else ""
        return where, params

    async def create(self, profile_id: str, user_id: str | None = None) -> Session:
        """Insert and return a new session with empty messages and context.

        ``name`` is left NULL and ``planning_logs`` takes its column default.
        """
        session = Session(profile_id=profile_id, user_id=user_id)
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            await conn.execute(
                "INSERT INTO sessions "
                "(id, profile_id, user_id, messages, context, pinned, created_at, last_active, context_token_count) "
                "VALUES ($1, $2, $3, '[]', '', FALSE, $4, $5, 0)",
                session.id, session.profile_id, session.user_id, session.created_at, session.last_active,
            )
        return session

    async def get(self, session_id: str) -> Session | None:
        """Load the session with *session_id*, or return None if there is none."""
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            row = await conn.fetchrow(
                f"SELECT {self._COLUMNS} FROM sessions WHERE id = $1",
                session_id,
            )
        return self._row_to_session(row) if row else None

    async def save(self, session: Session) -> None:
        """Persist *session* and set its ``last_active`` to now (UTC).

        ``pinned`` and ``name`` are not written here; use `set_pinned` /
        `set_name`. If no row with ``session.id`` exists, the UPDATE matches
        nothing and no error is raised.
        """
        session.last_active = datetime.now(timezone.utc)
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            await conn.execute(
                "UPDATE sessions SET profile_id = $1, user_id = $2, messages = $3, context = $4, "
                "last_active = $5, context_token_count = $6, planning_logs = $7 WHERE id = $8",
                session.profile_id, session.user_id, _serialize(session.messages), _serialize(session.context),
                session.last_active, session.context_token_count,
                json.dumps(session.planning_logs, ensure_ascii=False), session.id,
            )

    async def set_pinned(self, session_id: str, pinned: bool) -> bool:
        """Set the pinned flag; return True iff exactly one row was updated."""
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            result = await conn.execute(
                "UPDATE sessions SET pinned = $1 WHERE id = $2",
                pinned, session_id,
            )
        return result == "UPDATE 1"

    async def set_name(self, session_id: str, name: str) -> bool:
        """Rename a session; return True iff exactly one row was updated."""
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            result = await conn.execute(
                "UPDATE sessions SET name = $1 WHERE id = $2",
                name, session_id,
            )
        return result == "UPDATE 1"

    async def list_all(self, user_id: str | None = None, is_admin: bool = False) -> list[Session]:
        """Return all visible sessions, pinned first, then most recently active.

        Non-admin callers that supply *user_id* only see their own sessions.
        """
        where, params = self._build_where(user_id=user_id, is_admin=is_admin)
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            rows = await conn.fetch(
                f"SELECT {self._COLUMNS} FROM sessions {where} "
                "ORDER BY pinned DESC, last_active DESC",
                *params,
            )
        return [self._row_to_session(r) for r in rows]

    async def list_page(
        self,
        *,
        limit: int,
        offset: int,
        profile_id: str | None = None,
        user_id: str | None = None,
        is_admin: bool = False,
    ) -> list[Session]:
        """Return one page of sessions, optionally filtered by profile and owner.

        Ordering is pinned first, then most recently active. ``id`` is used as a
        unique final tiebreaker so consecutive pages neither repeat nor skip
        rows.
        """
        where, params = self._build_where(
            user_id=user_id, is_admin=is_admin, profile_id=profile_id
        )
        params.extend((limit, offset))
        n = len(params)
        query = (
            f"SELECT {self._COLUMNS} FROM sessions {where} "
            f"ORDER BY pinned DESC, last_active DESC, id "
            f"LIMIT ${n - 1} OFFSET ${n}"
        )
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            rows = await conn.fetch(query, *params)
        return [self._row_to_session(r) for r in rows]

    async def delete(self, session_id: str) -> bool:
        """Delete a session; return True iff a row was removed."""
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            result = await conn.execute("DELETE FROM sessions WHERE id = $1", session_id)
        return result == "DELETE 1"

    async def count_all(
        self,
        *,
        user_id: str | None = None,
        is_admin: bool = False,
        search: str | None = None,
    ) -> int:
        """Count the sessions matching the same filters as `search_list`.

        *search* is a literal, case-insensitive substring matched against id,
        name, user_id, profile_id and the serialized messages.
        """
        where, params = self._build_where(
            user_id=user_id, is_admin=is_admin, search=search
        )
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            count = await conn.fetchval(f"SELECT COUNT(*) FROM sessions {where}", *params)
        return count or 0

    async def search_list(
        self,
        *,
        limit: int,
        offset: int,
        user_id: str | None = None,
        is_admin: bool = False,
        search: str | None = None,
        sort_by: str = "last_active",
        sort_order: str = "desc",
    ) -> list[Session]:
        """Return one page of sessions matching *search*, in the requested order.

        Pinned sessions always come first. Within them, rows are ordered by
        *sort_by* (falls back to ``last_active`` if the name is not in
        `_SORTABLE_COLUMNS`) and *sort_order*. ``"desc"`` in any letter case
        gives descending order; anything else gives ascending. ``id`` breaks
        remaining ties so pagination is deterministic.
        """
        where, params = self._build_where(
            user_id=user_id, is_admin=is_admin, search=search
        )
        col = sort_by if sort_by in self._SORTABLE_COLUMNS else "last_active"
        direction = "DESC" if (sort_order or "").lower() == "desc" else "ASC"
        params.extend((limit, offset))
        n = len(params)
        query = (
            f"SELECT {self._COLUMNS} FROM sessions {where} "
            # pinned is the PRIMARY sort key; the requested column is secondary.
            f"ORDER BY pinned DESC, {col} {direction}, id "
            f"LIMIT ${n - 1} OFFSET ${n}"
        )
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            rows = await conn.fetch(query, *params)
        return [self._row_to_session(r) for r in rows]

    def _row_to_session(self, row: asyncpg.Record) -> Session:
        """Convert a row selected with `_COLUMNS` into a `Session`.

        An empty ``context`` column (as written by `create`) falls back to a
        copy of the session's messages.
        """
        messages = _deserialize(row["messages"])
        context_json = row["context"]
        context = _deserialize(context_json) if context_json else list(messages)
        planning_logs_raw = row["planning_logs"]
        planning_logs = json.loads(planning_logs_raw) if planning_logs_raw else []
        return Session(
            id=row["id"],
            profile_id=row["profile_id"],
            user_id=row["user_id"],
            messages=messages,
            context=context,
            pinned=bool(row["pinned"]),
            name=row["name"],
            created_at=row["created_at"],
            last_active=row["last_active"],
            context_token_count=row["context_token_count"] or 0,
            planning_logs=planning_logs,
        )