"""PostgreSQL-backed session store using asyncpg connection pool.

Normalized storage:
- session_messages         — hot messages (recent N per session)
- session_messages_archive — old messages moved here to keep the hot table small
- session_images           — base64 images referenced by session_messages (future use)

Per-session bookkeeping columns:
- sessions.next_sequence     — next free sequence number for the session (monotonic, shared by hot and archived rows)
- sessions.archive_threshold — every message with sequence_number < threshold lives in the archive table

All session message I/O goes through the normalized tables.  The legacy JSON
columns (sessions.messages, sessions.context) remain in the schema for backward
compatibility but are no longer read or written.
"""

import asyncio
import json
from datetime import datetime, timezone

import asyncpg

from navi.llm.base import Message, ToolCallRequest

from .session import Session, SessionStore

# Base ``sessions`` table.  ``messages`` and ``context`` are the legacy JSON
# columns (kept for backward compatibility, no longer read or written after the
# normalized migration).  Columns added later in the schema's life
# (planning_logs, next_sequence, archive_threshold) are not listed here; they are
# added by _MIGRATE.
# NOTE: user_id references navi_users(id), so that table must already exist when
# this DDL runs.
_DDL = """
CREATE TABLE IF NOT EXISTS sessions (
    id                  TEXT PRIMARY KEY,
    profile_id          TEXT NOT NULL,
    user_id             TEXT REFERENCES navi_users(id) ON DELETE SET NULL,
    messages            TEXT NOT NULL DEFAULT '[]',
    context             TEXT NOT NULL DEFAULT '',
    pinned              BOOLEAN NOT NULL DEFAULT FALSE,
    created_at          TIMESTAMPTZ NOT NULL,
    last_active         TIMESTAMPTZ NOT NULL,
    context_token_count INTEGER NOT NULL DEFAULT 0,
    name                TEXT,
    session_metadata    TEXT NOT NULL DEFAULT '{}'
)
"""

# Idempotent column additions for databases created with an older _DDL.  Every
# statement uses ADD COLUMN IF NOT EXISTS, so it is safe to run on every store
# initialization.  next_sequence / archive_threshold default to 0; the session
# read paths treat a next_sequence of 0 as "unknown" and derive a value from the
# stored messages instead.
_MIGRATE = """
ALTER TABLE sessions ADD COLUMN IF NOT EXISTS name TEXT;
ALTER TABLE sessions ADD COLUMN IF NOT EXISTS planning_logs TEXT NOT NULL DEFAULT '[]';
ALTER TABLE sessions ADD COLUMN IF NOT EXISTS user_id TEXT REFERENCES navi_users(id) ON DELETE SET NULL;
ALTER TABLE sessions ADD COLUMN IF NOT EXISTS next_sequence INTEGER NOT NULL DEFAULT 0;
ALTER TABLE sessions ADD COLUMN IF NOT EXISTS archive_threshold INTEGER NOT NULL DEFAULT 0;
ALTER TABLE sessions ADD COLUMN IF NOT EXISTS session_metadata TEXT NOT NULL DEFAULT '{}'
"""

# Hot message table: one row per message, unique per (session_id, sequence_number).
# When a session is loaded, rows with is_display populate Session.messages and
# rows with is_context populate Session.context (a row may be in both).
# images / tool_calls / files / metadata hold JSON-encoded text.
_SESSION_MESSAGES_DDL = """
CREATE TABLE IF NOT EXISTS session_messages (
    id              SERIAL PRIMARY KEY,
    session_id      TEXT NOT NULL REFERENCES sessions(id) ON DELETE CASCADE,
    sequence_number INT NOT NULL,
    role            TEXT NOT NULL,
    content         TEXT,
    images          TEXT,                       -- JSON array of base64 strings
    tool_calls      TEXT,                       -- JSON
    tool_call_id    TEXT,
    name            TEXT,
    created_at      TIMESTAMPTZ,
    is_summary      BOOLEAN NOT NULL DEFAULT FALSE,
    thinking        TEXT,
    is_plan         BOOLEAN NOT NULL DEFAULT FALSE,
    is_compression  BOOLEAN NOT NULL DEFAULT FALSE,
    is_context      BOOLEAN NOT NULL DEFAULT TRUE,
    is_display      BOOLEAN NOT NULL DEFAULT TRUE,
    elapsed_seconds FLOAT,
    tool_call_count INT,
    token_count     INT,
    files           TEXT,                       -- JSON
    metadata        TEXT,                       -- JSON
    is_recall       BOOLEAN NOT NULL DEFAULT FALSE,
    UNIQUE(session_id, sequence_number)
);

CREATE INDEX IF NOT EXISTS idx_session_messages_session_seq ON session_messages(session_id, sequence_number);
CREATE INDEX IF NOT EXISTS idx_session_messages_context ON session_messages(session_id, is_context, sequence_number);
"""

# Image side table.  Not read or written by this module yet: message images are
# currently stored inline as JSON in session_messages.images.  Rows cascade-delete
# with their session or their session_messages row.
# NOTE(review): archiving deletes hot session_messages rows, which would cascade
# to any images attached via message_id; revisit before putting this table to use.
_SESSION_IMAGES_DDL = """
CREATE TABLE IF NOT EXISTS session_images (
    id          SERIAL PRIMARY KEY,
    session_id  TEXT NOT NULL REFERENCES sessions(id) ON DELETE CASCADE,
    message_id  INT REFERENCES session_messages(id) ON DELETE CASCADE,
    base64      TEXT NOT NULL,
    filename    TEXT
);

CREATE INDEX IF NOT EXISTS idx_session_images_session ON session_images(session_id);
CREATE INDEX IF NOT EXISTS idx_session_images_message ON session_images(message_id);
"""

# Cold storage for messages moved out of session_messages by
# PgSessionStore.archive_old_messages.  Mirrors session_messages column for
# column so rows can be copied across unchanged (including the original id).
_SESSION_ARCHIVE_DDL = """
CREATE TABLE IF NOT EXISTS session_messages_archive (
    id              SERIAL PRIMARY KEY,
    session_id      TEXT NOT NULL REFERENCES sessions(id) ON DELETE CASCADE,
    sequence_number INT NOT NULL,
    role            TEXT NOT NULL,
    content         TEXT,
    images          TEXT,
    tool_calls      TEXT,
    tool_call_id    TEXT,
    name            TEXT,
    created_at      TIMESTAMPTZ,
    is_summary      BOOLEAN NOT NULL DEFAULT FALSE,
    thinking        TEXT,
    is_plan         BOOLEAN NOT NULL DEFAULT FALSE,
    is_compression  BOOLEAN NOT NULL DEFAULT FALSE,
    is_context      BOOLEAN NOT NULL DEFAULT TRUE,
    is_display      BOOLEAN NOT NULL DEFAULT TRUE,
    elapsed_seconds FLOAT,
    tool_call_count INT,
    token_count     INT,
    files           TEXT,
    metadata        TEXT,
    is_recall       BOOLEAN NOT NULL DEFAULT FALSE,
    UNIQUE(session_id, sequence_number)
);
CREATE INDEX IF NOT EXISTS idx_session_messages_archive_session_seq ON session_messages_archive(session_id, sequence_number);
"""


def _serialize(messages: list[Message]) -> str:
    """Encode *messages* as a JSON array string.

    ``None``-valued fields are omitted and non-ASCII characters are kept as-is.
    """
    payload = []
    for message in messages:
        payload.append(message.model_dump(mode="json", exclude_none=True))
    return json.dumps(payload, ensure_ascii=False)


def _deserialize(raw: str) -> list[Message]:
    """Decode a JSON array of message dicts back into ``Message`` objects.

    Falsy input (e.g. ``""``) yields an empty list.
    """
    if raw:
        return list(map(Message.model_validate, json.loads(raw)))
    return []


def _message_key(m: Message) -> tuple:
    """Build a hashable key used to match a message between messages[] and context[].

    Only the one-shot boot migration of legacy JSON rows uses this.  Structured
    fields (tool calls, files, metadata) are folded in as JSON text, or ``None``
    when empty.
    """

    def as_json(value):
        return json.dumps(value, ensure_ascii=False) if value else None

    timestamp = m.created_at.isoformat() if m.created_at else None
    tool_calls_json = (
        as_json([tc.model_dump(mode="json") for tc in m.tool_calls]) if m.tool_calls else None
    )
    return (
        m.role,
        m.content,
        m.tool_call_id,
        m.name,
        m.is_summary,
        m.is_plan,
        m.is_compression,
        m.is_recall,
        m.thinking,
        timestamp,
        tool_calls_json,
        as_json(m.files),
        as_json(m.metadata),
    )


def _row_to_message(row: asyncpg.Record) -> Message:
    """Convert a session_messages / session_messages_archive row into a ``Message``.

    JSON text columns are decoded.  Missing or empty ``images``/``files`` become
    ``None`` and ``metadata`` becomes ``{}``.  Absent flag columns fall back to
    the schema defaults.
    """

    def decode(column, fallback):
        raw = row.get(column)
        return json.loads(raw) if raw else fallback

    raw_tool_calls = row.get("tool_calls")
    if raw_tool_calls:
        tool_calls = [ToolCallRequest.model_validate(tc) for tc in json.loads(raw_tool_calls)]
    else:
        tool_calls = None

    return Message(
        role=row["role"],
        content=row["content"],
        images=decode("images", None),
        tool_calls=tool_calls,
        tool_call_id=row.get("tool_call_id"),
        name=row.get("name"),
        created_at=row.get("created_at"),
        is_summary=bool(row.get("is_summary", False)),
        thinking=row.get("thinking"),
        is_plan=bool(row.get("is_plan", False)),
        is_compression=bool(row.get("is_compression", False)),
        is_context=bool(row.get("is_context", True)),
        is_display=bool(row.get("is_display", True)),
        elapsed_seconds=row.get("elapsed_seconds"),
        tool_call_count=row.get("tool_call_count"),
        token_count=row.get("token_count"),
        files=decode("files", None),
        metadata=decode("metadata", {}),
        is_recall=bool(row.get("is_recall", False)),
        sequence_number=row.get("sequence_number", 0),
    )


async def _ensure_normalized_tables(conn: asyncpg.Connection) -> None:
    """Create the normalized message, image and archive tables (and indexes) if missing.

    Order matters: session_images references session_messages.
    """
    for ddl in (_SESSION_MESSAGES_DDL, _SESSION_IMAGES_DDL, _SESSION_ARCHIVE_DDL):
        await conn.execute(ddl)


async def _migrate_to_normalized(conn: asyncpg.Connection) -> None:
    """One-shot migration: copy legacy JSON ``messages``/``context`` into ``session_messages``.

    Skipped as soon as either normalized message table (hot or archive) holds any
    row.  Checking the archive too keeps a fully archived database from
    re-importing the legacy JSON, which is never cleared.

    Everything runs in one transaction.  Otherwise a failure part-way through
    would leave ``session_messages`` non-empty, and the guard above would treat
    the half-copied data as a finished migration on the next boot.

    For each session, legacy message *i* is stored with ``sequence_number = i``.
    ``is_context`` is set when an equal message (per ``_message_key``) appears in
    the legacy context list; an empty legacy context means "every message is in
    context".  ``sessions.next_sequence`` is then advanced past the copied rows.
    """
    already_migrated = await conn.fetchval(
        "SELECT EXISTS (SELECT 1 FROM session_messages) "
        "OR EXISTS (SELECT 1 FROM session_messages_archive)"
    )
    if already_migrated:
        return

    async with conn.transaction():
        rows = await conn.fetch("SELECT id, messages, context FROM sessions")
        for row in rows:
            session_id = row["id"]
            messages = _deserialize(row["messages"] or "[]")
            if not messages:
                continue
            context = _deserialize(row["context"] or "")
            if not context:
                context = list(messages)

            context_set = {_message_key(m) for m in context}

            await conn.executemany(
                """
                INSERT INTO session_messages
                (session_id, sequence_number, role, content, images, tool_calls, tool_call_id, name,
                 created_at, is_summary, thinking, is_plan, is_compression, is_context, is_display,
                 elapsed_seconds, tool_call_count, token_count, files, metadata, is_recall)
                VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21)
                ON CONFLICT (session_id, sequence_number) DO NOTHING
                """,
                [
                    (
                        session_id,
                        seq,
                        m.role,
                        m.content,
                        json.dumps(m.images, ensure_ascii=False) if m.images else None,
                        json.dumps([tc.model_dump(mode="json") for tc in m.tool_calls], ensure_ascii=False)
                        if m.tool_calls
                        else None,
                        m.tool_call_id,
                        m.name,
                        m.created_at,
                        m.is_summary,
                        m.thinking,
                        m.is_plan,
                        m.is_compression,
                        _message_key(m) in context_set,
                        True,  # legacy messages[] were all displayed
                        m.elapsed_seconds,
                        m.tool_call_count,
                        m.token_count,
                        json.dumps(m.files, ensure_ascii=False) if m.files else None,
                        json.dumps(m.metadata, ensure_ascii=False) if m.metadata else None,
                        m.is_recall,
                    )
                    for seq, m in enumerate(messages)
                ],
            )
            # Advance the counter past the copied rows.  Without this it stays at
            # its column default of 0, and the first save() would insert new
            # messages at sequence_number 0.. again (UNIQUE violation).
            await conn.execute(
                "UPDATE sessions SET next_sequence = GREATEST(next_sequence, $1) WHERE id = $2",
                len(messages),
                session_id,
            )


async def _load_messages_map(conn: asyncpg.Connection, session_ids: list[str]) -> dict[str, list[Message]]:
    """Fetch every hot message for *session_ids* in one query, grouped per session.

    Each requested id maps to a list ordered by ``sequence_number``; sessions
    without messages map to an empty list.  An empty *session_ids* returns ``{}``
    without querying.
    """
    if not session_ids:
        return {}
    by_session: dict[str, list[Message]] = {}
    for sid in session_ids:
        by_session[sid] = []
    records = await conn.fetch(
        "SELECT * FROM session_messages WHERE session_id = ANY($1) ORDER BY sequence_number",
        session_ids,
    )
    for record in records:
        by_session[record["session_id"]].append(_row_to_message(record))
    return by_session


async def _build_sessions(
    conn: asyncpg.Connection,
    rows: list[asyncpg.Record],
) -> list[Session]:
    """Hydrate ``sessions`` rows into ``Session`` objects with their hot messages.

    Messages are batch-loaded with one query for all rows.  For each session,
    ``messages`` holds the hot messages with ``is_display`` set and ``context``
    those with ``is_context`` set.  A message with both flags is the same Python
    object in both lists.
    """
    session_ids = [r["id"] for r in rows]
    messages_map = await _load_messages_map(conn, session_ids)
    sessions: list[Session] = []
    for row in rows:
        archive_threshold = row.get("archive_threshold", 0) or 0
        # Defensive: never surface hot rows below the archive threshold.
        hot_msgs = [
            m for m in messages_map.get(row["id"], []) if m.sequence_number >= archive_threshold
        ]
        messages = [m for m in hot_msgs if m.is_display]
        context = [m for m in hot_msgs if m.is_context]
        planning_logs_raw = row["planning_logs"]
        planning_logs = json.loads(planning_logs_raw) if planning_logs_raw else []

        next_seq = row.get("next_sequence", 0) or 0
        if next_seq == 0:
            # A zero counter means it was never maintained (e.g. legacy rows), so
            # derive it from the data.  Archived messages all sit below
            # archive_threshold, which therefore bounds the next free number from
            # below even when no hot messages remain.
            max_seq = max((m.sequence_number for m in hot_msgs), default=-1)
            next_seq = max(max_seq + 1, archive_threshold)

        s = Session(
            id=row["id"],
            profile_id=row["profile_id"],
            user_id=row["user_id"],
            messages=messages,
            context=context,
            pinned=bool(row["pinned"]),
            name=row["name"],
            created_at=row["created_at"],
            last_active=row["last_active"],
            context_token_count=row["context_token_count"] or 0,
            planning_logs=planning_logs,
        )
        raw_metadata = row.get("session_metadata")
        s.db_message_count = len(hot_msgs)
        s.db_next_sequence = next_seq
        s.archive_threshold = archive_threshold
        s.session_metadata = json.loads(raw_metadata) if raw_metadata else {}
        sessions.append(s)
    return sessions


class PgSessionStore(SessionStore):
    """``SessionStore`` implementation backed by PostgreSQL via an asyncpg pool.

    Schema creation and the one-shot legacy JSON migration run lazily on the
    first call that needs the pool (see ``_get_pool``).  Messages live in
    ``session_messages`` (hot) and ``session_messages_archive`` (cold).
    """

    # Columns read from ``sessions`` when hydrating ``Session`` objects.
    _SESSION_COLUMNS = (
        "id, profile_id, user_id, pinned, created_at, last_active, context_token_count, "
        "name, planning_logs, next_sequence, archive_threshold, session_metadata"
    )

    # Explicit column list for copying rows from the hot table to the archive, so
    # the copy does not silently depend on both tables sharing physical column order.
    _MESSAGE_COLUMNS = (
        "id, session_id, sequence_number, role, content, images, tool_calls, tool_call_id, "
        "name, created_at, is_summary, thinking, is_plan, is_compression, is_context, "
        "is_display, elapsed_seconds, tool_call_count, token_count, files, metadata, is_recall"
    )

    def __init__(self, pool: asyncpg.Pool) -> None:
        self._pool = pool
        self._initialized = False
        # Serializes the one-time schema setup in _get_pool.
        self._lock = asyncio.Lock()

    async def _get_pool(self) -> asyncpg.Pool:
        """Return the pool, creating and migrating the schema on first use.

        The flag is checked before and after taking the lock so concurrent first
        callers run the setup only once.  If setup raises, ``_initialized`` stays
        False and the next call retries it.
        """
        if not self._initialized:
            async with self._lock:
                if not self._initialized:
                    async with self._pool.acquire() as conn:
                        await conn.execute(_DDL)
                        await conn.execute(_MIGRATE)
                        await _ensure_normalized_tables(conn)
                        await _migrate_to_normalized(conn)
                    self._initialized = True
        return self._pool

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _like_pattern(search: str) -> str:
        """Wrap *search* in ``%`` wildcards, escaping LIKE metacharacters it contains.

        PostgreSQL's default LIKE escape character is backslash, so a literal
        ``%``, ``_`` or ``\\`` typed by the user matches only itself.
        """
        escaped = search.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
        return f"%{escaped}%"

    @classmethod
    def _session_filters(
        cls,
        *,
        user_id: str | None,
        is_admin: bool,
        profile_id: str | None = None,
        search: str | None = None,
    ) -> tuple[str, list]:
        """Build a ``WHERE`` clause and its positional parameters for session listings.

        Non-admin callers with a ``user_id`` only see their own sessions.  A
        *search* term matches the session id, name, user id, profile id, or the
        content of any displayed message (hot or archived).  Returns
        ``(where_sql, params)``; ``where_sql`` is ``""`` when there is no filter.
        """
        conditions: list[str] = []
        params: list = []

        def add_param(value) -> str:
            params.append(value)
            return f"${len(params)}"

        if not is_admin and user_id is not None:
            conditions.append(f"user_id = {add_param(user_id)}")
        if profile_id:
            conditions.append(f"profile_id = {add_param(profile_id)}")
        if search:
            # One placeholder reused for every column instead of binding the
            # same value six times.
            p = add_param(cls._like_pattern(search))
            conditions.append(
                f"(id ILIKE {p} OR name ILIKE {p} OR user_id ILIKE {p} OR profile_id ILIKE {p} "
                f"OR EXISTS (SELECT 1 FROM session_messages m WHERE m.session_id = sessions.id "
                f"AND m.is_display = true AND m.content ILIKE {p}) "
                f"OR EXISTS (SELECT 1 FROM session_messages_archive a WHERE a.session_id = sessions.id "
                f"AND a.is_display = true AND a.content ILIKE {p}))"
            )

        where = f"WHERE {' AND '.join(conditions)}" if conditions else ""
        return where, params

    # ------------------------------------------------------------------ CRUD

    async def create(self, profile_id: str, user_id: str | None = None) -> Session:
        """Insert and return a new, empty, unpinned session."""
        session = Session(profile_id=profile_id, user_id=user_id)
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            await conn.execute(
                "INSERT INTO sessions "
                "(id, profile_id, user_id, pinned, created_at, last_active, context_token_count) "
                "VALUES ($1, $2, $3, FALSE, $4, $5, 0)",
                session.id, session.profile_id, session.user_id, session.created_at, session.last_active,
            )
        return session

    async def get(self, session_id: str) -> Session | None:
        """Load one session with its hot messages, or return ``None`` if it does not exist.

        ``messages`` holds the hot rows with ``is_display`` set and ``context``
        those with ``is_context`` set.  A row with both flags is the same Python
        object in both lists, so the agent can match entries by ``id()``.
        """
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            row = await conn.fetchrow(
                f"SELECT {self._SESSION_COLUMNS} FROM sessions WHERE id = $1",
                session_id,
            )
            if not row:
                return None

            archive_threshold = row["archive_threshold"] or 0
            all_rows = await conn.fetch(
                "SELECT * FROM session_messages WHERE session_id = $1 AND sequence_number >= $2 ORDER BY sequence_number",
                session_id,
                archive_threshold,
            )

        all_messages = [_row_to_message(r) for r in all_rows]
        messages = [m for m in all_messages if m.is_display]
        context = [m for m in all_messages if m.is_context]

        planning_logs_raw = row["planning_logs"]
        planning_logs = json.loads(planning_logs_raw) if planning_logs_raw else []
        raw_metadata = row["session_metadata"]
        session_metadata = json.loads(raw_metadata) if raw_metadata else {}

        next_seq = row["next_sequence"] or 0
        if next_seq == 0:
            # A zero counter was never maintained (legacy rows): derive it from the
            # data.  Archived messages all have sequence_number < archive_threshold,
            # so the threshold is a lower bound even when no hot messages remain.
            max_seq = max((m.sequence_number for m in all_messages), default=-1)
            next_seq = max(max_seq + 1, archive_threshold)

        s = Session(
            id=row["id"],
            profile_id=row["profile_id"],
            user_id=row["user_id"],
            messages=messages,
            context=context,
            pinned=bool(row["pinned"]),
            name=row["name"],
            created_at=row["created_at"],
            last_active=row["last_active"],
            context_token_count=row["context_token_count"] or 0,
            planning_logs=planning_logs,
        )
        s.db_message_count = len(all_messages)
        s.db_next_sequence = next_seq
        s.archive_threshold = archive_threshold
        s.session_metadata = session_metadata
        return s

    async def save(self, session: Session) -> None:
        """Persist session-level fields and ``session.messages`` in one transaction.

        Messages with ``sequence_number >= 0`` already have a row.  Their mutable
        flags and stats are updated in place.  Messages with ``sequence_number < 0``
        are new.  They are numbered from the session's next free sequence number
        (assigned back onto each ``Message``) and inserted.

        Raises:
            LookupError: new messages were given but the session row no longer exists.
        """
        session.last_active = datetime.now(timezone.utc)
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            async with conn.transaction():
                await conn.execute(
                    "UPDATE sessions SET profile_id = $1, user_id = $2, "
                    "last_active = $3, context_token_count = $4, planning_logs = $5, session_metadata = $6 WHERE id = $7",
                    session.profile_id, session.user_id,
                    session.last_active, session.context_token_count,
                    json.dumps(session.planning_logs, ensure_ascii=False),
                    json.dumps(session.session_metadata, ensure_ascii=False), session.id,
                )

                # NOTE(review): only session.messages is walked.  get() fills that
                # list with is_display rows only, so a message held solely in
                # session.context (is_display=False) is neither inserted nor has
                # its flags updated here; confirm that is intended.
                messages = session.messages

                # 1. Update mutable flags for already-persisted rows.
                existing = [m for m in messages if m.sequence_number >= 0]
                if existing:
                    await conn.executemany(
                        """
                        UPDATE session_messages
                        SET is_context = $3,
                            is_display = $4,
                            is_summary = $5,
                            is_plan = $6,
                            is_compression = $7,
                            is_recall = $8,
                            thinking = $9,
                            elapsed_seconds = $10,
                            tool_call_count = $11,
                            token_count = $12
                        WHERE session_id = $1 AND sequence_number = $2
                        """,
                        [
                            (
                                session.id,
                                m.sequence_number,
                                m.is_context,
                                m.is_display,
                                m.is_summary,
                                m.is_plan,
                                m.is_compression,
                                m.is_recall,
                                m.thinking,
                                m.elapsed_seconds,
                                m.tool_call_count,
                                m.token_count,
                            )
                            for m in existing
                        ],
                    )

                # 2. Insert new messages (sequence_number < 0 means "not yet persisted").
                new_msgs = [m for m in messages if m.sequence_number < 0]
                if new_msgs:
                    # Lock the session row so concurrent saves (e.g. a UI render
                    # during an agent turn) serialize on sequence allocation.
                    lock_row = await conn.fetchrow(
                        "SELECT next_sequence, archive_threshold FROM sessions WHERE id = $1 FOR UPDATE",
                        session.id,
                    )
                    if lock_row is None:
                        raise LookupError(f"session {session.id!r} does not exist")
                    # next_sequence can lag the data (it is 0 for legacy/migrated
                    # sessions), so never allocate below an existing hot row or
                    # below the archive threshold.
                    hot_next = await conn.fetchval(
                        "SELECT COALESCE(MAX(sequence_number) + 1, 0) FROM session_messages WHERE session_id = $1",
                        session.id,
                    )
                    db_next = max(
                        lock_row["next_sequence"] or 0,
                        lock_row["archive_threshold"] or 0,
                        hot_next,
                    )

                    insert_rows = []
                    for i, m in enumerate(new_msgs):
                        seq = db_next + i
                        m.sequence_number = seq
                        insert_rows.append(
                            (
                                session.id,
                                seq,
                                m.role,
                                m.content,
                                json.dumps(m.images, ensure_ascii=False) if m.images else None,
                                json.dumps([tc.model_dump(mode="json") for tc in m.tool_calls], ensure_ascii=False)
                                if m.tool_calls
                                else None,
                                m.tool_call_id,
                                m.name,
                                m.created_at,
                                m.is_summary,
                                m.thinking,
                                m.is_plan,
                                m.is_compression,
                                m.is_context,
                                m.is_display,
                                m.elapsed_seconds,
                                m.tool_call_count,
                                m.token_count,
                                json.dumps(m.files, ensure_ascii=False) if m.files else None,
                                json.dumps(m.metadata, ensure_ascii=False) if m.metadata else None,
                                m.is_recall,
                            )
                        )
                    await conn.executemany(
                        """
                        INSERT INTO session_messages
                        (session_id, sequence_number, role, content, images, tool_calls, tool_call_id, name,
                         created_at, is_summary, thinking, is_plan, is_compression, is_context, is_display,
                         elapsed_seconds, tool_call_count, token_count, files, metadata, is_recall)
                        VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21)
                        """,
                        insert_rows,
                    )
                    new_next = db_next + len(new_msgs)
                    await conn.execute(
                        "UPDATE sessions SET next_sequence = $1 WHERE id = $2",
                        new_next, session.id,
                    )
                    session.db_next_sequence = new_next

                session.db_message_count = len(messages)

    async def archive_old_messages(self, session_id: str, keep_seq_threshold: int) -> int:
        """Move messages with ``sequence_number < keep_seq_threshold`` to the archive table.

        Copy, delete and threshold update run in one transaction, so a failure
        leaves the session untouched.  Rows already present in the archive (same
        session and sequence number) are not copied again.  Returns the number of
        rows newly inserted into the archive.
        """
        cols = self._MESSAGE_COLUMNS
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            async with conn.transaction():
                copied = await conn.execute(
                    f"INSERT INTO session_messages_archive ({cols}) "
                    f"SELECT {cols} FROM session_messages "
                    "WHERE session_id = $1 AND sequence_number < $2 "
                    "ON CONFLICT (session_id, sequence_number) DO NOTHING",
                    session_id, keep_seq_threshold,
                )
                await conn.execute(
                    "DELETE FROM session_messages WHERE session_id = $1 AND sequence_number < $2",
                    session_id, keep_seq_threshold,
                )
                await conn.execute(
                    "UPDATE sessions SET archive_threshold = $1 WHERE id = $2",
                    keep_seq_threshold, session_id,
                )
        # asyncpg's status string for INSERT is 'INSERT 0 N'; extract N.
        parts = copied.split()
        return int(parts[-1]) if len(parts) >= 2 else 0

    async def get_archived_messages(
        self, session_id: str, before_seq: int | None = None, limit: int = 50
    ) -> list[Message]:
        """Return up to *limit* of the newest archived messages, oldest first.

        With *before_seq*, only messages with a smaller sequence number are
        considered (for paging backwards through history).
        """
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            if before_seq is not None:
                rows = await conn.fetch(
                    "SELECT * FROM session_messages_archive "
                    "WHERE session_id = $1 AND sequence_number < $2 "
                    "ORDER BY sequence_number DESC LIMIT $3",
                    session_id, before_seq, limit,
                )
            else:
                rows = await conn.fetch(
                    "SELECT * FROM session_messages_archive "
                    "WHERE session_id = $1 "
                    "ORDER BY sequence_number DESC LIMIT $2",
                    session_id, limit,
                )
        # Queried newest-first to apply LIMIT; reversed so callers can prepend in order.
        return [_row_to_message(r) for r in reversed(rows)]

    async def set_pinned(self, session_id: str, pinned: bool) -> bool:
        """Set the pinned flag; return True if the session existed."""
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            result = await conn.execute(
                "UPDATE sessions SET pinned = $1 WHERE id = $2",
                pinned, session_id,
            )
        return result == "UPDATE 1"

    async def set_name(self, session_id: str, name: str) -> bool:
        """Rename a session; return True if the session existed."""
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            result = await conn.execute(
                "UPDATE sessions SET name = $1 WHERE id = $2",
                name, session_id,
            )
        return result == "UPDATE 1"

    # ------------------------------------------------------------------ listing

    async def list_all(self, user_id: str | None = None, is_admin: bool = False) -> list[Session]:
        """Return every visible session (pinned first, then most recently active)."""
        where, params = self._session_filters(user_id=user_id, is_admin=is_admin)
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            rows = await conn.fetch(
                f"SELECT {self._SESSION_COLUMNS} FROM sessions {where} "
                "ORDER BY pinned DESC, last_active DESC",
                *params,
            )
            return await _build_sessions(conn, rows)

    async def list_page(
        self,
        *,
        limit: int,
        offset: int,
        profile_id: str | None = None,
        user_id: str | None = None,
        is_admin: bool = False,
    ) -> list[Session]:
        """Return one page of visible sessions, optionally restricted to *profile_id*."""
        where, params = self._session_filters(
            user_id=user_id, is_admin=is_admin, profile_id=profile_id
        )
        n = len(params)
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            # ``id`` is a unique tie-breaker so LIMIT/OFFSET pages never overlap
            # or skip rows that share the same sort keys.
            rows = await conn.fetch(
                f"SELECT {self._SESSION_COLUMNS} FROM sessions {where} "
                f"ORDER BY pinned DESC, last_active DESC, id LIMIT ${n + 1} OFFSET ${n + 2}",
                *params, limit, offset,
            )
            return await _build_sessions(conn, rows)

    async def delete(self, session_id: str) -> bool:
        """Delete a session (its messages cascade); return True if it existed."""
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            result = await conn.execute("DELETE FROM sessions WHERE id = $1", session_id)
        return result == "DELETE 1"

    async def count_all(
        self,
        *,
        user_id: str | None = None,
        is_admin: bool = False,
        search: str | None = None,
    ) -> int:
        """Count visible sessions matching the optional *search* term."""
        where, params = self._session_filters(user_id=user_id, is_admin=is_admin, search=search)
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            count = await conn.fetchval(f"SELECT COUNT(*) FROM sessions {where}", *params)
        return count or 0

    async def search_list(
        self,
        *,
        limit: int,
        offset: int,
        user_id: str | None = None,
        is_admin: bool = False,
        search: str | None = None,
        sort_by: str = "last_active",
        sort_order: str = "desc",
    ) -> list[Session]:
        """Return one page of visible sessions matching *search*, sorted as requested.

        *sort_by* must be one of the whitelisted columns (otherwise ``last_active``).
        *sort_order* is compared case-insensitively; anything but ``"desc"`` sorts
        ascending.  Pinned sessions always come first.
        """
        where, params = self._session_filters(user_id=user_id, is_admin=is_admin, search=search)

        # Column names cannot be bound as parameters, hence the whitelist.
        allowed_cols = {"created_at", "last_active", "name", "profile_id", "user_id", "pinned"}
        col = sort_by if sort_by in allowed_cols else "last_active"
        order = "DESC" if sort_order.lower() == "desc" else "ASC"
        n = len(params)
        # Pinned sessions first, then the requested column; ``id`` last as a unique
        # tie-breaker so LIMIT/OFFSET pagination is stable.
        order_clause = f"ORDER BY pinned DESC, {col} {order}, id LIMIT ${n + 1} OFFSET ${n + 2}"

        pool = await self._get_pool()
        async with pool.acquire() as conn:
            rows = await conn.fetch(
                f"SELECT {self._SESSION_COLUMNS} FROM sessions {where} {order_clause}",
                *params, limit, offset,
            )
            return await _build_sessions(conn, rows)
