Newer
Older
navi-1 / navi / tools / ssh_exec.py
"""SSH tool — execute commands on remote hosts via SSH.

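Command connections are cached per session with a 20-minute TTL (counted from
when the connection was opened). Each session gets its own independent
connection pool, so parallel sessions working the same server don't interfere
with each other. File transfers (action=scp) open a separate connection each time.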

Named connections are defined in ssh_hosts.json (see .env.example for path config).
The tool also accepts inline host/username/password/key_path parameters, so no
config entry is required.

ssh_hosts.json example:
{
  "prod": {
    "host": "1.2.3.4",
    "port": 22,
    "username": "root",
    "client_keys": ["~/.ssh/id_rsa"],
    "known_hosts": null
  },
  "staging": {
    "host": "staging.example.com",
    "username": "ubuntu"
  }
}

known_hosts:
  null    — use the system ~/.ssh/known_hosts
  "none"  — skip host key verification (useful for fresh VPS, not recommended for prod)
  omitted — same as "none": host key verification is skipped by default
"""

import asyncio
import json
import os
import time
from dataclasses import dataclass
from pathlib import Path

import asyncssh

from navi.config import settings

from ._internal.base import Tool, ToolContext, ToolResult, current_session_id

_TIMEOUT = 60  # default per-call timeout in seconds; callers may override via "timeout"
_TTL = 60 * 20  # pooled connections older than this many seconds (20 min) are replaced


@dataclass
class _PoolEntry:
    """One cached SSH connection stored in the module-level ``_pool``."""

    conn: asyncssh.SSHClientConnection  # the open connection handed back to callers
    created_at: float  # time.monotonic() when the connection was opened; compared against _TTL
    connect_kwargs: dict  # keyword arguments that were passed to asyncssh.connect()


# Connection pool shared by every session, mapping pool key -> _PoolEntry.
# Keys embed the session id (built by _pool_key), which keeps sessions apart.
_pool: dict[str, _PoolEntry] = {}

# One asyncio.Lock per pool key so concurrent callers never open duplicate
# connections for the same key; _lock_meta serialises creation of those locks.
_locks: dict[str, asyncio.Lock] = {}
_lock_meta = asyncio.Lock()


async def _get_lock(key: str) -> asyncio.Lock:
    """Return the asyncio.Lock dedicated to *key*, creating it on first use.

    Creation happens under ``_lock_meta`` so every caller for the same key
    receives the very same lock object.
    """
    async with _lock_meta:
        lock = _locks.get(key)
        if lock is None:
            lock = asyncio.Lock()
            _locks[key] = lock
    return lock


async def _get_connection(key: str, connect_kwargs: dict) -> asyncssh.SSHClientConnection:
    """Return a live pooled connection for *key*, opening a new one when needed.

    A pooled entry is discarded (closed and removed) and replaced when any of
    these holds:

    - it is older than ``_TTL``;
    - its connection is already closed;
    - it was opened with different ``connect_kwargs``, e.g. the caller now
      supplies another password, key or known_hosts setting for the same
      session/host/port/user.

    The per-key lock ensures only one coroutine at a time inspects or creates
    the connection for *key*.

    Raises:
        Whatever ``asyncssh.connect`` raises when a new connection fails.
    """
    lock = await _get_lock(key)
    async with lock:
        entry = _pool.get(key)
        if entry is not None:
            expired = time.monotonic() - entry.created_at > _TTL
            if expired or entry.conn.is_closed() or entry.connect_kwargs != connect_kwargs:
                _pool.pop(key, None)
                try:
                    entry.conn.close()
                except Exception:
                    pass  # best effort: the connection may already be broken
                entry = None

        if entry is None:
            conn = await asyncssh.connect(**connect_kwargs)
            # Store a copy so later mutation of the caller's dict cannot affect
            # the staleness comparison above.
            entry = _PoolEntry(
                conn=conn,
                created_at=time.monotonic(),
                connect_kwargs=dict(connect_kwargs),
            )
            _pool[key] = entry

        return entry.conn


def _evict(key: str) -> None:
    """Drop *key* from the pool and close its connection, if one is pooled.

    Used after an error so the next call opens a fresh connection. Errors from
    closing are ignored.
    """
    stale = _pool.pop(key, None)
    if stale is None:
        return
    try:
        stale.conn.close()
    except Exception:
        pass  # best effort: the connection is being discarded anyway


def close_all_connections() -> None:
    """Close every pooled SSH connection and empty the pool (server shutdown hook).

    A failure closing one connection does not stop the others from being closed.
    """
    for pooled in tuple(_pool.values()):
        try:
            pooled.conn.close()
        except Exception:
            continue  # best effort: keep closing the remaining connections
    _pool.clear()


def _pool_key(session_id: str | None, host: str, port: int, username: str) -> str:
    sid = session_id or "anonymous"
    return f"{sid}:{host}:{port}:{username}"


def _load_hosts() -> dict:
    path = Path(settings.ssh_hosts_file).expanduser()
    if not path.exists():
        return {}
    try:
        return json.loads(path.read_text())
    except Exception:
        return {}


class SshExecTool(Tool):
    """Tool that runs shell commands (``exec``) or copies files (``scp``) over SSH.

    The target comes from inline parameters (host/username/password/key_path/port)
    or from a named entry in ssh_hosts.json (``connection``); inline values
    override stored ones. ``exec`` reuses pooled connections keyed by
    session/host/port/username. ``scp`` opens its own connection per transfer.
    """

    name = "ssh_exec"
    description = (
        "Run a command on a remote server over SSH. "
        "Required: command + host. Pass username and password (or key_path) directly — "
        "no config file needed. Use for any task on a remote VPS, server, or device. "
        "Host key verification is skipped by default for ad-hoc connections. "
        "Connections are reused within a session (20-minute TTL) — no reconnect overhead."
    )
    parameters = {
        "type": "object",
        "properties": {
            "action": {
                "type": "string",
                "enum": ["exec", "scp"],
                "description": "Action to perform: exec (run command, default) or scp (transfer file).",
            },
            "command": {
                "type": "string",
                "description": "Shell command to run on the remote host (required for action=exec).",
            },
            "host": {
                "type": "string",
                "description": "Hostname or IP address of the remote server",
            },
            "username": {
                "type": "string",
                "description": "SSH username",
            },
            "password": {
                "type": "string",
                "description": "SSH password (if using password authentication)",
            },
            "port": {
                "type": "integer",
                "description": "SSH port (default 22)",
            },
            "key_path": {
                "type": "string",
                "description": "Path to private key file, e.g. ~/.ssh/id_rsa (if using key authentication)",
            },
            "connection": {
                "type": "string",
                "description": "Named connection from ssh_hosts.json — shortcut that provides host/user/creds automatically",
            },
            "local_path": {
                "type": "string",
                "description": "Local file path (required for action=scp).",
            },
            "remote_path": {
                "type": "string",
                "description": "Remote file path (required for action=scp).",
            },
            "direction": {
                "type": "string",
                "enum": ["upload", "download"],
                "description": "Transfer direction for scp: upload (local→remote) or download (remote→local).",
            },
            "timeout": {
                "type": "integer",
                "description": f"Timeout in seconds (default {_TIMEOUT})",
            },
        },
        "required": [],
    }

    async def execute(self, params: dict, ctx: ToolContext | None = None) -> ToolResult:
        """Resolve the SSH target and dispatch to the exec (default) or scp action.

        Problems the caller can fix, such as no resolvable target or a missing
        command, are reported as a failed ToolResult rather than raised.
        """
        action = params.get("action", "exec")
        timeout = int(params.get("timeout") or _TIMEOUT)

        connect_kwargs = self._resolve(params)
        if connect_kwargs is None:
            connection = (params.get("connection") or "").strip()
            if connection:
                msg = (
                    f"SSH connection '{connection}' was not found in ssh_hosts.json "
                    "(or has no 'host'), and no 'host' parameter was given."
                )
            else:
                msg = (
                    "No SSH target specified. Provide 'host' (and optionally 'username', "
                    "'password', 'key_path'), or a named 'connection' from ssh_hosts.json."
                )
            return ToolResult(success=False, output=msg, error=msg)

        if action == "scp":
            return await self._run_scp(params, connect_kwargs, timeout)

        command = (params.get("command") or "").strip()
        if not command:
            return ToolResult(success=False, output="'command' is required for exec action.", error="missing_command")

        session_id = ctx.session_id if ctx else current_session_id.get()
        host = connect_kwargs["host"]
        port = int(connect_kwargs.get("port", 22))
        username = connect_kwargs.get("username", "")
        key = _pool_key(session_id, host, port, username)

        return await self._run_with_retry(key, connect_kwargs, command, timeout)

    async def _run_with_retry(
        self,
        key: str,
        connect_kwargs: dict,
        command: str,
        timeout: int,
    ) -> ToolResult:
        """Run *command* on the pooled connection for *key*, retrying once on disconnect.

        A dropped connection (e.g. a stale pooled socket) is evicted and the
        command retried once on a fresh connection. Authentication failures are
        not retried. For each attempt, *timeout* bounds connection setup plus
        command execution.
        """

        async def _connect_and_run():
            conn = await _get_connection(key, connect_kwargs)
            return await conn.run(command, check=False)

        for attempt in range(2):
            try:
                result = await asyncio.wait_for(_connect_and_run(), timeout=timeout)

            # PermissionDenied must be handled before DisconnectError: asyncssh
            # defines it as a DisconnectError subclass, so the broader clause
            # would otherwise swallow it and retry with the same bad credentials.
            except asyncssh.PermissionDenied:
                _evict(key)
                msg = "SSH permission denied. Check username and password/key."
                return ToolResult(success=False, output=msg, error=msg)

            except (asyncssh.DisconnectError, asyncssh.ConnectionLost, EOFError) as e:
                _evict(key)
                if attempt == 0:
                    continue  # retry with fresh connection
                msg = f"SSH disconnected: {e}"
                return ToolResult(success=False, output=msg, error=msg)

            except (TimeoutError, asyncio.TimeoutError):
                msg = f"SSH command timed out after {timeout}s"
                return ToolResult(success=False, output=msg, error=msg)

            except Exception as e:
                _evict(key)
                msg = f"SSH error: {e}"
                return ToolResult(success=False, output=msg, error=msg)

            output_parts = []
            if result.stdout:
                output_parts.append(result.stdout)
            if result.stderr:
                output_parts.append(f"[stderr]\n{result.stderr}")

            output_text = "\n".join(output_parts) or "(no output)"
            success = result.exit_status == 0
            return ToolResult(
                success=success,
                output=output_text,
                metadata={"exit_status": result.exit_status, "host": connect_kwargs.get("host")},
                error=None if success else f"Exit status {result.exit_status}\n{output_text}",
            )

        # Should not be reached
        return ToolResult(success=False, output="SSH: unexpected retry exhaustion", error="retry_failed")

    async def _run_scp(self, params: dict, connect_kwargs: dict, timeout: int) -> ToolResult:
        """Upload or download a single file with ``asyncssh.scp``.

        ``direction`` defaults to "upload"; ``~`` in ``local_path`` is expanded.
        The transfer, including its own connection setup, is bounded by *timeout*.
        """
        local_path = (params.get("local_path") or "").strip()
        remote_path = (params.get("remote_path") or "").strip()
        # `or` (not a .get default) so an explicit null also means "upload".
        direction = params.get("direction") or "upload"

        if not local_path:
            return ToolResult(success=False, output="'local_path' is required for scp.", error="missing_local_path")
        if not remote_path:
            return ToolResult(success=False, output="'remote_path' is required for scp.", error="missing_remote_path")
        if direction not in ("upload", "download"):
            msg = f"Invalid 'direction' {direction!r} for scp; use 'upload' or 'download'."
            return ToolResult(success=False, output=msg, error="invalid_direction")

        local_path = os.path.expanduser(local_path)
        host = connect_kwargs["host"]
        port = int(connect_kwargs.get("port", 22))
        username = connect_kwargs.get("username", os.environ.get("USER", "root"))
        # Two constraints from asyncssh.scp shape the arguments below:
        # - It forwards **kwargs to asyncssh.connect() together with the host taken
        #   from the path tuple, so "host"/"port" must not also be in the kwargs
        #   (that raises a duplicate 'host' argument error).
        # - A non-default port is given as a (host, port) connection tuple; a flat
        #   (host, port, path) triple is not a valid path.
        scp_kwargs = {k: v for k, v in connect_kwargs.items() if k not in ("host", "port")}
        ssh_target = ((host, port), remote_path)

        try:
            if direction == "upload":
                await asyncio.wait_for(
                    asyncssh.scp(local_path, ssh_target, **scp_kwargs),
                    timeout=timeout,
                )
                return ToolResult(success=True, output=f"Uploaded: {local_path} → {username}@{host}:{remote_path}")
            await asyncio.wait_for(
                asyncssh.scp(ssh_target, local_path, **scp_kwargs),
                timeout=timeout,
            )
            return ToolResult(success=True, output=f"Downloaded: {username}@{host}:{remote_path} → {local_path}")
        except (TimeoutError, asyncio.TimeoutError):
            return ToolResult(success=False, output=f"SCP timed out after {timeout}s", error="timeout")
        except Exception as e:
            return ToolResult(success=False, output=f"SCP error: {e}", error=str(e))

    def _resolve(self, params: dict) -> dict | None:
        """Build asyncssh.connect() kwargs from a named connection or inline params.

        The lookup order is:

        1. A named ``connection`` found in ssh_hosts.json takes precedence. Inline
           host/username/password/port/key_path override the stored values.
        2. Otherwise, inline params are used directly.

        Host key verification defaults to off ("none") in both cases.

        Returns:
            The connect kwargs, or None when no host can be determined.
        """
        connection = (params.get("connection") or "").strip()
        if connection:
            hosts = _load_hosts()
            stored = hosts.get(connection) if isinstance(hosts, dict) else None
            if isinstance(stored, dict):
                cfg = dict(stored)
                # Inline params override stored values
                for k in ("host", "username", "password", "port"):
                    if params.get(k):
                        cfg[k] = params[k]
                if params.get("key_path"):
                    cfg["client_keys"] = [params["key_path"]]
                # Skip host key verification by default (same as ad-hoc)
                cfg.setdefault("known_hosts", "none")
                # An entry without a host can't be used; fall through instead of KeyError.
                if cfg.get("host"):
                    return self._build_kwargs(cfg)

        # Direct params
        host = (params.get("host") or "").strip()
        if not host:
            return None

        cfg: dict = {"host": host}
        if params.get("username"):
            cfg["username"] = params["username"]
        if params.get("password"):
            cfg["password"] = params["password"]
        if params.get("port"):
            cfg["port"] = params["port"]
        if params.get("key_path"):
            cfg["client_keys"] = [params["key_path"]]
        cfg.setdefault("known_hosts", "none")
        return self._build_kwargs(cfg)

    def _build_kwargs(self, cfg: dict) -> dict:
        """Translate a host config dict into keyword arguments for asyncssh.connect().

        - ``port`` defaults to 22 (also when null); ``username`` defaults to
          $USER, or "root" when $USER is unset.
        - ``client_keys`` (a list, or a single path string) are ``~``-expanded.
          A password given alongside keys is passed as a fallback.
        - A password without keys sets ``client_keys=[]`` so only password auth
          is used.
        - ``known_hosts`` "none" disables host key checking, any other non-null
          value is used as a path, and None/absent leaves asyncssh's default.
        """
        kwargs: dict = {
            "host": cfg["host"],
            "port": int(cfg.get("port") or 22),
            "username": cfg.get("username") or os.environ.get("USER", "root"),
        }

        client_keys = cfg.get("client_keys")
        if isinstance(client_keys, str):
            # Tolerate a bare path in ssh_hosts.json; iterating it would yield characters.
            client_keys = [client_keys]
        password = cfg.get("password")

        if client_keys:
            kwargs["client_keys"] = [str(Path(k).expanduser()) for k in client_keys]
            if password:
                kwargs["password"] = password  # fallback
        elif password:
            kwargs["client_keys"] = []  # disable key lookup, use password only
            kwargs["password"] = password
        # else: no creds — asyncssh tries ~/.ssh/* by default

        known_hosts = cfg.get("known_hosts")
        if known_hosts == "none":
            kwargs["known_hosts"] = None
        elif known_hosts is not None:
            kwargs["known_hosts"] = str(Path(known_hosts).expanduser())

        return kwargs