diff --git a/.env.example b/.env.example index 3852b53..454713d 100644 --- a/.env.example +++ b/.env.example @@ -1,5 +1,5 @@ OLLAMA_HOST=http://localhost:11434 -OLLAMA_DEFAULT_MODEL=gemma4:e4b-it-q8 +OLLAMA_DEFAULT_MODEL=gemma4:e4b-it-q_8 OPENAI_API_KEY= ANTHROPIC_API_KEY= diff --git a/navi/core/agent.py b/navi/core/agent.py index 335fc40..4f2d7be 100644 --- a/navi/core/agent.py +++ b/navi/core/agent.py @@ -99,6 +99,10 @@ mem = await self._memory_msg() + # Expose session_id to tools (e.g. SSH connection pool) via ContextVar + from navi.tools.base import current_session_id as _sid_var + _sid_var.set(session_id) + user_msg = Message(role="user", content=user_message, images=images or None, created_at=datetime.now(timezone.utc)) session.messages.append(user_msg) @@ -170,6 +174,10 @@ mem = await self._memory_msg() + # Expose session_id to tools (e.g. SSH connection pool) via ContextVar + from navi.tools.base import current_session_id as _sid_var + _sid_token = _sid_var.set(session_id) + user_msg = Message(role="user", content=user_message, images=images or None, created_at=datetime.now(timezone.utc)) session.messages.append(user_msg) diff --git a/navi/tools/base.py b/navi/tools/base.py index 6692ce8..af9ce84 100644 --- a/navi/tools/base.py +++ b/navi/tools/base.py @@ -3,13 +3,20 @@ Each tool is self-describing: name, description, and parameters (JSON Schema). The schema() method builds the LLM-facing function spec automatically. + +current_session_id — ContextVar set by Agent before every tool call. +Tools that need per-session state (e.g. SSH connection pool) read it here. """ from abc import ABC, abstractmethod +from contextvars import ContextVar from dataclasses import dataclass, field from navi.llm.base import ToolSchema +# Set by Agent before every tool call. Tools that need per-session state read this. +current_session_id: ContextVar[str | None] = ContextVar("current_session_id", default=None) + @dataclass class ToolResult: diff --git a/navi/tools/ssh_exec.py b/navi/tools/ssh_exec.py index c6b8a65..d856636 100644 --- a/navi/tools/ssh_exec.py +++ b/navi/tools/ssh_exec.py @@ -1,5 +1,9 @@ """SSH tool — execute commands on remote hosts via SSH. +Connections are cached per-session with a 20-minute TTL. Each session gets +its own independent connection pool, so parallel sessions working the same +server don't interfere with each other. + Connections are defined in ssh_hosts.json (see .env.example for path config). The tool also accepts inline host/user/key parameters for one-off connections. @@ -26,15 +30,79 @@ import asyncio import json import os +import time +from dataclasses import dataclass from pathlib import Path import asyncssh from navi.config import settings -from .base import Tool, ToolResult +from .base import Tool, ToolResult, current_session_id _TIMEOUT = 60 +_TTL = 20 * 60 # 20 minutes in seconds + + +@dataclass +class _PoolEntry: + conn: asyncssh.SSHClientConnection + created_at: float # time.monotonic() + connect_kwargs: dict + + +# Global pool: key → _PoolEntry +_pool: dict[str, _PoolEntry] = {} + +# Per-key locks to prevent concurrent connection creation for the same key. +# Protected by a meta-lock when creating new per-key locks. +_lock_meta = asyncio.Lock() +_locks: dict[str, asyncio.Lock] = {} + + +async def _get_lock(key: str) -> asyncio.Lock: + async with _lock_meta: + if key not in _locks: + _locks[key] = asyncio.Lock() + return _locks[key] + + +async def _get_connection(key: str, connect_kwargs: dict) -> asyncssh.SSHClientConnection: + """Return a live connection from pool, creating or reconnecting as needed.""" + lock = await _get_lock(key) + async with lock: + entry = _pool.get(key) + if entry is not None: + age = time.monotonic() - entry.created_at + if age > _TTL or entry.conn.is_closing(): + # Expired or already closed — drop it + try: + entry.conn.close() + except Exception: + pass + del _pool[key] + entry = None + + if entry is None: + conn = await asyncssh.connect(**connect_kwargs) + _pool[key] = _PoolEntry(conn=conn, created_at=time.monotonic(), connect_kwargs=connect_kwargs) + + return _pool[key].conn + + +def _evict(key: str) -> None: + """Remove a connection from the pool (called after an error).""" + entry = _pool.pop(key, None) + if entry is not None: + try: + entry.conn.close() + except Exception: + pass + + +def _pool_key(session_id: str | None, host: str, port: int, username: str) -> str: + sid = session_id or "anonymous" + return f"{sid}:{host}:{port}:{username}" def _load_hosts() -> dict: @@ -53,7 +121,8 @@ "Run a command on a remote server over SSH. " "Required: command + host. Pass username and password (or key_path) directly — " "no config file needed. Use for any task on a remote VPS, server, or device. " - "Host key verification is skipped by default for ad-hoc connections." + "Host key verification is skipped by default for ad-hoc connections. " + "Connections are reused within a session (20-minute TTL) — no reconnect overhead." ) parameters = { "type": "object", @@ -109,34 +178,70 @@ error="no_target", ) - try: - async with asyncssh.connect(**connect_kwargs) as conn: + session_id = current_session_id.get() + host = connect_kwargs["host"] + port = int(connect_kwargs.get("port", 22)) + username = connect_kwargs.get("username", "") + key = _pool_key(session_id, host, port, username) + + return await self._run_with_retry(key, connect_kwargs, command, timeout) + + async def _run_with_retry( + self, + key: str, + connect_kwargs: dict, + command: str, + timeout: int, + ) -> ToolResult: + for attempt in range(2): + try: + conn = await _get_connection(key, connect_kwargs) result = await asyncio.wait_for( conn.run(command, check=False), timeout=timeout, ) - output_parts = [] - if result.stdout: - output_parts.append(result.stdout) - if result.stderr: - output_parts.append(f"[stderr]\n{result.stderr}") + output_parts = [] + if result.stdout: + output_parts.append(result.stdout) + if result.stderr: + output_parts.append(f"[stderr]\n{result.stderr}") - success = result.exit_status == 0 - return ToolResult( - success=success, - output="\n".join(output_parts) or "(no output)", - metadata={"exit_status": result.exit_status, "host": connect_kwargs.get("host")}, - error=None if success else f"Exit status {result.exit_status}", - ) - except asyncssh.DisconnectError as e: - return ToolResult(success=False, output=f"SSH disconnected: {e}", error=str(e)) - except asyncssh.PermissionDenied: - return ToolResult(success=False, output="SSH permission denied. Check username and password/key.", error="permission_denied") - except (TimeoutError, asyncio.TimeoutError): - return ToolResult(success=False, output=f"SSH command timed out after {timeout}s", error="timeout") - except Exception as e: - return ToolResult(success=False, output=f"SSH error: {e}", error=str(e)) + success = result.exit_status == 0 + return ToolResult( + success=success, + output="\n".join(output_parts) or "(no output)", + metadata={"exit_status": result.exit_status, "host": connect_kwargs.get("host")}, + error=None if success else f"Exit status {result.exit_status}", + ) + + except (asyncssh.DisconnectError, asyncssh.ConnectionLost, EOFError) as e: + _evict(key) + if attempt == 0: + continue # retry with fresh connection + return ToolResult(success=False, output=f"SSH disconnected: {e}", error=str(e)) + + except asyncssh.PermissionDenied: + _evict(key) + return ToolResult( + success=False, + output="SSH permission denied. Check username and password/key.", + error="permission_denied", + ) + + except (TimeoutError, asyncio.TimeoutError): + return ToolResult( + success=False, + output=f"SSH command timed out after {timeout}s", + error="timeout", + ) + + except Exception as e: + _evict(key) + return ToolResult(success=False, output=f"SSH error: {e}", error=str(e)) + + # Should not be reached + return ToolResult(success=False, output="SSH: unexpected retry exhaustion", error="retry_failed") def _resolve(self, params: dict) -> dict | None: # Named connection from ssh_hosts.json takes precedence @@ -197,5 +302,3 @@ kwargs["known_hosts"] = str(Path(known_hosts).expanduser()) return kwargs - -