"""PostgreSQL-backed session store using asyncpg connection pool."""
import asyncio
import json
from datetime import datetime, timezone
import asyncpg
from navi.llm.base import Message
from .session import Session, SessionStore
# Schema for the ``sessions`` table. PgSessionStore._get_pool executes this
# statement once, right after the pool is created; IF NOT EXISTS makes it safe
# to re-run against an existing database. The ``messages`` and ``context``
# columns hold JSON text written by _serialize; an empty ``context`` string
# is treated by the reader as "no separate context stored".
_DDL = """
CREATE TABLE IF NOT EXISTS sessions (
id TEXT PRIMARY KEY,
profile_id TEXT NOT NULL,
messages TEXT NOT NULL DEFAULT '[]',
context TEXT NOT NULL DEFAULT '',
pinned BOOLEAN NOT NULL DEFAULT FALSE,
created_at TIMESTAMPTZ NOT NULL,
last_active TIMESTAMPTZ NOT NULL,
context_token_count INTEGER NOT NULL DEFAULT 0
)
"""
def _serialize(messages: list[Message]) -> str:
return json.dumps(
[m.model_dump(mode="json", exclude_none=True) for m in messages],
ensure_ascii=False,
)
def _deserialize(raw: str) -> list[Message]:
if not raw:
return []
return [Message.model_validate(m) for m in json.loads(raw)]
class PgSessionStore(SessionStore):
    """Session store that persists sessions in a PostgreSQL ``sessions`` table.

    The asyncpg pool is created lazily on first use; the table DDL is executed
    at that moment. Message lists are stored as JSON text in the ``messages``
    and ``context`` columns.
    """

    def __init__(self, dsn: str) -> None:
        """Remember the connection string; no connection is opened here."""
        self._dsn = dsn
        self._pool: asyncpg.Pool | None = None
        # Serializes first-time pool creation across concurrent callers.
        self._lock = asyncio.Lock()

    async def _get_pool(self) -> asyncpg.Pool:
        """Return the shared pool, creating it and the schema on first call.

        Raises:
            Whatever ``asyncpg.create_pool`` or the DDL statement raises. If
            the DDL step fails, the freshly created pool is terminated before
            the error propagates, so a later retry starts from a clean state
            instead of leaking the connections of the abandoned pool.
        """
        if self._pool is not None:
            return self._pool
        async with self._lock:
            # Re-check under the lock: another task may have completed the
            # setup while this one was waiting.
            if self._pool is None:
                pool = await asyncpg.create_pool(self._dsn)
                try:
                    async with pool.acquire() as conn:
                        await conn.execute(_DDL)
                except BaseException:
                    # The pool is not published yet, so nobody else can close
                    # it. terminate() is synchronous, which keeps this cleanup
                    # safe even when the failure is a task cancellation.
                    pool.terminate()
                    raise
                self._pool = pool
        return self._pool

    async def create(self, profile_id: str) -> Session:
        """Create a new session for ``profile_id``, insert its row, and return it.

        The row is written with empty messages/context, unpinned, and a zero
        context token count; id and timestamps come from the new ``Session``.
        """
        session = Session(profile_id=profile_id)
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            await conn.execute(
                "INSERT INTO sessions "
                "(id, profile_id, messages, context, pinned, created_at, last_active, context_token_count) "
                "VALUES ($1, $2, '[]', '', FALSE, $3, $4, 0)",
                session.id, session.profile_id, session.created_at, session.last_active,
            )
        return session

    async def get(self, session_id: str) -> Session | None:
        """Return the session with the given id, or ``None`` if no row matches."""
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            row = await conn.fetchrow(
                "SELECT id, profile_id, messages, context, pinned, created_at, last_active, context_token_count "
                "FROM sessions WHERE id = $1",
                session_id,
            )
        return self._row_to_session(row) if row else None

    async def save(self, session: Session) -> None:
        """Persist the mutable fields of ``session`` to its existing row.

        Side effect: ``session.last_active`` is set to the current UTC time
        before writing. ``pinned`` and ``created_at`` are not written here.
        The UPDATE result is not inspected, so saving a session whose row no
        longer exists changes nothing and raises nothing.
        """
        session.last_active = datetime.now(timezone.utc)
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            await conn.execute(
                "UPDATE sessions SET profile_id = $1, messages = $2, context = $3, "
                "last_active = $4, context_token_count = $5 WHERE id = $6",
                session.profile_id, _serialize(session.messages), _serialize(session.context),
                session.last_active, session.context_token_count, session.id,
            )

    async def set_pinned(self, session_id: str, pinned: bool) -> bool:
        """Set the pinned flag; return ``True`` only if exactly one row was updated."""
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            result = await conn.execute(
                "UPDATE sessions SET pinned = $1 WHERE id = $2",
                pinned, session_id,
            )
        # execute() returns the command status string of the statement.
        return result == "UPDATE 1"

    async def list_all(self) -> list[Session]:
        """Return every session, pinned ones first, then most recently active."""
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            rows = await conn.fetch(
                "SELECT id, profile_id, messages, context, pinned, created_at, last_active, context_token_count "
                "FROM sessions ORDER BY pinned DESC, last_active DESC"
            )
        return [self._row_to_session(r) for r in rows]

    async def delete(self, session_id: str) -> bool:
        """Delete the session row; return ``True`` only if exactly one row was removed."""
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            result = await conn.execute("DELETE FROM sessions WHERE id = $1", session_id)
        return result == "DELETE 1"

    def _row_to_session(self, row: asyncpg.Record) -> Session:
        """Build a ``Session`` from a row of the ``sessions`` table.

        An empty ``context`` column means no separate context was stored; in
        that case the context is a shallow copy of the decoded messages list.
        A falsy ``context_token_count`` is normalized to ``0``.
        """
        messages = _deserialize(row["messages"])
        context_json = row["context"]
        context = _deserialize(context_json) if context_json else list(messages)
        return Session(
            id=row["id"],
            profile_id=row["profile_id"],
            messages=messages,
            context=context,
            pinned=bool(row["pinned"]),
            created_at=row["created_at"],
            last_active=row["last_active"],
            context_token_count=row["context_token_count"] or 0,
        )