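"""WebSocket endpoint for streaming agent responses.

Protocol (client -> server):
  {"type": "message", "content": "...", "images": [...], "files": [...]}
      "images" (optional): up to 8 base64 strings or data: URLs, 50 MB total.
      "files"  (optional): [{"name": "...", "path": "..."}] for already-uploaded files.
  {"type": "form_submit", "form_id": "...", "values": {...}}

Protocol (server -> client):
  {"type": "stream_start"}
  {"type": "replay_start",    "count": int}    (re-attach only; buffered events follow)
  {"type": "replay_end"}
  {"type": "thinking_delta",  "delta": "..."}
  {"type": "thinking_end"}
  {"type": "stream_delta",    "delta": "..."}
  {"type": "tool_started",    "tool": "...", "args": {...}, "is_subagent": bool}
  {"type": "tool_call",       "tool": "...", "args": {...}, "result": "...", "success": bool, "is_subagent": bool}
  {"type": "stream_end",      "content": "..."}
  {"type": "stream_stopped"}
  {"type": "context_compressed"}
  {"type": "heartbeat"}                        (sent after 20 s without any event)
  {"type": "session_sync"}                     (client should reload session history)
  {"type": "error",           "message": "..."}
"""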

import asyncio
import json
from typing import Annotated

import structlog
from fastapi import APIRouter, Depends, HTTPException, WebSocket, WebSocketDisconnect

from navi.api.deps import get_orchestrator, get_session_store
from navi.auth import User
from navi.auth.deps import check_session_access
from navi.auth.deps import get_current_user, get_current_user_ws
from navi.core import SessionStore
from navi.core.event_bus import get_event_bus
from navi.core.events import AgentEvent, RecallUpdate
from navi.llm.base import Message

# Router for this module's endpoints: the HTTP stop endpoint and the
# per-session WebSocket streaming endpoint.
router = APIRouter(tags=["websocket"])
# Module-level structlog logger; events are logged as "ws.*" keys.
log = structlog.get_logger()


# ── Helpers ───────────────────────────────────────────────────────────────────

def _event_to_dict(event) -> dict | None:
    if hasattr(event, "to_wire"):
        return event.to_wire()
    return None


async def _start_agent_run(
    *,
    session_id: str,
    user_content: str,
    display_content: str | None,
    raw_images: list[str] | None,
    uploaded_files: list[dict],
    hidden_msg: Message | None,
    hidden: bool,
    websocket: WebSocket,
    orchestrator: "AgentSessionOrchestrator",
    session_store: SessionStore,
    user: User | None,
) -> bool:
    """Atomically start an agent run, stream its events, and clean up.

    Under the session lock the run is registered and subscribed to (so no
    event can be missed). The agent task is then spawned with the caller's
    identity in the tool-sandboxing context variables, and the run's events
    are forwarded to *websocket* until the run reports ``done``.

    Args:
        session_id: Session the run belongs to.
        user_content: Text handed to ``run_agent`` (may include appended notes
            about inline images / uploaded files).
        display_content: Passed through to ``run_agent``; callers pass the
            original user text, or ``None`` for hidden form submissions.
        raw_images: Inline image payloads passed to ``run_agent``, or ``None``.
        uploaded_files: Uploaded-file descriptors passed to ``run_agent``.
        hidden_msg: Not read by this function; kept for call-site compatibility.
        hidden: Forwarded to ``run_agent`` as ``hidden=``.
        websocket: Client socket that receives the streamed events.
        orchestrator: Owns runs, per-session locks and the agent coroutine.
        session_store: Passed through to ``run_agent``.
        user: Authenticated user, or ``None`` for anonymous/legacy sessions.

    Returns:
        True if the client stayed connected (also when the request is rejected
        because a run is already active), False if it disconnected mid-stream.
    """
    # Guard against concurrent runs for the same session (atomically).
    async with orchestrator.session_lock(session_id):
        if orchestrator.is_running(session_id):
            await websocket.send_json({
                "type": "error",
                "message": "Agent is already running for this session.",
            })
            return True  # socket still alive; caller should keep reading

        # Register run and subscribe before starting the task so we never
        # miss events even if the task is very fast.
        run = orchestrator.create_run(session_id)
        queue = run.subscribe()

    try:
        from navi.tools._internal.base import (
            current_user_id as _uid_var,
            current_user_info as _uinfo_var,
            current_user_role as _role_var,
        )

        # Resolve every value first so a failure (e.g. in model_dump) cannot
        # leave the context variables half-set.
        if user is not None:
            uid, role, uinfo = user.id, user.role, user.model_dump(mode="json")
        else:
            uid, role, uinfo = None, "user", None

        # Set user context for tool sandboxing. asyncio.create_task copies the
        # current context, so the agent task inherits these values; they are
        # reset right afterwards so this handler's own context is unchanged.
        uid_token = _uid_var.set(uid)
        role_token = _role_var.set(role)
        uinfo_token = _uinfo_var.set(uinfo)
        try:
            run.task = asyncio.create_task(
                orchestrator.run_agent(
                    session_id,
                    user_content,
                    raw_images,
                    display_content,
                    uploaded_files,
                    session_store,
                    hidden=hidden,
                )
            )
        finally:
            _uid_var.reset(uid_token)
            _role_var.reset(role_token)
            _uinfo_var.reset(uinfo_token)
    except Exception:
        # Anything failing before the task exists must release the run, or
        # the session would stay "running" forever.
        orchestrator.clear_run(session_id)
        run.unsubscribe(queue)
        raise

    try:
        await websocket.send_json({"type": "stream_start"})
        return await _stream_to_client(websocket, queue)
    finally:
        # Always detach, even if the send fails or this handler is cancelled,
        # so no orphaned subscriber queue stays attached to the run.
        run.unsubscribe(queue)


_HEARTBEAT_INTERVAL = 20.0  # seconds — keeps the browser from dropping idle connections


async def _stream_to_client(websocket: WebSocket, queue: asyncio.Queue) -> bool:
    """
    Forward queue items to the WebSocket until ("done", None).
    Returns True if client stayed connected, False if it disconnected mid-stream.
    On disconnect, continues draining silently so the agent task can finish cleanly.
    Sends periodic heartbeat pings while waiting so the browser doesn't time out
    during long-running silent operations (e.g. AIHelper LLM calls).
    """
    client_alive = True
    while True:
        try:
            kind, payload = await asyncio.wait_for(queue.get(), timeout=_HEARTBEAT_INTERVAL)
        except asyncio.TimeoutError:
            # No event arrived — send a heartbeat to keep the WS alive
            if client_alive:
                try:
                    await websocket.send_json({"type": "heartbeat"})
                except Exception:
                    client_alive = False
            continue

        if kind == "done":
            return client_alive
        if not client_alive:
            continue  # keep draining so run_agent can broadcast without blocking
        try:
            if kind == "error":
                await websocket.send_json({"type": "error", "message": payload})
            elif kind == "stopped":
                await websocket.send_json({"type": "stream_stopped"})
            elif kind == "event":
                msg = _event_to_dict(payload)
                if msg:
                    await websocket.send_json(msg)
        except Exception:
            client_alive = False


# ── Endpoints ─────────────────────────────────────────────────────────────────

@router.post("/sessions/{session_id}/stop")
async def stop_session(
    session_id: str,
    store: Annotated[SessionStore, Depends(get_session_store)],
    user: Annotated[User | None, Depends(get_current_user)] = None,
) -> dict:
    """Signal the running agent for this session to stop cooperatively.

    Applies the same ownership rule as the WebSocket endpoint: anonymous
    callers may only stop legacy (owner-less) sessions, and authenticated
    callers must pass ``check_session_access``.

    Returns:
        ``{"ok": ok}`` when ``orchestrator.stop`` reports success, otherwise
        ``{"ok": False, "reason": "no active run"}``.

    Raises:
        HTTPException: 404 if the session does not exist; 401 if an anonymous
            caller targets an owned session. Access-denied errors come from
            ``check_session_access`` (exact type/status defined there).
    """
    session = await store.get(session_id)
    if session is None:
        raise HTTPException(status_code=404, detail="Session not found")
    if user is None and session.user_id is not None:
        raise HTTPException(status_code=401, detail="Authentication required")
    if user is not None:
        check_session_access(session, user)
    orchestrator = get_orchestrator()
    ok = orchestrator.stop(session_id)
    return {"ok": ok} if ok else {"ok": False, "reason": "no active run"}


_MAX_IMAGES = 8
_MAX_IMAGE_BYTES_TOTAL = 50 * 1024 * 1024  # 50 MB total payload
_USAGE_ERROR = (
    "Expected {type: 'message', content: '...'} or "
    "{type: 'form_submit', form_id: '...', values: {...}}"
)


def _clean_images(raw_images: object) -> tuple[list[str] | None, str | None]:
    """Validate client-supplied inline images.

    Returns ``(images, None)`` on success (``images`` is ``None`` when nothing
    was attached) or ``(None, error_message)`` when the payload is rejected.
    ``data:...,`` URL prefixes are stripped so only the encoded body is kept.
    The size limit is measured on the encoded strings, not decoded bytes.
    """
    if not raw_images:
        return None, None
    if not isinstance(raw_images, list) or not all(isinstance(i, str) for i in raw_images):
        return None, "Expected 'images' to be a list of strings."
    # Guard against abuse: limit count and total payload size.
    if len(raw_images) > _MAX_IMAGES:
        return None, f"Too many images ({len(raw_images)}). Max {_MAX_IMAGES} allowed."
    cleaned: list[str] = []
    total_bytes = 0
    for img in raw_images:
        if "," in img and img.startswith("data:"):
            img = img.split(",", 1)[1]
        total_bytes += len(img.encode("utf-8"))
        if total_bytes > _MAX_IMAGE_BYTES_TOTAL:
            return None, "Total image payload exceeds 50 MB limit."
        cleaned.append(img)
    return cleaned, None


def _clean_files(raw_files: object) -> tuple[list[dict], str | None]:
    """Validate client-supplied uploaded-file descriptors.

    Returns ``(files, None)`` (``[]`` when none were sent) or ``([], error)``
    when the value is not a list of objects carrying ``name`` and ``path``.
    """
    if not raw_files:
        return [], None
    if not isinstance(raw_files, list) or not all(
        isinstance(f, dict) and "name" in f and "path" in f for f in raw_files
    ):
        return [], "Expected 'files' to be a list of {name: '...', path: '...'} objects."
    return raw_files, None


def _compose_user_content(content: str, images: list[str] | None, files: list[dict]) -> str:
    """Append agent-facing notes about inline images and uploaded files to *content*."""
    user_content = content
    # Tell the model the inline images are already in its multimodal context,
    # so it doesn't hallucinate a path/URL and call image_view to "load" them.
    if images:
        n = len(images)
        noun = "image" if n == 1 else "images"
        user_content += (
            f"\n\n[{n} {noun} attached inline — already in your context, no extra loading needed.]"
        )
    # List uploaded file paths so Navi knows about them.
    if files:
        file_lines = "\n".join(f"- {f['name']} → {f['path']}" for f in files)
        user_content += f"\n\n[Uploaded files on disk:\n{file_lines}]"
    return user_content


async def _authorize_ws(websocket: WebSocket, session_id: str, session, user: User | None) -> bool:
    """Close *websocket* with code 4003 and return False if *user* may not use *session*."""
    if user is None:
        # Anonymous users may only connect to legacy sessions (no owner).
        if session.user_id is not None:
            log.warning("ws.anonymous_denied", session_id=session_id)
            await websocket.close(code=4003, reason="Authentication required")
            return False
        return True
    try:
        check_session_access(session, user)
    except Exception:
        log.warning("ws.access_denied", session_id=session_id, user_id=user.id)
        await websocket.close(code=4003, reason="Access denied")
        return False
    log.info("ws.access_granted", session_id=session_id, user_id=user.id)
    return True


async def _reattach_to_run(websocket: WebSocket, session_id: str, run) -> bool:
    """Replay an in-progress *run*'s buffered events, then stream the rest live.

    Used when the client reloads the page mid-stream. Returns True if the run
    finished while the client stayed connected, False if the client went away.
    """
    # Subscribe BEFORE noting replay_count — single-threaded async, no race:
    # any broadcast() that happens after subscribe() goes into our queue,
    # and we replay only events[0:replay_count] from the buffer.
    queue = run.subscribe()
    try:
        replay_count = len(run.events)
        log.info("ws.reattached", session_id=session_id, replay_count=replay_count)
        await websocket.send_json({"type": "stream_start"})
        # Replay all events emitted before we subscribed so the client can
        # reconstruct the full in-progress UI state without any gaps.
        if replay_count > 0:
            await websocket.send_json({"type": "replay_start", "count": replay_count})
            for ev_dict in run.events[:replay_count]:
                try:
                    await websocket.send_json(ev_dict)
                except Exception:
                    return False
            await websocket.send_json({"type": "replay_end"})
        return await _stream_to_client(websocket, queue)
    finally:
        # Single unsubscribe on every exit path (success, failure, cancellation).
        run.unsubscribe(queue)


async def _run_form_submit(
    data: dict,
    *,
    session_id: str,
    websocket: WebSocket,
    orchestrator: "AgentSessionOrchestrator",
    session_store: SessionStore,
    user: User | None,
) -> bool:
    """Handle a ``form_submit`` frame; return False only if the client disconnected."""
    form_id = data.get("form_id")
    values = data.get("values")
    if not form_id or not isinstance(values, dict):
        await websocket.send_json({
            "type": "error",
            "message": "Expected {type: 'form_submit', form_id: '...', values: {...}}",
        })
        return True

    # The submitted values are sent to the agent as a hidden (non-display) user turn.
    payload_text = json.dumps(
        {"form_id": form_id, "values": values},
        ensure_ascii=False,
    )
    hidden_msg = Message(
        role="user",
        content=payload_text,
        created_at=None,
        is_display=False,
        is_context=True,
    )
    return await _start_agent_run(
        session_id=session_id,
        user_content=payload_text,
        display_content=None,
        raw_images=None,
        uploaded_files=[],
        hidden_msg=hidden_msg,
        hidden=True,
        websocket=websocket,
        orchestrator=orchestrator,
        session_store=session_store,
        user=user,
    )


async def _run_user_message(
    data: dict,
    *,
    session_id: str,
    websocket: WebSocket,
    orchestrator: "AgentSessionOrchestrator",
    session_store: SessionStore,
    user: User | None,
) -> bool:
    """Handle a ``message`` frame (rejecting any other type); return False only on disconnect."""
    content = data.get("content")
    if data.get("type") != "message" or not content or not isinstance(content, str):
        await websocket.send_json({"type": "error", "message": _USAGE_ERROR})
        return True

    files: list[dict] = []
    images, error = _clean_images(data.get("images"))
    if error is None:
        files, error = _clean_files(data.get("files"))
    if error is not None:
        await websocket.send_json({"type": "error", "message": error})
        return True

    return await _start_agent_run(
        session_id=session_id,
        user_content=_compose_user_content(content, images, files),
        display_content=content,
        raw_images=images,
        uploaded_files=files,
        hidden_msg=None,
        hidden=False,
        websocket=websocket,
        orchestrator=orchestrator,
        session_store=session_store,
        user=user,
    )


@router.websocket("/ws/sessions/{session_id}")
async def websocket_session(
    session_id: str,
    websocket: WebSocket,
) -> None:
    """Stream agent runs for *session_id* over a WebSocket.

    Connection flow:
      1. Resolve the user from the handshake; failures fall back to anonymous.
      2. Accept the socket, then reject with close code 4004 (unknown session)
         or 4003 (unauthenticated / access denied). Accepting first lets the
         client see these codes; closing before accept surfaces as an HTTP 403
         on the upgrade request instead.
      3. Register the socket with the orchestrator. If a run is in progress,
         replay its buffered events and stream the rest. Then send
         ``session_sync`` so the client reloads history.
      4. Read ``message`` / ``form_submit`` frames, starting one agent run each.

    A registered socket is always unregistered from the orchestrator on exit.
    """
    log.info("ws.handler_enter", session_id=session_id)

    # Resolve user manually — bypass FastAPI Depends() to avoid HTTP 403
    # on the upgrade request when auth resolution fails.
    try:
        user: User | None = await get_current_user_ws(websocket)
    except Exception as exc:
        log.warning("ws.resolve_user_exc", session_id=session_id, exc_type=type(exc).__name__, error=str(exc))
        user = None
    log.info("ws.user_resolved", session_id=session_id, user_id=user.id if user else None)

    session_store = get_session_store()
    orchestrator = get_orchestrator()

    session = await session_store.get(session_id)

    # Accept before any rejection so failures are delivered as WebSocket close
    # codes rather than an HTTP 403 on the upgrade request.
    await websocket.accept()
    log.info("ws.accepted", session_id=session_id)

    if session is None:
        log.warning("ws.session_not_found", session_id=session_id)
        await websocket.close(code=4004, reason="Session not found")
        return

    log.info("ws.session_found", session_id=session_id, session_user_id=session.user_id)

    if not await _authorize_ws(websocket, session_id, session, user):
        return

    # Register only after authorization and immediately before the try/finally,
    # so rejected sockets are never tracked and every tracked socket is removed.
    orchestrator.add_websocket(session_id, websocket)
    try:
        # Re-attach to an in-progress run (e.g. client reloaded the page mid-stream).
        existing = orchestrator.get_run(session_id)
        if existing is not None and not await _reattach_to_run(websocket, session_id, existing):
            return  # client disconnected again — stop here
        # Either the re-attached run just finished, or this is a reconnect after
        # the agent already finished: the client must reload session history to
        # see the saved response (covers any events missed while disconnected).
        await websocket.send_json({"type": "session_sync"})

        handler_kwargs = dict(
            session_id=session_id,
            websocket=websocket,
            orchestrator=orchestrator,
            session_store=session_store,
            user=user,
        )
        while True:
            raw = await websocket.receive_text()

            try:
                data = json.loads(raw)
            except json.JSONDecodeError:
                await websocket.send_json({"type": "error", "message": "Invalid JSON"})
                continue
            if not isinstance(data, dict):
                # Valid JSON that is not an object cannot carry a "type".
                await websocket.send_json({"type": "error", "message": _USAGE_ERROR})
                continue

            if data.get("type") == "form_submit":
                connected = await _run_form_submit(data, **handler_kwargs)
            else:
                connected = await _run_user_message(data, **handler_kwargs)
            if not connected:
                break  # avoid calling receive_text() on a dead socket

    except (WebSocketDisconnect, RuntimeError):
        log.info("ws.disconnected", session_id=session_id)
    finally:
        # Remove this socket from the session tracking set.
        await orchestrator.remove_websocket(session_id, websocket)
