diff --git a/clients/terminal/tui/file_refs.py b/clients/terminal/tui/file_refs.py new file mode 100644 index 0000000..4871e2d --- /dev/null +++ b/clients/terminal/tui/file_refs.py @@ -0,0 +1,250 @@ +"""Resolve @path references inside user input. + +Supported forms: + @path/to/file.py → file content wrapped in code fence + @dir/ → list of files in directory (recursive if trailing /) + @tests/**/*.py → glob expansion, files only + +Size limits apply per-file and in total to avoid flooding the LLM context. +""" + +from __future__ import annotations + +import re +from dataclasses import dataclass, field +from pathlib import Path +from typing import Iterable + + +MAX_FILE_BYTES = 64_000 +MAX_TOTAL_BYTES = 128_000 +TRUNCATED_NOTICE = "\n... [truncated by Navi Code]" + + +@dataclass +class ResolvedFile: + """A file resolved from an @ reference.""" + + path: Path + display_path: str + content: str + truncated: bool = False + + +@dataclass +class FileRefResult: + """Result of resolving @ references in a prompt.""" + + prompt: str # user-visible prompt (with @ markers replaced by file list) + attachments: list[ResolvedFile] = field(default_factory=list) + errors: list[str] = field(default_factory=list) + total_bytes: int = 0 + + def is_empty(self) -> bool: + return not self.attachments and not self.errors + + def to_message(self) -> str: + """Build the full message to send to the backend.""" + if not self.attachments and not self.errors: + return self.prompt + + parts = [self.prompt] + if self.attachments: + parts.append("") + parts.append("--- attached files ---") + for f in self.attachments: + lang = _guess_language(f.path) + label = f"file: {f.display_path}" + if f.truncated: + label += " (truncated)" + parts.append(f"```{lang} {label}") + parts.append(f.content) + parts.append("```") + if self.errors: + parts.append("") + parts.append("--- attachment errors ---") + for err in self.errors: + parts.append(f"- {err}") + return "\n".join(parts) + + +_ref_pattern = re.compile(r"@((?:[A-Za-z0-9_\-\.~/$*?]|\\\s)+)") + + +def find_refs(text: str) -> list[str]: + """Return all @path tokens found in text, in order, without duplicates.""" + seen: set[str] = set() + refs: list[str] = OrderedRefs() + for raw in _ref_pattern.findall(text): + # Un-escape backslash-space inside the token. + ref = raw.replace("\\ ", " ") + if ref not in seen: + seen.add(ref) + refs.append(ref) + return refs + + +class OrderedRefs(list): + """Stub kept for type clarity; plain list suffices.""" + + +class FileRefResolver: + """Resolve @ references relative to a base directory.""" + + def __init__(self, base_dir: Path | str | None = None) -> None: + self.base_dir = Path(base_dir or Path.cwd()).expanduser().resolve() + + def resolve(self, text: str) -> FileRefResult: + refs = find_refs(text) + if not refs: + return FileRefResult(prompt=text) + + result = FileRefResult(prompt=text) + for ref in refs: + self._resolve_ref(ref, result) + if result.total_bytes >= MAX_TOTAL_BYTES: + result.errors.append("total attachment size limit reached; remaining files skipped") + break + return result + + def _resolve_ref(self, ref: str, result: FileRefResult) -> None: + path = self._expand_path(ref) + if path is None: + result.errors.append(f"could not resolve {ref!r}") + return + + if path.exists(): + if path.is_dir(): + files = sorted(_collect_files(path, recursive=ref.endswith("/"))) + if not files: + result.errors.append(f"no files found in {ref}") + return + for file_path in files: + self._attach_file(file_path, result, root_dir=path) + if result.total_bytes >= MAX_TOTAL_BYTES: + return + return + + if path.is_file(): + self._attach_file(path, result) + return + + result.errors.append(f"not a file or directory: {ref}") + return + + # Non-existent path: try glob expansion if it looks like a pattern. + if _is_glob(ref): + matches = sorted(self.base_dir.glob(ref)) + if not matches: + result.errors.append(f"no matches for {ref}") + return + for file_path in matches: + if not file_path.is_file(): + continue + self._attach_file(file_path, result) + if result.total_bytes >= MAX_TOTAL_BYTES: + return + return + + result.errors.append(f"not found: {ref}") + + def _attach_file(self, path: Path, result: FileRefResult, root_dir: Path | None = None) -> None: + display = _relative_or_absolute(path, self.base_dir) + if root_dir is not None: + display = _relative_or_absolute(path, root_dir) + try: + data = path.read_bytes() + except Exception as exc: + result.errors.append(f"failed to read {path}: {exc}") + return + + truncated = False + if len(data) > MAX_FILE_BYTES: + data = data[:MAX_FILE_BYTES] + truncated = True + + remaining = MAX_TOTAL_BYTES - result.total_bytes + if remaining <= 0: + return + if len(data) > remaining: + data = data[:remaining] + truncated = True + + try: + text = data.decode("utf-8", errors="replace") + except Exception as exc: + result.errors.append(f"failed to decode {path}: {exc}") + return + + if truncated: + text += TRUNCATED_NOTICE + result.attachments.append(ResolvedFile(path=path, display_path=display, content=text, truncated=truncated)) + result.total_bytes += len(data) + + def _expand_path(self, ref: str) -> Path | None: + # Strip any trailing slash for expansion, but keep the flag later. + clean = ref.rstrip("/") + if not clean: + return None + if clean.startswith("~"): + return Path(clean).expanduser() + return (self.base_dir / clean).resolve() + + +def _is_glob(ref: str) -> bool: + """Return True if ref contains glob metacharacters.""" + return "*" in ref or "?" in ref or "[" in ref + + +def _collect_files(path: Path, recursive: bool = False) -> Iterable[Path]: + """Yield files inside a directory.""" + if recursive: + for p in sorted(path.rglob("*")): + if p.is_file(): + yield p + else: + for p in sorted(path.iterdir()): + if p.is_file(): + yield p + + +def _relative_or_absolute(path: Path, base: Path) -> str: + try: + return str(path.relative_to(base)) + except ValueError: + return str(path) + + +def _guess_language(path: Path) -> str: + """Best-effort language tag for markdown code fence.""" + mapping = { + ".py": "python", + ".js": "javascript", + ".ts": "typescript", + ".tsx": "tsx", + ".jsx": "jsx", + ".go": "go", + ".rs": "rust", + ".c": "c", + ".cpp": "cpp", + ".h": "c", + ".java": "java", + ".kt": "kotlin", + ".sh": "bash", + ".zsh": "bash", + ".bash": "bash", + ".md": "markdown", + ".json": "json", + ".yaml": "yaml", + ".yml": "yaml", + ".toml": "toml", + ".html": "html", + ".css": "css", + ".scss": "scss", + ".sql": "sql", + ".dockerfile": "dockerfile", + ".lock": "text", + ".txt": "text", + ".env": "bash", + } + return mapping.get(path.suffix.lower(), "text") diff --git a/clients/terminal/tui/permissions.py b/clients/terminal/tui/permissions.py index 8f2a809..a6bbe35 100644 --- a/clients/terminal/tui/permissions.py +++ b/clients/terminal/tui/permissions.py @@ -105,4 +105,8 @@ return args.get("path", "") or args.get("destination", "") if tool == "terminal": return args.get("command", "") or args.get("action", "") + if tool == "code_exec": + return args.get("language", "") or args.get("code", "")[:40] + if tool == "ssh_exec": + return args.get("host", "") or args.get("command", "") return "" diff --git a/clients/terminal/tui/renderers/__init__.py b/clients/terminal/tui/renderers/__init__.py index 7cabd58..fc30a6e 100644 --- a/clients/terminal/tui/renderers/__init__.py +++ b/clients/terminal/tui/renderers/__init__.py @@ -4,7 +4,7 @@ from .base import ContentRenderer from .registry import RendererRegistry -from . import message, tool, thinking, error, markdown_content, plain +from . import message, tool, thinking, error, markdown_content, plain, diff, artifact def default_registry() -> RendererRegistry: @@ -17,6 +17,8 @@ reg.register(tool.ToolResultRenderer()) reg.register(error.ErrorRenderer()) reg.register(markdown_content.MarkdownRenderer()) + reg.register(diff.DiffRenderer()) + reg.register(artifact.ArtifactRenderer()) reg.register(plain.PlainRenderer()) return reg diff --git a/clients/terminal/tui/renderers/artifact.py b/clients/terminal/tui/renderers/artifact.py new file mode 100644 index 0000000..a054d48 --- /dev/null +++ b/clients/terminal/tui/renderers/artifact.py @@ -0,0 +1,92 @@ +"""Renderer for code artifact messages.""" + +from __future__ import annotations + +from rich.console import RenderableType +from rich.panel import Panel +from rich.syntax import Syntax +from rich.text import Text + +from clients.terminal.tui.themes import get_active_theme + +from .base import ContentRenderer + + +def _theme_aware_code_theme(theme_name: str) -> str: + """Pick a Pygments code theme that matches the Navi theme brightness.""" + return "dracula" if theme_name == "gnexus-dark" else "github-light" + + +class ArtifactRenderer(ContentRenderer): + """Render a file-like artifact with syntax highlighting.""" + + def accepts(self, msg: dict) -> bool: + return msg.get("type") == "artifact" + + def render(self, msg: dict) -> RenderableType: + theme = get_active_theme() + path = msg.get("path", "artifact") + language = msg.get("language") or _guess_language(path) + content = msg.get("content", "") + + if not content.strip(): + return Panel( + Text("(empty artifact)", style=theme.text_dim.hex), + title=path, + title_align="left", + border_style=theme.border.hex, + ) + + code_theme = _theme_aware_code_theme(theme.name) + syntax = Syntax( + content, + language, + theme=code_theme, + background_color=theme.surface.hex, + line_numbers=True, + word_wrap=True, + ) + return Panel( + syntax, + title=f"{path} [{language}]", + title_align="left", + border_style=theme.tool_border.hex, + ) + + +def _guess_language(path: str) -> str: + """Best-effort language tag from a file path.""" + mapping = { + ".py": "python", + ".js": "javascript", + ".ts": "typescript", + ".tsx": "tsx", + ".jsx": "jsx", + ".go": "go", + ".rs": "rust", + ".c": "c", + ".cpp": "cpp", + ".h": "c", + ".java": "java", + ".kt": "kotlin", + ".sh": "bash", + ".zsh": "bash", + ".bash": "bash", + ".md": "markdown", + ".json": "json", + ".yaml": "yaml", + ".yml": "yaml", + ".toml": "toml", + ".html": "html", + ".css": "css", + ".scss": "scss", + ".sql": "sql", + ".dockerfile": "dockerfile", + ".txt": "text", + ".env": "bash", + } + lower = path.lower() + for ext, lang in mapping.items(): + if lower.endswith(ext): + return lang + return "text" diff --git a/clients/terminal/tui/renderers/diff.py b/clients/terminal/tui/renderers/diff.py new file mode 100644 index 0000000..7f99592 --- /dev/null +++ b/clients/terminal/tui/renderers/diff.py @@ -0,0 +1,45 @@ +"""Renderer for unified diff messages.""" + +from __future__ import annotations + +from rich.console import RenderableType +from rich.panel import Panel +from rich.text import Text + +from clients.terminal.tui.themes import get_active_theme + +from .base import ContentRenderer + + +class DiffRenderer(ContentRenderer): + """Render a unified diff with added/removed line highlighting.""" + + def accepts(self, msg: dict) -> bool: + return msg.get("type") == "diff" + + def render(self, msg: dict) -> RenderableType: + theme = get_active_theme() + content = msg.get("content", "") + old_label = msg.get("old_label", "---") + new_label = msg.get("new_label", "+++") + + lines = content.splitlines() + highlighted = Text() + for idx, line in enumerate(lines): + if idx: + highlighted.append("\n") + if line.startswith("+") and not line.startswith("+++"): + highlighted.append(line, style=theme.success.hex) + elif line.startswith("-") and not line.startswith("---"): + highlighted.append(line, style=theme.error.hex) + elif line.startswith("@@"): + highlighted.append(line, style=theme.text_dim.hex) + else: + highlighted.append(line, style=theme.text.hex) + + return Panel( + highlighted, + title=f"diff: {old_label} → {new_label}", + title_align="left", + border_style=theme.border.hex, + ) diff --git a/clients/terminal/tui/screens/permission_dialog.py b/clients/terminal/tui/screens/permission_dialog.py new file mode 100644 index 0000000..d5b79ad --- /dev/null +++ b/clients/terminal/tui/screens/permission_dialog.py @@ -0,0 +1,80 @@ +"""Modal permission dialog for destructive tool operations.""" + +from __future__ import annotations + +from textual.app import ComposeResult +from textual.containers import Grid, Horizontal +from textual.screen import ModalScreen +from textual.widgets import Button, Static + + +class PermissionDialogScreen(ModalScreen[bool | str]): + """Ask user whether to allow a potentially destructive tool call. + + Returns: + "allow_once" | "allow_always" | "deny_once" | "deny_always" | None (dismissed) + """ + + DEFAULT_CSS = """ + PermissionDialogScreen { align: center middle; } + PermissionDialogScreen > Grid { + grid-size: 1; + grid-gutter: 1 2; + padding: 1 2; + border: thick $tui-error; + background: $tui-surface; + width: 60; + height: auto; + } + PermissionDialogScreen > Grid > Static { + width: 100%; + color: $tui-text; + } + PermissionDialogScreen .tool-name { + color: $tui-error; + text-style: bold; + } + PermissionDialogScreen .details { + color: $tui-text-dim; + } + PermissionDialogScreen .buttons { height: auto; } + PermissionDialogScreen Button { + margin: 0 1; + } + """ + + def __init__( + self, + tool: str, + action: str, + target: str, + details: str, + ) -> None: + super().__init__() + self._tool = tool + self._action = action + self._target = target + self._details = details + + def compose(self) -> ComposeResult: + with Grid(): + yield Static("Permission required", classes="tool-name") + yield Static(f"Tool: {self._tool}", classes="details") + if self._action: + yield Static(f"Action: {self._action}", classes="details") + if self._target: + yield Static(f"Target: {self._target}", classes="details") + if self._details: + yield Static(self._details, classes="details") + with Horizontal(classes="buttons"): + yield Button("Allow once", id="allow_once", variant="primary") + yield Button("Always allow", id="allow_always", variant="primary") + yield Button("Deny", id="deny_once", variant="error") + yield Button("Always deny", id="deny_always", variant="error") + + def on_button_pressed(self, event: Button.Pressed) -> None: + self.dismiss(event.button.id) + + def on_key(self, event) -> None: + if event.key == "escape": + self.dismiss("deny_once") diff --git a/clients/terminal/tui/shell_runner.py b/clients/terminal/tui/shell_runner.py new file mode 100644 index 0000000..7a12f5d --- /dev/null +++ b/clients/terminal/tui/shell_runner.py @@ -0,0 +1,74 @@ +"""Run local shell commands from !input in the TUI.""" + +from __future__ import annotations + +import subprocess +from dataclasses import dataclass +from pathlib import Path + + +DEFAULT_TIMEOUT = 30.0 +MAX_OUTPUT_LINES = 200 + + +@dataclass +class ShellResult: + """Result of running a shell command.""" + + command: str + returncode: int + stdout: str + stderr: str + truncated: bool = False + + def summary(self) -> str: + """Return a short, chat-friendly summary.""" + marker = "✓" if self.returncode == 0 else "✗" + lines = [f"{marker} $ {self.command}", ""] + if self.stdout: + lines.append(self.stdout) + if self.stderr: + if self.stdout: + lines.append("") + lines.append(f"--- stderr ---\n{self.stderr}") + lines.append("") + lines.append(f"exit code: {self.returncode}") + return "\n".join(lines) + + +def run_shell_command(raw: str, cwd: Path | str | None = None, timeout: float = DEFAULT_TIMEOUT) -> ShellResult: + """Run a shell command from user input (without the leading !). + + The command is passed to a real shell so pipes, redirections and globs work. + """ + command = raw[1:] if raw.startswith("!") else raw + command = command.strip() + if not command: + return ShellResult(command="", returncode=1, stdout="", stderr="empty command") + + work_dir = Path(cwd or Path.cwd()).expanduser().resolve() + try: + proc = subprocess.run( + command, + shell=True, + cwd=work_dir, + capture_output=True, + text=True, + timeout=timeout, + ) + stdout = _truncate(proc.stdout) + stderr = _truncate(proc.stderr) + return ShellResult(command=command, returncode=proc.returncode, stdout=stdout, stderr=stderr) + except subprocess.TimeoutExpired: + return ShellResult(command=command, returncode=124, stdout="", stderr=f"timed out after {timeout}s") + except Exception as exc: + return ShellResult(command=command, returncode=1, stdout="", stderr=str(exc)) + + +def _truncate(text: str) -> str: + """Limit output to the last MAX_OUTPUT_LINES lines to avoid flooding the UI.""" + lines = text.splitlines() + if len(lines) <= MAX_OUTPUT_LINES: + return text + truncated = lines[-MAX_OUTPUT_LINES:] + return f"... [{len(lines) - MAX_OUTPUT_LINES} lines truncated]\n" + "\n".join(truncated) diff --git a/clients/terminal/tui/tui_app.py b/clients/terminal/tui/tui_app.py index 90b6188..0a19b60 100644 --- a/clients/terminal/tui/tui_app.py +++ b/clients/terminal/tui/tui_app.py @@ -16,10 +16,13 @@ UserSubmitted, WsEvent, ) +from clients.terminal.tui.file_refs import FileRefResolver from clients.terminal.tui.permissions import PermissionEngine +from clients.terminal.tui.shell_runner import run_shell_command from clients.terminal.tui.themes import ThemeRegistry, set_active_theme from clients.terminal.tui.commands.registry import get_registry from clients.terminal.tui.screens.command_palette import CommandPaletteScreen +from clients.terminal.tui.screens.permission_dialog import PermissionDialogScreen from clients.terminal.tui.widgets import ChatPanel, InputBox, StatusPanel from clients.terminal.tui.ws_bridge import WsBridge @@ -152,13 +155,30 @@ self._run_command(text) return - # @ file references and ! shell commands can be parsed here in Phase 4. - self._chat_panel.add_user_message(text) + if text.startswith("!"): + self._run_shell_command(text) + return + + resolved = FileRefResolver().resolve(text) + self._chat_panel.add_user_message(resolved.prompt) + if resolved.attachments: + names = ", ".join(a.display_path + (" (truncated)" if a.truncated else "") for a in resolved.attachments) + self._chat_panel.handle_ws_event({"type": "status", "content": f"Attached: {names}"}) + for err in resolved.errors: + self._chat_panel.handle_ws_event({"type": "error", "message": err}) + if self._bridge and self._bridge.connected: - self._bridge.client.enqueue(text) + self._bridge.client.enqueue(resolved.to_message()) else: self._chat_panel.handle_ws_event({"type": "error", "message": "Not connected to a session"}) + def _run_shell_command(self, text: str) -> None: + self.run_worker(self._shell_worker(text)) + + async def _shell_worker(self, text: str) -> None: + result = run_shell_command(text) + self._chat_panel.handle_ws_event({"type": "status", "content": result.summary()}) + def _run_command(self, text: str) -> None: parts = text[1:].split(None, 1) name = parts[0].lower() @@ -174,19 +194,68 @@ await cmd.execute(self._ctx, args) def on_ws_event(self, event: WsEvent) -> None: - self._chat_panel.handle_ws_event(event.payload) + payload = event.payload + if payload.get("type") == "tool_started": + rule = self._permission_engine.check(payload.get("tool", ""), payload.get("args") or {}) + if rule is not None: + self._show_permission_dialog(payload, rule) + return + self._chat_panel.handle_ws_event(payload) + + def _show_permission_dialog(self, payload: dict, rule) -> None: + tool = payload.get("tool", "?") + args = payload.get("args") or {} + action = args.get("action", "") + target = self._permission_engine._extract_target(tool, args) + + def on_decision(choice: str | None) -> None: + if choice == "allow_once": + self._chat_panel.handle_ws_event(payload) + elif choice == "allow_always": + self._permission_engine.set_always_allow(tool, args) + self._chat_panel.handle_ws_event(payload) + elif choice == "deny_once": + self._deny_tool(tool, args) + elif choice == "deny_always": + self._permission_engine.set_always_deny(tool, args) + self._deny_tool(tool, args) + else: + # Dismissed / escape — treat as deny once. + self._deny_tool(tool, args) + + self.push_screen( + PermissionDialogScreen( + tool=tool, + action=action, + target=target, + details=rule.message, + ), + callback=on_decision, + ) + + def _deny_tool(self, tool: str, args: dict) -> None: + self._chat_panel.handle_ws_event({"type": "error", "message": f"Denied: {tool} {args}"}) + if self._ctx.session_id: + self.run_worker(self._stop_session_worker(self._ctx.session_id)) + + async def _stop_session_worker(self, session_id: str) -> None: + try: + api.stop_session(session_id) + except Exception as exc: + self._chat_panel.handle_ws_event({"type": "error", "message": f"Failed to stop session: {exc}"}) + if self._bridge: + await self._bridge.stop() + self._status_panel.set_connection(False, "permission denied") def on_connection_status_changed(self, event: ConnectionStatusChanged) -> None: self._status_panel.set_connection(event.connected, event.detail) def on_permission_request(self, event: PermissionRequest) -> None: - self._pending_permission = event - self._chat_panel.handle_ws_event( - { - "type": "status", - "content": f"Permission required: {event.details}\nAllow once (y) / always (a) / reject (n)", - } - ) + """Manual PermissionRequest event (fallback from components).""" + tool = event.tool + args = {"action": event.action, "command": event.details} + payload = {"type": "tool_started", "tool": tool, "args": args} + self._show_permission_dialog(payload, type("R", (), {"message": event.details})()) def action_command_palette(self) -> None: registry = get_registry() diff --git a/tests/clients/test_diff_artifact_renderers.py b/tests/clients/test_diff_artifact_renderers.py new file mode 100644 index 0000000..75b8d51 --- /dev/null +++ b/tests/clients/test_diff_artifact_renderers.py @@ -0,0 +1,48 @@ +"""Tests for diff and artifact renderers.""" + +from __future__ import annotations + +from clients.terminal.tui.renderers.artifact import ArtifactRenderer +from clients.terminal.tui.renderers.diff import DiffRenderer +from clients.terminal.tui.themes import set_active_theme + + +def test_diff_renderer_accepts_and_highlights() -> None: + set_active_theme("gnexus-dark") + renderer = DiffRenderer() + msg = { + "type": "diff", + "old_label": "a.py", + "new_label": "b.py", + "content": "--- a.py\n+++ b.py\n@@ -1,2 +1,2 @@\n-old\n+new\n unchanged\n", + } + assert renderer.accepts(msg) + renderable = renderer.render(msg) + text = str(renderable.renderable) + assert "old" in text + assert "new" in text + + +def test_artifact_renderer_accepts_and_renders() -> None: + set_active_theme("gnexus-dark") + renderer = ArtifactRenderer() + msg = { + "type": "artifact", + "path": "src/main.py", + "content": "def main():\n pass\n", + } + assert renderer.accepts(msg) + renderable = renderer.render(msg) + assert renderable.renderable is not None + assert "src/main.py" in renderable.title + + +def test_artifact_guesses_language_from_path() -> None: + renderer = ArtifactRenderer() + msg = { + "type": "artifact", + "path": "config.yaml", + "content": "key: value\n", + } + renderable = renderer.render(msg) + assert "yaml" in renderable.title diff --git a/tests/clients/test_file_refs.py b/tests/clients/test_file_refs.py new file mode 100644 index 0000000..560f0a2 --- /dev/null +++ b/tests/clients/test_file_refs.py @@ -0,0 +1,108 @@ +"""Tests for @ file reference resolver.""" + +from __future__ import annotations + +from pathlib import Path + +import pytest + +from clients.terminal.tui.file_refs import FileRefResolver, MAX_FILE_BYTES, MAX_TOTAL_BYTES + + +@pytest.fixture +def sample_dir(tmp_path: Path) -> Path: + d = tmp_path / "project" + d.mkdir() + (d / "main.py").write_text("def main():\n pass\n", encoding="utf-8") + (d / "readme.md").write_text("# Hello\n", encoding="utf-8") + (d / "sub").mkdir() + (d / "sub" / "util.py").write_text("def util():\n return 1\n", encoding="utf-8") + return d + + +def test_resolve_single_file(sample_dir: Path) -> None: + resolver = FileRefResolver(sample_dir) + result = resolver.resolve("check @main.py") + assert not result.errors + assert len(result.attachments) == 1 + assert result.attachments[0].display_path == "main.py" + assert "def main():" in result.to_message() + assert "```python file: main.py" in result.to_message() + + +def test_resolve_directory_without_recursive(sample_dir: Path) -> None: + resolver = FileRefResolver(sample_dir) + result = resolver.resolve("look @.") + assert not result.errors + paths = {a.display_path for a in result.attachments} + assert "main.py" in paths + assert "readme.md" in paths + assert "sub/util.py" not in paths + + +def test_resolve_directory_recursive(sample_dir: Path) -> None: + resolver = FileRefResolver(sample_dir) + result = resolver.resolve("look @./") + paths = {a.display_path for a in result.attachments} + assert "main.py" in paths + assert "sub/util.py" in paths + + +def test_resolve_missing_file(sample_dir: Path) -> None: + resolver = FileRefResolver(sample_dir) + result = resolver.resolve("check @missing.py") + assert not result.attachments + assert any("not found" in e for e in result.errors) + + +def test_resolve_glob(sample_dir: Path) -> None: + resolver = FileRefResolver(sample_dir) + result = resolver.resolve("check @**/*.py") + paths = {a.display_path for a in result.attachments} + assert "main.py" in paths + assert "sub/util.py" in paths + assert "readme.md" not in paths + + +def test_resolve_tilde_expands(tmp_path: Path, monkeypatch) -> None: + home = tmp_path / "home" + home.mkdir() + (home / "note.txt").write_text("hello", encoding="utf-8") + monkeypatch.setenv("HOME", str(home)) + resolver = FileRefResolver(base_dir=tmp_path) + result = resolver.resolve("read @~/note.txt") + assert len(result.attachments) == 1 + assert result.attachments[0].content == "hello" + + +def test_no_refs_returns_unchanged() -> None: + resolver = FileRefResolver(Path.cwd()) + text = "just a regular message" + result = resolver.resolve(text) + assert result.prompt == text + assert result.is_empty() + + +def test_size_limit_per_file(sample_dir: Path) -> None: + big = sample_dir / "big.txt" + big.write_bytes(b"x" * (MAX_FILE_BYTES + 100)) + resolver = FileRefResolver(sample_dir) + result = resolver.resolve("read @big.txt") + assert len(result.attachments) == 1 + assert result.attachments[0].truncated + assert len(result.attachments[0].content) <= MAX_FILE_BYTES + 100 + + +def test_total_size_limit_stops_processing(sample_dir: Path) -> None: + # Each file is exactly half the total limit; a third file should be skipped. + big1 = sample_dir / "big1.txt" + big2 = sample_dir / "big2.txt" + big3 = sample_dir / "big3.txt" + big1.write_bytes(b"x" * MAX_FILE_BYTES) + big2.write_bytes(b"x" * MAX_FILE_BYTES) + big3.write_bytes(b"x" * 100) + resolver = FileRefResolver(sample_dir) + result = resolver.resolve("read @big1.txt @big2.txt @big3.txt") + assert result.total_bytes <= MAX_TOTAL_BYTES + assert len(result.attachments) == 2 + assert any("limit reached" in e for e in result.errors) diff --git a/tests/clients/test_permission_dialog.py b/tests/clients/test_permission_dialog.py new file mode 100644 index 0000000..3a9bae3 --- /dev/null +++ b/tests/clients/test_permission_dialog.py @@ -0,0 +1,48 @@ +"""Tests for the inline permission dialog.""" + +from __future__ import annotations + +import pytest + +from clients.terminal.tui.screens.permission_dialog import PermissionDialogScreen +from clients.terminal.tui.events import WsEvent +from clients.terminal.tui.tui_app import NaviCodeTui + + +@pytest.mark.anyio +async def test_tool_started_triggers_permission_dialog() -> None: + """A destructive tool_started event opens the permission dialog.""" + async with NaviCodeTui(new_session=True).run_test() as pilot: + await pilot.pause() + chat = pilot.app.query_one("ChatPanel") + pilot.app.post_message(WsEvent({"type": "tool_started", "tool": "filesystem", "args": {"action": "delete", "path": "/tmp/x"}})) + await pilot.pause() + assert pilot.app.screen_stack[-1].__class__.__name__ == "PermissionDialogScreen" + + +@pytest.mark.anyio +async def test_allow_once_passes_tool_to_chat() -> None: + async with NaviCodeTui(new_session=True).run_test() as pilot: + await pilot.pause() + chat = pilot.app.query_one("ChatPanel") + pilot.app.post_message(WsEvent({"type": "tool_started", "tool": "filesystem", "args": {"action": "delete", "path": "/tmp/x"}})) + await pilot.pause() + dialog = pilot.app.screen_stack[-1] + assert isinstance(dialog, PermissionDialogScreen) + await pilot.click("#allow_once") + await pilot.pause() + assert any(item.kind == "tool_started" for item in chat._model.items) + + +@pytest.mark.anyio +async def test_deny_once_adds_error_and_stops() -> None: + async with NaviCodeTui(new_session=True).run_test() as pilot: + await pilot.pause() + chat = pilot.app.query_one("ChatPanel") + pilot.app.post_message(WsEvent({"type": "tool_started", "tool": "filesystem", "args": {"action": "delete", "path": "/tmp/x"}})) + await pilot.pause() + dialog = pilot.app.screen_stack[-1] + assert isinstance(dialog, PermissionDialogScreen) + await pilot.click("#deny_once") + await pilot.pause() + assert any(item.kind == "error" for item in chat._model.items) diff --git a/tests/clients/test_shell_runner.py b/tests/clients/test_shell_runner.py new file mode 100644 index 0000000..0228c4c --- /dev/null +++ b/tests/clients/test_shell_runner.py @@ -0,0 +1,60 @@ +"""Tests for ! shell command runner.""" + +from __future__ import annotations + +from pathlib import Path + + +from clients.terminal.tui.shell_runner import run_shell_command, MAX_OUTPUT_LINES + + +def test_run_simple_command() -> None: + result = run_shell_command("!echo hello") + assert result.returncode == 0 + assert "hello" in result.stdout + assert result.stderr == "" + + +def test_run_with_stderr() -> None: + result = run_shell_command("!echo error >&2") + assert result.returncode == 0 + assert "error" in result.stderr + + +def test_run_failing_command() -> None: + result = run_shell_command("!false") + assert result.returncode == 1 + assert "✗" in result.summary() + + +def test_run_piped_command() -> None: + result = run_shell_command("!echo hi | tr a-z A-Z") + assert result.returncode == 0 + assert "HI" in result.stdout + + +def test_run_timeout() -> None: + result = run_shell_command("!sleep 5", timeout=0.1) + assert result.returncode == 124 + assert "timed out" in result.stderr + + +def test_empty_command() -> None: + result = run_shell_command("!") + assert result.returncode == 1 + + +def test_truncate_long_output() -> None: + result = run_shell_command(f"!seq 1 {MAX_OUTPUT_LINES + 50}") + assert result.truncated is False + assert "lines truncated" in result.stdout + assert len(result.stdout.splitlines()) <= MAX_OUTPUT_LINES + 1 + + +def test_run_in_cwd(tmp_path: Path) -> None: + sub = tmp_path / "sub" + sub.mkdir() + (sub / "file.txt").write_text("data", encoding="utf-8") + result = run_shell_command("!cat file.txt", cwd=sub) + assert result.returncode == 0 + assert result.stdout.strip() == "data"