diff --git a/clients/terminal/tui/events.py b/clients/terminal/tui/events.py index 7800458..b3c72c2 100644 --- a/clients/terminal/tui/events.py +++ b/clients/terminal/tui/events.py @@ -42,13 +42,12 @@ class PermissionRequest(Message): - """Tool call requires destructive-operation confirmation.""" + """Tool call requires destructive-operation confirmation (component fallback).""" - def __init__(self, tool: str, action: str, details: str, callback) -> None: + def __init__(self, tool: str, args: dict, message: str) -> None: self.tool = tool - self.action = action - self.details = details - self.callback = callback + self.args = args + self.message = message super().__init__() diff --git a/clients/terminal/tui/file_refs.py b/clients/terminal/tui/file_refs.py index 4871e2d..34ca6f5 100644 --- a/clients/terminal/tui/file_refs.py +++ b/clients/terminal/tui/file_refs.py @@ -15,11 +15,68 @@ from pathlib import Path from typing import Iterable +from clients.terminal.tui.renderers.language import guess_language + MAX_FILE_BYTES = 64_000 MAX_TOTAL_BYTES = 128_000 TRUNCATED_NOTICE = "\n... [truncated by Navi Code]" +# Sensitive paths/patterns that should never be attached automatically. +SENSITIVE_NAMES: set[str] = { + ".env", + ".env.local", + ".env.production", + ".env.staging", + ".git", + ".gitignore", + ".ssh", + ".aws", + ".docker", + ".npmrc", + ".pypirc", + ".netrc", + ".pgpass", + "id_rsa", + "id_rsa.pub", + "id_dsa", + "id_dsa.pub", + "id_ecdsa", + "id_ecdsa.pub", + "id_ed25519", + "id_ed25519.pub", + ".DS_Store", + "Thumbs.db", +} + +SENSITIVE_SUFFIXES: tuple[str, ...] = ( + ".pem", + ".key", + ".crt", + ".p12", + ".pfx", + ".keystore", + ".jks", + ".pyc", + ".pyo", +) + +SENSITIVE_DIR_NAMES: set[str] = { + ".git", + ".ssh", + ".aws", + ".venv", + "venv", + "node_modules", + "__pycache__", + ".tox", + ".pytest_cache", + ".mypy_cache", + ".egg-info", + "dist", + "build", +} + @dataclass class ResolvedFile: @@ -53,7 +110,7 @@ parts.append("") parts.append("--- attached files ---") for f in self.attachments: - lang = _guess_language(f.path) + lang = guess_language(f.path) label = f"file: {f.display_path}" if f.truncated: label += " (truncated)" @@ -68,13 +125,13 @@ return "\n".join(parts) -_ref_pattern = re.compile(r"@((?:[A-Za-z0-9_\-\.~/$*?]|\\\s)+)") +_ref_pattern = re.compile(r"@((?:[A-Za-z0-9_\-\.~/$*?\[\]\\]|\\\s)+)") def find_refs(text: str) -> list[str]: """Return all @path tokens found in text, in order, without duplicates.""" seen: set[str] = set() - refs: list[str] = OrderedRefs() + refs: list[str] = [] for raw in _ref_pattern.findall(text): # Un-escape backslash-space inside the token. ref = raw.replace("\\ ", " ") @@ -84,15 +141,12 @@ return refs -class OrderedRefs(list): - """Stub kept for type clarity; plain list suffices.""" - - class FileRefResolver: """Resolve @ references relative to a base directory.""" def __init__(self, base_dir: Path | str | None = None) -> None: self.base_dir = Path(base_dir or Path.cwd()).expanduser().resolve() + self._home_dir = Path.home().expanduser().resolve() def resolve(self, text: str) -> FileRefResult: refs = find_refs(text) @@ -149,15 +203,22 @@ result.errors.append(f"not found: {ref}") def _attach_file(self, path: Path, result: FileRefResult, root_dir: Path | None = None) -> None: - display = _relative_or_absolute(path, self.base_dir) - if root_dir is not None: - display = _relative_or_absolute(path, root_dir) + if _is_sensitive_path(path): + display = _display_path(path, self.base_dir, root_dir) + result.errors.append(f"skipped sensitive file: {display}") + return + + display = _display_path(path, self.base_dir, root_dir) try: data = path.read_bytes() except Exception as exc: result.errors.append(f"failed to read {path}: {exc}") return + if b"\x00" in data: + result.errors.append(f"skipped binary file: {display}") + return + truncated = False if len(data) > MAX_FILE_BYTES: data = data[:MAX_FILE_BYTES] @@ -178,17 +239,31 @@ if truncated: text += TRUNCATED_NOTICE + result.total_bytes += len(TRUNCATED_NOTICE.encode("utf-8")) result.attachments.append(ResolvedFile(path=path, display_path=display, content=text, truncated=truncated)) result.total_bytes += len(data) def _expand_path(self, ref: str) -> Path | None: + """Expand a raw @ reference into an absolute, validated Path. + + Paths are restricted to the resolver's base directory. The only + exception is explicit ``~`` expansion, which is allowed inside the + user's home directory. + """ # Strip any trailing slash for expansion, but keep the flag later. clean = ref.rstrip("/") if not clean: return None if clean.startswith("~"): - return Path(clean).expanduser() - return (self.base_dir / clean).resolve() + candidate = Path(clean).expanduser().resolve() + allowed_root = self._home_dir + else: + candidate = (self.base_dir / clean).resolve() + allowed_root = self.base_dir + + if not _is_under_root(candidate, allowed_root): + return None + return candidate def _is_glob(ref: str) -> bool: @@ -197,54 +272,46 @@ def _collect_files(path: Path, recursive: bool = False) -> Iterable[Path]: - """Yield files inside a directory.""" - if recursive: - for p in sorted(path.rglob("*")): - if p.is_file(): - yield p - else: - for p in sorted(path.iterdir()): - if p.is_file(): - yield p + """Yield non-sensitive files inside a directory.""" + iterator = path.rglob("*") if recursive else path.iterdir() + for p in sorted(iterator): + if p.is_file() and not _is_inside_sensitive_dir(p): + yield p -def _relative_or_absolute(path: Path, base: Path) -> str: +def _display_path(path: Path, base: Path, root_dir: Path | None = None) -> str: + for candidate in (root_dir, base): + if candidate is not None: + try: + return str(path.relative_to(candidate)) + except ValueError: + continue + return str(path) + + +def _is_under_root(candidate: Path, root: Path) -> bool: + """Return True if candidate stays inside root (after symlink resolution).""" try: - return str(path.relative_to(base)) + candidate.relative_to(root) + return True except ValueError: - return str(path) + return False -def _guess_language(path: Path) -> str: - """Best-effort language tag for markdown code fence.""" - mapping = { - ".py": "python", - ".js": "javascript", - ".ts": "typescript", - ".tsx": "tsx", - ".jsx": "jsx", - ".go": "go", - ".rs": "rust", - ".c": "c", - ".cpp": "cpp", - ".h": "c", - ".java": "java", - ".kt": "kotlin", - ".sh": "bash", - ".zsh": "bash", - ".bash": "bash", - ".md": "markdown", - ".json": "json", - ".yaml": "yaml", - ".yml": "yaml", - ".toml": "toml", - ".html": "html", - ".css": "css", - ".scss": "scss", - ".sql": "sql", - ".dockerfile": "dockerfile", - ".lock": "text", - ".txt": "text", - ".env": "bash", - } - return mapping.get(path.suffix.lower(), "text") +def _is_inside_sensitive_dir(path: Path) -> bool: + """Return True if path lives inside a directory that should be skipped.""" + for part in path.parts: + if part in SENSITIVE_DIR_NAMES: + return True + return False + + +def _is_sensitive_path(path: Path) -> bool: + """Return True if path matches a sensitive file pattern.""" + if _is_inside_sensitive_dir(path): + return True + if path.name in SENSITIVE_NAMES: + return True + if path.name.lower().endswith(SENSITIVE_SUFFIXES): + return True + return False diff --git a/clients/terminal/tui/permissions.py b/clients/terminal/tui/permissions.py index a6bbe35..7250e55 100644 --- a/clients/terminal/tui/permissions.py +++ b/clients/terminal/tui/permissions.py @@ -27,6 +27,9 @@ PermissionRule(tool="terminal", pattern="rm *", message="Remove files/directories"), PermissionRule(tool="terminal", pattern="*format*", message="Format operation"), PermissionRule(tool="terminal", pattern="*drop*", message="Drop database/table"), + PermissionRule(tool="code_exec", message="Execute arbitrary code"), + PermissionRule(tool="ssh_exec", message="Execute remote command"), + PermissionRule(tool="shell", action="run", message="Local shell command"), ] @@ -60,13 +63,15 @@ self._store_path.write_text(json.dumps(data, indent=2)) def check(self, tool: str, args: dict) -> PermissionRule | None: - """Return matching rule if confirmation is needed, else None.""" + """Return matching rule if confirmation is needed, else None. + + Always-allow and always-deny entries both bypass the confirmation + dialog. Callers that need to actively reject always-deny matches can + use :meth:`is_always_deny`. + """ rule_key = self._rule_key(tool, args) - if rule_key in self._always_allow: + if rule_key in self._always_allow or rule_key in self._always_deny: return None - if rule_key in self._always_deny: - # Always denied — treat as matched, the caller will reject. - return PermissionRule(tool=tool, message="always denied by user policy") for rule in self._rules: if rule.tool != tool: @@ -75,12 +80,20 @@ if args.get("action") != rule.action: continue if rule.pattern is not None: - target = self._extract_target(tool, args) + target = self.extract_target(tool, args) if target is None or not fnmatch.fnmatch(target, rule.pattern): continue return rule return None + def is_always_deny(self, tool: str, args: dict) -> bool: + """Return True if this tool call was permanently denied by the user.""" + return self._rule_key(tool, args) in self._always_deny + + def extract_target(self, tool: str, args: dict) -> str: + """Public helper for extracting the human-readable target of a tool call.""" + return self._extract_target(tool, args) + def set_always_allow(self, tool: str, args: dict) -> None: self._always_allow.add(self._rule_key(tool, args)) self._always_deny.discard(self._rule_key(tool, args)) @@ -91,10 +104,9 @@ self._always_allow.discard(self._rule_key(tool, args)) self._save() - @staticmethod - def _rule_key(tool: str, args: dict) -> str: + def _rule_key(self, tool: str, args: dict) -> str: action = args.get("action", "") - target = PermissionEngine._extract_target(tool, args) + target = self.extract_target(tool, args) if target: return f"{tool}:{action}:{target}" return f"{tool}:{action}" diff --git a/clients/terminal/tui/renderers/artifact.py b/clients/terminal/tui/renderers/artifact.py index a054d48..0f5db13 100644 --- a/clients/terminal/tui/renderers/artifact.py +++ b/clients/terminal/tui/renderers/artifact.py @@ -10,6 +10,7 @@ from clients.terminal.tui.themes import get_active_theme from .base import ContentRenderer +from .language import guess_language def _theme_aware_code_theme(theme_name: str) -> str: @@ -26,7 +27,7 @@ def render(self, msg: dict) -> RenderableType: theme = get_active_theme() path = msg.get("path", "artifact") - language = msg.get("language") or _guess_language(path) + language = msg.get("language") or guess_language(path) content = msg.get("content", "") if not content.strip(): @@ -52,41 +53,3 @@ title_align="left", border_style=theme.tool_border.hex, ) - - -def _guess_language(path: str) -> str: - """Best-effort language tag from a file path.""" - mapping = { - ".py": "python", - ".js": "javascript", - ".ts": "typescript", - ".tsx": "tsx", - ".jsx": "jsx", - ".go": "go", - ".rs": "rust", - ".c": "c", - ".cpp": "cpp", - ".h": "c", - ".java": "java", - ".kt": "kotlin", - ".sh": "bash", - ".zsh": "bash", - ".bash": "bash", - ".md": "markdown", - ".json": "json", - ".yaml": "yaml", - ".yml": "yaml", - ".toml": "toml", - ".html": "html", - ".css": "css", - ".scss": "scss", - ".sql": "sql", - ".dockerfile": "dockerfile", - ".txt": "text", - ".env": "bash", - } - lower = path.lower() - for ext, lang in mapping.items(): - if lower.endswith(ext): - return lang - return "text" diff --git a/clients/terminal/tui/renderers/language.py b/clients/terminal/tui/renderers/language.py new file mode 100644 index 0000000..2cfe919 --- /dev/null +++ b/clients/terminal/tui/renderers/language.py @@ -0,0 +1,47 @@ +"""Shared language-guessing helpers for code fences and syntax highlighting.""" + +from __future__ import annotations + +from pathlib import Path + + +LANGUAGE_MAPPING: dict[str, str] = { + ".py": "python", + ".js": "javascript", + ".ts": "typescript", + ".tsx": "tsx", + ".jsx": "jsx", + ".go": "go", + ".rs": "rust", + ".c": "c", + ".cpp": "cpp", + ".h": "c", + ".java": "java", + ".kt": "kotlin", + ".sh": "bash", + ".zsh": "bash", + ".bash": "bash", + ".md": "markdown", + ".json": "json", + ".yaml": "yaml", + ".yml": "yaml", + ".toml": "toml", + ".html": "html", + ".css": "css", + ".scss": "scss", + ".sql": "sql", + ".dockerfile": "dockerfile", + ".lock": "text", + ".txt": "text", + ".env": "bash", +} + + +def guess_language(path: Path | str) -> str: + """Best-effort language tag for a file path or artifact name.""" + name = str(path) + lower_name = name.lower() + for ext, lang in LANGUAGE_MAPPING.items(): + if lower_name.endswith(ext): + return lang + return "text" diff --git a/clients/terminal/tui/screens/permission_dialog.py b/clients/terminal/tui/screens/permission_dialog.py index d5b79ad..1d95024 100644 --- a/clients/terminal/tui/screens/permission_dialog.py +++ b/clients/terminal/tui/screens/permission_dialog.py @@ -8,7 +8,7 @@ from textual.widgets import Button, Static -class PermissionDialogScreen(ModalScreen[bool | str]): +class PermissionDialogScreen(ModalScreen[str | None]): """Ask user whether to allow a potentially destructive tool call. Returns: diff --git a/clients/terminal/tui/shell_runner.py b/clients/terminal/tui/shell_runner.py index 7a12f5d..0d0329f 100644 --- a/clients/terminal/tui/shell_runner.py +++ b/clients/terminal/tui/shell_runner.py @@ -56,19 +56,29 @@ text=True, timeout=timeout, ) - stdout = _truncate(proc.stdout) - stderr = _truncate(proc.stderr) - return ShellResult(command=command, returncode=proc.returncode, stdout=stdout, stderr=stderr) + stdout, stdout_truncated = _truncate(proc.stdout) + stderr, stderr_truncated = _truncate(proc.stderr) + return ShellResult( + command=command, + returncode=proc.returncode, + stdout=stdout, + stderr=stderr, + truncated=stdout_truncated or stderr_truncated, + ) except subprocess.TimeoutExpired: - return ShellResult(command=command, returncode=124, stdout="", stderr=f"timed out after {timeout}s") + return ShellResult(command=command, returncode=124, stdout="", stderr=f"timed out after {timeout}s", truncated=False) except Exception as exc: - return ShellResult(command=command, returncode=1, stdout="", stderr=str(exc)) + return ShellResult(command=command, returncode=1, stdout="", stderr=str(exc), truncated=False) -def _truncate(text: str) -> str: - """Limit output to the last MAX_OUTPUT_LINES lines to avoid flooding the UI.""" +def _truncate(text: str) -> tuple[str, bool]: + """Limit output to the last MAX_OUTPUT_LINES lines to avoid flooding the UI. + + Returns the possibly-truncated text and a flag indicating whether truncation + actually happened. + """ lines = text.splitlines() if len(lines) <= MAX_OUTPUT_LINES: - return text + return text, False truncated = lines[-MAX_OUTPUT_LINES:] - return f"... [{len(lines) - MAX_OUTPUT_LINES} lines truncated]\n" + "\n".join(truncated) + return f"... [{len(lines) - MAX_OUTPUT_LINES} lines truncated]\n" + "\n".join(truncated), True diff --git a/clients/terminal/tui/tui_app.py b/clients/terminal/tui/tui_app.py index 0a19b60..4debd62 100644 --- a/clients/terminal/tui/tui_app.py +++ b/clients/terminal/tui/tui_app.py @@ -17,7 +17,7 @@ WsEvent, ) from clients.terminal.tui.file_refs import FileRefResolver -from clients.terminal.tui.permissions import PermissionEngine +from clients.terminal.tui.permissions import PermissionEngine, PermissionRule from clients.terminal.tui.shell_runner import run_shell_command from clients.terminal.tui.themes import ThemeRegistry, set_active_theme from clients.terminal.tui.commands.registry import get_registry @@ -173,7 +173,42 @@ self._chat_panel.handle_ws_event({"type": "error", "message": "Not connected to a session"}) def _run_shell_command(self, text: str) -> None: - self.run_worker(self._shell_worker(text)) + command = text[1:].strip() + args = {"action": "run", "command": command} + if self._permission_engine.is_always_deny("shell", args): + self._chat_panel.handle_ws_event({"type": "error", "message": f"Shell command denied by policy: {command}"}) + return + if self._permission_engine.check("shell", args) is None: + self.run_worker(self._shell_worker(text)) + return + self._confirm_shell_command(text) + + def _confirm_shell_command(self, text: str) -> None: + command = text[1:].strip() + + def on_decision(choice: str | None) -> None: + if choice == "allow_once": + self.run_worker(self._shell_worker(text)) + elif choice == "allow_always": + self._permission_engine.set_always_allow("shell", {"action": "run", "command": command}) + self.run_worker(self._shell_worker(text)) + elif choice == "deny_once": + self._chat_panel.handle_ws_event({"type": "error", "message": f"Shell command cancelled: {command}"}) + elif choice == "deny_always": + self._permission_engine.set_always_deny("shell", {"action": "run", "command": command}) + self._chat_panel.handle_ws_event({"type": "error", "message": f"Shell command cancelled: {command}"}) + else: + self._chat_panel.handle_ws_event({"type": "error", "message": f"Shell command cancelled: {command}"}) + + self.push_screen( + PermissionDialogScreen( + tool="shell", + action="run", + target=command, + details="Local shell command", + ), + callback=on_decision, + ) async def _shell_worker(self, text: str) -> None: result = run_shell_command(text) @@ -196,7 +231,12 @@ def on_ws_event(self, event: WsEvent) -> None: payload = event.payload if payload.get("type") == "tool_started": - rule = self._permission_engine.check(payload.get("tool", ""), payload.get("args") or {}) + tool = payload.get("tool", "") + args = payload.get("args") or {} + if self._permission_engine.is_always_deny(tool, args): + self._deny_tool(tool, args) + return + rule = self._permission_engine.check(tool, args) if rule is not None: self._show_permission_dialog(payload, rule) return @@ -206,7 +246,7 @@ tool = payload.get("tool", "?") args = payload.get("args") or {} action = args.get("action", "") - target = self._permission_engine._extract_target(tool, args) + target = self._permission_engine.extract_target(tool, args) def on_decision(choice: str | None) -> None: if choice == "allow_once": @@ -234,7 +274,18 @@ ) def _deny_tool(self, tool: str, args: dict) -> None: - self._chat_panel.handle_ws_event({"type": "error", "message": f"Denied: {tool} {args}"}) + # Render a synthetic tool result so the user sees the denial, then stop + # the session because the backend is already executing the tool and the + # TUI cannot inject a result into the running tool-call loop. + self._chat_panel.handle_ws_event( + { + "type": "tool_call", + "tool": tool, + "args": args, + "success": False, + "result": "permission denied by user", + } + ) if self._ctx.session_id: self.run_worker(self._stop_session_worker(self._ctx.session_id)) @@ -253,9 +304,10 @@ def on_permission_request(self, event: PermissionRequest) -> None: """Manual PermissionRequest event (fallback from components).""" tool = event.tool - args = {"action": event.action, "command": event.details} + args = event.args payload = {"type": "tool_started", "tool": tool, "args": args} - self._show_permission_dialog(payload, type("R", (), {"message": event.details})()) + rule = PermissionRule(tool=tool, message=event.message) + self._show_permission_dialog(payload, rule) def action_command_palette(self) -> None: registry = get_registry() diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/tests/__init__.py diff --git a/tests/clients/__init__.py b/tests/clients/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/tests/clients/__init__.py diff --git a/tests/clients/test_file_refs.py b/tests/clients/test_file_refs.py index 560f0a2..4611e67 100644 --- a/tests/clients/test_file_refs.py +++ b/tests/clients/test_file_refs.py @@ -106,3 +106,36 @@ assert result.total_bytes <= MAX_TOTAL_BYTES assert len(result.attachments) == 2 assert any("limit reached" in e for e in result.errors) + + +def test_resolve_absolute_path_outside_base_rejected(sample_dir: Path) -> None: + resolver = FileRefResolver(sample_dir) + result = resolver.resolve("check @/etc/passwd") + assert not result.attachments + assert any("not found" in e or "could not resolve" in e for e in result.errors) + + +def test_resolve_sensitive_file_skipped(sample_dir: Path) -> None: + (sample_dir / ".env").write_text("SECRET=123\n", encoding="utf-8") + resolver = FileRefResolver(sample_dir) + result = resolver.resolve("read @.env") + assert not result.attachments + assert any("skipped sensitive file" in e for e in result.errors) + + +def test_resolve_binary_file_skipped(sample_dir: Path) -> None: + (sample_dir / "binary.bin").write_bytes(b"\x00\x01\x02\x03") + resolver = FileRefResolver(sample_dir) + result = resolver.resolve("read @binary.bin") + assert not result.attachments + assert any("skipped binary file" in e for e in result.errors) + + +def test_resolve_glob_with_brackets(sample_dir: Path) -> None: + (sample_dir / "a.py").write_text("a", encoding="utf-8") + (sample_dir / "b.py").write_text("b", encoding="utf-8") + resolver = FileRefResolver(sample_dir) + result = resolver.resolve("check @[ab].py") + paths = {a.display_path for a in result.attachments} + assert "a.py" in paths + assert "b.py" in paths diff --git a/tests/clients/test_permission_dialog.py b/tests/clients/test_permission_dialog.py index 3a9bae3..43b2604 100644 --- a/tests/clients/test_permission_dialog.py +++ b/tests/clients/test_permission_dialog.py @@ -15,6 +15,7 @@ async with NaviCodeTui(new_session=True).run_test() as pilot: await pilot.pause() chat = pilot.app.query_one("ChatPanel") + assert chat is not None pilot.app.post_message(WsEvent({"type": "tool_started", "tool": "filesystem", "args": {"action": "delete", "path": "/tmp/x"}})) await pilot.pause() assert pilot.app.screen_stack[-1].__class__.__name__ == "PermissionDialogScreen" @@ -25,6 +26,7 @@ async with NaviCodeTui(new_session=True).run_test() as pilot: await pilot.pause() chat = pilot.app.query_one("ChatPanel") + assert chat is not None pilot.app.post_message(WsEvent({"type": "tool_started", "tool": "filesystem", "args": {"action": "delete", "path": "/tmp/x"}})) await pilot.pause() dialog = pilot.app.screen_stack[-1] @@ -35,14 +37,15 @@ @pytest.mark.anyio -async def test_deny_once_adds_error_and_stops() -> None: +async def test_deny_once_adds_synthetic_tool_call() -> None: async with NaviCodeTui(new_session=True).run_test() as pilot: await pilot.pause() chat = pilot.app.query_one("ChatPanel") + assert chat is not None pilot.app.post_message(WsEvent({"type": "tool_started", "tool": "filesystem", "args": {"action": "delete", "path": "/tmp/x"}})) await pilot.pause() dialog = pilot.app.screen_stack[-1] assert isinstance(dialog, PermissionDialogScreen) await pilot.click("#deny_once") await pilot.pause() - assert any(item.kind == "error" for item in chat._model.items) + assert any(item.kind == "tool_call" for item in chat._model.items) diff --git a/tests/clients/test_permissions.py b/tests/clients/test_permissions.py new file mode 100644 index 0000000..e6c3711 --- /dev/null +++ b/tests/clients/test_permissions.py @@ -0,0 +1,53 @@ +"""Tests for the permission engine.""" + +from __future__ import annotations + +from pathlib import Path + +from clients.terminal.tui.permissions import PermissionEngine + + +def test_default_rules_match_destructive_filesystem() -> None: + engine = PermissionEngine(store_path=Path("/dev/null")) + assert engine.check("filesystem", {"action": "delete", "path": "x.txt"}) is not None + assert engine.check("filesystem", {"action": "read", "path": "x.txt"}) is None + + +def test_default_rules_match_terminal_rm() -> None: + engine = PermissionEngine(store_path=Path("/dev/null")) + assert engine.check("terminal", {"command": "rm -rf /tmp/x"}) is not None + assert engine.check("terminal", {"command": "ls /tmp"}) is None + + +def test_default_rules_match_code_exec_and_ssh_exec() -> None: + engine = PermissionEngine(store_path=Path("/dev/null")) + assert engine.check("code_exec", {"language": "python", "code": "print(1)"}) is not None + assert engine.check("ssh_exec", {"host": "server", "command": "uptime"}) is not None + + +def test_default_rules_match_shell_command() -> None: + engine = PermissionEngine(store_path=Path("/dev/null")) + assert engine.check("shell", {"action": "run", "command": "ls"}) is not None + + +def test_always_allow_bypasses_confirmation(tmp_path: Path) -> None: + store = tmp_path / "permissions.json" + engine = PermissionEngine(store_path=store) + engine.set_always_allow("terminal", {"command": "rm -rf /tmp/x"}) + assert engine.check("terminal", {"command": "rm -rf /tmp/x"}) is None + + +def test_always_deny_is_detected_without_rule(tmp_path: Path) -> None: + store = tmp_path / "permissions.json" + engine = PermissionEngine(store_path=store) + engine.set_always_deny("terminal", {"command": "rm -rf /tmp/x"}) + assert engine.is_always_deny("terminal", {"command": "rm -rf /tmp/x"}) is True + assert engine.check("terminal", {"command": "rm -rf /tmp/x"}) is None + + +def test_extract_target_for_tools() -> None: + engine = PermissionEngine(store_path=Path("/dev/null")) + assert engine.extract_target("filesystem", {"path": "/tmp/x"}) == "/tmp/x" + assert engine.extract_target("filesystem", {"destination": "/tmp/y"}) == "/tmp/y" + assert engine.extract_target("terminal", {"command": "ls"}) == "ls" + assert engine.extract_target("ssh_exec", {"host": "h1"}) == "h1" diff --git a/tests/clients/test_shell_runner.py b/tests/clients/test_shell_runner.py index 0228c4c..9d138d1 100644 --- a/tests/clients/test_shell_runner.py +++ b/tests/clients/test_shell_runner.py @@ -46,7 +46,7 @@ def test_truncate_long_output() -> None: result = run_shell_command(f"!seq 1 {MAX_OUTPUT_LINES + 50}") - assert result.truncated is False + assert result.truncated is True assert "lines truncated" in result.stdout assert len(result.stdout.splitlines()) <= MAX_OUTPUT_LINES + 1 diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/tests/integration/__init__.py diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 29db170..7b9e693 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -1,15 +1,11 @@ """Integration test fixtures — FastAPI app with mocked dependencies.""" -from typing import AsyncGenerator - import pytest from fastapi.testclient import TestClient from navi.auth import User -from navi.core.events import StreamEnd, TextDelta from navi.core.registry import BackendRegistry -from navi.core.session import InMemorySessionStore, Session -from navi.llm.base import Message +from navi.core.session import InMemorySessionStore from tests.conftest_factory import FakeLLMBackend, make_profile_registry, make_registry_with_tools @@ -66,6 +62,8 @@ # Build container directly — no more module-level singletons from navi.core.container import AppContainer + from navi.core.orchestrator import AgentSessionOrchestrator + container = AppContainer( database=None, memory_store=None, @@ -79,6 +77,7 @@ workers=[], mcp_manager=None, ) + container.orchestrator = AgentSessionOrchestrator(container) fake_agent = FakeAgent() container._agent = fake_agent diff --git a/tests/integration/test_websocket.py b/tests/integration/test_websocket.py index b701f3b..8970e9c 100644 --- a/tests/integration/test_websocket.py +++ b/tests/integration/test_websocket.py @@ -4,48 +4,43 @@ import json import pytest -from fastapi.testclient import TestClient from navi.core.events import StreamEnd, TextDelta -from navi.llm.base import Message -class FakeAgent: - """Deterministic agent for WebSocket tests.""" +def _get_orchestrator(mock_deps): + from navi.main import app - def __init__(self, stream_events=None, run_response="Hello") -> None: - self._stream_events = stream_events or [] - self._run_response = run_response - - async def run(self, session_id: str, user_message: str, images=None) -> str: - return self._run_response - - async def run_stream(self, session_id, user_message, images=None, display_message=None): - for ev in self._stream_events: - yield ev + return app.state.container.orchestrator @pytest.fixture(autouse=True) -def _clear_runs(monkeypatch): - """Clear the module-level _runs dict before every WS test.""" - import navi.api.websocket as ws_mod - - ws_mod._runs.clear() +def _clear_runs(mock_deps): + """Clear orchestrator state before every WS test.""" + orchestrator = _get_orchestrator(mock_deps) + for session_id in list(orchestrator._sessions.keys()): + state = orchestrator._sessions.get(session_id) + if state and state.run and state.run.task: + state.run.task.cancel() + orchestrator._sessions.clear() + orchestrator._session_locks.clear() yield @pytest.fixture def fake_agent_ws(monkeypatch, mock_deps): - """Patch Agent in websocket module so handlers use FakeAgent.""" - import navi.api.websocket as ws_mod + """Patch orchestrator.run_agent so it broadcasts deterministic events.""" + orchestrator = _get_orchestrator(mock_deps) - events = [ - TextDelta(delta="Hello"), - StreamEnd(full_content="Hello"), - ] - fake = FakeAgent(stream_events=events) - monkeypatch.setattr(ws_mod, "Agent", lambda *a, **kw: fake) - return fake + async def fake_run_agent(session_id, user_content, raw_images, display_content, files, session_store): + run = orchestrator.get_run(session_id) + if run is None: + return + await run.broadcast(("event", TextDelta(delta="Hello"))) + await run.broadcast(("event", StreamEnd(full_content="Hello"))) + + monkeypatch.setattr(orchestrator, "run_agent", fake_run_agent) + return fake_run_agent class TestWebSocketConnect: @@ -75,19 +70,17 @@ assert any(m.get("type") == "stream_end" for m in msgs) @pytest.mark.anyio - async def test_reconnect_replay(self, client, make_session, monkeypatch): + async def test_reconnect_replay(self, client, make_session, mock_deps): """Reconnect while a run is active — replay buffer should emit past events.""" - import navi.api.websocket as ws_mod - + orchestrator = _get_orchestrator(mock_deps) session = await make_session("secretary") # Inject an active run with buffered events - run = ws_mod._AgentRun() + run = orchestrator.create_run(session.id) run.events = [ {"type": "stream_start"}, {"type": "stream_delta", "delta": "hello"}, ] - ws_mod._runs[session.id] = run with client.websocket_connect(f"/ws/sessions/{session.id}") as ws: msgs = _collect_until_done(ws, max_messages=5) @@ -99,9 +92,7 @@ assert "replay_end" in types # Clean up injected run - ws_mod._runs.pop(session.id, None) - if run.task: - run.task.cancel() + orchestrator._sessions.pop(session.id, None) @pytest.mark.anyio async def test_invalid_json(self, client, make_session): @@ -128,7 +119,7 @@ class TestStopSession: @pytest.mark.anyio - async def test_stop_no_active_run(self, client, make_session): + async def test_stop_no_active_run(self, client, make_session, mock_deps): session = await make_session("secretary") response = client.post(f"/sessions/{session.id}/stop") assert response.status_code == 200 @@ -136,15 +127,13 @@ assert data["ok"] is False @pytest.mark.anyio - async def test_stop_active_run(self, client, make_session, monkeypatch): - import navi.api.websocket as ws_mod - + async def test_stop_active_run(self, client, make_session, mock_deps): + orchestrator = _get_orchestrator(mock_deps) session = await make_session("secretary") # Start a long-running agent task in background - run = ws_mod._AgentRun() + run = orchestrator.create_run(session.id) run.task = asyncio.create_task(asyncio.sleep(10)) - ws_mod._runs[session.id] = run response = client.post(f"/sessions/{session.id}/stop") assert response.status_code == 200 @@ -156,6 +145,7 @@ await run.task except asyncio.CancelledError: pass + orchestrator._sessions.pop(session.id, None) # ── Helpers ────────────────────────────────────────────────────────────────── diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/tests/unit/__init__.py diff --git a/tests/unit/api/__init__.py b/tests/unit/api/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/tests/unit/api/__init__.py