"""Gmail tool — send, list, read, and reply to emails via Gmail API."""

import asyncio
import base64
import codecs
from email.message import Message
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
from pathlib import Path
from typing import Any

import html2text as _html2text
from google.auth.transport.requests import Request
from google.oauth2.credentials import Credentials
from googleapiclient.discovery import build

name = "gmail"
# Tool summary. 'list' has no label filter, so it covers the whole mailbox
# (use a query such as 'in:inbox' to narrow it); 'list_unread' is capped at 100.
description = (
    "Interact with Gmail: send emails, list messages (paginated, optionally filtered "
    "by a Gmail search query), list unread messages (up to 100), "
    "read a specific email by ID (marks it as read), or reply to an email thread."
)

# JSON-Schema-style description of the ``params`` keys that ``execute`` reads.
parameters = {
    "type": "object",
    "properties": {
        "action": {
            "type": "string",
            "enum": ["send", "list", "list_unread", "read", "reply"],
            "description": (
                "Action to perform. "
                "'send': send a new email. "
                "'list': list mailbox messages (paginated; use 'query', e.g. 'in:inbox', to filter). "
                "'list_unread': list unread messages (up to 100). "
                "'read': fetch and display a specific message by ID (marks it read). "
                "'reply': reply to a message thread by ID."
            ),
        },
        "to": {
            "type": "string",
            "description": "Recipient email address. Required for 'send'.",
        },
        "subject": {
            "type": "string",
            "description": "Email subject line. Required for 'send'.",
        },
        "body": {
            "type": "string",
            "description": "Email body (HTML). Required for 'send' and 'reply'.",
        },
        "message_id": {
            "type": "string",
            "description": "Gmail message ID (from list output). Required for 'read' and 'reply'.",
        },
        "max_results": {
            "type": "integer",
            # Bounds mirror the documented range below so callers see them in the schema.
            "minimum": 1,
            "maximum": 50,
            "description": "Max messages to return for 'list' (default 10, max 50).",
        },
        "page_token": {
            "type": "string",
            "description": "Page token returned by a previous 'list' call for pagination.",
        },
        "query": {
            "type": "string",
            "description": "Gmail search query to filter 'list' results (e.g. 'from:example.com subject:invoice').",
        },
    },
    "required": ["action"],
}

# OAuth scope the stored token is loaded with in _get_service.
_SCOPES = ["https://www.googleapis.com/auth/gmail.modify"]
# Authorized-user token file, stored alongside this module.
_TOKEN_PATH = Path(__file__).with_name("gmail_token.json")
# Message body text longer than this many characters is truncated by _do_read.
_MAX_CHARS = 5_000


# ─── Auth ──────────────────────────────────────────────────────────────────

def _get_service():
    """Build an authorized Gmail API client from the locally stored OAuth token.

    When the loaded credentials are not valid and a refresh token is available,
    they are refreshed and the new token is written back to ``_TOKEN_PATH``.

    Raises:
        RuntimeError: if the token file is missing or malformed, the credentials
            are invalid with no refresh token, or the refresh is rejected.
            ``execute`` converts this into an ``"Error: ..."`` result string, so
            every auth problem surfaces with re-authorization instructions
            instead of a raw google-auth traceback.
    """
    # Local import: only needed here, and keeps the module's import block unchanged.
    from google.auth.exceptions import RefreshError

    reauth_hint = "Re-run: python tools/gmail_auth.py"
    if not _TOKEN_PATH.exists():
        raise RuntimeError(
            "Gmail not authorized. Steps:\n"
            "1. Place gmail_credentials.json in the tools/ directory\n"
            "2. Run: python tools/gmail_auth.py\n"
            "3. Complete the browser login flow"
        )

    try:
        creds = Credentials.from_authorized_user_file(str(_TOKEN_PATH), _SCOPES)
    except ValueError as exc:  # bad JSON (JSONDecodeError) or missing fields
        raise RuntimeError(
            f"Gmail token file {_TOKEN_PATH} is invalid ({exc}). {reauth_hint}"
        ) from exc

    # `not valid` also covers a stored token of None, which `expired` alone misses.
    if not creds.valid:
        if not creds.refresh_token:
            raise RuntimeError(
                f"Gmail token is no longer valid and cannot be refreshed. {reauth_hint}"
            )
        try:
            creds.refresh(Request())
        except RefreshError as exc:  # e.g. revoked or expired grant
            raise RuntimeError(f"Gmail token refresh failed ({exc}). {reauth_hint}") from exc
        _TOKEN_PATH.write_text(creds.to_json(), encoding="utf-8")

    return build("gmail", "v1", credentials=creds, cache_discovery=False)


# ─── MIME helpers ──────────────────────────────────────────────────────────

def _header(msg: dict, name: str) -> str:
    """Return the value of the first payload header called *name*, or ``""``.

    Header names are compared case-insensitively; headers are read from
    ``msg["payload"]["headers"]`` (a list of ``{"name", "value"}`` dicts).
    """
    wanted = name.lower()
    headers = msg.get("payload", {}).get("headers", [])
    return next((h["value"] for h in headers if h["name"].lower() == wanted), "")


def _find_part(payload: dict, mime_type: str) -> str | None:
    """Recursively find the first part with the given MIME type; return decoded text."""
    if payload.get("mimeType") == mime_type:
        data = payload.get("body", {}).get("data", "")
        if data:
            return base64.urlsafe_b64decode(data + "==").decode("utf-8", errors="replace")
    for part in payload.get("parts", []):
        result = _find_part(part, mime_type)
        if result is not None:
            return result
    return None


def _extract_text(payload: dict) -> str:
    """Extract readable body text from a Gmail message payload.

    Prefers the first text/plain part, but only if it contains non-whitespace
    text. Some messages carry an empty plain-text placeholder next to the real
    HTML body. Otherwise the first text/html part is converted to text with
    html2text.

    Returns:
        The stripped text, or ``""`` when neither part yields any.
    """
    plain = (_find_part(payload, "text/plain") or "").strip()
    if plain:
        return plain
    html = _find_part(payload, "text/html")
    if not html:
        return ""
    converter = _html2text.HTML2Text()
    converter.ignore_links = False  # keep link targets visible in the text
    converter.ignore_images = True
    converter.body_width = 0  # no hard line wrapping
    return converter.handle(html).strip()


def _list_attachments(payload: dict) -> list[tuple[str, int]]:
    """Collect ``(filename, size)`` for every MIME part that has a filename.

    The tree is walked depth-first in document order: a part before its
    children, children left to right. The size is taken from the part's
    ``body.size`` and defaults to 0 when absent.
    """
    found: list[tuple[str, int]] = []
    pending = [payload]
    while pending:
        part = pending.pop()
        filename = part.get("filename", "")
        if filename:
            found.append((filename, part.get("body", {}).get("size", 0)))
        # Push children reversed so the leftmost child is visited next.
        pending.extend(list(part.get("parts", []))[::-1])
    return found


def _fmt_size(n: int) -> str:
    """Format a byte count as ``"<n>B"``, ``"<x.x>KB"`` or ``"<x.x>MB"`` (1024-based)."""
    for unit, scale in (("MB", 1024 * 1024), ("KB", 1024)):
        if n >= scale:
            return f"{n / scale:.1f}{unit}"
    return f"{n}B"


def _build_mime(
    to: str,
    subject: str,
    body_html: str,
    in_reply_to: str = "",
    references: str = "",
) -> MIMEMultipart:
    """Build a multipart/alternative message from an HTML body.

    A plain-text rendering of *body_html* (via html2text) is attached first and
    the HTML second, both as UTF-8. ``To`` and ``Subject`` are always set.
    ``In-Reply-To`` and ``References`` are added only when non-empty; they are
    used for threaded replies.
    """
    renderer = _html2text.HTML2Text()
    renderer.ignore_links = False
    renderer.ignore_images = True
    renderer.body_width = 0
    plain_version = renderer.handle(body_html)

    envelope = MIMEMultipart("alternative")
    envelope["To"] = to
    envelope["Subject"] = subject
    for header_name, header_value in (("In-Reply-To", in_reply_to), ("References", references)):
        if header_value:
            envelope[header_name] = header_value
    for content, subtype in ((plain_version, "plain"), (body_html, "html")):
        envelope.attach(MIMEText(content, subtype, "utf-8"))
    return envelope


# ─── Actions ───────────────────────────────────────────────────────────────

def _do_send(service: Any, to: str, subject: str, body_html: str) -> str:
    """Send a new HTML email and return a one-line confirmation.

    Returns an ``"Error: ..."`` string, without calling the API, when *to*,
    *subject* or *body_html* is empty. The fields are checked in that order.
    """
    required = (
        (to, "Error: 'to' is required."),
        (subject, "Error: 'subject' is required."),
        (body_html, "Error: 'body' is required."),
    )
    for value, problem in required:
        if not value:
            return problem

    encoded = base64.urlsafe_b64encode(_build_mime(to, subject, body_html).as_bytes())
    service.users().messages().send(userId="me", body={"raw": encoded.decode()}).execute()
    return f"Sent to {to} | Subject: {subject}"


def _do_list(
    service: Any,
    max_results: int,
    page_token: str | None,
    query: str | None,
) -> str:
    max_results = min(int(max_results), 50)
    kwargs: dict[str, Any] = {"userId": "me", "maxResults": max_results}
    if page_token:
        kwargs["pageToken"] = page_token
    if query:
        kwargs["q"] = query

    result = service.users().messages().list(**kwargs).execute()
    messages = result.get("messages", [])
    next_token = result.get("nextPageToken")

    if not messages:
        return "No messages found."

    lines: list[str] = []
    for m in messages:
        full = service.users().messages().get(
            userId="me",
            id=m["id"],
            format="metadata",
            metadataHeaders=["From", "Subject", "Date"],
        ).execute()
        from_hdr = _header(full, "From")
        subject = _header(full, "Subject") or "(no subject)"
        date = _header(full, "Date")
        unread = "●" if "UNREAD" in full.get("labelIds", []) else "○"
        lines.append(f"{unread} {m['id']} | {from_hdr} | {subject} | {date}")

    out = "\n".join(lines)
    if next_token:
        out += f"\n\n[next_page_token: {next_token}]"
    return out


def _do_list_unread(service: Any, max_results: int = 100) -> str:
    """List unread messages (at most *max_results*, default 100) with sender, subject and date.

    Returns:
        ``"No unread messages."`` when there are none. Otherwise a header line
        followed by one ``"● <id> | <From> | <Subject> | <Date>"`` line per
        message. The header reads ``"Unread: N"``, or ``"Unread: N+ (showing
        first N)"`` when Gmail reports further pages of unread messages.
    """
    result = service.users().messages().list(
        userId="me", q="is:unread", maxResults=max_results
    ).execute()
    messages = result.get("messages", [])

    if not messages:
        return "No unread messages."

    lines: list[str] = []
    for m in messages:
        full = service.users().messages().get(
            userId="me",
            id=m["id"],
            format="metadata",
            metadataHeaders=["From", "Subject", "Date"],
        ).execute()
        from_hdr = _header(full, "From")
        subject = _header(full, "Subject") or "(no subject)"
        date = _header(full, "Date")
        lines.append(f"● {m['id']} | {from_hdr} | {subject} | {date}")

    # Use Gmail's nextPageToken to detect truncation. Comparing len(messages) to
    # the limit mislabels a mailbox with exactly that many unread messages.
    count = len(messages)
    if result.get("nextPageToken"):
        header_line = f"Unread: {count}+ (showing first {count})"
    else:
        header_line = f"Unread: {count}"
    return header_line + "\n" + "\n".join(lines)


def _do_read(service: Any, message_id: str) -> str:
    """Fetch a message in full, render it as text, then mark it read.

    The rendering has From/To/Subject/Date lines, a divider and the body text
    (cut at ``_MAX_CHARS`` with a truncation note). An attachment list follows
    when the message has named parts. The UNREAD label is removed only after
    rendering, so a failure while rendering leaves the message unread.

    Returns:
        The rendered message, or ``"Error: 'message_id' is required."``.
    """
    if not message_id:
        return "Error: 'message_id' is required."

    messages_api = service.users().messages()
    msg = messages_api.get(userId="me", id=message_id, format="full").execute()
    payload = msg.get("payload", {})

    body_text = _extract_text(payload)
    omitted = len(body_text) - _MAX_CHARS
    divider = "─" * 40

    out = [
        f"From:    {_header(msg, 'From')}",
        f"To:      {_header(msg, 'To')}",
        f"Subject: {_header(msg, 'Subject') or '(no subject)'}",
        f"Date:    {_header(msg, 'Date')}",
        divider,
    ]
    if omitted > 0:
        out.append(body_text[:_MAX_CHARS])
        out.append(f"\n\n[truncated — {omitted} chars omitted]")
    else:
        out.append(body_text or "(no text content)")

    attachments = _list_attachments(payload)
    if attachments:
        out.append(divider)
        out.append("Attachments:")
        out.extend(f"  • {fname} ({_fmt_size(fsize)})" for fname, fsize in attachments)

    messages_api.modify(
        userId="me", id=message_id, body={"removeLabelIds": ["UNREAD"]}
    ).execute()
    return "\n".join(out)


def _do_reply(service: Any, message_id: str, body_html: str) -> str:
    """Reply to message *message_id* inside its existing Gmail thread.

    The reply goes to the original's ``Reply-To`` header when present, otherwise
    to its ``From`` header. The subject gains a ``Re: `` prefix unless it
    already starts with ``re:`` (any case). ``In-Reply-To`` and ``References``
    are set from the original so mail clients thread the reply.

    Returns:
        A confirmation line, or an ``"Error: ..."`` string when *message_id* or
        *body_html* is empty or the original has no address to reply to.
    """
    if not message_id:
        return "Error: 'message_id' is required."
    if not body_html:
        return "Error: 'body' is required."

    orig = service.users().messages().get(
        userId="me",
        id=message_id,
        format="metadata",
        metadataHeaders=["From", "Reply-To", "Subject", "Message-ID", "References"],
    ).execute()

    thread_id = orig.get("threadId", "")
    # RFC 5322 §3.6.2: when Reply-To is present, replies should be sent there.
    recipient = _header(orig, "Reply-To") or _header(orig, "From")
    if not recipient:
        return "Error: original message has no From or Reply-To address to reply to."
    subject = _header(orig, "Subject")
    orig_msg_id = _header(orig, "Message-ID")
    references = _header(orig, "References")

    reply_subject = subject if subject.lower().startswith("re:") else f"Re: {subject}"
    # New References = the original's chain followed by its own Message-ID.
    ref_chain = f"{references} {orig_msg_id}".strip() if references else orig_msg_id

    msg = _build_mime(recipient, reply_subject, body_html, orig_msg_id, ref_chain)
    raw = base64.urlsafe_b64encode(msg.as_bytes()).decode()
    service.users().messages().send(
        userId="me", body={"raw": raw, "threadId": thread_id}
    ).execute()

    return f"Reply sent to {recipient} | Subject: {reply_subject}"


# ─── Entry point ───────────────────────────────────────────────────────────

async def execute(params: dict) -> str:
    """Tool entry point: run the Gmail action named by ``params["action"]``.

    All Google API work happens in a worker thread via ``asyncio.to_thread``.
    A ``RuntimeError`` from ``_get_service`` (authorization problems) is
    returned as ``"Error: <message>"``. An unrecognized action returns an error
    string listing the valid actions.
    """
    action = params.get("action", "")

    # (action name, handler taking the API client), matched in order with ==.
    routes = (
        ("send", lambda svc: _do_send(
            svc,
            to=params.get("to", ""),
            subject=params.get("subject", ""),
            body_html=params.get("body", ""),
        )),
        ("list", lambda svc: _do_list(
            svc,
            max_results=params.get("max_results", 10),
            page_token=params.get("page_token"),
            query=params.get("query"),
        )),
        ("list_unread", lambda svc: _do_list_unread(svc)),
        ("read", lambda svc: _do_read(svc, message_id=params.get("message_id", ""))),
        ("reply", lambda svc: _do_reply(
            svc,
            message_id=params.get("message_id", ""),
            body_html=params.get("body", ""),
        )),
    )

    def _run() -> str:
        try:
            service = _get_service()
        except RuntimeError as exc:
            return f"Error: {exc}"
        handler = next((fn for key, fn in routes if action == key), None)
        if handler is None:
            return f"Error: unknown action '{action}'. Valid: send, list, list_unread, read, reply."
        return handler(service)

    return await asyncio.to_thread(_run)
