"""Fact CRUD and search — vector + ILIKE fallback."""

import re
import uuid
from datetime import datetime, timezone

import asyncpg
import structlog

from navi.config import settings

from ._embeddings import _vector_to_str

log = structlog.get_logger()

# Matches a single "-", "_", "/", "\" or "."; _normalize_query replaces each
# match with a space so compound tokens are split into separate terms.
_SEPARATORS = re.compile(r"[-_/\\.]")
# Matches any character that is neither a word character nor whitespace;
# _normalize_query also replaces these with a space.
_NOISE = re.compile(r"[^\w\s]", re.UNICODE)
# In search_facts' text fallback: when the scope holds at most this many
# facts, all of them (up to the limit) are returned instead of ILIKE filtering.
_AUTO_DUMP_THRESHOLD = 60
# Vector hits are kept only when their `<=>` distance is strictly below this.
_VECTOR_CUTOFF_DISTANCE = 0.3


def _normalize_query(query: str) -> list[str]:
    """Split *query* into lowercase search terms.

    Separator characters and any other non-word, non-space characters are
    replaced with spaces first; tokens shorter than two characters are dropped.
    """
    cleaned = _NOISE.sub(" ", _SEPARATORS.sub(" ", query)).lower()
    return [token for token in cleaned.split() if len(token) >= 2]


class FactMixin:
    """Fact storage operations.

    Expected on the composite class:
        _get_pool() -> asyncpg.Pool
        _generate_embedding(text) -> list[float] | None
        _has_pgvector() -> bool
    """

    async def upsert_fact(
        self,
        category: str,
        key: str,
        value: str,
        user_id: str | None = None,
        source_session_id: str | None = None,
        source: str = "conversation",
        confidence: int = 70,
        expires_at: datetime | None = None,
        source_context: str = "",
    ) -> None:
        now = datetime.now(timezone.utc)
        embedding = await self._generate_embedding(value)
        vec_str = _vector_to_str(embedding) if embedding else None
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            if vec_str:
                await conn.execute(
                    """INSERT INTO memory_facts
                        (id, user_id, category, key, value, created_at, updated_at, source_session_id,
                         embedding, source, confidence, expires_at, source_context)
                       VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9::vector, $10, $11, $12, $13)
                       ON CONFLICT(user_id, category, key) DO UPDATE SET
                           value             = EXCLUDED.value,
                           updated_at        = EXCLUDED.updated_at,
                           source_session_id = EXCLUDED.source_session_id,
                           embedding         = EXCLUDED.embedding,
                           source            = EXCLUDED.source,
                           confidence        = EXCLUDED.confidence,
                           expires_at        = EXCLUDED.expires_at,
                           source_context    = EXCLUDED.source_context""",
                    str(uuid.uuid4()), user_id, category, key, value, now, now,
                    source_session_id, vec_str, source, confidence, expires_at, source_context,
                )
            else:
                await conn.execute(
                    """INSERT INTO memory_facts
                        (id, user_id, category, key, value, created_at, updated_at, source_session_id,
                         source, confidence, expires_at, source_context)
                       VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)
                       ON CONFLICT(user_id, category, key) DO UPDATE SET
                           value             = EXCLUDED.value,
                           updated_at        = EXCLUDED.updated_at,
                           source_session_id = EXCLUDED.source_session_id,
                           source            = EXCLUDED.source,
                           confidence        = EXCLUDED.confidence,
                           expires_at        = EXCLUDED.expires_at,
                           source_context    = EXCLUDED.source_context""",
                    str(uuid.uuid4()), user_id, category, key, value, now, now,
                    source_session_id, source, confidence, expires_at, source_context,
                )

    async def search_facts(self, query: str, user_id: str | None = None, limit: int = 15) -> list[dict]:
        """Find up to *limit* facts relevant to *query* within one user scope.

        ``user_id=None`` restricts the search to rows whose ``user_id`` is NULL.

        1. Vector search — attempted when ``self._embedding_backend`` is truthy
           and ``_has_pgvector()`` is true. Non-expired facts that have an
           embedding are ordered by ``embedding <=> query_embedding``; only hits
           whose distance is below ``_VECTOR_CUTOFF_DISTANCE`` are kept. If any
           hit survives, those hits are returned.
        2. Text fallback — used when vector search is unavailable, raises, or
           yields no close hit. If the query has no usable terms, or the scope
           holds at most ``_AUTO_DUMP_THRESHOLD`` facts, ``get_all_facts`` is
           returned as-is. Otherwise facts matching *all* terms (ILIKE on
           category/key/value) are returned, newest first; if there are none,
           facts matching *any* term, ordered by number of matching terms.

        Returns:
            Fact dicts as produced by ``_row_to_dict``.

        NOTE(review): the text fallback (including ``get_all_facts``) does not
        filter on ``expires_at`` while the vector path does — confirm whether
        expired facts should be excluded there too.
        """
        # 1. Try vector search if pgvector + embedding backend are available
        if self._embedding_backend and await self._has_pgvector():
            query_embedding = await self._generate_embedding(query)
            vec_str = _vector_to_str(query_embedding) if query_embedding else None
            if vec_str:
                if user_id is None:
                    scope, vec_params = "user_id IS NULL", [vec_str, limit]
                else:
                    scope, vec_params = "user_id = $3", [vec_str, limit, user_id]
                try:
                    pool = await self._get_pool()
                    async with pool.acquire() as conn:
                        # Facts stored without an embedding would produce a NULL
                        # distance; exclude them so they neither occupy LIMIT
                        # slots nor reach the comparison below.
                        rows = await conn.fetch(
                            f"""SELECT id, category, key, value, updated_at,
                                       source, confidence, expires_at, source_context,
                                       embedding <=> $1::vector AS distance
                                FROM memory_facts
                                WHERE (expires_at IS NULL OR expires_at > now())
                                  AND embedding IS NOT NULL
                                  AND {scope}
                                ORDER BY embedding <=> $1::vector
                                LIMIT $2""",
                            *vec_params,
                        )
                    # Defensive None check: `None < float` raises TypeError, which
                    # the broad handler below would swallow, silently turning
                    # every vector search into a text search.
                    results = [
                        _row_to_dict(r)
                        for r in rows
                        if r["distance"] is not None
                        and r["distance"] < _VECTOR_CUTOFF_DISTANCE
                    ]
                    if results:
                        log.debug("memory.vector_search", hits=len(results), query=query[:40])
                        return results
                except Exception:
                    # Best effort: any failure here degrades to the text fallback.
                    log.warning("memory.vector_search_failed", query=query[:40], exc_info=True)
            else:
                log.warning("memory.vector_fallback_to_ilike", reason="embed_failed", query=query[:40])
        else:
            reason = "no_embedding_backend" if not self._embedding_backend else "no_pgvector"
            log.warning("memory.vector_fallback_to_ilike", reason=reason, query=query[:40])

        # 2. Text fallback (ILIKE)
        terms = _normalize_query(query)
        if not terms:
            return await self.get_all_facts(user_id=user_id, limit=limit)

        # Small fact sets are returned wholesale instead of being filtered.
        if await self.fact_count(user_id=user_id) <= _AUTO_DUMP_THRESHOLD:
            return await self.get_all_facts(user_id=user_id, limit=limit)

        # One "(category|key|value ILIKE pattern)" condition per term, each with
        # its own three positional parameters.
        base_params: list = []
        term_conds: list[str] = []
        for term in terms:
            like = f"%{term}%"
            i = len(base_params) + 1
            term_conds.append(
                f"(category ILIKE ${i} OR key ILIKE ${i + 1} OR value ILIKE ${i + 2})"
            )
            base_params.extend([like, like, like])

        if user_id is None:
            user_filter = "user_id IS NULL"
        else:
            base_params.append(user_id)
            user_filter = f"user_id = ${len(base_params)}"

        limit_idx = len(base_params) + 1
        and_where = " AND ".join(term_conds)
        or_where = " OR ".join(term_conds)

        pool = await self._get_pool()
        async with pool.acquire() as conn:
            # Strict pass: every term must match somewhere in the fact.
            rows = await conn.fetch(
                f"SELECT id, category, key, value, updated_at, source, confidence, "
                f"expires_at, source_context FROM memory_facts "
                f"WHERE {user_filter} AND {and_where} ORDER BY updated_at DESC LIMIT ${limit_idx}",
                *base_params, limit,
            )
            # With a single term the relaxed query is the same predicate, so a
            # second round trip cannot find anything new.
            if rows or len(term_conds) == 1:
                return [_row_to_dict(r) for r in rows]

            # Relaxed pass: any term may match; rank by how many terms matched.
            score_expr = " + ".join(
                f"CASE WHEN {c} THEN 1 ELSE 0 END" for c in term_conds
            )
            rows = await conn.fetch(
                f"SELECT id, category, key, value, updated_at, source, confidence, "
                f"expires_at, source_context, ({score_expr}) AS score "
                f"FROM memory_facts WHERE {user_filter} AND ({or_where}) "
                f"ORDER BY score DESC, updated_at DESC LIMIT ${limit_idx}",
                *base_params, limit,
            )
        return [_row_to_dict(r) for r in rows]

    async def delete_fact(self, key: str, category: str | None = None, user_id: str | None = None) -> int:
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            if category:
                if user_id is None:
                    result = await conn.execute(
                        "DELETE FROM memory_facts WHERE LOWER(key)=LOWER($1) AND LOWER(category)=LOWER($2) AND user_id IS NULL",
                        key, category,
                    )
                else:
                    result = await conn.execute(
                        "DELETE FROM memory_facts WHERE LOWER(key)=LOWER($1) AND LOWER(category)=LOWER($2) AND user_id = $3",
                        key, category, user_id,
                    )
            else:
                if user_id is None:
                    result = await conn.execute(
                        "DELETE FROM memory_facts WHERE LOWER(key)=LOWER($1) AND user_id IS NULL", key
                    )
                else:
                    result = await conn.execute(
                        "DELETE FROM memory_facts WHERE LOWER(key)=LOWER($1) AND user_id = $2", key, user_id
                    )
        return int(result.split()[1])

    async def get_all_facts(
        self,
        user_id: str | None = None,
        limit: int | None = None,
        offset: int = 0,
        search: str | None = None,
        sort_by: str = "category",
        sort_order: str = "desc",
        all_users: bool = False,
    ) -> list[dict]:
        q = (
            "SELECT id, category, key, value, updated_at, source, confidence, "
            "expires_at, source_context FROM memory_facts "
        )
        params: list = []
        param_idx = 0

        def add_param(value):
            nonlocal param_idx
            param_idx += 1
            params.append(value)
            return f"${param_idx}"

        conditions = []
        if not all_users:
            if user_id is None:
                conditions.append("user_id IS NULL")
            else:
                conditions.append(f"user_id = {add_param(user_id)}")
        if search:
            like = f"%{search}%"
            conditions.append(
                f"(key ILIKE {add_param(like)} OR value ILIKE {add_param(like)} OR category ILIKE {add_param(like)})"
            )

        if conditions:
            q += "WHERE " + " AND ".join(conditions) + " "

        allowed_cols = {"category", "key", "updated_at", "confidence", "source"}
        col = sort_by if sort_by in allowed_cols else "category"
        order = "DESC" if sort_order == "desc" else "ASC"
        if col == "category":
            q += f"ORDER BY {col} {order}, updated_at DESC"
        else:
            q += f"ORDER BY {col} {order}"

        if limit is not None:
            q += f" LIMIT {add_param(limit)}"
        if offset:
            q += f" OFFSET {add_param(offset)}"

        pool = await self._get_pool()
        async with pool.acquire() as conn:
            rows = await conn.fetch(q, *params)
        return [_row_to_dict(r) for r in rows]

    async def fact_count(
        self,
        user_id: str | None = None,
        all_users: bool = False,
        search: str | None = None,
    ) -> int:
        q = "SELECT COUNT(*) FROM memory_facts"
        params: list = []
        conditions = []
        if not all_users:
            if user_id is None:
                conditions.append("user_id IS NULL")
            else:
                conditions.append(f"user_id = ${len(params) + 1}")
                params.append(user_id)
        if search:
            like = f"%{search}%"
            conditions.append(
                f"(key ILIKE ${len(params) + 1} OR value ILIKE ${len(params) + 2} OR category ILIKE ${len(params) + 3})"
            )
            params.extend([like, like, like])
        if conditions:
            q += " WHERE " + " AND ".join(conditions)
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            return await conn.fetchval(q, *params) or 0


    async def get_categories(self, user_id: str | None = None) -> list[str]:
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            if user_id is None:
                rows = await conn.fetch(
                    "SELECT DISTINCT category FROM memory_facts WHERE user_id IS NULL ORDER BY category"
                )
            else:
                rows = await conn.fetch(
                    "SELECT DISTINCT category FROM memory_facts WHERE user_id = $1 ORDER BY category",
                    user_id,
                )
        return [r["category"] for r in rows]


def _row_to_dict(row: asyncpg.Record) -> dict:
    """Convert a ``memory_facts`` row into a plain dict.

    ``updated_at`` is rendered with ``isoformat()`` when the value supports it,
    otherwise with ``str()``. ``expires_at`` is rendered with ``isoformat()``
    or becomes ``None``. ``source``, ``confidence`` and ``source_context``
    fall back to "conversation", 70 and "" when the column is absent from
    the row.
    """
    stamp = row["updated_at"]
    if hasattr(stamp, "isoformat"):
        updated_iso = stamp.isoformat()
    else:
        updated_iso = str(stamp)

    expiry = row.get("expires_at")
    expiry_iso = expiry.isoformat() if hasattr(expiry, "isoformat") else None

    # Required columns are copied verbatim; the rest are added in the same
    # key order the callers have always received.
    fact = {name: row[name] for name in ("id", "category", "key", "value")}
    fact.update(
        updated_at=updated_iso,
        source=row.get("source", "conversation"),
        confidence=row.get("confidence", 70),
        expires_at=expiry_iso,
        source_context=row.get("source_context", ""),
    )
    return fact
