"""SQLite-backed session store — sessions survive server restarts."""

import json
import sqlite3
from datetime import datetime, timezone

import aiosqlite

from navi.llm.base import Message

from .session import Session, SessionStore

# Baseline schema for the ``sessions`` table. Columns introduced later
# (``name`` and ``planning_logs``) are not listed here; the store's
# constructor adds them with ALTER TABLE, so both fresh and older database
# files end up with the same column set. ``messages``/``context`` hold JSON
# text and ``created_at``/``last_active`` hold ISO-8601 timestamp strings.
_CREATE_TABLE = """
CREATE TABLE IF NOT EXISTS sessions (
    id                  TEXT PRIMARY KEY,
    profile_id          TEXT NOT NULL,
    messages            TEXT NOT NULL DEFAULT '[]',
    context             TEXT NOT NULL DEFAULT '',
    pinned              INTEGER NOT NULL DEFAULT 0,
    created_at          TEXT NOT NULL,
    last_active         TEXT NOT NULL,
    context_token_count INTEGER NOT NULL DEFAULT 0
)
"""


def _serialize(messages: list[Message]) -> str:
    """Encode *messages* as a JSON array string for storage in a text column.

    Each message is converted with ``model_dump(mode='json', exclude_none=True)``,
    so fields whose value is ``None`` are omitted. Non-ASCII characters are
    written as-is rather than ``\\u`` escaped.
    """
    payload = [msg.model_dump(mode='json', exclude_none=True) for msg in messages]
    return json.dumps(payload, ensure_ascii=False)


def _parse_dt(value: str) -> datetime:
    """Parse ISO datetime string, always returning a timezone-aware datetime."""
    dt = datetime.fromisoformat(value)
    if dt.tzinfo is None:
        dt = dt.replace(tzinfo=timezone.utc)
    return dt


def _deserialize(raw: str) -> list[Message]:
    """Decode a JSON array string into a list of ``Message`` objects.

    A falsy *raw* value (for example the empty string, which is the column
    default for ``context``) yields an empty list. Each element is validated
    with ``Message.model_validate``.
    """
    if raw:
        return list(map(Message.model_validate, json.loads(raw)))
    return []


class SqliteSessionStore(SessionStore):
    """``SessionStore`` implementation that persists sessions in a SQLite file.

    Each async method opens its own short-lived ``aiosqlite`` connection, so
    the store keeps no database handle open between calls. Message lists are
    stored as JSON text and timestamps as ISO-8601 strings.
    """

    # Columns added after the original schema, as (column name, column DDL).
    # ``__init__`` adds whichever are missing so older database files are
    # upgraded in place. These are fixed constants (never user input), which
    # is why interpolating them into ALTER TABLE below is safe.
    _MIGRATIONS: tuple[tuple[str, str], ...] = (
        ("pinned", "INTEGER NOT NULL DEFAULT 0"),
        ("context", "TEXT NOT NULL DEFAULT ''"),
        ("context_token_count", "INTEGER NOT NULL DEFAULT 0"),
        ("name", "TEXT"),
        ("planning_logs", "TEXT NOT NULL DEFAULT '[]'"),
    )

    # Column list shared by get() and list_all(); its order must match the
    # tuple unpacking in _row_to_session().
    _SELECT_COLUMNS = (
        "id, profile_id, messages, context, pinned, created_at, last_active, "
        "context_token_count, name, planning_logs"
    )

    def __init__(self, db_path: str = "navi.db") -> None:
        """Create the ``sessions`` table if needed and apply pending column migrations.

        Args:
            db_path: Path to the SQLite database file.

        Raises:
            sqlite3.Error: If the database cannot be opened, or the schema
                cannot be created or migrated.
        """
        self._db_path = db_path
        # The sqlite3 connection context manager only commits/rolls back; it
        # does not close the connection, so close it explicitly here.
        conn = sqlite3.connect(db_path)
        try:
            conn.execute(_CREATE_TABLE)
            # Look at the live schema instead of swallowing every
            # OperationalError, so real failures (e.g. a locked database) are
            # not mistaken for "column already exists".
            existing = {row[1] for row in conn.execute("PRAGMA table_info(sessions)")}
            for column, ddl in self._MIGRATIONS:
                if column not in existing:
                    conn.execute(f"ALTER TABLE sessions ADD COLUMN {column} {ddl}")
            conn.commit()
        finally:
            conn.close()

    async def create(self, profile_id: str) -> Session:
        """Insert a new, empty session for *profile_id* and return it.

        The row is written with empty ``messages`` and ``context``,
        ``pinned = 0`` and ``context_token_count = 0``. ``name`` and
        ``planning_logs`` are left to their column defaults.
        """
        session = Session(profile_id=profile_id)
        async with aiosqlite.connect(self._db_path) as db:
            await db.execute(
                "INSERT INTO sessions "
                "(id, profile_id, messages, context, pinned, created_at, last_active, context_token_count) "
                "VALUES (?, ?, '[]', '', 0, ?, ?, 0)",
                (session.id, session.profile_id,
                 session.created_at.isoformat(), session.last_active.isoformat()),
            )
            await db.commit()
        return session

    async def get(self, session_id: str) -> Session | None:
        """Return the session with id *session_id*, or ``None`` if no such row exists."""
        async with aiosqlite.connect(self._db_path) as db:
            async with db.execute(
                f"SELECT {self._SELECT_COLUMNS} FROM sessions WHERE id = ?",
                (session_id,),
            ) as cur:
                row = await cur.fetchone()
        return self._row_to_session(row) if row else None

    async def save(self, session: Session) -> None:
        """Persist *session*'s mutable state and refresh its ``last_active`` time.

        Side effect: ``session.last_active`` is set to the current UTC time on
        the object passed in. ``pinned`` and ``name`` are NOT written here; use
        ``set_pinned`` / ``set_name``. This is an UPDATE, so if no row with
        ``session.id`` exists nothing is written.
        """
        session.last_active = datetime.now(timezone.utc)
        async with aiosqlite.connect(self._db_path) as db:
            await db.execute(
                "UPDATE sessions SET profile_id = ?, messages = ?, context = ?, "
                "last_active = ?, context_token_count = ?, planning_logs = ? WHERE id = ?",
                (session.profile_id, _serialize(session.messages), _serialize(session.context),
                 session.last_active.isoformat(), session.context_token_count,
                 json.dumps(session.planning_logs, ensure_ascii=False), session.id),
            )
            await db.commit()

    async def set_pinned(self, session_id: str, pinned: bool) -> bool:
        """Set the pinned flag of a session.

        Returns:
            ``True`` if a row with *session_id* was updated, ``False`` otherwise.
        """
        async with aiosqlite.connect(self._db_path) as db:
            cur = await db.execute(
                "UPDATE sessions SET pinned = ? WHERE id = ?",
                (1 if pinned else 0, session_id),
            )
            await db.commit()
            return cur.rowcount > 0

    async def set_name(self, session_id: str, name: str) -> bool:
        """Set the display name of a session.

        Returns:
            ``True`` if a row with *session_id* was updated, ``False`` otherwise.
        """
        async with aiosqlite.connect(self._db_path) as db:
            cur = await db.execute(
                "UPDATE sessions SET name = ? WHERE id = ?",
                (name, session_id),
            )
            await db.commit()
            return cur.rowcount > 0

    async def list_all(self) -> list[Session]:
        """Return every stored session, pinned ones first, then by ``last_active`` descending.

        Note: ``last_active`` is compared as stored ISO-8601 text, so the
        ordering relies on SQLite's lexical comparison of those strings.
        """
        async with aiosqlite.connect(self._db_path) as db:
            async with db.execute(
                f"SELECT {self._SELECT_COLUMNS} FROM sessions "
                "ORDER BY pinned DESC, last_active DESC"
            ) as cur:
                rows = await cur.fetchall()
        return [self._row_to_session(r) for r in rows]

    async def delete(self, session_id: str) -> bool:
        """Delete a session.

        Returns:
            ``True`` if a row with *session_id* was removed, ``False`` otherwise.
        """
        async with aiosqlite.connect(self._db_path) as db:
            cur = await db.execute("DELETE FROM sessions WHERE id = ?", (session_id,))
            await db.commit()
            return cur.rowcount > 0

    def _row_to_session(self, row: tuple) -> Session:
        """Build a ``Session`` from a row selected with ``_SELECT_COLUMNS``.

        If the stored ``context`` is empty, the context falls back to a copy of
        the message list. The ``context`` column defaults to ``''``, for example
        on newly created rows and rows added before that column existed.
        """
        id_, profile_id, messages_json, context_json, pinned, created_at, last_active, context_token_count, name, planning_logs_json = row
        messages = _deserialize(messages_json)
        context = _deserialize(context_json) if context_json else list(messages)
        planning_logs = json.loads(planning_logs_json) if planning_logs_json else []
        return Session(
            id=id_,
            profile_id=profile_id,
            messages=messages,
            context=context,
            pinned=bool(pinned),
            name=name,
            created_at=_parse_dt(created_at),
            last_active=_parse_dt(last_active),
            context_token_count=context_token_count or 0,
            planning_logs=planning_logs,
        )
