"""SSH tool — execute commands on remote hosts via SSH.

Connections are defined in ssh_hosts.json (see .env.example for path config).
The tool also accepts an inline 'user@host' string for one-off connections that use the default SSH keys.

ssh_hosts.json example:
{
  "prod": {
    "host": "1.2.3.4",
    "port": 22,
    "username": "root",
    "client_keys": ["~/.ssh/id_rsa"],
    "known_hosts": null
  },
  "staging": {
    "host": "staging.example.com",
    "username": "ubuntu"
  }
}

known_hosts:
  null (or omitted) — use the system ~/.ssh/known_hosts
  "none"            — skip host key verification (handy for a fresh VPS; not recommended for prod)
  "<path>"          — use the given known_hosts file (~ is expanded)
"""

import asyncio
import json
import os
from pathlib import Path

import asyncssh

from navi.config import settings

from .base import Tool, ToolResult

# Default timeout in seconds for an ssh_exec call when the caller omits "timeout"
# (also advertised in the tool's parameter schema).
_TIMEOUT = 60


def _load_hosts() -> dict:
    path = Path(settings.ssh_hosts_file).expanduser()
    if not path.exists():
        return {}
    try:
        return json.loads(path.read_text())
    except Exception:
        return {}


class SshExecTool(Tool):
    """Tool that runs one shell command on a remote host over SSH.

    The target is either a named entry from the hosts file (loaded via
    ``_load_hosts``) or an inline ``user@host`` string that relies on the
    default SSH keys. Failures are reported through the returned
    ``ToolResult`` rather than raised.
    """

    name = "ssh_exec"
    description = (
        "Execute a command on a remote server via SSH. "
        "Use a named connection from ssh_hosts.json or specify host/username directly."
    )
    parameters = {
        "type": "object",
        "properties": {
            "connection": {
                "type": "string",
                "description": (
                    "Named connection from ssh_hosts.json (e.g. 'prod'), "
                    "or 'user@host' for a direct connection using default SSH keys."
                ),
            },
            "command": {
                "type": "string",
                "description": "Shell command to run on the remote host",
            },
            "timeout": {
                "type": "integer",
                "description": f"Timeout in seconds (default {_TIMEOUT})",
            },
        },
        "required": ["connection", "command"],
    }

    async def execute(self, params: dict) -> ToolResult:
        """Connect to the target, run the command, and report its output.

        Args:
            params: Tool-call arguments.

                * ``connection`` (required): a named connection or ``user@host``.
                * ``command`` (required): the shell command to run.
                * ``timeout`` (optional): seconds; defaults to ``_TIMEOUT``.

        Returns:
            A ToolResult whose ``success`` is True only for exit status 0.
            stdout and a ``[stderr]`` section are combined in ``output``.
            Invalid arguments and resolution, connection, authentication or
            timeout failures come back as unsuccessful results.
        """
        connection = params["connection"].strip()
        command = params["command"].strip()

        try:
            timeout = int(params.get("timeout") or _TIMEOUT)
        except (TypeError, ValueError):
            return ToolResult(
                success=False,
                output=f"Invalid timeout {params.get('timeout')!r}; expected a number of seconds.",
                error="invalid_timeout",
            )
        if timeout <= 0:
            # 0 already maps to the default via `or`; treat negatives the same
            # way instead of letting wait_for expire immediately.
            timeout = _TIMEOUT

        if not command:
            return ToolResult(success=False, output="No command given.", error="empty_command")

        try:
            connect_kwargs = self._resolve_connection(connection)
        except (KeyError, TypeError, ValueError) as e:
            # The named entry exists but is malformed, e.g. missing "host" or a
            # non-numeric "port".
            return ToolResult(
                success=False,
                output=f"Invalid configuration for connection '{connection}': {e!r}",
                error="invalid_connection_config",
            )
        if connect_kwargs is None:
            hosts = list(_load_hosts().keys())
            hint = f"Available named connections: {hosts}" if hosts else "No ssh_hosts.json found."
            return ToolResult(
                success=False,
                output=f"Unknown connection '{connection}'. {hint}",
                error="unknown_connection",
            )

        async def _connect_and_run():
            # The connection is closed when the context exits, including on cancellation.
            async with asyncssh.connect(**connect_kwargs) as conn:
                return await conn.run(command, check=False)

        try:
            # Bound the whole exchange (TCP connect, handshake, auth, command),
            # not just the command, so an unreachable host cannot hang the caller.
            result = await asyncio.wait_for(_connect_and_run(), timeout=timeout)
        except asyncssh.PermissionDenied:
            # Must precede DisconnectError: asyncssh.PermissionDenied subclasses
            # it, so the broader handler would otherwise shadow this one.
            return ToolResult(success=False, output="SSH permission denied. Check credentials.", error="permission_denied")
        except asyncssh.DisconnectError as e:
            return ToolResult(success=False, output=f"SSH disconnected: {e}", error=str(e))
        except (TimeoutError, asyncio.TimeoutError):
            return ToolResult(success=False, output=f"SSH command timed out after {timeout}s", error="timeout")
        except Exception as e:
            # Tool boundary: surface refused connections, DNS errors, etc. as results.
            return ToolResult(success=False, output=f"SSH error: {e}", error=str(e))

        output_parts = []
        if result.stdout:
            output_parts.append(result.stdout)
        if result.stderr:
            output_parts.append(f"[stderr]\n{result.stderr}")

        success = result.exit_status == 0
        return ToolResult(
            success=success,
            output="\n".join(output_parts) or "(no output)",
            metadata={"exit_status": result.exit_status, "host": connect_kwargs.get("host")},
            error=None if success else f"Exit status {result.exit_status}",
        )

    def _resolve_connection(self, connection: str) -> dict | None:
        """Turn a connection spec into ``asyncssh.connect`` keyword arguments.

        A name present in the hosts file wins. Otherwise a ``user@host`` string
        is parsed.

        Returns:
            The keyword-argument dict, or None when the spec matches neither form.

        Raises:
            KeyError, TypeError, ValueError: if a named entry is malformed
                (propagated from ``_build_kwargs``).
        """
        hosts = _load_hosts()

        if connection in hosts:
            return self._build_kwargs(hosts[connection])

        # Split on the LAST '@': hostnames cannot contain '@', but some
        # usernames (e.g. "user@domain") can.
        username, sep, host = connection.rpartition("@")
        if sep and username and host:
            return self._build_kwargs({"host": host, "username": username})

        return None

    def _build_kwargs(self, cfg: dict) -> dict:
        """Map one connection config entry to ``asyncssh.connect`` kwargs.

        Recognised keys:

        * ``host`` (required)
        * ``port`` (default 22)
        * ``username`` (default: ``$USER``, else "root")
        * ``client_keys`` (a list of paths, or a single path)
        * ``password``
        * ``known_hosts``: null/omitted for the system file, "none" to skip
          verification, or a path to a known_hosts file

        Raises:
            TypeError: if ``cfg`` is not a dict.
            KeyError: if ``host`` is missing.
            ValueError: if ``port`` is not an integer.
        """
        if not isinstance(cfg, dict):
            raise TypeError(f"connection entry must be a JSON object, got {type(cfg).__name__}")

        kwargs: dict = {
            "host": cfg["host"],
            # `or` so an explicit JSON null falls back to the default as well.
            "port": int(cfg.get("port") or 22),
            "username": cfg.get("username") or os.environ.get("USER", "root"),
        }

        client_keys = cfg.get("client_keys")
        if isinstance(client_keys, str):
            # Accept a single path as shorthand; iterating a str would yield characters.
            client_keys = [client_keys]
        if client_keys:
            kwargs["client_keys"] = [str(Path(k).expanduser()) for k in client_keys]

        password = cfg.get("password")
        if password:
            kwargs["password"] = password

        known_hosts = cfg.get("known_hosts")
        if known_hosts == "none":
            # asyncssh treats known_hosts=None as "do not verify host keys".
            kwargs["known_hosts"] = None
        elif known_hosts is not None:
            kwargs["known_hosts"] = str(Path(known_hosts).expanduser())
        # else: omit → asyncssh uses system known_hosts

        return kwargs


