diff --git a/README.md b/README.md index 7e694aa..ce7a8cb 100644 --- a/README.md +++ b/README.md @@ -90,6 +90,26 @@ - Профиль по умолчанию: `navi_code`. - Без авторизации: `NAVI_AUTH_ENABLED=false`. - CLI сохраняет `session_id` в `~/.navi_code/state.json`. +- **Navi работает с проектом в текущей директории:** клиент автоматически передаёт серверу рабочую папку, из которой запущен `navi-code`. Относительные пути в `filesystem`, `terminal` и `code_exec` разрешаются относительно неё. + +### Установка для запуска из любой директории + +После `pip install -e .` консольная команда `navi-code` появляется в PATH внутри активированного venv. Если хочешь вызывать `navi-code` из любой папки без активации venv, есть два варианта: + +**Вариант 1 — symlink на wrapper (рекомендуется для постоянного использования):** +```bash +ln -s /path/to/navi-1/bin/navi-code ~/.local/bin/navi-code +# или любая другая директория из $PATH +``` + +**Вариант 2 — editable install даёт entry point:** +```bash +.venv/bin/pip install -e . +# затем можно symlink и сам entry point: +ln -s /path/to/navi-1/.venv/bin/navi-code ~/.local/bin/navi-code +``` + +Wrapper в `bin/navi-code` сам находит `.venv` рядом с репозиторием и запускает клиент, передавая текущую директорию серверу. Подробнее: [`docs/navi_code.md`](docs/navi_code.md) и [`docs/navi_code_cli.md`](docs/navi_code_cli.md). diff --git a/bin/navi-code b/bin/navi-code new file mode 100755 index 0000000..f4dc66d --- /dev/null +++ b/bin/navi-code @@ -0,0 +1,32 @@ +#!/usr/bin/env bash +# +# Wrapper script for launching the Navi Code terminal client from anywhere. +# +# Usage: +#   1. Clone the navi repository. +#   2. Create a venv and install the package: +#        python -m venv .venv +#        .venv/bin/pip install -e . +#   3. Symlink this script onto your PATH: +#        ln -s /path/to/navi/bin/navi-code ~/.local/bin/navi-code +#   4. cd into any project and run:
+#        navi-code +# +# The wrapper locates the project root from the script's own location, uses +# the local venv if present, and forwards all arguments to the `navi-code` entry +# point. The client captures the shell's current working directory and sends it to +# the backend so Navi resolves relative paths against your project directory. +# +# Set NAVI_CODE_VENV to override the venv location (default: <repo>/.venv). + +# Resolve symlinks first: when run via the PATH symlink, BASH_SOURCE points at the link, not the repo. +SCRIPT_DIR="$(cd "$(dirname "$(readlink -f "${BASH_SOURCE[0]}" 2>/dev/null || printf '%s' "${BASH_SOURCE[0]}")")" && pwd)" +PROJECT_ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)" +VENV_DIR="${NAVI_CODE_VENV:-${PROJECT_ROOT}/.venv}" + +if [[ -d "${VENV_DIR}" && -f "${VENV_DIR}/bin/navi-code" ]]; then + exec "${VENV_DIR}/bin/navi-code" "$@" +fi + +# Fallback: run the entry point directly if the package is installed elsewhere. +exec python -m clients.terminal.cli "$@" diff --git a/clients/terminal/cli.py b/clients/terminal/cli.py index 151a95b..8fe33ec 100644 --- a/clients/terminal/cli.py +++ b/clients/terminal/cli.py @@ -3,6 +3,7 @@ from __future__ import annotations import asyncio +from pathlib import Path import click @@ -50,21 +51,23 @@ settings.show_thinking = True settings.show_events = show_events + cwd = Path.cwd().resolve() + if raw or prompt: - _run_raw(prompt, new_session, profile_id) + _run_raw(prompt, new_session, profile_id, cwd) return - _run_tui(profile_id, new_session, theme, mouse) + _run_tui(profile_id, new_session, theme, mouse, cwd) -def _run_raw(prompt: str | None, new_session: bool, profile_id: str | None) -> None: +def _run_raw(prompt: str | None, new_session: bool, profile_id: str | None, cwd: Path) -> None: state = StateManager() session_id = _resolve_session_id(state, new_session, profile_id) if not session_id: raise click.ClickException("Failed to create or resume a session.") renderer = Renderer(show_thinking=settings.show_thinking, show_events=settings.show_events) - client = NaviWebSocketClient(session_id, renderer=renderer) + client = NaviWebSocketClient(session_id, renderer=renderer, cwd=cwd) if prompt: asyncio.run(_run_one_shot(client,
prompt)) @@ -74,11 +77,15 @@ def _run_tui( - profile_id: str | None, new_session: bool, theme: str | None, mouse: bool | None + profile_id: str | None, + new_session: bool, + theme: str | None, + mouse: bool | None, + cwd: Path, ) -> None: from clients.terminal.tui.tui_app import NaviCodeTui - app = NaviCodeTui(profile_id=profile_id, new_session=new_session, theme_name=theme) + app = NaviCodeTui(profile_id=profile_id, new_session=new_session, theme_name=theme, cwd=cwd) if mouse is not None: app._mouse_enabled = mouse app.run(mouse=app._mouse_enabled) diff --git a/clients/terminal/tui/chat_model.py b/clients/terminal/tui/chat_model.py index d269e49..307535f 100644 --- a/clients/terminal/tui/chat_model.py +++ b/clients/terminal/tui/chat_model.py @@ -85,6 +85,13 @@ self.items.append(item) return item + if msg_type == "stream_stopped": + self._current_assistant = None + self._current_thinking = None + item = ChatItem(kind="status", content="Generation stopped by user") + self.items.append(item) + return item + if msg_type in ("stream_end", "context_compressed", "heartbeat", "session_sync"): return None diff --git a/clients/terminal/tui/context.py b/clients/terminal/tui/context.py index 418c597..b2cd02d 100644 --- a/clients/terminal/tui/context.py +++ b/clients/terminal/tui/context.py @@ -2,7 +2,8 @@ from __future__ import annotations -from dataclasses import dataclass +from dataclasses import dataclass, field +from pathlib import Path from typing import TYPE_CHECKING if TYPE_CHECKING: @@ -28,6 +29,7 @@ status_panel: "StatusPanel | None" = None sessions_panel: "SessionsPanel | None" = None chat_model: "ChatModel | None" = None + cwd: Path = field(default_factory=lambda: Path.cwd().resolve()) def app(self): """Return the running TuiApp instance.""" diff --git a/clients/terminal/tui/tui_app.py b/clients/terminal/tui/tui_app.py index b928c4e..9488786 100644 --- a/clients/terminal/tui/tui_app.py +++ b/clients/terminal/tui/tui_app.py @@ -2,6 +2,9 @@ from __future__ import 
annotations +import asyncio +from pathlib import Path + from textual.app import App, ComposeResult from textual.containers import Horizontal, Vertical from textual.widgets import Footer, Header @@ -41,6 +44,7 @@ ("ctrl+x l", "list_sessions", "Sessions"), ("ctrl+x c", "compact", "Compact"), ("ctrl+x t", "toggle_thinking", "Thinking"), + ("escape", "stop_stream", "Stop"), ] def __init__( @@ -49,6 +53,7 @@ profile_id: str | None = None, new_session: bool = False, theme_name: str | None = None, + cwd: Path | None = None, ) -> None: self._tui_settings = get_tui_settings() self._theme_name = theme_name or self._tui_settings.theme @@ -68,6 +73,7 @@ status_panel=self._status_panel, sessions_panel=self._sessions_panel, settings=self._tui_settings, + cwd=cwd or Path.cwd().resolve(), ) self._bridge: WsBridge | None = None self._permission_engine = PermissionEngine() @@ -75,6 +81,7 @@ self._requested_session_id = session_id self._requested_profile_id = profile_id self._force_new_session = new_session + self._streaming = False def compose(self) -> ComposeResult: yield Header(show_clock=False) @@ -90,6 +97,11 @@ self.apply_theme() self.run_worker(self._startup) self._input_box.focus_input() + self._update_footer_bindings() + + def _update_footer_bindings(self) -> None: + """Refresh footer so the Stop binding visibility tracks streaming state.""" + self.refresh_bindings() def _register_textual_themes(self) -> None: """Register every Navi theme as a Textual theme so $tui-* variables resolve.""" @@ -163,10 +175,12 @@ ) self._status_panel.set_backend(settings.base_url) self._status_panel.set_theme(self._theme_name) + if cwd := self._ctx.cwd: + self._chat_panel.handle_ws_event({"type": "status", "content": f"cwd: {cwd}"}) if self._bridge: await self._bridge.stop() - self._bridge = WsBridge(self, session_id) + self._bridge = WsBridge(self, session_id, cwd=self._ctx.cwd) await self._bridge.start() self._ctx.ws_client = self._bridge.client self._chat_panel.handle_ws_event( @@ -307,7 
+321,18 @@ def on_ws_event(self, event: WsEvent) -> None: payload = event.payload - if payload.get("type") == "tool_started": + msg_type = payload.get("type") + if msg_type == "stream_start": + self._streaming = True + self._input_box.set_placeholder("Navi is thinking... (Esc to stop)") + self._update_footer_bindings() + elif msg_type in ("stream_end", "stream_stopped", "error"): + self._streaming = False + self._input_box.set_placeholder("Ask anything...") + self._input_box.focus_input() + self._update_footer_bindings() + + if msg_type == "tool_started": tool = payload.get("tool", "") args = payload.get("args") or {} if self._permission_engine.is_always_deny(tool, args): @@ -409,6 +434,24 @@ def action_toggle_thinking(self) -> None: self._run_command("/thinking") + def action_stop_stream(self) -> None: + if not self._streaming: + return + session_id = self._ctx.session_id + if not session_id: + return + self.run_worker(self._stop_stream_worker(session_id)) + + async def _stop_stream_worker(self, session_id: str) -> None: + try: + result = api.stop_session(session_id) + if asyncio.iscoroutine(result): + await result + except Exception as exc: + self._chat_panel.handle_ws_event( + {"type": "error", "message": f"Failed to stop generation: {exc}"} + ) + async def action_quit(self) -> None: if self._bridge: await self._bridge.stop() diff --git a/clients/terminal/tui/widgets/input_box.py b/clients/terminal/tui/widgets/input_box.py index d9964f2..5b42cdf 100644 --- a/clients/terminal/tui/widgets/input_box.py +++ b/clients/terminal/tui/widgets/input_box.py @@ -4,7 +4,7 @@ from textual.app import ComposeResult from textual.containers import Vertical -from textual.widgets import Input, Static +from textual.widgets import Input from clients.terminal.tui.events import UserSubmitted @@ -69,5 +69,8 @@ def focus_input(self) -> None: self._input.focus() + def set_placeholder(self, text: str) -> None: + self._input.placeholder = text + def set_prompt_char(self, char: str) -> None: 
self._prompt.update(char) diff --git a/clients/terminal/tui/ws_bridge.py b/clients/terminal/tui/ws_bridge.py index d20bd18..455bf6b 100644 --- a/clients/terminal/tui/ws_bridge.py +++ b/clients/terminal/tui/ws_bridge.py @@ -4,6 +4,7 @@ import asyncio import json +from pathlib import Path from textual.app import App @@ -14,10 +15,10 @@ class WsBridge: """Wraps NaviWebSocketClient and forwards events to a Textual App.""" - def __init__(self, app: App, session_id: str) -> None: + def __init__(self, app: App, session_id: str, cwd: Path | None = None) -> None: self.app = app self.session_id = session_id - self._client = NaviWebSocketClient(session_id) + self._client = NaviWebSocketClient(session_id, cwd=cwd) self._receive_task: asyncio.Task | None = None self._input_task: asyncio.Task | None = None self._connected = False diff --git a/clients/terminal/ws_client.py b/clients/terminal/ws_client.py index 6630e3c..1e388ae 100644 --- a/clients/terminal/ws_client.py +++ b/clients/terminal/ws_client.py @@ -4,6 +4,7 @@ import asyncio import json +from pathlib import Path import websockets @@ -18,6 +19,7 @@ self, session_id: str, renderer: Renderer | None = None, + cwd: Path | None = None, ) -> None: self.session_id = session_id self.renderer = renderer or Renderer( @@ -25,6 +27,7 @@ show_events=settings.show_events, ) self.url = settings.websocket_url(session_id) + self._cwd = cwd or Path.cwd().resolve() self._ws: websockets.ClientConnection | None = None self._stop_event = asyncio.Event() self._input_queue: asyncio.Queue[str | None] = asyncio.Queue() @@ -40,7 +43,12 @@ async def send(self, content: str) -> None: if not self._ws: raise RuntimeError("WebSocket is not connected") - await self._ws.send(json.dumps({"type": "message", "content": content})) + payload = { + "type": "message", + "content": content, + "cwd": str(self._cwd), + } + await self._ws.send(json.dumps(payload)) async def receive_loop(self) -> None: if not self._ws: @@ -52,7 +60,7 @@ except json.JSONDecodeError: 
continue self.renderer.render(msg) - if msg.get("type") in ("stream_end", "error"): + if msg.get("type") in ("stream_end", "stream_stopped", "error"): self._stop_event.set() except websockets.exceptions.ConnectionClosed: pass diff --git a/docs/websocket.md b/docs/websocket.md index 9d09f3b..0141d10 100644 --- a/docs/websocket.md +++ b/docs/websocket.md @@ -31,6 +31,7 @@ { "type": "message", "content": "user text", + "cwd": "/home/user/projects/my-app", "images": ["base64string", ...], "files": [{"name": "file.pdf", "path": "/abs/path"}] } @@ -38,6 +39,7 @@ - `type` must be `"message"`. Other types return an error frame. - `content` is required and must be non-empty. +- `cwd`: optional absolute path of the client's current working directory. Recommended for terminal clients; the server treats it as the project root and resolves relative paths in `filesystem`, `terminal`, and `code_exec` against it. Stored per-session in `session_metadata["cwd"]` and injected into the LLM system context. - `images`: optional list of base64-encoded images (data URIs accepted; the `data:...;base64,` prefix is stripped server-side). **Limits:** max 10 images per message, 5 MB each (decoded). Excess images are rejected with a WebSocket error. - `files`: optional list of uploaded file references (appended to content as `[Uploaded files on disk: ...]`). diff --git a/navi/api/websocket.py b/navi/api/websocket.py index 13d2145..01a8fa3 100644 --- a/navi/api/websocket.py +++ b/navi/api/websocket.py @@ -1,7 +1,11 @@ """WebSocket endpoint for streaming agent responses. Protocol (client -> server): - {"type": "message", "content": "..."} + {"type": "message", "content": "...", "cwd": "..."} + + `cwd` is optional but recommended for terminal clients; when present the + server treats it as the user's working directory and resolves relative paths + against it. 
Protocol (server -> client): {"type": "stream_start"} @@ -27,8 +31,6 @@ from navi.auth import User from navi.auth.deps import check_session_access from navi.core import SessionStore -from navi.core.event_bus import get_event_bus -from navi.core.events import AgentEvent, RecallUpdate router = APIRouter(tags=["websocket"]) log = structlog.get_logger() @@ -36,6 +38,7 @@ # ── Helpers ─────────────────────────────────────────────────────────────────── + def _event_to_dict(event) -> dict | None: if hasattr(event, "to_wire"): return event.to_wire() @@ -85,6 +88,7 @@ # ── Endpoints ───────────────────────────────────────────────────────────────── + @router.post("/sessions/{session_id}/stop") async def stop_session( session_id: str, @@ -92,6 +96,8 @@ user: Annotated[User | None, Depends(get_current_user)] = None, ) -> dict: """Signal the running agent for this session to stop cooperatively.""" + from fastapi import HTTPException + session = await store.get(session_id) if session is None: raise HTTPException(status_code=404, detail="Session not found") @@ -116,7 +122,12 @@ try: user: User | None = await get_current_user_ws(websocket) except Exception as exc: - log.warning("ws.resolve_user_exc", session_id=session_id, exc_type=type(exc).__name__, error=str(exc)) + log.warning( + "ws.resolve_user_exc", + session_id=session_id, + exc_type=type(exc).__name__, + error=str(exc), + ) user = None log.info("ws.user_resolved", session_id=session_id, user_id=user.id if user else None) @@ -202,10 +213,12 @@ continue if data.get("type") != "message" or not data.get("content"): - await websocket.send_json({ - "type": "error", - "message": "Expected {type: 'message', content: '...'}", - }) + await websocket.send_json( + { + "type": "error", + "message": "Expected {type: 'message', content: '...', cwd?: '...'}", + } + ) continue original_content = data["content"] @@ -216,10 +229,12 @@ _MAX_IMAGES = 8 _MAX_IMAGE_BYTES_TOTAL = 50 * 1024 * 1024 # 50 MB total payload if len(raw_images) > 
_MAX_IMAGES: - await websocket.send_json({ - "type": "error", - "message": f"Too many images ({len(raw_images)}). Max {_MAX_IMAGES} allowed.", - }) + await websocket.send_json( + { + "type": "error", + "message": f"Too many images ({len(raw_images)}). Max {_MAX_IMAGES} allowed.", + } + ) continue cleaned = [] total_bytes = 0 @@ -228,10 +243,12 @@ img = img.split(",", 1)[1] img_bytes = len(img.encode("utf-8")) if total_bytes + img_bytes > _MAX_IMAGE_BYTES_TOTAL: - await websocket.send_json({ - "type": "error", - "message": "Total image payload exceeds 50 MB limit.", - }) + await websocket.send_json( + { + "type": "error", + "message": "Total image payload exceeds 50 MB limit.", + } + ) cleaned = None break total_bytes += img_bytes @@ -253,20 +270,20 @@ # Append uploaded file paths to user content so Navi knows about them uploaded_files: list[dict] = data.get("files") or [] if uploaded_files: - file_lines = "\n".join( - f"- {f['name']} → {f['path']}" for f in uploaded_files - ) - user_content = ( - user_content + f"\n\n[Uploaded files on disk:\n{file_lines}]" - ) + file_lines = "\n".join(f"- {f['name']} → {f['path']}" for f in uploaded_files) + user_content = user_content + f"\n\n[Uploaded files on disk:\n{file_lines}]" + + client_cwd = data.get("cwd") # Guard against concurrent runs for the same session (atomically). 
async with orchestrator.session_lock(session_id): if orchestrator.is_running(session_id): - await websocket.send_json({ - "type": "error", - "message": "Agent is already running for this session.", - }) + await websocket.send_json( + { + "type": "error", + "message": "Agent is already running for this session.", + } + ) continue # Register run and subscribe before starting the task so we never @@ -276,7 +293,12 @@ current_run = run # Set user context for tool sandboxing (inherited by the agent task) - from navi.tools._internal.base import current_user_id as _uid_var, current_user_role as _role_var, current_user_info as _uinfo_var + from navi.tools._internal.base import ( + current_user_id as _uid_var, + current_user_role as _role_var, + current_user_info as _uinfo_var, + ) + if user is not None: uid_token = _uid_var.set(user.id) role_token = _role_var.set(user.role) @@ -289,7 +311,13 @@ try: run.task = asyncio.create_task( orchestrator.run_agent( - session_id, user_content, raw_images, original_content, uploaded_files, session_store + session_id, + user_content, + raw_images, + original_content, + uploaded_files, + session_store, + cwd=client_cwd, ) ) except Exception: diff --git a/navi/core/agent.py b/navi/core/agent.py index 978e67d..c69a575 100644 --- a/navi/core/agent.py +++ b/navi/core/agent.py @@ -28,9 +28,22 @@ from PIL import Image from navi.config import settings -from navi.exceptions import ContextTooLargeError, LLMBackendError, LLMConnectionError, MaxIterationsReached, SessionNotFound -from navi.llm.base import LLMBackend, Message, ToolCallRequest -from navi.tools._internal.base import Tool, ToolContext, current_event_sink, current_stop_event, current_user_role, current_user_info +from navi.exceptions import ( + ContextTooLargeError, + LLMBackendError, + LLMConnectionError, + MaxIterationsReached, + SessionNotFound, +) +from navi.llm.base import LLMBackend, Message +from navi.tools._internal.base import ( + Tool, + ToolContext, + current_event_sink, + 
current_stop_event, + current_user_role, + current_user_info, +) from .agent_run_context import AgentTurnContext, StreamState from .anti_stall import AntiStallMonitor @@ -39,15 +52,12 @@ from .planning import PlanningEngine from .stream_guard import _iter_stream_guarded from .subagent_runner import SubAgentRunner -from .tool_utils import build_tool_list, load_user_enabled_tools +from .tool_utils import build_tool_list from .events import ( AgentEvent, AIHelperTokensUsed, CompressionStarted, - ContextCompressed, PlanningDebugData, - PlanningStatus, - PlanReady, StreamEnd, StreamStopped, SubagentComplete, @@ -56,7 +66,6 @@ ThinkingEnd, ToolEvent, ToolStarted, - TurnThinking, ) from .registry import BackendRegistry, ProfileRegistry, ToolRegistry from .session import SessionStore @@ -65,8 +74,7 @@ if TYPE_CHECKING: from navi.context_providers._loader import ContextProviderRegistry from navi.memory.store import MemoryStore - from navi.workers.base import Worker, WorkerContext - + from navi.workers.base import Worker log = structlog.get_logger() @@ -75,28 +83,117 @@ _TOOL_DONE = object() -_CASUAL_WORDS = frozenset({ - # Russian greetings/social - "привет", "здравствуй", "здравствуйте", "хай", "хелло", "хеллоу", - "как", "дела", "делишки", "ты", "вы", "поживаешь", "поживаете", - "жизнь", "сам", "сама", - "спасибо", "спс", "пока", "bye", "goodbye", - "доброе", "утро", "добрый", "день", "вечер", "спокойной", "ночи", - "ок", "окей", "ладно", "давай", - # English greetings/social - "hi", "hello", "hey", "hola", "bonjour", - "how", "are", "you", "it", "going", "things", "what", "up", "s", - "thanks", "thank", "thx", - "good", "morning", "afternoon", "evening", "night", "see", "cya", - "ok", "okay", - # Common fillers that keep a social phrase social - "a", "an", "the", "and", "too", "very", "much", "today", "now", - "there", "here", "again", "well", "oh", "ah", "um", - "в", "и", "а", "но", "же", "тоже", "очень", "сегодня", "сейчас", - "ну", "вот", "тут", "ещё", "раз", "ка", - 
"is", "am", "are", "do", "does", "did", "be", "been", "being", - "man", "dude", "bro", "mate", "friend", "dear", -}) +_CASUAL_WORDS = frozenset( + { + # Russian greetings/social + "привет", + "здравствуй", + "здравствуйте", + "хай", + "хелло", + "хеллоу", + "как", + "дела", + "делишки", + "ты", + "вы", + "поживаешь", + "поживаете", + "жизнь", + "сам", + "сама", + "спасибо", + "спс", + "пока", + "bye", + "goodbye", + "доброе", + "утро", + "добрый", + "день", + "вечер", + "спокойной", + "ночи", + "ок", + "окей", + "ладно", + "давай", + # English greetings/social + "hi", + "hello", + "hey", + "hola", + "bonjour", + "how", + "are", + "you", + "it", + "going", + "things", + "what", + "up", + "s", + "thanks", + "thank", + "thx", + "good", + "morning", + "afternoon", + "evening", + "night", + "see", + "cya", + "ok", + "okay", + # Common fillers that keep a social phrase social + "a", + "an", + "the", + "and", + "too", + "very", + "much", + "today", + "now", + "there", + "here", + "again", + "well", + "oh", + "ah", + "um", + "в", + "и", + "а", + "но", + "же", + "тоже", + "очень", + "сегодня", + "сейчас", + "ну", + "вот", + "тут", + "ещё", + "раз", + "ка", + "is", + "am", + "are", + "do", + "does", + "did", + "be", + "been", + "being", + "man", + "dude", + "bro", + "mate", + "friend", + "dear", + } +) def _is_casual_message(text: str) -> bool: @@ -112,7 +209,22 @@ if any(marker in text for marker in ("@", "!", "http://", "https://", "file://")): return False # Path-like or command-like fragments. - if any(fragment in text for fragment in ("/home", "/tmp", "/etc", "./", "../", "~/", "\\", ".py ", ".txt", ".md", ".json")): + if any( + fragment in text + for fragment in ( + "/home", + "/tmp", + "/etc", + "./", + "../", + "~/", + "\\", + ".py ", + ".txt", + ".md", + ".json", + ) + ): return False # A bare leading slash is a command/path indicator. 
if text.strip().startswith("/"): @@ -131,9 +243,12 @@ return casual_count / len(words) >= 0.5 -async def _todo_progress_message(session_id: str, *, first_iteration: bool = False) -> "Message | None": +async def _todo_progress_message( + session_id: str, *, first_iteration: bool = False +) -> "Message | None": """Build a compact system reminder with current todo state and update discipline.""" from navi.tools.todo import get_progress_message + return await get_progress_message(session_id, first_iteration=first_iteration) @@ -187,10 +302,19 @@ # Public interface # ------------------------------------------------------------------ - async def run(self, session_id: str, user_message: str, images: list[str] | None = None, files: list[dict] | None = None, is_recall: bool = False) -> str: + async def run( + self, + session_id: str, + user_message: str, + images: list[str] | None = None, + files: list[dict] | None = None, + is_recall: bool = False, + ) -> str: """Non-streaming: run the full tool-calling loop and return the final text.""" full_content = "" - async for event in self.run_stream(session_id, user_message, images=images, files=files, is_recall=is_recall): + async for event in self.run_stream( + session_id, user_message, images=images, files=files, is_recall=is_recall + ): if isinstance(event, StreamEnd): full_content = event.full_content or "" return full_content @@ -230,6 +354,7 @@ display_message: str | None = None, files: list[dict] | None = None, is_recall: bool = False, + cwd: str | None = None, ) -> AsyncGenerator[AgentEvent, None]: """ Streaming variant. 
Yields AgentEvent objects: @@ -254,10 +379,16 @@ mem = await self._ctx_builder._memory_msg(user_id=session.user_id) - # Expose session_id and model to tools via ContextVar - from navi.tools._internal.base import current_session_id as _sid_var, current_model as _model_var + # Expose session_id, model and cwd to tools via ContextVar + from navi.tools._internal.base import ( + current_session_id as _sid_var, + current_model as _model_var, + current_working_directory as _cwd_var, + ) + _sid_token = _sid_var.set(session_id) _model_var.set(profile.model) + _cwd_token = _cwd_var.set(cwd) # Pre-turn compression: if the last turn filled the context past the # threshold, compress NOW before calling the LLM. This prevents the @@ -267,9 +398,15 @@ yield _ev display_text = display_message if display_message is not None else user_message - user_msg_display = Message(role="user", content=display_text, images=images or None, - files=files or None, created_at=datetime.now(timezone.utc), - is_recall=is_recall, is_context=False) + user_msg_display = Message( + role="user", + content=display_text, + images=images or None, + files=files or None, + created_at=datetime.now(timezone.utc), + is_recall=is_recall, + is_context=False, + ) # Image token budgeting: fit as many images as possible into the LLM context. 
# Overflow images are saved to the session directory so Navi can view them @@ -302,13 +439,17 @@ except Exception: pass if saved_names: - context_content += ( - f"\n\n[Additional images saved to session directory: {', '.join(saved_names)}]" - ) + context_content += f"\n\n[Additional images saved to session directory: {', '.join(saved_names)}]" - user_msg_context = Message(role="user", content=context_content, images=images_for_context or None, - files=files or None, created_at=datetime.now(timezone.utc), - is_recall=is_recall, is_display=False) + user_msg_context = Message( + role="user", + content=context_content, + images=images_for_context or None, + files=files or None, + created_at=datetime.now(timezone.utc), + is_recall=is_recall, + is_display=False, + ) session.messages.append(user_msg_display) session.messages.append(user_msg_context) session.context.append(user_msg_context) @@ -328,8 +469,22 @@ _is_casual = _is_casual_message(context_content) and not profile.planning_mandatory _force_plan = (_is_first_message and not _is_casual) or profile.planning_mandatory if (_is_first_message or profile.planning_enabled) and not _is_casual: - log.debug("agent.planning_enter", session_id=session_id, first_message=_is_first_message, planning_enabled=profile.planning_enabled, force_plan=_force_plan) - async for _ev in self._planning.run(session.context, profile, llm, mem, tool_schemas, messages=session.messages, force_plan=_force_plan): + log.debug( + "agent.planning_enter", + session_id=session_id, + first_message=_is_first_message, + planning_enabled=profile.planning_enabled, + force_plan=_force_plan, + ) + async for _ev in self._planning.run( + session.context, + profile, + llm, + mem, + tool_schemas, + messages=session.messages, + force_plan=_force_plan, + ): if isinstance(_ev, AIHelperTokensUsed): turn_ctx.subagent_tokens += _ev.completion_tokens elif isinstance(_ev, PlanningDebugData): @@ -351,7 +506,11 @@ mem_facts = await mem_facts_task if mem_facts: 
ctx_injections.append(mem_facts) - log.debug("agent.memory_facts_injected", session_id=session_id, facts_msg_length=len(mem_facts.content or "")) + log.debug( + "agent.memory_facts_injected", + session_id=session_id, + facts_msg_length=len(mem_facts.content or ""), + ) else: log.debug("agent.memory_facts_none", session_id=session_id) @@ -368,22 +527,32 @@ yield StreamStopped() return - async for _ev in self._compression_events_midturn(session, llm, profile, session_id, iteration, ctx_injections, mem): + async for _ev in self._compression_events_midturn( + session, llm, profile, session_id, iteration, ctx_injections, mem + ): yield _ev state = StreamState() - built_ctx = self._ctx_builder.build(session.context, profile, mem, - iteration=iteration, max_iterations=profile.max_iterations, - extra_system=ctx_injections, - session_id=session_id) + built_ctx = self._ctx_builder.build( + session.context, + profile, + mem, + iteration=iteration, + max_iterations=profile.max_iterations, + extra_system=ctx_injections, + session_id=session_id, + session_metadata=session.session_metadata, + ) if ( profile.goal_anchoring_enabled and iteration > 0 and iteration % profile.goal_anchoring_interval == 0 ): - built_ctx.append(await self._ctx_builder._build_goal_anchor(session_id, user_message)) + built_ctx.append( + await self._ctx_builder._build_goal_anchor(session_id, user_message) + ) todo_msg = await _todo_progress_message(session_id, first_iteration=(iteration == 0)) if todo_msg: @@ -433,11 +602,14 @@ # Stopped mid-stream — save partial response and exit if stop_event and stop_event.is_set(): if state.accumulated_text: - session.messages.append(Message( - role="assistant", content=state.accumulated_text, - created_at=datetime.now(timezone.utc), - is_context=False, - )) + session.messages.append( + Message( + role="assistant", + content=state.accumulated_text, + created_at=datetime.now(timezone.utc), + is_context=False, + ) + ) await self._sessions.save(session) yield 
StreamStopped() return @@ -472,7 +644,9 @@ token_count=_net_tokens if _net_tokens else None, ) - for event in await self._run_workers(session, llm, profile.model, state.context_tokens): + for event in await self._run_workers( + session, llm, profile.model, state.context_tokens + ): yield event return @@ -494,13 +668,19 @@ user_id=session.user_id, user_role=current_user_role.get(), user_info=current_user_info.get(), + cwd=_cwd_var.get(), ) - async for _ev in self._execute_tools_with_sink(turn_tool_calls, tools, turn_ctx, session, stop_event, tool_ctx): + async for _ev in self._execute_tools_with_sink( + turn_tool_calls, tools, turn_ctx, session, stop_event, tool_ctx + ): yield _ev # 6. Cooperative stop: check after tool execution before next LLM call if stop_event and stop_event.is_set(): await self._sessions.save(session) + # Reset context vars before returning so they don't leak. + _cwd_var.reset(_cwd_token) + _sid_var.reset(_sid_token) yield StreamStopped() return @@ -526,6 +706,9 @@ ) await self._sessions.save(session) + # Reset cwd ContextVar so it does not leak into other calls. 
+ _cwd_var.reset(_cwd_token) + _sid_var.reset(_sid_token) raise MaxIterationsReached(profile.max_iterations) # ------------------------------------------------------------------ @@ -557,13 +740,12 @@ result = await worker.run(session, ctx) events.extend(result.events) except Exception: - log.warning("agent.worker_failed", - worker=type(worker).__name__, exc_info=True) + log.warning("agent.worker_failed", worker=type(worker).__name__, exc_info=True) return events def _tool_list( self, - scope: "ToolScopeConfig", + scope: "navi.profiles.base.ToolScopeConfig", # noqa: F821 ) -> list[Tool]: return build_tool_list(scope.native, scope.mcp, self._tools, self._mcp_manager) @@ -598,7 +780,9 @@ if event: yield event - async def _compression_events_midturn(self, session, llm, profile, session_id, iteration, ctx_injections, mem): + async def _compression_events_midturn( + self, session, llm, profile, session_id, iteration, ctx_injections, mem + ): if settings.context_compression_enabled and iteration > 0: preflight_ctx = self._ctx_builder.build( session.context, @@ -632,7 +816,9 @@ if event: yield event - async def _consume_stream(self, stream_gen, stop_event, turn_ctx: AgentTurnContext, state: StreamState): + async def _consume_stream( + self, stream_gen, stop_event, turn_ctx: AgentTurnContext, state: StreamState + ): async for chunk in stream_gen: if stop_event and stop_event.is_set(): if state.thinking_active: @@ -659,7 +845,9 @@ state.thinking_active = False yield ThinkingEnd() - async def _execute_tools_with_sink(self, turn_tool_calls, tools, turn_ctx: AgentTurnContext, session, stop_event, tool_ctx=None): + async def _execute_tools_with_sink( + self, turn_tool_calls, tools, turn_ctx: AgentTurnContext, session, stop_event, tool_ctx=None + ): """Execute tool calls with cooperative stop support. 
Polls *stop_event* every second while draining the event sink so the @@ -676,7 +864,9 @@ async def _run_with_sentinel(_tc=tc, _holder=result_holder, _sink=sink): try: - _holder.append(await self._tool_executor._run_single_tool(_tc, tool_map, ctx=tool_ctx)) + _holder.append( + await self._tool_executor._run_single_tool(_tc, tool_map, ctx=tool_ctx) + ) except Exception as exc: _holder.append(exc) finally: @@ -713,29 +903,48 @@ # Tool was interrupted by user — record a synthetic cancellation result log.info("agent.tool_stopped", tool=tc.name) yield ToolEvent( - tool_name=tc.name, arguments=tc.arguments, - result="Tool execution was stopped by the user.", success=False, + tool_name=tc.name, + arguments=tc.arguments, + result="Tool execution was stopped by the user.", + success=False, tool_call_id=tc.id, ) - session.messages.append(Message( - role="tool", content="Tool execution was stopped by the user.", - tool_call_id=tc.id, name=tc.name, metadata={}, - is_context=False, - )) + session.messages.append( + Message( + role="tool", + content="Tool execution was stopped by the user.", + tool_call_id=tc.id, + name=tc.name, + metadata={}, + is_context=False, + ) + ) await self._sessions.save(session) return - r = result_holder[0] if result_holder else RuntimeError("tool task produced no result") + r = ( + result_holder[0] + if result_holder + else RuntimeError("tool task produced no result") + ) if isinstance(r, Exception): if isinstance(r, (LLMBackendError, LLMConnectionError)): raise r log.warning("agent.tool_exception", tool=tc.name, error=str(r)) tool_event = ToolEvent( - tool_name=tc.name, arguments=tc.arguments, - result=f"Error: {r}", success=False, + tool_name=tc.name, + arguments=tc.arguments, + result=f"Error: {r}", + success=False, tool_call_id=tc.id, ) - msg = Message(role="tool", content=f"Error: {r}", tool_call_id=tc.id, name=tc.name, metadata={}) + msg = Message( + role="tool", + content=f"Error: {r}", + tool_call_id=tc.id, + name=tc.name, + metadata={}, + ) 
image_msg = None else: tool_event, msg, image_msg = r @@ -754,4 +963,3 @@ await tool_task except Exception: pass - diff --git a/navi/core/context_builder.py b/navi/core/context_builder.py index 2145d59..3b0b8d8 100644 --- a/navi/core/context_builder.py +++ b/navi/core/context_builder.py @@ -3,7 +3,6 @@ Extracted from agent.py to reduce the Agent class surface area. """ -from datetime import datetime, timezone from typing import TYPE_CHECKING import structlog @@ -24,6 +23,7 @@ """Return a list of formatted todo lines for goal anchoring.""" try: from navi.tools.todo import render_todo_lines as _rtl + return await _rtl(session_id) except Exception: return [] @@ -166,13 +166,21 @@ lines.append("Current todo:") lines.extend(todo_lines) lines.append("Stay on track — complete the remaining pending/in_progress steps.") - lines.append("Use 1-based todo indexes. Mark completed steps done only after verification, with validation.") - lines.append("Before final response, update todo for every completed step, including the final one.") + lines.append( + "Use 1-based todo indexes. Mark completed steps done only after verification, with validation." + ) + lines.append( + "Before final response, update todo for every completed step, including the final one." 
+ ) return Message(role="system", content="\n".join(lines)) def _security_policy_msg(self) -> Message | None: """Build a dynamic security policy system message based on user role.""" - from navi.tools._internal.base import current_user_id as _uid_var, current_user_role as _role_var + from navi.tools._internal.base import ( + current_user_id as _uid_var, + current_user_role as _role_var, + ) + user_id = _uid_var.get(None) role = _role_var.get() if role == "admin": @@ -207,7 +215,11 @@ def _user_context_msg(self) -> Message | None: """Build a [User context] system message from current_user_info ContextVar.""" - from navi.tools._internal.base import current_user_info as _uinfo_var, current_user_role as _role_var + from navi.tools._internal.base import ( + current_user_info as _uinfo_var, + current_user_role as _role_var, + ) + info = _uinfo_var.get(None) if not info: return None @@ -257,7 +269,7 @@ return None lines = ["[MCP servers — external knowledge sources]"] for name, text in instructions.items(): - lines.append(f"") + lines.append("") lines.append(f"## {name}") lines.append(text) return Message(role="system", content="\n".join(lines)) @@ -271,6 +283,7 @@ max_iterations: int | None = None, extra_system: list[Message] | None = None, session_id: str | None = None, + session_metadata: dict | None = None, ) -> list[Message]: system_prompt = self.build_system_prompt(profile) if session_id: @@ -284,6 +297,16 @@ f"(e.g. 'falcon9_rocket.scad') for source_path and output_path. Do NOT include the session_id or " f"the session_files directory in those paths — the MCP server resolves them automatically." ) + cwd = (session_metadata or {}).get("cwd") + if cwd: + system_prompt += ( + f"\n\n---\n\n" + f"[Working directory]\n" + f"The user launched Navi Code from: {cwd}\n" + f"Treat this directory as the project root. Resolve relative paths against it " + f"unless the user provides an absolute path. 
For filesystem/terminal/code_exec operations, " + f"prefer this directory when no explicit working_dir is given." + ) system_msg = Message(role="system", content=system_prompt) conv = [m for m in session_context if m.role != "system"] result: list[Message] = [system_msg] @@ -309,7 +332,11 @@ result.extend(extra_system) result.extend(conv) - if profile.iteration_budget_enabled and iteration is not None and max_iterations is not None: + if ( + profile.iteration_budget_enabled + and iteration is not None + and max_iterations is not None + ): remaining_after_this = max_iterations - iteration - 1 if remaining_after_this <= 2: urgency = ( @@ -323,12 +350,14 @@ ) else: urgency = "" - result.append(Message( - role="system", - content=( - f"[Iteration {iteration + 1}/{max_iterations} — " - f"{remaining_after_this} iteration(s) after this one.{urgency}]" - ), - )) + result.append( + Message( + role="system", + content=( + f"[Iteration {iteration + 1}/{max_iterations} — " + f"{remaining_after_this} iteration(s) after this one.{urgency}]" + ), + ) + ) return result diff --git a/navi/core/orchestrator.py b/navi/core/orchestrator.py index f08e0a5..8134afe 100644 --- a/navi/core/orchestrator.py +++ b/navi/core/orchestrator.py @@ -6,7 +6,7 @@ import dataclasses from dataclasses import dataclass, field from datetime import datetime, timedelta, timezone -from typing import TYPE_CHECKING, Any, Callable +from typing import TYPE_CHECKING, Any import structlog @@ -132,7 +132,12 @@ state = self._sessions.get(session_id) if state is None: return - if state.run is None and state.busy_event is None and not state.websockets and not state.terminals: + if ( + state.run is None + and state.busy_event is None + and not state.websockets + and not state.terminals + ): self._sessions.pop(session_id, None) self._session_locks.pop(session_id, None) @@ -233,6 +238,7 @@ display_content: str | None, files: list[dict] | None, session_store, + cwd: str | None = None, ) -> None: """Execute the agent to 
completion, broadcasting events to subscribers.""" from navi.tools._internal.base import current_stop_event @@ -241,6 +247,11 @@ run = self._sessions[session_id].run current_stop_event.set(run.stop_event) + session = await session_store.get(session_id) + if session is not None and cwd: + session.session_metadata["cwd"] = cwd + await session_store.save(session) + agent = self._build_agent(session_store) try: @@ -250,6 +261,7 @@ images=raw_images, display_message=display_content, files=files, + cwd=cwd, ): await run.broadcast(("event", event)) except asyncio.CancelledError: @@ -284,8 +296,12 @@ ) await scheduler.reschedule(recall.id, next_trigger) await _publish_recall_update( - recall.session_id, recall.id, recall.call_type, - trigger_at=next_trigger.isoformat(), status="pending", action="rescheduled" + recall.session_id, + recall.id, + recall.call_type, + trigger_at=next_trigger.isoformat(), + status="pending", + action="rescheduled", ) return @@ -293,21 +309,33 @@ if outcome == "success": await scheduler.mark_fired(recall.id) await _publish_recall_update( - recall.session_id, recall.id, recall.call_type, - trigger_at=recall.trigger_at.isoformat(), status="fired", action="fired" + recall.session_id, + recall.id, + recall.call_type, + trigger_at=recall.trigger_at.isoformat(), + status="fired", + action="fired", ) elif outcome == "failed": await scheduler.mark_cancelled(recall.id) await _publish_recall_update( - recall.session_id, recall.id, recall.call_type, - trigger_at=recall.trigger_at.isoformat(), status="cancelled", action="cancelled" + recall.session_id, + recall.id, + recall.call_type, + trigger_at=recall.trigger_at.isoformat(), + status="cancelled", + action="cancelled", ) else: # max_iterations for one-time -> mark_fired (preserving existing behaviour) await scheduler.mark_fired(recall.id) await _publish_recall_update( - recall.session_id, recall.id, recall.call_type, - trigger_at=recall.trigger_at.isoformat(), status="fired", action="fired" + 
recall.session_id, + recall.id, + recall.call_type, + trigger_at=recall.trigger_at.isoformat(), + status="fired", + action="fired", ) async def run_recall( @@ -317,7 +345,7 @@ store, ) -> None: """Execute a scheduled recall headlessly and notify connected clients.""" - from navi.core.agent import Agent, MaxIterationsReached + from navi.core.agent import MaxIterationsReached from navi.core.event_bus import get_event_bus from navi.core.events import StreamEnd from navi.tools._internal.base import current_stop_event @@ -346,6 +374,7 @@ current_user_role as _role_var, current_user_info as _uinfo_var, ) + if session and session.user_id is not None: uid_token = _uid_var.set(session.user_id) role_token = _role_var.set("user") diff --git a/navi/tools/_internal/base.py b/navi/tools/_internal/base.py index 32d1b54..a8834c8 100644 --- a/navi/tools/_internal/base.py +++ b/navi/tools/_internal/base.py @@ -20,12 +20,16 @@ # Set by run_stream() before executing a tool. run_ephemeral() reads this to forward # sub-agent tool events up to the parent WS stream. -current_event_sink: ContextVar[asyncio.Queue | None] = ContextVar("current_event_sink", default=None) +current_event_sink: ContextVar[asyncio.Queue | None] = ContextVar( + "current_event_sink", default=None +) # Set by _run_agent() before run_stream(). Cooperative stop: when set, the agent # breaks out of LLM loops cleanly (aclose() is called → Ollama stream closes gracefully, # model stays in VRAM). Never use task.cancel() for stopping generation. -current_stop_event: ContextVar[asyncio.Event | None] = ContextVar("current_stop_event", default=None) +current_stop_event: ContextVar[asyncio.Event | None] = ContextVar( + "current_stop_event", default=None +) # Set by run_stream() / run_ephemeral() to expose the current profile's model name # to tools that need to make their own LLM calls (e.g. AIHelper-powered tools). @@ -41,6 +45,13 @@ # ContextBuilder so the LLM receives [User context] with name, email, locale, etc. 
current_user_info: ContextVar[dict | None] = ContextVar("current_user_info", default=None) +# Set by run_stream() to expose the terminal client's working directory. Tools that +# resolve relative paths (filesystem, terminal, code_exec) can use it as the default +# base in single-user/legacy mode. +current_working_directory: ContextVar[str | None] = ContextVar( + "current_working_directory", default=None +) + @dataclass class ToolContext: @@ -53,6 +64,7 @@ user_id: str | None = None user_role: str = "user" user_info: dict | None = None + cwd: str | None = None @dataclass diff --git a/navi/tools/code_exec.py b/navi/tools/code_exec.py index 211fc61..78671ac 100644 --- a/navi/tools/code_exec.py +++ b/navi/tools/code_exec.py @@ -11,12 +11,24 @@ import tempfile from pathlib import Path -from ._internal.base import Tool, ToolContext, ToolResult, current_user_id, current_user_role +from ._internal.base import ( + Tool, + ToolContext, + ToolResult, + current_user_id, + current_user_role, + current_working_directory, +) _TIMEOUT = 30 -def _resolve_working_dir(working_dir: str | None, user_id: str | None = None, role: str | None = None) -> Path: +def _resolve_working_dir( + working_dir: str | None, + user_id: str | None = None, + role: str | None = None, + cwd: str | None = None, +) -> Path: """Resolve working directory with sandbox enforcement for non-admins.""" user_id = user_id or current_user_id.get(None) role = role or current_user_role.get() @@ -39,6 +51,12 @@ if working_dir: return Path(working_dir).expanduser().resolve() + + # In single-user/legacy/admin mode, default to the client's working directory + # so Python scripts execute in the project the user launched from. 
+ session_cwd = cwd or current_working_directory.get(None) + if session_cwd: + return Path(session_cwd).expanduser().resolve() return Path(tempfile.gettempdir()) @@ -71,7 +89,9 @@ role = ctx.user_role if ctx else current_user_role.get() user_id = ctx.user_id if ctx else current_user_id.get(None) - cwd = _resolve_working_dir(params.get("working_dir"), user_id, role) + cwd = _resolve_working_dir( + params.get("working_dir"), user_id, role, cwd=ctx.cwd if ctx else None + ) if user_id and role != "admin": # Write temp file inside the sandbox so file I/O in user code @@ -105,7 +125,7 @@ if stdout: output_parts.append(stdout.decode(errors="replace")) if stderr: - output_parts.append(f"[stderr]\n{stderr.decode(errors="replace")}") + output_parts.append(f"[stderr]\n{stderr.decode(errors='replace')}") success = proc.returncode == 0 return ToolResult( diff --git a/navi/tools/filesystem.py b/navi/tools/filesystem.py index 0832744..8502d8a 100644 --- a/navi/tools/filesystem.py +++ b/navi/tools/filesystem.py @@ -15,16 +15,23 @@ from navi.config import settings -from ._internal.base import Tool, ToolContext, ToolResult, current_user_id, current_user_role +from ._internal.base import ( + Tool, + ToolContext, + ToolResult, + current_user_id, + current_user_role, + current_working_directory, +) -_READ_WARN_BYTES = 100_000 # 100 KB — add size warning in output -_READ_HARD_BYTES = 1_000_000 # 1 MB — refuse full read without offset/limit +_READ_WARN_BYTES = 100_000 # 100 KB — add size warning in output +_READ_HARD_BYTES = 1_000_000 # 1 MB — refuse full read without offset/limit _LIST_MAX_ENTRIES = 500 _FIND_MAX_RESULTS = 200 # AI actions: ~20k tokens of file content per chunk (4 chars ≈ 1 token) -_AI_CHUNK_CHARS = 80_000 -_AI_OVERLAP_LINES = 30 +_AI_CHUNK_CHARS = 80_000 +_AI_OVERLAP_LINES = 30 # smart_edit: refuse files larger than ~50k tokens (full file must fit in one call) _AI_EDIT_MAX_CHARS = 200_000 @@ -69,7 +76,32 @@ # ── Path helpers 
────────────────────────────────────────────────────────────── -def _check_path(path_str: str, user_id: str | None = None, role: str | None = None) -> Path | None: + +def _resolve_relative_path( + p: Path, user_id: str | None, role: str | None, cwd: str | None = None +) -> Path: + """Resolve a relative path against the session cwd when available. + + In single-user/legacy/admin mode, uses the client-provided working directory + so that `filesystem read src/main.py` works from the project directory the + user launched Navi Code in, regardless of the server's cwd. + """ + if user_id and role != "admin": + # Sandbox mode keeps its own resolution rules. + return p + + cwd = cwd or current_working_directory.get(None) + if cwd: + return (Path(cwd) / p).resolve() + return p.resolve() + + +def _check_path( + path_str: str, + user_id: str | None = None, + role: str | None = None, + cwd: str | None = None, +) -> Path | None: """Return resolved Path if access is allowed, else None. When a user_id is active (multi-user mode), all paths are resolved @@ -106,7 +138,10 @@ # Fallback to FS_ALLOWED_PATHS for single-user / legacy mode / admin try: - p = p.expanduser().resolve() + if not p.is_absolute(): + p = _resolve_relative_path(p, user_id, role, cwd=cwd) + else: + p = p.expanduser().resolve() except Exception: return None @@ -124,10 +159,13 @@ def _fmt_size(n: int) -> str: - if n < 1024: return f"{n} B" - if n < 1024 ** 2: return f"{n / 1024:.1f} KB" - if n < 1024 ** 3: return f"{n / 1024 ** 2:.1f} MB" - return f"{n / 1024 ** 3:.1f} GB" + if n < 1024: + return f"{n} B" + if n < 1024**2: + return f"{n / 1024:.1f} KB" + if n < 1024**3: + return f"{n / 1024**2:.1f} MB" + return f"{n / 1024**3:.1f} GB" def _fmt_time(ts: float) -> str: @@ -136,6 +174,7 @@ # ── AI helpers (module-level, no self) ──────────────────────────────────────── + def _number_lines(lines: list[str], start: int = 1) -> str: """Return file lines with 1-based line numbers, right-aligned.""" width = len(str(start + 
len(lines) - 1)) @@ -150,7 +189,7 @@ if not lines: return [(0, 0)] total = len(lines) - total_chars = sum(len(l) + 1 for l in lines) + total_chars = sum(len(line) + 1 for line in lines) if total_chars <= target_chars: return [(0, total)] @@ -174,10 +213,12 @@ errors: list[str] = [] for i, op in enumerate(ops): if not isinstance(op, dict): - errors.append(f"op[{i}] is not a dict"); continue + errors.append(f"op[{i}] is not a dict") + continue kind = op.get("op") if kind not in ("replace", "delete", "insert"): - errors.append(f"op[{i}] unknown type {kind!r}"); continue + errors.append(f"op[{i}] unknown type {kind!r}") + continue if kind in ("replace", "delete"): s, e = op.get("start"), op.get("end") if not isinstance(s, int) or not isinstance(e, int): @@ -204,8 +245,8 @@ for op in sorted_ops: kind = op["op"] if kind == "replace": - s = op["start"] - 1 # 0-based - e = op["end"] # exclusive (1-based end = exclusive 0-based end) + s = op["start"] - 1 # 0-based + e = op["end"] # exclusive (1-based end = exclusive 0-based end) new = op.get("content", "").split("\n") result[s:e] = new elif kind == "delete": @@ -213,25 +254,28 @@ e = op["end"] del result[s:e] elif kind == "insert": - after = op["after"] # insert after this 1-based line (0 = before line 1) + after = op["after"] # insert after this 1-based line (0 = before line 1) new = op.get("content", "").split("\n") result[after:after] = new return result def _unified_diff(original: list[str], modified: list[str], path: Path) -> str: - diff = list(difflib.unified_diff( - [l + "\n" for l in original], - [l + "\n" for l in modified], - fromfile=f"a/{path.name}", - tofile=f"b/{path.name}", - lineterm="", - )) + diff = list( + difflib.unified_diff( + [line + "\n" for line in original], + [line + "\n" for line in modified], + fromfile=f"a/{path.name}", + tofile=f"b/{path.name}", + lineterm="", + ) + ) return "\n".join(diff) # ── Tool class ──────────────────────────────────────────────────────────────── + class 
FilesystemTool(Tool): name = "filesystem" description = ( @@ -251,9 +295,24 @@ "action": { "type": "string", "enum": [ - "read", "write", "append", "edit", "edit_lines", "list", "find", "find_up", - "info", "move", "copy", "delete", "exists", "mkdir", - "query", "smart_edit", "grep", "diff", + "read", + "write", + "append", + "edit", + "edit_lines", + "list", + "find", + "find_up", + "info", + "move", + "copy", + "delete", + "exists", + "mkdir", + "query", + "smart_edit", + "grep", + "diff", ], "description": "Operation to perform.", }, @@ -277,7 +336,7 @@ "type": "array", "description": ( "JSON array of line-based edit operations (required for edit_lines). " - "Each op is {\"op\": \"replace\"|\"delete\"|\"insert\", \"start\": int, \"end\": int, \"content\": str}. " + 'Each op is {"op": "replace"|"delete"|"insert", "start": int, "end": int, "content": str}. ' "Line numbers are 1-based and inclusive. Use edit_lines for fast deterministic edits " "when you know the exact lines (e.g. 'change line 15 from X to Y')." 
), @@ -351,11 +410,11 @@ self._ai = ai_helper async def execute(self, params: dict, ctx: ToolContext | None = None) -> ToolResult: - action = params.get("action", "") + action = params.get("action", "") raw_path = params.get("path", "") user_id = ctx.user_id if ctx else current_user_id.get(None) role = ctx.user_role if ctx else current_user_role.get() - path = _check_path(raw_path, user_id, role) + path = _check_path(raw_path, user_id, role, cwd=ctx.cwd if ctx else None) if path is None: if not raw_path or raw_path.strip() == "": @@ -376,26 +435,46 @@ try: match action: - case "read": return await asyncio.to_thread(self._read, path, params) - case "write": return await asyncio.to_thread(self._write, path, params) - case "append": return await asyncio.to_thread(self._append, path, params) - case "edit": return await asyncio.to_thread(self._edit, path, params) - case "edit_lines": return await asyncio.to_thread(self._edit_lines, path, params) - case "list": return await asyncio.to_thread(self._list, path, params) - case "find": return await asyncio.to_thread(self._find, path, params) - case "find_up": return await asyncio.to_thread(self._find_up, path, params) - case "info": return await asyncio.to_thread(self._info, path) - case "move": return await asyncio.to_thread(self._move, path, params) - case "copy": return await asyncio.to_thread(self._copy, path, params) - case "delete": return await asyncio.to_thread(self._delete, path) - case "exists": return ToolResult(success=True, output="true" if path.exists() else "false") - case "mkdir": return await asyncio.to_thread(self._mkdir, path) - case "query": return await self._query(path, params) - case "smart_edit": return await self._smart_edit(path, params) - case "grep": return await asyncio.to_thread(self._grep, path, params) - case "diff": return await asyncio.to_thread(self._diff, path, params) + case "read": + return await asyncio.to_thread(self._read, path, params) + case "write": + return await 
asyncio.to_thread(self._write, path, params) + case "append": + return await asyncio.to_thread(self._append, path, params) + case "edit": + return await asyncio.to_thread(self._edit, path, params) + case "edit_lines": + return await asyncio.to_thread(self._edit_lines, path, params) + case "list": + return await asyncio.to_thread(self._list, path, params) + case "find": + return await asyncio.to_thread(self._find, path, params) + case "find_up": + return await asyncio.to_thread(self._find_up, path, params) + case "info": + return await asyncio.to_thread(self._info, path) + case "move": + return await asyncio.to_thread(self._move, path, params) + case "copy": + return await asyncio.to_thread(self._copy, path, params) + case "delete": + return await asyncio.to_thread(self._delete, path) + case "exists": + return ToolResult(success=True, output="true" if path.exists() else "false") + case "mkdir": + return await asyncio.to_thread(self._mkdir, path) + case "query": + return await self._query(path, params) + case "smart_edit": + return await self._smart_edit(path, params) + case "grep": + return await asyncio.to_thread(self._grep, path, params) + case "diff": + return await asyncio.to_thread(self._diff, path, params) case _: - return ToolResult(success=False, output=f"Unknown action: {action}", error="invalid_action") + return ToolResult( + success=False, output=f"Unknown action: {action}", error="invalid_action" + ) except PermissionError as e: return ToolResult(success=False, output=f"Permission denied: {e}", error=str(e)) @@ -408,11 +487,15 @@ if not path.exists(): return ToolResult(success=False, output=f"File not found: {path}", error="not_found") if path.is_dir(): - return ToolResult(success=False, output=f"Path is a directory, use 'list': {path}", error="is_directory") + return ToolResult( + success=False, + output=f"Path is a directory, use 'list': {path}", + error="is_directory", + ) file_size = path.stat().st_size offset = params.get("offset") - limit = 
params.get("limit") + limit = params.get("limit") if file_size > _READ_HARD_BYTES and offset is None and limit is None: return ToolResult( @@ -425,13 +508,13 @@ error="file_too_large", ) - text = path.read_text(encoding="utf-8", errors="replace") + text = path.read_text(encoding="utf-8", errors="replace") lines = text.splitlines(keepends=True) total_lines = len(lines) if offset is not None or limit is not None: - start = max(0, (offset or 1) - 1) - end = (start + limit) if limit is not None else total_lines + start = max(0, (offset or 1) - 1) + end = (start + limit) if limit is not None else total_lines selected = lines[start:end] actual_end = min(end, total_lines) header = ( @@ -442,12 +525,16 @@ warn = ( f"⚠ Large file ({_fmt_size(file_size)}) — consider offset/limit next time.\n" - if file_size > _READ_WARN_BYTES else "" + if file_size > _READ_WARN_BYTES + else "" ) numbered = params.get("numbered", False) header = f"[{path} | {total_lines} lines | {_fmt_size(file_size)}]\n" if numbered: - return ToolResult(success=True, output=header + warn + _number_lines(text.splitlines(keepends=False), 1)) + return ToolResult( + success=True, + output=header + warn + _number_lines(text.splitlines(keepends=False), 1), + ) return ToolResult(success=True, output=header + warn + text) def _write(self, path: Path, params: dict) -> ToolResult: @@ -455,30 +542,44 @@ path.parent.mkdir(parents=True, exist_ok=True) path.write_text(content, encoding="utf-8") lines = len(content.splitlines()) - return ToolResult(success=True, output=f"Written {_fmt_size(len(content.encode()))} ({lines} lines) → {path}") + return ToolResult( + success=True, + output=f"Written {_fmt_size(len(content.encode()))} ({lines} lines) → {path}", + ) def _append(self, path: Path, params: dict) -> ToolResult: content = params.get("content", "") if not content: - return ToolResult(success=False, output="'content' is required for append", error="missing_content") + return ToolResult( + success=False, 
output="'content' is required for append", error="missing_content" + ) path.parent.mkdir(parents=True, exist_ok=True) with path.open("a", encoding="utf-8") as f: f.write(content) - return ToolResult(success=True, output=f"Appended {_fmt_size(len(content.encode()))} to {path} (file now {_fmt_size(path.stat().st_size)})") + return ToolResult( + success=True, + output=f"Appended {_fmt_size(len(content.encode()))} to {path} (file now {_fmt_size(path.stat().st_size)})", + ) def _edit(self, path: Path, params: dict) -> ToolResult: old = params.get("old") new = params.get("new") if old is None: - return ToolResult(success=False, output="'old' is required for edit", error="missing_old") + return ToolResult( + success=False, output="'old' is required for edit", error="missing_old" + ) if new is None: - return ToolResult(success=False, output="'new' is required for edit", error="missing_new") + return ToolResult( + success=False, output="'new' is required for edit", error="missing_new" + ) if old == "": return ToolResult(success=False, output="'old' must not be empty", error="empty_old") if not path.exists(): return ToolResult(success=False, output=f"File not found: {path}", error="not_found") if path.is_dir(): - return ToolResult(success=False, output="edit works on files, not directories.", error="is_directory") + return ToolResult( + success=False, output="edit works on files, not directories.", error="is_directory" + ) text = path.read_text(encoding="utf-8", errors="replace") count = text.count(old) @@ -515,11 +616,19 @@ def _edit_lines(self, path: Path, params: dict) -> ToolResult: ops = params.get("operations") if not ops: - return ToolResult(success=False, output="'operations' is required for edit_lines", error="missing_operations") + return ToolResult( + success=False, + output="'operations' is required for edit_lines", + error="missing_operations", + ) if not path.exists(): return ToolResult(success=False, output=f"File not found: {path}", error="not_found") if 
path.is_dir(): - return ToolResult(success=False, output="edit_lines works on files, not directories.", error="is_directory") + return ToolResult( + success=False, + output="edit_lines works on files, not directories.", + error="is_directory", + ) text = path.read_text(encoding="utf-8", errors="replace") lines = text.splitlines() @@ -554,16 +663,16 @@ if path.is_file(): return self._info(path) - recursive = params.get("recursive", False) + recursive = params.get("recursive", False) raw_entries = list(path.rglob("*") if recursive else path.iterdir()) raw_entries.sort(key=lambda e: (e.is_file(), str(e).lower())) truncated = len(raw_entries) > _LIST_MAX_ENTRIES - entries = raw_entries[:_LIST_MAX_ENTRIES] - lines = [] + entries = raw_entries[:_LIST_MAX_ENTRIES] + lines = [] for e in entries: try: - s = e.stat() + s = e.stat() rel = e.relative_to(path) if e.is_dir(): if not recursive: @@ -575,18 +684,22 @@ else: lines.append(f"d {rel}/") else: - lines.append(f" {str(rel):<48} {_fmt_size(s.st_size):>10} {_fmt_time(s.st_mtime)}") + lines.append( + f" {str(rel):<48} {_fmt_size(s.st_size):>10} {_fmt_time(s.st_mtime)}" + ) except Exception: lines.append(f"? 
{e.name}") - note = " ⚠ truncated" if truncated else "" + note = " ⚠ truncated" if truncated else "" header = f"[{path} | {len(entries)} entries{note}]\n" return ToolResult(success=True, output=header + ("\n".join(lines) or "(empty directory)")) def _find(self, path: Path, params: dict) -> ToolResult: pattern = params.get("pattern") if not pattern: - return ToolResult(success=False, output="'pattern' is required for find", error="missing_pattern") + return ToolResult( + success=False, output="'pattern' is required for find", error="missing_pattern" + ) if not path.exists(): return ToolResult(success=False, output=f"Path not found: {path}", error="not_found") @@ -611,7 +724,9 @@ except Exception: lines.append(str(m)) - extra = f" ⚠ showing first {_FIND_MAX_RESULTS}" if len(matches) == _FIND_MAX_RESULTS else "" + extra = ( + f" ⚠ showing first {_FIND_MAX_RESULTS}" if len(matches) == _FIND_MAX_RESULTS else "" + ) header = f"[{len(matches)} matches for '{pattern}' in {path}{extra}]\n" return ToolResult(success=True, output=header + "\n".join(lines)) @@ -628,14 +743,16 @@ return ToolResult(success=True, output=str(target)) parent = current.parent if parent == current: - return ToolResult(success=True, output=f"not found (searched: {', '.join(checked)})") + return ToolResult( + success=True, output=f"not found (searched: {', '.join(checked)})" + ) current = parent def _info(self, path: Path) -> ToolResult: if not path.exists(): return ToolResult(success=False, output=f"Not found: {path}", error="not_found") - s = path.stat() + s = path.stat() kind = "symlink" if path.is_symlink() else ("directory" if path.is_dir() else "file") lines = [ f"path: {path}", @@ -654,7 +771,9 @@ elif path.is_dir(): try: children = list(path.iterdir()) - lines.append(f"contents: {sum(c.is_file() for c in children)} files, {sum(c.is_dir() for c in children)} dirs (top level)") + lines.append( + f"contents: {sum(c.is_file() for c in children)} files, {sum(c.is_dir() for c in children)} dirs (top 
level)" + ) except Exception: pass return ToolResult(success=True, output="\n".join(lines)) @@ -662,10 +781,18 @@ def _move(self, path: Path, params: dict) -> ToolResult: dest_raw = params.get("destination") if not dest_raw: - return ToolResult(success=False, output="'destination' is required for move", error="missing_destination") + return ToolResult( + success=False, + output="'destination' is required for move", + error="missing_destination", + ) dest = _check_path(dest_raw) if dest is None: - return ToolResult(success=False, output=f"Access denied: destination '{dest_raw}' outside allowed paths.", error="access_denied") + return ToolResult( + success=False, + output=f"Access denied: destination '{dest_raw}' outside allowed paths.", + error="access_denied", + ) if not path.exists(): return ToolResult(success=False, output=f"Not found: {path}", error="not_found") dest.parent.mkdir(parents=True, exist_ok=True) @@ -688,10 +815,18 @@ def _copy(self, path: Path, params: dict) -> ToolResult: dest_raw = params.get("destination") if not dest_raw: - return ToolResult(success=False, output="'destination' is required for copy", error="missing_destination") + return ToolResult( + success=False, + output="'destination' is required for copy", + error="missing_destination", + ) dest = _check_path(dest_raw) if dest is None: - return ToolResult(success=False, output=f"Access denied: destination '{dest_raw}' outside allowed paths.", error="access_denied") + return ToolResult( + success=False, + output=f"Access denied: destination '{dest_raw}' outside allowed paths.", + error="access_denied", + ) if not path.exists(): return ToolResult(success=False, output=f"Not found: {path}", error="not_found") dest.parent.mkdir(parents=True, exist_ok=True) @@ -701,7 +836,9 @@ def _grep(self, path: Path, params: dict) -> ToolResult: pattern = params.get("pattern", "").strip() if not pattern: - return ToolResult(success=False, output="'pattern' is required for grep", error="missing_pattern") + 
return ToolResult( + success=False, output="'pattern' is required for grep", error="missing_pattern" + ) use_regex = params.get("regex", False) glob_filter = params.get("glob", "") @@ -711,7 +848,9 @@ else: compiled = re.compile(re.escape(pattern), re.IGNORECASE) except re.error as e: - return ToolResult(success=False, output=f"Invalid pattern: {e}", error="invalid_pattern") + return ToolResult( + success=False, output=f"Invalid pattern: {e}", error="invalid_pattern" + ) max_results = 200 matches: list[str] = [] @@ -753,16 +892,26 @@ def _diff(self, path: Path, params: dict) -> ToolResult: dest_raw = params.get("destination") if not dest_raw: - return ToolResult(success=False, output="'destination' is required for diff (path to second file)", error="missing_destination") + return ToolResult( + success=False, + output="'destination' is required for diff (path to second file)", + error="missing_destination", + ) dest = _check_path(dest_raw) if dest is None: - return ToolResult(success=False, output=f"Access denied: destination '{dest_raw}' outside allowed paths.", error="access_denied") + return ToolResult( + success=False, + output=f"Access denied: destination '{dest_raw}' outside allowed paths.", + error="access_denied", + ) if not path.exists(): return ToolResult(success=False, output=f"Not found: {path}", error="not_found") if not dest.exists(): return ToolResult(success=False, output=f"Not found: {dest}", error="not_found") if path.is_dir() or dest.is_dir(): - return ToolResult(success=False, output="diff works on files, not directories.", error="is_directory") + return ToolResult( + success=False, output="diff works on files, not directories.", error="is_directory" + ) try: a_lines = path.read_text(encoding="utf-8", errors="replace").splitlines() @@ -770,13 +919,15 @@ except Exception as e: return ToolResult(success=False, output=f"Read error: {e}", error=str(e)) - diff = list(difflib.unified_diff( - a_lines, - b_lines, - fromfile=f"a/{path.name}", - 
tofile=f"b/{dest.name}", - lineterm="", - )) + diff = list( + difflib.unified_diff( + a_lines, + b_lines, + fromfile=f"a/{path.name}", + tofile=f"b/{dest.name}", + lineterm="", + ) + ) if not diff: return ToolResult(success=True, output="Files are identical.") return ToolResult(success=True, output="\n".join(diff)) @@ -798,15 +949,19 @@ question = params.get("question", "").strip() if not question: - return ToolResult(success=False, output="'question' is required for query.", error="missing_question") + return ToolResult( + success=False, output="'question' is required for query.", error="missing_question" + ) if not path.exists(): return ToolResult(success=False, output=f"File not found: {path}", error="not_found") if path.is_dir(): - return ToolResult(success=False, output="query works on files, not directories.", error="is_directory") + return ToolResult( + success=False, output="query works on files, not directories.", error="is_directory" + ) - text = await asyncio.to_thread(path.read_text, "utf-8", "replace") - lines = text.splitlines() - total = len(lines) + text = await asyncio.to_thread(path.read_text, "utf-8", "replace") + lines = text.splitlines() + total = len(lines) chunks = _make_chunks(lines, _AI_CHUNK_CHARS, _AI_OVERLAP_LINES) if len(chunks) == 1: @@ -822,7 +977,7 @@ partials: list[str] = [] for s, e in chunks: numbered = _number_lines(lines[s:e], s + 1) - partial = await self._ai.ask( + partial = await self._ai.ask( _QUERY_CHUNK_SYSTEM, f"File: {path} (lines {s + 1}–{e} of {total})\nQuestion: {question}\n\nContent:\n{numbered}", ) @@ -830,7 +985,10 @@ partials.append(f"[lines {s + 1}–{e}] {partial}") if not partials: - return ToolResult(success=True, output=f"No information found in '{path.name}' relevant to: {question}") + return ToolResult( + success=True, + output=f"No information found in '{path.name}' relevant to: {question}", + ) if len(partials) == 1: # Single finding — strip range prefix, return directly @@ -839,7 +997,8 @@ answer = await 
self._ai.ask( _QUERY_SYNTHESIS_SYSTEM, - f"Question: {question}\n\nFindings from {len(partials)} sections:\n\n" + "\n\n".join(partials), + f"Question: {question}\n\nFindings from {len(partials)} sections:\n\n" + + "\n\n".join(partials), ) return ToolResult(success=True, output=answer) @@ -849,11 +1008,19 @@ instruction = params.get("instruction", "").strip() if not instruction: - return ToolResult(success=False, output="'instruction' is required for smart_edit.", error="missing_instruction") + return ToolResult( + success=False, + output="'instruction' is required for smart_edit.", + error="missing_instruction", + ) if not path.exists(): return ToolResult(success=False, output=f"File not found: {path}", error="not_found") if path.is_dir(): - return ToolResult(success=False, output="smart_edit works on files, not directories.", error="is_directory") + return ToolResult( + success=False, + output="smart_edit works on files, not directories.", + error="is_directory", + ) text = await asyncio.to_thread(path.read_text, "utf-8", "replace") if len(text) > _AI_EDIT_MAX_CHARS: @@ -867,7 +1034,7 @@ error="file_too_large", ) - lines = text.splitlines() + lines = text.splitlines() numbered = _number_lines(lines, 1) raw_ops = await self._ai.ask_json( @@ -897,10 +1064,11 @@ ) new_lines = _apply_ops(lines, raw_ops) - diff = _unified_diff(lines, new_lines, path) + diff = _unified_diff(lines, new_lines, path) # Preserve trailing newline — write atomically to avoid partial writes on failure import os + new_text = "\n".join(new_lines) + ("\n" if text.endswith("\n") else "") tmp = path.with_suffix(path.suffix + ".tmp") diff --git a/navi/tools/terminal.py b/navi/tools/terminal.py index e09436d..9b7972e 100644 --- a/navi/tools/terminal.py +++ b/navi/tools/terminal.py @@ -31,7 +31,15 @@ from navi.config import settings -from ._internal.base import Tool, ToolContext, ToolResult, current_event_sink, current_user_id, current_user_role +from ._internal.base import ( + Tool, + ToolContext, + 
ToolResult, + current_event_sink, + current_user_id, + current_user_role, + current_working_directory, +) from ..tools._internal.terminal_manager import TerminalManager _DEFAULT_TIMEOUT = 20 @@ -69,7 +77,12 @@ return None -def _resolve_working_dir(working_dir: str | None, user_id: str | None = None, role: str | None = None) -> Path | None: +def _resolve_working_dir( + working_dir: str | None, + user_id: str | None = None, + role: str | None = None, + cwd: str | None = None, +) -> Path | None: """Resolve working directory with sandbox enforcement for non-admins.""" user_id = user_id or current_user_id.get(None) role = role or current_user_role.get() @@ -92,6 +105,12 @@ if working_dir: return Path(working_dir).expanduser().resolve() + + # In single-user/legacy/admin mode, default to the client's working directory + # when available so shell commands run in the project the user launched from. + session_cwd = cwd or current_working_directory.get(None) + if session_cwd: + return Path(session_cwd).expanduser().resolve() return None @@ -151,17 +170,20 @@ async def execute(self, params: dict, ctx: ToolContext | None = None) -> ToolResult: action = params.get("action", "") if not action: - return ToolResult(success=False, output="Missing 'action' parameter.", error="missing_action") + return ToolResult( + success=False, output="Missing 'action' parameter.", error="missing_action" + ) role = ctx.user_role if ctx else current_user_role.get() user_id = ctx.user_id if ctx else current_user_id.get(None) session_id = ctx.session_id if ctx else None + tool_cwd = ctx.cwd if ctx else current_working_directory.get(None) match action: case "run": - return await self._run(params, role, user_id) + return await self._run(params, role, user_id, tool_cwd) case "open": - return await self._open(params, role, user_id, session_id) + return await self._open(params, role, user_id, session_id, tool_cwd) case "close": return await self._close(params, session_id) case "list": @@ -171,11 +193,19 @@ 
case "send_input": return await self._send_input(params, session_id) case _: - return ToolResult(success=False, output=f"Unknown action: {action}", error="invalid_action") + return ToolResult( + success=False, output=f"Unknown action: {action}", error="invalid_action" + ) # ── Action handlers ────────────────────────────────────────────────────── - async def _run(self, params: dict, role: str | None, user_id: str | None) -> ToolResult: + async def _run( + self, + params: dict, + role: str | None, + user_id: str | None, + tool_cwd: str | None = None, + ) -> ToolResult: """One-shot command execution (original terminal behaviour).""" command = params.get("command", "").strip() if not command: @@ -183,11 +213,19 @@ working_dir = params.get("working_dir") or None raw_timeout = params.get("timeout") - timeout = max(1, min(int(raw_timeout), _MAX_TIMEOUT)) if raw_timeout is not None else _DEFAULT_TIMEOUT + timeout = ( + max(1, min(int(raw_timeout), _MAX_TIMEOUT)) + if raw_timeout is not None + else _DEFAULT_TIMEOUT + ) - cwd = _resolve_working_dir(working_dir, user_id, role) + cwd = _resolve_working_dir(working_dir, user_id, role, cwd=tool_cwd) if cwd is None and working_dir: - return ToolResult(success=False, output="Working directory is outside your sandbox.", error="sandbox_violation") + return ToolResult( + success=False, + output="Working directory is outside your sandbox.", + error="sandbox_violation", + ) # Admin / legacy unrestricted mode if not user_id or role == "admin": @@ -202,11 +240,24 @@ return ToolResult(success=False, output=f"Blocked: {danger}", error="dangerous_command") return await self._run_user_restricted(command, cwd, timeout) - async def _open(self, params: dict, role: str | None, user_id: str | None, session_id: str | None) -> ToolResult: + async def _open( + self, + params: dict, + role: str | None, + user_id: str | None, + session_id: str | None, + tool_cwd: str | None = None, + ) -> ToolResult: if self._tm is None: - return 
ToolResult(success=False, output="Terminal manager is not available.", error="no_manager") + return ToolResult( + success=False, output="Terminal manager is not available.", error="no_manager" + ) if not session_id: - return ToolResult(success=False, output="Persistent terminals require a session context.", error="no_session") + return ToolResult( + success=False, + output="Persistent terminals require a session context.", + error="no_session", + ) name = params.get("terminal_name", "").strip() description = params.get("description", "").strip() @@ -214,18 +265,32 @@ background = bool(params.get("background")) working_dir = params.get("working_dir") or None raw_timeout = params.get("timeout") - timeout = max(1, min(int(raw_timeout), _MAX_TIMEOUT)) if raw_timeout is not None else _DEFAULT_TIMEOUT + timeout = ( + max(1, min(int(raw_timeout), _MAX_TIMEOUT)) + if raw_timeout is not None + else _DEFAULT_TIMEOUT + ) if not name: - return ToolResult(success=False, output="Missing 'terminal_name' for open.", error="missing_name") + return ToolResult( + success=False, output="Missing 'terminal_name' for open.", error="missing_name" + ) if not description: - return ToolResult(success=False, output="Missing 'description' for open.", error="missing_description") + return ToolResult( + success=False, output="Missing 'description' for open.", error="missing_description" + ) if not command: - return ToolResult(success=False, output="Missing 'command' for open.", error="empty_command") + return ToolResult( + success=False, output="Missing 'command' for open.", error="empty_command" + ) - cwd = _resolve_working_dir(working_dir, user_id, role) + cwd = _resolve_working_dir(working_dir, user_id, role, cwd=tool_cwd) if cwd is None and working_dir: - return ToolResult(success=False, output="Working directory is outside your sandbox.", error="sandbox_violation") + return ToolResult( + success=False, + output="Working directory is outside your sandbox.", + error="sandbox_violation", + ) # 
Security checks (same as run) if not user_id or role == "admin": @@ -242,7 +307,9 @@ else: danger = _check_dangerous(command) if danger: - return ToolResult(success=False, output=f"Blocked: {danger}", error="dangerous_command") + return ToolResult( + success=False, output=f"Blocked: {danger}", error="dangerous_command" + ) tokens = shlex.split(command) allowed = settings.terminal_user_allowed_commands_list if tokens and tokens[0] not in allowed: @@ -260,13 +327,17 @@ try: exec_tokens = shlex.split(command) except ValueError as e: - return ToolResult(success=False, output=f"Invalid command syntax: {e}", error=str(e)) + return ToolResult( + success=False, output=f"Invalid command syntax: {e}", error=str(e) + ) else: # Non-admin: always use exec to enforce restrictions try: exec_tokens = shlex.split(command) except ValueError as e: - return ToolResult(success=False, output=f"Invalid command syntax: {e}", error=str(e)) + return ToolResult( + success=False, output=f"Invalid command syntax: {e}", error=str(e) + ) try: session = await self._tm.open( @@ -296,7 +367,9 @@ output_parts = list(session.output_buffer) combined = "".join(output_parts) if len(combined) > _MAX_OUTPUT_CHARS: - combined = combined[:_MAX_OUTPUT_CHARS] + f"\n…[truncated — {len(combined)} chars total]" + combined = ( + combined[:_MAX_OUTPUT_CHARS] + f"\n…[truncated — {len(combined)} chars total]" + ) rc = session.proc.returncode if session.proc else None # Close foreground terminal immediately so it doesn't clutter list await self._tm.close(session_id, name) @@ -309,22 +382,30 @@ async def _close(self, params: dict, session_id: str | None) -> ToolResult: if self._tm is None: - return ToolResult(success=False, output="Terminal manager is not available.", error="no_manager") + return ToolResult( + success=False, output="Terminal manager is not available.", error="no_manager" + ) if not session_id: return ToolResult(success=False, output="No session context.", error="no_session") name = 
params.get("terminal_name", "").strip() if not name: - return ToolResult(success=False, output="Missing 'terminal_name' for close.", error="missing_name") + return ToolResult( + success=False, output="Missing 'terminal_name' for close.", error="missing_name" + ) ok = await self._tm.close(session_id, name) if not ok: - return ToolResult(success=False, output=f"Terminal '{name}' not found.", error="not_found") + return ToolResult( + success=False, output=f"Terminal '{name}' not found.", error="not_found" + ) return ToolResult(success=True, output=f"Terminal '{name}' closed.") def _list(self, session_id: str | None) -> ToolResult: if self._tm is None: - return ToolResult(success=False, output="Terminal manager is not available.", error="no_manager") + return ToolResult( + success=False, output="Terminal manager is not available.", error="no_manager" + ) if not session_id: return ToolResult(success=False, output="No session context.", error="no_session") @@ -344,17 +425,23 @@ def _status(self, params: dict, session_id: str | None) -> ToolResult: if self._tm is None: - return ToolResult(success=False, output="Terminal manager is not available.", error="no_manager") + return ToolResult( + success=False, output="Terminal manager is not available.", error="no_manager" + ) if not session_id: return ToolResult(success=False, output="No session context.", error="no_session") name = params.get("terminal_name", "").strip() if not name: - return ToolResult(success=False, output="Missing 'terminal_name' for status.", error="missing_name") + return ToolResult( + success=False, output="Missing 'terminal_name' for status.", error="missing_name" + ) st = self._tm.status(session_id, name) if st is None: - return ToolResult(success=False, output=f"Terminal '{name}' not found.", error="not_found") + return ToolResult( + success=False, output=f"Terminal '{name}' not found.", error="not_found" + ) lines = [ f"Terminal: {st['name']}", @@ -375,18 +462,28 @@ async def _send_input(self, 
params: dict, session_id: str | None) -> ToolResult: if self._tm is None: - return ToolResult(success=False, output="Terminal manager is not available.", error="no_manager") + return ToolResult( + success=False, output="Terminal manager is not available.", error="no_manager" + ) if not session_id: return ToolResult(success=False, output="No session context.", error="no_session") name = params.get("terminal_name", "").strip() text = params.get("input", "") if not name: - return ToolResult(success=False, output="Missing 'terminal_name' for send_input.", error="missing_name") + return ToolResult( + success=False, + output="Missing 'terminal_name' for send_input.", + error="missing_name", + ) ok = await self._tm.send_input(session_id, name, text) if not ok: - return ToolResult(success=False, output=f"Cannot send input to '{name}'. Terminal may be closed or not accepting input.", error="send_failed") + return ToolResult( + success=False, + output=f"Cannot send input to '{name}'. Terminal may be closed or not accepting input.", + error="send_failed", + ) return ToolResult(success=True, output=f"Sent input to '{name}'.") # ── Low-level runners (unchanged from original) ───────────────────────── @@ -404,17 +501,21 @@ stdout, stderr = await asyncio.wait_for(proc.communicate(), timeout=timeout) except asyncio.TimeoutError: proc.kill() - return ToolResult(success=False, output=f"Timed out after {timeout}s", error="timeout") + return ToolResult( + success=False, output=f"Timed out after {timeout}s", error="timeout" + ) output_parts = [] if stdout: output_parts.append(stdout.decode(errors="replace")) if stderr: - output_parts.append(f"[stderr]\n{stderr.decode(errors="replace")}") + output_parts.append(f"[stderr]\n{stderr.decode(errors='replace')}") combined = "\n".join(output_parts) or "(no output)" if len(combined) > _MAX_OUTPUT_CHARS: - combined = combined[:_MAX_OUTPUT_CHARS] + f"\n…[truncated — {len(combined)} chars total]" + combined = ( + combined[:_MAX_OUTPUT_CHARS] + 
f"\n…[truncated — {len(combined)} chars total]" + ) success = proc.returncode == 0 return ToolResult( @@ -441,7 +542,7 @@ return ToolResult( success=False, output=f"Command '{tokens[0]}' is not in the allowed list. " - f"Allowed: {allowed}. Set TERMINAL_ALLOWED_COMMANDS=* to allow all.", + f"Allowed: {allowed}. Set TERMINAL_ALLOWED_COMMANDS=* to allow all.", error="not_allowed", ) @@ -456,17 +557,21 @@ stdout, stderr = await asyncio.wait_for(proc.communicate(), timeout=timeout) except asyncio.TimeoutError: proc.kill() - return ToolResult(success=False, output=f"Timed out after {timeout}s", error="timeout") + return ToolResult( + success=False, output=f"Timed out after {timeout}s", error="timeout" + ) output_parts = [] if stdout: output_parts.append(stdout.decode(errors="replace")) if stderr: - output_parts.append(f"[stderr]\n{stderr.decode(errors="replace")}") + output_parts.append(f"[stderr]\n{stderr.decode(errors='replace')}") combined = "\n".join(output_parts) or "(no output)" if len(combined) > _MAX_OUTPUT_CHARS: - combined = combined[:_MAX_OUTPUT_CHARS] + f"\n…[truncated — {len(combined)} chars total]" + combined = ( + combined[:_MAX_OUTPUT_CHARS] + f"\n…[truncated — {len(combined)} chars total]" + ) success = proc.returncode == 0 return ToolResult( @@ -476,11 +581,15 @@ error=None if success else f"Exit code {proc.returncode}", ) except FileNotFoundError: - return ToolResult(success=False, output=f"Command not found: {tokens[0]}", error="not_found") + return ToolResult( + success=False, output=f"Command not found: {tokens[0]}", error="not_found" + ) except Exception as e: return ToolResult(success=False, output=f"Execution error: {e}", error=str(e)) - async def _run_user_restricted(self, command: str, cwd: Path | None, timeout: int) -> ToolResult: + async def _run_user_restricted( + self, command: str, cwd: Path | None, timeout: int + ) -> ToolResult: """Run for non-admin users: allowlist + shell features + sandbox cwd.""" try: tokens = shlex.split(command) 
@@ -495,7 +604,7 @@ return ToolResult( success=False, output=f"Command '{tokens[0]}' is not in the allowed list for non-admin users. " - f"Allowed: {allowed}.", + f"Allowed: {allowed}.", error="not_allowed", ) diff --git a/tests/clients/test_terminal_client.py b/tests/clients/test_terminal_client.py index 0ffcdd0..78fe95a 100644 --- a/tests/clients/test_terminal_client.py +++ b/tests/clients/test_terminal_client.py @@ -94,8 +94,9 @@ """Raw CLI must read session_id/name/preview from the server API.""" class FakeWsClient: - def __init__(self, session_id: str, renderer=None) -> None: + def __init__(self, session_id: str, renderer=None, cwd=None) -> None: self.session_id = session_id + self.cwd = cwd async def run_one_shot(self, prompt: str) -> None: pass diff --git a/tests/clients/test_tui_app.py b/tests/clients/test_tui_app.py index 37ee734..b83ab1c 100644 --- a/tests/clients/test_tui_app.py +++ b/tests/clients/test_tui_app.py @@ -6,6 +6,7 @@ import pytest +from clients.terminal.tui.events import WsEvent from clients.terminal.tui.tui_app import NaviCodeTui @@ -121,3 +122,91 @@ theme_text = str(status._theme.render()) assert "Backend:" in backend_text assert "Theme: gnexus-dark" in theme_text + + +@pytest.mark.anyio +async def test_streaming_state_tracks_ws_events() -> None: + """stream_start enters streaming mode; stream_end/stream_stopped/error leave it.""" + async with NaviCodeTui(new_session=True).run_test() as pilot: + await pilot.pause() + app = pilot.app + input_box = app.query_one("InputBox") + + app.on_ws_event(WsEvent({"type": "stream_start"})) + await pilot.pause() + assert app._streaming is True + assert "Esc to stop" in input_box._input.placeholder + + app.on_ws_event(WsEvent({"type": "stream_delta", "delta": "hi"})) + await pilot.pause() + assert app._streaming is True + + app.on_ws_event(WsEvent({"type": "stream_stopped"})) + await pilot.pause() + assert app._streaming is False + assert input_box._input.placeholder == "Ask anything..." 
+ + +@pytest.mark.anyio +async def test_escape_stops_active_stream(monkeypatch: pytest.MonkeyPatch) -> None: + """Pressing Esc while streaming calls api.stop_session for the current session.""" + stopped: list[str] = [] + + async def fake_stop_session(session_id: str) -> dict: + stopped.append(session_id) + return {"ok": True} + + import clients.terminal.tui.tui_app as tui_app_module + + monkeypatch.setattr(tui_app_module.api, "stop_session", fake_stop_session) + + async with NaviCodeTui(new_session=True).run_test() as pilot: + await pilot.pause() + app = pilot.app + # Simulate an active run for the resolved session. + app.on_ws_event(WsEvent({"type": "stream_start"})) + await pilot.pause() + assert app._streaming is True + current_session = app._ctx.session_id + assert current_session + + await pilot.press("escape") + await pilot.pause(0.1) + assert stopped == [current_session] + + +@pytest.mark.anyio +async def test_stream_stopped_renders_status() -> None: + """stream_stopped resets the assistant buffer and shows a status message.""" + async with NaviCodeTui(new_session=True).run_test() as pilot: + await pilot.pause() + chat = pilot.app.query_one("ChatPanel") + chat.handle_ws_event({"type": "stream_start"}) + chat.handle_ws_event({"type": "stream_delta", "delta": "partial"}) + chat.handle_ws_event({"type": "stream_stopped"}) + await pilot.pause() + assert chat._model._current_assistant is None + assert any( + item.kind == "status" and "stopped" in item.content.lower() + for item in chat._model.items + ) + + +@pytest.mark.anyio +async def test_escape_does_nothing_when_not_streaming(monkeypatch: pytest.MonkeyPatch) -> None: + """Esc without an active stream does not call stop_session.""" + calls: list[str] = [] + + async def fake_stop_session(session_id: str) -> dict: + calls.append(session_id) + return {"ok": True} + + import clients.terminal.tui.tui_app as tui_app_module + + monkeypatch.setattr(tui_app_module.api, "stop_session", fake_stop_session) + + async with 
NaviCodeTui(new_session=True).run_test() as pilot: + await pilot.pause() + await pilot.press("escape") + await pilot.pause(0.1) + assert calls == [] diff --git a/tests/integration/test_websocket.py b/tests/integration/test_websocket.py index 8970e9c..f9d0653 100644 --- a/tests/integration/test_websocket.py +++ b/tests/integration/test_websocket.py @@ -32,7 +32,15 @@ """Patch orchestrator.run_agent so it broadcasts deterministic events.""" orchestrator = _get_orchestrator(mock_deps) - async def fake_run_agent(session_id, user_content, raw_images, display_content, files, session_store): + async def fake_run_agent( + session_id, + user_content, + raw_images, + display_content, + files, + session_store, + cwd=None, + ): run = orchestrator.get_run(session_id) if run is None: return @@ -150,6 +158,7 @@ # ── Helpers ────────────────────────────────────────────────────────────────── + def _collect_until_done(ws, max_messages: int = 10) -> list[dict]: """Collect websocket messages until stream_end, error, or max messages.""" msgs: list[dict] = []