# navi-1 / navi / core / sqlite_session_store.py
"""SQLite-backed session store — sessions survive server restarts."""

import json
import sqlite3
from datetime import datetime

import aiosqlite

from navi.llm.base import Message

from .session import Session, SessionStore

# DDL for the single `sessions` table, executed once when the store is
# constructed. `messages` holds a JSON-encoded list (defaulting to '[]'),
# `pinned` is stored as 0/1, and `created_at` / `last_active` are written as
# ISO-8601 text by the store.
_CREATE_TABLE = """
CREATE TABLE IF NOT EXISTS sessions (
    id          TEXT PRIMARY KEY,
    profile_id  TEXT NOT NULL,
    messages    TEXT NOT NULL DEFAULT '[]',
    pinned      INTEGER NOT NULL DEFAULT 0,
    created_at  TEXT NOT NULL,
    last_active TEXT NOT NULL
)
"""


class SqliteSessionStore(SessionStore):
    """Session store that persists sessions in a SQLite database file.

    The schema is created (and migrated when necessary) synchronously in
    ``__init__``. Every other operation opens its own short-lived
    ``aiosqlite`` connection to ``db_path``.
    """

    def __init__(self, db_path: str = "navi.db") -> None:
        """Open ``db_path``, ensure the ``sessions`` table exists, and migrate it.

        Args:
            db_path: Path of the SQLite database file.

        Raises:
            sqlite3.OperationalError: If the schema cannot be created or the
                ``pinned`` column cannot be added.
        """
        self._db_path = db_path
        conn = sqlite3.connect(db_path)
        try:
            # `with conn` commits on success and rolls back on error, but it
            # does NOT close the connection -- hence the explicit finally.
            with conn:
                conn.execute(_CREATE_TABLE)
                self._ensure_pinned_column(conn)
        finally:
            conn.close()

    @staticmethod
    def _column_names(conn: sqlite3.Connection) -> set[str]:
        """Return the column names currently present on the ``sessions`` table."""
        # PRAGMA table_info rows are (cid, name, type, notnull, dflt_value, pk).
        return {row[1] for row in conn.execute("PRAGMA table_info(sessions)")}

    @classmethod
    def _ensure_pinned_column(cls, conn: sqlite3.Connection) -> None:
        """Add the ``pinned`` column to tables created before it existed."""
        if "pinned" in cls._column_names(conn):
            return
        try:
            conn.execute("ALTER TABLE sessions ADD COLUMN pinned INTEGER NOT NULL DEFAULT 0")
        except sqlite3.OperationalError:
            # Tolerate only the benign case where the column appeared in the
            # meantime (e.g. another process migrated first). Any other
            # operational error (locked / read-only database, I/O failure)
            # must surface instead of being silently swallowed.
            if "pinned" not in cls._column_names(conn):
                raise

    async def create(self, profile_id: str) -> Session:
        """Create a new session for ``profile_id`` and insert it into the database.

        The row is written with an empty message list and ``pinned = 0``; the
        id and timestamps are taken from the freshly constructed ``Session``.

        Returns:
            The newly created ``Session``.
        """
        session = Session(profile_id=profile_id)
        async with aiosqlite.connect(self._db_path) as db:
            await db.execute(
                "INSERT INTO sessions (id, profile_id, messages, pinned, created_at, last_active) "
                "VALUES (?, ?, '[]', 0, ?, ?)",
                (
                    session.id,
                    session.profile_id,
                    session.created_at.isoformat(),
                    session.last_active.isoformat(),
                ),
            )
            await db.commit()
        return session

    async def get(self, session_id: str) -> Session | None:
        """Return the session with ``session_id``, or ``None`` if no row matches."""
        async with aiosqlite.connect(self._db_path) as db:
            async with db.execute(
                "SELECT id, profile_id, messages, pinned, created_at, last_active "
                "FROM sessions WHERE id = ?",
                (session_id,),
            ) as cur:
                row = await cur.fetchone()
        return self._row_to_session(row) if row else None

    async def save(self, session: Session) -> None:
        """Persist ``session.messages`` and refresh its ``last_active`` timestamp.

        ``session.last_active`` is mutated in place before the write. Only the
        ``messages`` and ``last_active`` columns are updated; ``pinned`` is
        changed through ``set_pinned``. If no row has ``session.id`` the UPDATE
        matches nothing and no error is raised.
        """
        # NOTE(review): datetime.utcnow() yields a naive datetime and is
        # deprecated since Python 3.12. Kept as-is because the timestamps
        # handled here appear to be naive -- confirm before switching to
        # timezone-aware values.
        session.last_active = datetime.utcnow()
        messages_json = json.dumps(
            [m.model_dump(mode='json', exclude_none=True) for m in session.messages],
            ensure_ascii=False,
        )
        async with aiosqlite.connect(self._db_path) as db:
            await db.execute(
                "UPDATE sessions SET messages = ?, last_active = ? WHERE id = ?",
                (messages_json, session.last_active.isoformat(), session.id),
            )
            await db.commit()

    async def set_pinned(self, session_id: str, pinned: bool) -> bool:
        """Set the pinned flag of a session.

        Returns:
            ``True`` if a row with ``session_id`` was updated, ``False`` otherwise.
        """
        async with aiosqlite.connect(self._db_path) as db:
            cur = await db.execute(
                "UPDATE sessions SET pinned = ? WHERE id = ?",
                (1 if pinned else 0, session_id),
            )
            await db.commit()
            return cur.rowcount > 0

    async def list_all(self) -> list[Session]:
        """Return every stored session, pinned first, then most recently active."""
        async with aiosqlite.connect(self._db_path) as db:
            async with db.execute(
                "SELECT id, profile_id, messages, pinned, created_at, last_active "
                "FROM sessions ORDER BY pinned DESC, last_active DESC"
            ) as cur:
                rows = await cur.fetchall()
        return [self._row_to_session(r) for r in rows]

    async def delete(self, session_id: str) -> bool:
        """Delete a session.

        Returns:
            ``True`` if a row with ``session_id`` was removed, ``False`` otherwise.
        """
        async with aiosqlite.connect(self._db_path) as db:
            cur = await db.execute("DELETE FROM sessions WHERE id = ?", (session_id,))
            await db.commit()
            return cur.rowcount > 0

    def _row_to_session(self, row: tuple) -> Session:
        """Build a ``Session`` from a row in SELECT column order.

        Expected order: id, profile_id, messages (JSON text), pinned (0/1),
        created_at, last_active (both ISO-8601 text).
        """
        id_, profile_id, messages_json, pinned, created_at, last_active = row
        messages = [Message.model_validate(m) for m in json.loads(messages_json)]
        return Session(
            id=id_,
            profile_id=profile_id,
            messages=messages,
            pinned=bool(pinned),
            created_at=datetime.fromisoformat(created_at),
            last_active=datetime.fromisoformat(last_active),
        )