"""Memory tool — save, search, list, and forget long-term facts about the user."""

from datetime import datetime, timedelta, timezone

from navi.memory.store import MemoryStore
from navi.tools._internal.base import current_session_id, current_user_id

from ._internal.base import Tool, ToolResult

_VALID_CATEGORIES = {"profile", "preferences", "technical", "projects", "other"}
_VALID_SOURCES = {"conversation", "tool_call", "auto_discovery", "user_explicit"}


class MemoryTool(Tool):
    """Tool that exposes the user's long-term memory (a ``MemoryStore``) to the model.

    The ``action`` parameter selects one of four operations: ``save``, ``search``,
    ``forget`` or ``list``. The current user id, and for saves the current session
    id, are read from context variables at call time and passed through to the store.
    """

    name = "memory"
    description = (
        "Manage long-term memory about the user — facts that survive across sessions. "
        "Actions: save (upsert a fact), search (find facts by query), forget (delete by key), "
        "list (show which categories hold facts). "
        "When saving system facts (IPs, hosts, services), set source='tool_call', confidence=95, and expires_days=7."
    )
    # JSON-schema description of the accepted params. The category/source enums
    # mirror the module-level _VALID_CATEGORIES / _VALID_SOURCES checks in _save.
    parameters = {
        "type": "object",
        "properties": {
            "action": {
                "type": "string",
                "enum": ["save", "search", "forget", "list"],
                "description": (
                    "save — upsert a fact (overwrites existing key). "
                    "search — find facts by keyword query. "
                    "forget — delete a fact by key. "
                    "list — return the categories that currently hold facts."
                ),
            },
            "query": {
                "type": "string",
                "description": "search only: keywords describing what to look for.",
            },
            "category": {
                "type": "string",
                "enum": ["profile", "preferences", "technical", "projects", "other"],
                "description": (
                    "save/forget: fact category. "
                    "profile=who they are, preferences=likes/dislikes, "
                    "technical=OS/tools/servers, projects=ongoing work, other=anything else."
                ),
            },
            "key": {
                "type": "string",
                "description": (
                    "save/forget: snake_case identifier unique within the category. "
                    "Examples: name, primary_os, home_server_ip, response_language."
                ),
            },
            "value": {
                "type": "string",
                "description": "save only: the fact as a concise plain-text statement.",
            },
            "source": {
                "type": "string",
                "enum": ["conversation", "tool_call", "auto_discovery", "user_explicit"],
                "description": (
                    "save only: how the fact was obtained. "
                    "conversation=extracted from chat, tool_call=found via tool execution, "
                    "auto_discovery=system scan, user_explicit=user told me directly."
                ),
            },
            "confidence": {
                "type": "integer",
                "description": "save only: 0-100. Tool output=95, user statement=80, web=50, guess=30.",
                "minimum": 0,
                "maximum": 100,
            },
            "expires_days": {
                "type": "integer",
                "description": "save only: how many days this fact stays valid. Null = never expires.",
            },
            "source_context": {
                "type": "string",
                "description": "save only: provenance — 'found via ip addr on localhost', 'user said in session X'.",
            },
        },
        "required": ["action"],
    }

    def __init__(self, memory_store: MemoryStore) -> None:
        # Backing store; every action delegates persistence/lookup to it.
        self._store = memory_store

    @staticmethod
    def _text_param(params: dict, name: str) -> str:
        """Return ``params[name]`` as a stripped string, or ``""`` when absent or None.

        Model-generated tool calls sometimes send numbers for string fields
        (e.g. a port as ``8080``). Coercing with ``str`` keeps those from
        crashing on ``.strip()`` and preserves falsy-but-valid values like ``0``.
        """
        raw = params.get(name)
        if raw is None:
            return ""
        return str(raw).strip()

    async def execute(self, params: dict) -> ToolResult:
        """Run the action named by ``params["action"]``.

        The action name is matched after trimming whitespace and lowercasing.
        An unknown action returns a failed ``ToolResult`` that lists the valid choices.
        """
        action = self._text_param(params, "action").lower()

        if action == "save":
            return await self._save(params)
        if action == "search":
            return await self._search(params)
        if action == "forget":
            return await self._forget(params)
        if action == "list":
            return await self._list()
        return ToolResult(
            success=False,
            output=f"Unknown action '{action}'. Must be one of: forget, list, save, search",
            error="invalid action",
        )

    async def _save(self, params: dict) -> ToolResult:
        """Validate *params* and upsert one fact for the current user.

        ``category`` (one of _VALID_CATEGORIES), ``key`` and ``value`` are required.
        Optional fields fall back as follows:
        - ``source``: ``"conversation"`` when missing or not in _VALID_SOURCES.
        - ``confidence``: 70 when missing or non-numeric; always clamped to 0-100.
        - ``expires_days``: ``None`` means no expiry. A value that cannot be
          converted with ``int`` is ignored and the output says so.
        """
        category = self._text_param(params, "category").lower()
        key = self._text_param(params, "key")
        value = self._text_param(params, "value")
        source = self._text_param(params, "source").lower()
        source_context = self._text_param(params, "source_context")

        if not category:
            return ToolResult(success=False, output="category is required for save.", error="missing category")
        if category not in _VALID_CATEGORIES:
            return ToolResult(
                success=False,
                output=f"Invalid category '{category}'. Must be one of: {', '.join(sorted(_VALID_CATEGORIES))}",
                error="invalid category",
            )
        if not key:
            return ToolResult(success=False, output="key is required for save.", error="missing key")
        if not value:
            return ToolResult(success=False, output="value is required for save.", error="missing value")
        if source not in _VALID_SOURCES:
            # Lenient by design: a missing/unknown provenance label should not fail the save.
            source = "conversation"

        try:
            confidence = int(params.get("confidence", 70))
        except (TypeError, ValueError, OverflowError):
            # None, non-numeric text, or an infinite float -> fall back to the default.
            confidence = 70
        confidence = max(0, min(100, confidence))

        expires_at = None
        expiry_note = ""
        expires_days = params.get("expires_days")
        if expires_days is not None:
            try:
                expires_at = datetime.now(timezone.utc) + timedelta(days=int(expires_days))
            except (TypeError, ValueError, OverflowError):
                # Keep the save best-effort, but surface that the expiry was dropped
                # instead of silently turning a short-lived fact into a permanent one.
                expiry_note = f" (ignored invalid expires_days {expires_days!r}; fact will not expire)"

        session_id = current_session_id.get(None)
        user_id = current_user_id.get(None)
        await self._store.upsert_fact(
            category=category,
            key=key,
            value=value,
            user_id=user_id,
            source_session_id=session_id,
            source=source,
            confidence=confidence,
            expires_at=expires_at,
            source_context=source_context,
        )
        return ToolResult(success=True, output=f"Saved [{category}] {key}: {value}{expiry_note}")

    async def _search(self, params: dict) -> ToolResult:
        """Search the current user's facts for ``query`` and format up to 15 hits.

        Each hit is rendered as ``[category] key: value``. When the fact has a
        ``source``, a provenance suffix ``(src: …, conf: …, ctx: …)`` is appended;
        the conf and ctx parts appear only when those fields are present.
        """
        query = self._text_param(params, "query")
        if not query:
            return ToolResult(success=False, output="query is required for search.", error="missing query")

        user_id = current_user_id.get(None)
        facts = await self._store.search_facts(query, user_id=user_id, limit=15)
        if not facts:
            return ToolResult(success=True, output="No matching facts found in memory.")

        lines = []
        for fact in facts:
            provenance = ""
            if fact.get("source"):
                details = [f"src: {fact['source']}"]
                # `is not None` so a legitimate confidence of 0 is still shown.
                if fact.get("confidence") is not None:
                    details.append(f"conf: {fact['confidence']}")
                if fact.get("source_context"):
                    details.append(f"ctx: {fact['source_context']}")
                provenance = f" ({', '.join(details)})"
            lines.append(f"[{fact['category']}] {fact['key']}: {fact['value']}{provenance}")
        return ToolResult(success=True, output=f"Found {len(facts)} fact(s):\n" + "\n".join(lines))

    async def _forget(self, params: dict) -> ToolResult:
        """Delete the current user's fact(s) with ``key`` via ``MemoryStore.delete_fact``.

        ``category`` is optional. It is lowercased exactly as ``_save`` does, so
        it matches how facts were stored, and passed through as ``None`` when
        omitted. Returns a failed result when the store reports 0 deletions.
        """
        key = self._text_param(params, "key")
        # Lowercase to match _save, which stores categories lowercased.
        category = self._text_param(params, "category").lower() or None

        if not key:
            return ToolResult(success=False, output="key is required for forget.", error="missing key")

        user_id = current_user_id.get(None)
        deleted = await self._store.delete_fact(key, category, user_id=user_id)
        if deleted == 0:
            where = f" in category '{category}'" if category else ""
            return ToolResult(success=False, output=f"No fact found with key '{key}'{where}.", error="not found")

        noun = "fact" if deleted == 1 else "facts"
        return ToolResult(success=True, output=f"Deleted {deleted} {noun} with key '{key}'.")

    async def _list(self) -> ToolResult:
        """Return the names of the categories the store reports for the current user.

        Only category names are listed, not individual facts.
        """
        user_id = current_user_id.get(None)
        categories = await self._store.get_categories(user_id=user_id)
        if not categories:
            return ToolResult(success=True, output="Memory is empty.")
        return ToolResult(
            success=True,
            output="Categories in memory:\n" + "\n".join(f"  • {c}" for c in categories),
        )
