"""SQLite-backed session store — sessions survive server restarts."""

import json
import sqlite3
from datetime import datetime, timezone

import aiosqlite

from navi.llm.base import Message

from .session import Session, SessionStore

# Schema of the single ``sessions`` table.
#   messages / context -- JSON arrays produced by _serialize (context may be ''
#                         on rows predating the column; see _row_to_session)
#   pinned             -- 0/1 flag written by set_pinned
#   created_at / last_active -- ISO-8601 strings from datetime.isoformat()
# The ALTER TABLE migrations in SqliteSessionStore.__init__ add ``pinned`` and
# ``context`` to tables that were created before those columns existed.
_CREATE_TABLE = """
CREATE TABLE IF NOT EXISTS sessions (
    id          TEXT PRIMARY KEY,
    profile_id  TEXT NOT NULL,
    messages    TEXT NOT NULL DEFAULT '[]',
    context     TEXT NOT NULL DEFAULT '',
    pinned      INTEGER NOT NULL DEFAULT 0,
    created_at  TEXT NOT NULL,
    last_active TEXT NOT NULL
)
"""


def _serialize(messages: list[Message]) -> str:
    """Encode *messages* as a JSON array string.

    Each message is dumped in JSON mode with ``None`` fields omitted, and
    non-ASCII characters are written verbatim rather than ``\\u``-escaped.
    """
    payload = []
    for message in messages:
        payload.append(message.model_dump(mode='json', exclude_none=True))
    return json.dumps(payload, ensure_ascii=False)


def _deserialize(raw: str) -> list[Message]:
    """Decode a JSON array string back into ``Message`` objects.

    A falsy *raw* (e.g. the empty string stored by legacy rows) yields an
    empty list without touching the JSON parser.
    """
    if raw:
        return list(map(Message.model_validate, json.loads(raw)))
    return []


class SqliteSessionStore(SessionStore):
    """Session store persisted in a SQLite database file.

    Schema setup in ``__init__`` uses the blocking stdlib ``sqlite3`` driver;
    every other operation opens a short-lived ``aiosqlite`` connection, so no
    connection is held open between calls. Message lists are stored as JSON
    text in the ``messages`` and ``context`` columns.
    """

    # (column name, DDL) pairs for columns added after the original schema.
    # Each ALTER is applied only when the column is missing from an existing
    # ``sessions`` table; tables created by _CREATE_TABLE already have them.
    _MIGRATIONS: tuple[tuple[str, str], ...] = (
        ("pinned", "ALTER TABLE sessions ADD COLUMN pinned INTEGER NOT NULL DEFAULT 0"),
        ("context", "ALTER TABLE sessions ADD COLUMN context TEXT NOT NULL DEFAULT ''"),
    )

    def __init__(self, db_path: str = "navi.db") -> None:
        """Open (creating if needed) the database at *db_path* and migrate its schema.

        This runs synchronously and blocks the caller while it touches the file.
        """
        self._db_path = db_path
        conn = sqlite3.connect(db_path)
        try:
            # ``with conn`` commits on success / rolls back on error, but does
            # NOT close the connection -- the ``finally`` below does that.
            with conn:
                conn.execute(_CREATE_TABLE)
                # PRAGMA table_info rows are (cid, name, type, notnull, dflt, pk).
                existing = {row[1] for row in conn.execute("PRAGMA table_info(sessions)")}
                for column, ddl in self._MIGRATIONS:
                    # Checking first (instead of catching OperationalError)
                    # lets genuine failures such as a locked database surface.
                    if column not in existing:
                        conn.execute(ddl)
        finally:
            conn.close()

    async def create(self, profile_id: str) -> Session:
        """Create a new session for *profile_id*, persist it, and return it.

        The row is written from the freshly constructed ``Session``'s own
        fields, serialized the same way ``save`` does.
        """
        session = Session(profile_id=profile_id)
        async with aiosqlite.connect(self._db_path) as db:
            await db.execute(
                "INSERT INTO sessions (id, profile_id, messages, context, pinned, created_at, last_active) "
                "VALUES (?, ?, ?, ?, ?, ?, ?)",
                (session.id, session.profile_id,
                 _serialize(session.messages), _serialize(session.context),
                 1 if session.pinned else 0,
                 session.created_at.isoformat(), session.last_active.isoformat()),
            )
            await db.commit()
        return session

    async def get(self, session_id: str) -> Session | None:
        """Return the session with *session_id*, or ``None`` if no row matches."""
        async with aiosqlite.connect(self._db_path) as db:
            async with db.execute(
                "SELECT id, profile_id, messages, context, pinned, created_at, last_active "
                "FROM sessions WHERE id = ?",
                (session_id,),
            ) as cur:
                row = await cur.fetchone()
        return self._row_to_session(row) if row else None

    async def save(self, session: Session) -> None:
        """Persist *session*'s messages and context and stamp ``last_active``.

        ``session.last_active`` is set to the current UTC time in place before
        writing. NOTE: this is an UPDATE only -- if no row with ``session.id``
        exists (e.g. it was deleted), nothing is written and no error is raised.
        """
        session.last_active = datetime.now(timezone.utc)
        async with aiosqlite.connect(self._db_path) as db:
            await db.execute(
                "UPDATE sessions SET messages = ?, context = ?, last_active = ? WHERE id = ?",
                (_serialize(session.messages), _serialize(session.context),
                 session.last_active.isoformat(), session.id),
            )
            await db.commit()

    async def set_pinned(self, session_id: str, pinned: bool) -> bool:
        """Set the pinned flag of *session_id*; return ``True`` if the session exists."""
        async with aiosqlite.connect(self._db_path) as db:
            cur = await db.execute(
                "UPDATE sessions SET pinned = ? WHERE id = ?",
                (1 if pinned else 0, session_id),
            )
            await db.commit()
            return cur.rowcount > 0

    async def list_all(self) -> list[Session]:
        """Return every session, pinned ones first, then most recently active first."""
        async with aiosqlite.connect(self._db_path) as db:
            async with db.execute(
                "SELECT id, profile_id, messages, context, pinned, created_at, last_active "
                "FROM sessions ORDER BY pinned DESC, last_active DESC"
            ) as cur:
                rows = await cur.fetchall()
        return [self._row_to_session(r) for r in rows]

    async def delete(self, session_id: str) -> bool:
        """Delete *session_id*; return ``True`` if a row was removed."""
        async with aiosqlite.connect(self._db_path) as db:
            cur = await db.execute("DELETE FROM sessions WHERE id = ?", (session_id,))
            await db.commit()
            return cur.rowcount > 0

    def _row_to_session(self, row: tuple) -> Session:
        """Build a ``Session`` from a row selected in this class's column order.

        Expected order: id, profile_id, messages, context, pinned, created_at,
        last_active.
        """
        id_, profile_id, messages_json, context_json, pinned, created_at, last_active = row
        messages = _deserialize(messages_json)
        # Backward compat: existing sessions have empty context column —
        # initialize context from messages so they work without re-compression.
        context = _deserialize(context_json) if context_json else list(messages)
        return Session(
            id=id_,
            profile_id=profile_id,
            messages=messages,
            context=context,
            pinned=bool(pinned),
            created_at=datetime.fromisoformat(created_at),
            last_active=datetime.fromisoformat(last_active),
        )
