"""
Agent: the tool-calling loop.

Flow:
1. Receive user message, load session + profile
2. Build tool schemas from profile's enabled_tools
3. Loop (up to max_iterations):
   a. Call LLM with current messages + tool schemas
   b. If finish_reason == "stop" or no tool calls were requested -> done, return content
   c. If finish_reason == "tool_calls" -> execute tools concurrently, append results, continue
4. Streaming variant (run_stream): once step 3b is reached, re-request the final answer via llm.stream() and yield text deltas to WebSocket clients

For multi-agent extension: instantiate multiple Agent objects with different profiles.
An Orchestrator (core/orchestrator.py) dispatches tasks to worker agents via asyncio Queues.
"""

import asyncio
import json
from dataclasses import dataclass
from datetime import datetime, timezone
from pathlib import Path
from typing import AsyncGenerator

import structlog

from navi.config import settings
from navi.exceptions import MaxIterationsReached, SessionNotFound
from navi.llm.base import LLMBackend, Message, ToolCallRequest
from navi.tools.base import Tool

from .compressor import compress_session, should_compress
from .registry import BackendRegistry, ProfileRegistry, ToolRegistry
from .session import SessionStore

# Optional user-maintained file listing extra tool names; read on every agent
# turn via _load_user_enabled_tools() and merged into the profile's tools.
_USER_ENABLED_FILE = Path(settings.tools_dir) / "enabled.json"


def _load_user_enabled_tools() -> list[str]:
    try:
        return json.loads(_USER_ENABLED_FILE.read_text())
    except Exception:
        return []

# Module-level structured logger used by the agent loop and tool execution.
log = structlog.get_logger()


@dataclass
class ToolEvent:
    """Emitted during streaming to inform the client about tool activity."""

    tool_name: str
    arguments: dict
    result: str
    success: bool


@dataclass
class TextDelta:
    """A chunk of text from the streaming LLM response."""

    delta: str


@dataclass
class ThinkingDelta:
    """A chunk of thinking/reasoning text from the streaming LLM response."""

    delta: str


@dataclass
class ThinkingEnd:
    """Marks the end of the thinking phase."""


@dataclass
class StreamEnd:
    """Marks the end of the streaming response."""

    full_content: str
    context_tokens: int | None = None   # total tokens used in this turn
    max_context_tokens: int = 0         # ollama_num_ctx from config


@dataclass
class ContextCompressed:
    """Emitted after compression runs successfully."""

    messages_before: int
    messages_after: int


AgentEvent = ToolEvent | TextDelta | ThinkingDelta | ThinkingEnd | StreamEnd | ContextCompressed


class Agent:
    """Drives the LLM tool-calling loop for one turn of a session.

    ``run`` returns the final answer as a string; ``run_stream`` yields
    ``AgentEvent`` objects for a streaming client. Both share the same turn
    setup, tool execution and persistence helpers below.
    """

    def __init__(
        self,
        session_store: SessionStore,
        profile_registry: ProfileRegistry,
        tool_registry: ToolRegistry,
        backend_registry: BackendRegistry,
    ) -> None:
        self._sessions = session_store
        self._profiles = profile_registry
        self._tools = tool_registry
        self._backends = backend_registry

    # ------------------------------------------------------------------
    # Public interface
    # ------------------------------------------------------------------

    async def run(self, session_id: str, user_message: str, images: list[str] | None = None) -> str:
        """Non-streaming: run the full tool-calling loop and return the final text.

        Raises:
            SessionNotFound: if ``session_id`` does not exist.
            MaxIterationsReached: if the model still requests tools after
                ``profile.max_iterations`` LLM calls (the session is saved first).
        """
        session, profile, tools, tool_schemas, llm = await self._begin_turn(
            session_id, user_message, images
        )

        for iteration in range(profile.max_iterations):
            log.debug("agent.iteration", session_id=session_id, iteration=iteration)
            response = await self._complete(llm, session.messages, profile, tool_schemas)

            if response.finish_reason == "stop" or not response.tool_calls:
                content = response.content or ""
                # Timestamped like run_stream's final assistant message.
                session.messages.append(
                    Message(role="assistant", content=content, created_at=datetime.now(timezone.utc))
                )
                await self._sessions.save(session)
                return content

            # Tool-calls turn: record the request, run the tools, feed results back.
            session.messages.append(
                Message(role="assistant", content=response.content, tool_calls=response.tool_calls)
            )
            tool_results, image_injections = await self._execute_tool_calls(response.tool_calls, tools)
            session.messages.extend(tool_results)
            session.messages.extend(image_injections)

        await self._sessions.save(session)
        raise MaxIterationsReached(profile.max_iterations)

    async def run_stream(
        self, session_id: str, user_message: str, images: list[str] | None = None
    ) -> AsyncGenerator[AgentEvent, None]:
        """
        Streaming variant of ``run``. Yields AgentEvent objects:
        - ToolEvent: one per tool call, after all calls of an iteration finish
        - ThinkingDelta / ThinkingEnd: reasoning text of the final response, if any
        - TextDelta: each text chunk from the final LLM response
        - StreamEnd: final answer event with the full assembled content
        - ContextCompressed: after StreamEnd, only if the history was compressed

        Raises SessionNotFound and MaxIterationsReached like ``run``.
        """
        session, profile, tools, tool_schemas, llm = await self._begin_turn(
            session_id, user_message, images
        )

        # Tool-calling loop runs non-streaming; only the final answer is streamed.
        for iteration in range(profile.max_iterations):
            log.debug("agent.iteration", session_id=session_id, iteration=iteration)
            response = await self._complete(llm, session.messages, profile, tool_schemas)

            if response.finish_reason == "stop" or not response.tool_calls:
                # The completed response is discarded and the final answer is
                # re-generated via llm.stream() (no tools offered) so it can be
                # forwarded chunk by chunk.
                parts: list[str] = []
                thinking_active = False
                context_tokens: int | None = None

                async for chunk in llm.stream(
                    session.messages.copy(), temperature=profile.temperature, model=profile.model
                ):
                    if chunk.prompt_tokens is not None or chunk.completion_tokens is not None:
                        context_tokens = (chunk.prompt_tokens or 0) + (chunk.completion_tokens or 0)
                    if chunk.thinking:
                        thinking_active = True
                        yield ThinkingDelta(delta=chunk.thinking)
                    # Not `elif`: a chunk may carry both reasoning and answer
                    # text, and the answer text must not be dropped.
                    if chunk.delta:
                        if thinking_active:
                            thinking_active = False
                            yield ThinkingEnd()
                        parts.append(chunk.delta)
                        yield TextDelta(delta=chunk.delta)

                if thinking_active:
                    yield ThinkingEnd()

                accumulated = "".join(parts)
                session.messages.append(
                    Message(role="assistant", content=accumulated, created_at=datetime.now(timezone.utc))
                )
                await self._sessions.save(session)

                yield StreamEnd(
                    full_content=accumulated,
                    context_tokens=context_tokens,
                    max_context_tokens=settings.ollama_num_ctx,
                )

                # Post-response compression -- runs after client receives StreamEnd.
                compressed = await self._maybe_compress(session_id, session, llm, profile, context_tokens)
                if compressed is not None:
                    yield compressed
                return

            # Tool calls: record the request, execute, then emit events.
            session.messages.append(
                Message(role="assistant", content=response.content, tool_calls=response.tool_calls)
            )
            event_msg_pairs, image_injections = await self._execute_tool_calls_streaming(
                response.tool_calls, tools
            )
            # Record every tool result BEFORE yielding: if the client disconnects
            # at a yield, the history must not keep an assistant tool_calls
            # message without its matching tool messages.
            session.messages.extend(msg for _event, msg in event_msg_pairs)
            session.messages.extend(image_injections)
            for event, _msg in event_msg_pairs:
                yield event

        await self._sessions.save(session)
        raise MaxIterationsReached(profile.max_iterations)

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    async def _begin_turn(
        self, session_id: str, user_message: str, images: list[str] | None
    ) -> tuple:
        """Load the session, resolve profile/tools/backend, and record the user message.

        Returns:
            ``(session, profile, tools, tool_schemas, llm)``.

        Raises:
            SessionNotFound: if the session store has no such session.
        """
        session = await self._sessions.get(session_id)
        if session is None:
            raise SessionNotFound(session_id)

        profile = self._profiles.get(session.profile_id)
        tools = self._tool_list(profile.enabled_tools)
        tool_schemas = [t.schema() for t in tools]
        llm = self._get_backend(profile.llm_backend)

        # Inject the system prompt only for a brand-new (empty) session.
        if not session.messages:
            session.messages.append(
                Message(role="system", content=self._build_system_prompt(profile.system_prompt))
            )

        session.messages.append(
            Message(
                role="user",
                content=user_message,
                images=images or None,
                created_at=datetime.now(timezone.utc),
            )
        )
        return session, profile, tools, tool_schemas, llm

    @staticmethod
    async def _complete(llm: LLMBackend, messages: list[Message], profile, tool_schemas: list):
        """Make one non-streaming LLM call; tool schemas are offered only if any exist."""
        return await llm.complete(
            messages,
            tools=tool_schemas or None,
            temperature=profile.temperature,
            model=profile.model,
        )

    async def _maybe_compress(
        self, session_id: str, session, llm: LLMBackend, profile, context_tokens: int | None
    ) -> ContextCompressed | None:
        """Compress and save the session history if the last call crossed the threshold.

        Compression is best-effort: failures are logged and ``None`` is returned.
        Returns the event to emit when compression replaced the history.
        """
        if not (
            settings.context_compression_enabled
            and context_tokens is not None
            and should_compress(context_tokens, settings.ollama_num_ctx, settings.context_compression_threshold)
        ):
            return None

        count_before = len(session.messages)
        try:
            new_messages = await compress_session(
                messages=session.messages,
                llm=llm,
                model=profile.model,
                temperature=settings.context_summary_temperature,
                keep_recent=settings.context_keep_recent,
            )
            if new_messages is None:
                return None
            session.messages = new_messages
            await self._sessions.save(session)
        except Exception:
            log.warning("agent.compress_failed", session_id=session_id, exc_info=True)
            return None

        log.info(
            "agent.compressed",
            session_id=session_id,
            before=count_before,
            after=len(session.messages),
        )
        return ContextCompressed(messages_before=count_before, messages_after=len(session.messages))

    def _build_system_prompt(self, profile_prompt: str) -> str:
        """Prefix the profile prompt with the global persona, if one is configured."""
        persona = settings.navi_persona.strip()
        if persona:
            return f"{persona}\n\n---\n\n{profile_prompt}"
        return profile_prompt

    def _tool_list(self, enabled: list[str]) -> list[Tool]:
        """Resolve profile-enabled plus user-enabled tool names to registered tools.

        Names are de-duplicated (first occurrence wins, order preserved) so no
        tool schema is sent to the LLM twice.
        """
        # Merge in user-created tools from tools/enabled.json (string names only).
        user_extra = [name for name in _load_user_enabled_tools() if isinstance(name, str)]
        names = dict.fromkeys([*enabled, *user_extra])
        resolved: list[Tool] = []
        for name in names:
            try:
                resolved.append(self._tools.get(name))
            except Exception:
                # Skip any names not registered (e.g. tool was deleted).
                log.debug("agent.tool_unavailable", tool=name)
        return resolved

    def _get_backend(self, backend_key: str) -> LLMBackend:
        """Look up the LLM backend configured for a profile."""
        return self._backends.get(backend_key)

    @staticmethod
    async def _run_tool_call(
        tc: ToolCallRequest, tool_map: dict[str, Tool]
    ) -> tuple[ToolEvent, Message, Message | None]:
        """Execute one tool call and build its client event, tool message and optional image message.

        Unknown tools and tool exceptions become error results (success=False)
        that are fed back to the LLM, so every tool_call id always gets a tool
        message. Cancellation (a BaseException) is not caught.
        """
        tool = tool_map.get(tc.name)
        image_msg: Message | None = None
        if tool is None:
            content = f"Error: tool '{tc.name}' not found."
            success = False
        else:
            log.info("tool.execute", tool=tc.name, args=tc.arguments)
            try:
                result = await tool.execute(tc.arguments)
            except Exception as exc:
                log.warning("tool.execute_failed", tool=tc.name, exc_info=True)
                content = f"Error: tool '{tc.name}' failed: {type(exc).__name__}: {exc}"
                success = False
            else:
                content = result.to_message_content()
                success = result.success
                # Image-producing tools: re-inject the picture as a user message.
                if result.success and result.metadata and result.metadata.get("is_image"):
                    b64 = result.metadata.get("base64")
                    if b64:
                        image_msg = Message(
                            role="user",
                            content=f"[Image loaded via {tc.name} — analyse it]",
                            images=[b64],
                        )
        event = ToolEvent(tool_name=tc.name, arguments=tc.arguments, result=content, success=success)
        tool_msg = Message(role="tool", content=content, tool_call_id=tc.id, name=tc.name)
        return event, tool_msg, image_msg

    async def _gather_tool_calls(
        self, tool_calls: list[ToolCallRequest], tools: list[Tool]
    ) -> list[tuple[ToolEvent, Message, Message | None]]:
        """Run all requested tool calls concurrently; results keep the request order."""
        tool_map = {t.name: t for t in tools}
        return list(await asyncio.gather(*(self._run_tool_call(tc, tool_map) for tc in tool_calls)))

    async def _execute_tool_calls(
        self, tool_calls: list[ToolCallRequest], tools: list[Tool]
    ) -> tuple[list[Message], list[Message]]:
        """Return (tool messages in request order, image messages to inject)."""
        outcomes = await self._gather_tool_calls(tool_calls, tools)
        tool_msgs = [msg for _event, msg, _image in outcomes]
        image_msgs = [image for _event, _msg, image in outcomes if image is not None]
        return tool_msgs, image_msgs

    async def _execute_tool_calls_streaming(
        self, tool_calls: list[ToolCallRequest], tools: list[Tool]
    ) -> tuple[list[tuple[ToolEvent, Message]], list[Message]]:
        """Return ((event, tool message) pairs in request order, image messages to inject)."""
        outcomes = await self._gather_tool_calls(tool_calls, tools)
        pairs = [(event, msg) for event, msg, _image in outcomes]
        image_msgs = [image for _event, _msg, image in outcomes if image is not None]
        return pairs, image_msgs
