Newer
Older
navi-1 / clients / terminal / tui / file_refs.py
"""Resolve @path references inside user input.

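Supported forms:
    @path/to/file.py     → file content wrapped in a code fence
    @dir                 → content of each file directly inside dir
    @dir/                → content of every file under dir (recursive)
    @tests/**/*.py       → glob expansion (tried only when no such path exists), files only

Size limits apply per-file and in total to avoid flooding the LLM context.
Sensitive files (private keys, .env files, anything under .git/, .venv/,
node_modules/, ...) are never attached.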
"""

from __future__ import annotations

import re
from dataclasses import dataclass, field
from pathlib import Path
from typing import Iterable

from clients.terminal.tui.renderers.language import guess_language


MAX_FILE_BYTES = 64_000  # per-file cap; longer files are truncated
MAX_TOTAL_BYTES = 128_000  # cap on the combined size of all attachments for one prompt
TRUNCATED_NOTICE = "\n... [truncated by Navi Code]"  # appended to truncated content

# Sensitive paths/patterns that should never be attached automatically.

# Exact names compared against a path's final component.
SENSITIVE_NAMES: set[str] = {
    # dotenv files
    ".env",
    ".env.local",
    ".env.production",
    ".env.staging",
    # version-control metadata and stored VCS credentials
    ".git",
    ".gitignore",
    ".git-credentials",
    # credential stores / tool auth config
    ".ssh",
    ".aws",
    ".docker",
    ".npmrc",
    ".pypirc",
    ".netrc",
    ".pgpass",
    # SSH key pairs
    "id_rsa",
    "id_rsa.pub",
    "id_dsa",
    "id_dsa.pub",
    "id_ecdsa",
    "id_ecdsa.pub",
    "id_ed25519",
    "id_ed25519.pub",
    # OS metadata files
    ".DS_Store",
    "Thumbs.db",
}

# File-name endings (matched against the lower-cased name): key/certificate
# containers and compiled Python bytecode.
SENSITIVE_SUFFIXES: tuple[str, ...] = (
    ".pem",
    ".key",
    ".crt",
    ".p12",
    ".pfx",
    ".keystore",
    ".jks",
    ".pyc",
    ".pyo",
)

# Directory names: a file with one of these as a path component is skipped.
# Credential directories must be listed here (not only in SENSITIVE_NAMES),
# because attached paths are files and their final name is e.g. "config.json".
SENSITIVE_DIR_NAMES: set[str] = {
    # credential / secret stores
    ".git",
    ".ssh",
    ".aws",
    ".docker",
    ".gnupg",
    ".kube",
    # virtualenvs, dependency trees, caches and build output
    ".venv",
    "venv",
    "node_modules",
    "__pycache__",
    ".tox",
    ".pytest_cache",
    ".mypy_cache",
    # NOTE(review): egg-info dirs are named "<pkg>.egg-info", so this entry only
    # takes effect if the matcher also compares suffixes.
    ".egg-info",
    "dist",
    "build",
}


@dataclass
class ResolvedFile:
    """One file produced by resolving an @ reference, ready to be attached."""

    path: Path  # path of the attached file as found during resolution
    display_path: str  # shorter path shown in the rendered message
    content: str  # file text to embed in the message
    truncated: bool = field(default=False)  # True when content was shortened to fit size limits


@dataclass
class FileRefResult:
    """Result of resolving @ references in a prompt.

    Holds the prompt text together with the files that were attached and any
    errors encountered; :meth:`to_message` renders all of it into the single
    string sent to the backend.
    """

    prompt: str  # prompt text; FileRefResolver.resolve() stores the user's input here unchanged
    attachments: list[ResolvedFile] = field(default_factory=list)
    errors: list[str] = field(default_factory=list)
    total_bytes: int = 0  # attachment bytes counted against the total size budget

    def is_empty(self) -> bool:
        """Return True when there are neither attachments nor errors."""
        return not self.attachments and not self.errors

    @staticmethod
    def _fence_for(content: str) -> str:
        """Return a backtick fence that *content* cannot close early.

        The fence is at least three backticks and one longer than the longest
        backtick run inside *content*. Attached Markdown (or any file that
        itself contains ``` fences) therefore stays inside its own block.
        """
        longest = max((len(run) for run in re.findall(r"`+", content)), default=0)
        return "`" * max(3, longest + 1)

    def to_message(self) -> str:
        """Build the full message to send to the backend.

        Returns the prompt unchanged when there is nothing to add. Otherwise it
        appends an "attached files" section and/or an "attachment errors" list.
        Each attached file gets one fenced block, tagged with its guessed
        language and display path.
        """
        if self.is_empty():
            return self.prompt

        parts = [self.prompt]
        if self.attachments:
            parts.extend(["", "--- attached files ---"])
            for attachment in self.attachments:
                lang = guess_language(attachment.path)
                label = f"file: {attachment.display_path}"
                if attachment.truncated:
                    label += " (truncated)"
                fence = self._fence_for(attachment.content)
                parts.extend([f"{fence}{lang} {label}", attachment.content, fence])
        if self.errors:
            parts.extend(["", "--- attachment errors ---"])
            parts.extend(f"- {err}" for err in self.errors)
        return "\n".join(parts)


# An @ preceded by a word character, ".", "+" or "-" does not start a
# reference, so e-mail addresses such as "me@example.com" are not mistaken for
# paths. The escaped-space alternative is tried first: otherwise the character
# class (which contains a backslash) consumes the backslash on its own and
# "my\ file.py" is cut off at the space.
_ref_pattern = re.compile(r"(?<![\w.+\-])@((?:\\ |[A-Za-z0-9_\-\.~/$*?\[\]\\])+)")


def find_refs(text: str) -> list[str]:
    """Return all @path tokens found in text, in order, without duplicates.

    A backslash-escaped space (``@my\\ file.py``) becomes a literal space in
    the returned path. A single trailing ``.`` that reads as sentence
    punctuation is dropped (``see @main.py.`` yields ``main.py``), while
    ``.``, ``..`` and ``dir/.`` are kept intact.
    """
    seen: set[str] = set()
    refs: list[str] = []
    for raw in _ref_pattern.findall(text):
        # Un-escape backslash-space inside the token.
        ref = raw.replace("\\ ", " ")
        # Strip one trailing sentence period, but never from "." / ".." forms.
        if len(ref) > 1 and ref.endswith(".") and ref[-2] not in "./":
            ref = ref[:-1]
        if ref not in seen:
            seen.add(ref)
            refs.append(ref)
    return refs


class FileRefResolver:
    """Resolve @ references relative to a base directory.

    Plain references are confined to ``base_dir``. References that start with
    ``~`` are confined to the current user's home directory. Every attached
    file, including those produced by directory or glob expansion, is
    re-checked after symlink resolution. A link inside the project therefore
    cannot pull in a file from outside the allowed root.
    """

    def __init__(self, base_dir: Path | str | None = None) -> None:
        self.base_dir = Path(base_dir or Path.cwd()).expanduser().resolve()
        self._home_dir = Path.home().expanduser().resolve()

    def resolve(self, text: str) -> FileRefResult:
        """Resolve every @ reference in *text* into attachments and errors.

        References are processed in order of first appearance. Once the total
        budget (MAX_TOTAL_BYTES) is used up, the remaining references are
        skipped and an error is recorded.
        """
        result = FileRefResult(prompt=text)
        for ref in find_refs(text):
            self._resolve_ref(ref, result)
            if result.total_bytes >= MAX_TOTAL_BYTES:
                result.errors.append("total attachment size limit reached; remaining files skipped")
                break
        return result

    def _resolve_ref(self, ref: str, result: FileRefResult) -> None:
        """Resolve one reference, appending attachments and errors to *result*."""
        path = self._expand_path(ref)
        if path is None:
            result.errors.append(f"could not resolve {ref!r}")
            return

        root = self._root_for(ref)
        if path.exists():
            if path.is_dir():
                self._attach_directory(ref, path, root, result)
            elif path.is_file():
                self._attach_file(path, result, allowed_root=root)
            else:
                result.errors.append(f"not a file or directory: {ref}")
            return

        # Non-existent path: try glob expansion if it looks like a pattern.
        if _is_glob(ref):
            self._attach_glob(ref, path, root, result)
            return

        result.errors.append(f"not found: {ref}")

    def _attach_directory(self, ref: str, directory: Path, root: Path, result: FileRefResult) -> None:
        """Attach the files of *directory*, recursing when *ref* ends with ``/``."""
        try:
            files = sorted(_collect_files(directory, recursive=ref.endswith("/")))
        except OSError as exc:
            # e.g. an unreadable directory: report it instead of aborting the prompt.
            result.errors.append(f"failed to list {ref}: {exc}")
            return
        if not files:
            result.errors.append(f"no files found in {ref}")
            return
        for file_path in files:
            self._attach_file(file_path, result, root_dir=directory, allowed_root=root)
            if result.total_bytes >= MAX_TOTAL_BYTES:
                return

    def _attach_glob(self, ref: str, pattern_path: Path, root: Path, result: FileRefResult) -> None:
        """Expand the glob *ref* under *root* and attach every matching file.

        *pattern_path* is the expanded, root-validated form of *ref*. Globbing
        its root-relative remainder from *root* makes ``~/notes/*.md`` search
        the home directory instead of a literal ``~`` folder in ``base_dir``.
        """
        try:
            pattern = str(pattern_path.relative_to(root))
            matches = sorted(root.glob(pattern))
        except (OSError, ValueError, NotImplementedError) as exc:
            result.errors.append(f"failed to expand {ref}: {exc}")
            return
        if not matches:
            result.errors.append(f"no matches for {ref}")
            return
        # Files inside sensitive directories are dropped silently, as directory
        # expansion does. Otherwise ``**/*.py`` would add one error per file
        # found under .venv/ or node_modules/.
        files = [
            match
            for match in matches
            if match.is_file() and not _is_inside_sensitive_dir(match.relative_to(root))
        ]
        if not files:
            result.errors.append(f"no files matched {ref}")
            return
        for file_path in files:
            self._attach_file(file_path, result, allowed_root=root)
            if result.total_bytes >= MAX_TOTAL_BYTES:
                return

    def _attach_file(
        self,
        path: Path,
        result: FileRefResult,
        root_dir: Path | None = None,
        allowed_root: Path | None = None,
    ) -> None:
        """Read *path* and append it to *result* as an attachment.

        Args:
            path: File to attach. May be an unresolved symlink when it came
                from directory or glob expansion.
            result: Accumulator for attachments, errors and the byte budget.
            root_dir: Directory used to build the display path (directory
                expansion); ``base_dir`` is tried after it.
            allowed_root: Directory the file's real location must stay inside.
                When omitted, ``base_dir`` is used if the file resolves inside
                it, otherwise the home directory if it resolves there.

        Files that are outside the allowed root, sensitive, unreadable or
        binary are reported in ``result.errors`` instead of being attached.
        Nothing is recorded when the total byte budget is already exhausted.
        """
        import codecs

        display = _display_path(path, self.base_dir, root_dir)
        try:
            real = path.resolve()
        except (OSError, RuntimeError) as exc:  # RuntimeError: symlink loops on some Python versions
            result.errors.append(f"failed to read {path}: {exc}")
            return

        root = allowed_root if allowed_root is not None else self._default_root(real)
        if root is None or not _is_under_root(real, root):
            result.errors.append(f"skipped file outside allowed directory: {display}")
            return

        if self._is_sensitive(path, real, root):
            result.errors.append(f"skipped sensitive file: {display}")
            return

        try:
            with real.open("rb") as handle:
                # Read one byte past the per-file cap: enough to detect truncation
                # without loading an arbitrarily large file into memory.
                data = handle.read(MAX_FILE_BYTES + 1)
        except OSError as exc:
            result.errors.append(f"failed to read {path}: {exc}")
            return

        # A NUL byte in the inspected prefix marks the file as binary.
        if b"\x00" in data:
            result.errors.append(f"skipped binary file: {display}")
            return

        truncated = len(data) > MAX_FILE_BYTES
        data = data[:MAX_FILE_BYTES]

        remaining = MAX_TOTAL_BYTES - result.total_bytes
        if remaining <= 0:
            return
        if len(data) > remaining:
            data = data[:remaining]
            truncated = True

        # With final=False, an incomplete multi-byte sequence left at the cut
        # point is held back instead of being rendered as U+FFFD.
        decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
        text = decoder.decode(data, final=not truncated)

        if truncated:
            text += TRUNCATED_NOTICE
            result.total_bytes += len(TRUNCATED_NOTICE.encode("utf-8"))
        result.attachments.append(ResolvedFile(path=path, display_path=display, content=text, truncated=truncated))
        result.total_bytes += len(data)

    @staticmethod
    def _is_sensitive(path: Path, real: Path, root: Path) -> bool:
        """Return True if *path* or the file it resolves to must not be attached.

        Sensitive-file rules apply to the part of each path below *root*. A
        project that merely lives under a directory such as ``build`` or
        ``venv`` is therefore not rejected wholesale. Ancestors of *root* are
        still checked for hidden sensitive directories (``.ssh``, ``.aws``,
        ``.git``, ...), so running from inside one of those stays blocked.
        """
        if any(part.startswith(".") and part in SENSITIVE_DIR_NAMES for part in root.parts):
            return True
        for candidate in (path, real):
            try:
                relative = candidate.relative_to(root)
            except ValueError:
                continue
            if _is_sensitive_path(relative):
                return True
        return False

    def _root_for(self, ref: str) -> Path:
        """Return the directory a reference is confined to (home for ``~`` refs)."""
        return self._home_dir if ref.startswith("~") else self.base_dir

    def _default_root(self, real: Path) -> Path | None:
        """Return base_dir or the home directory, whichever contains *real* first."""
        for root in (self.base_dir, self._home_dir):
            if _is_under_root(real, root):
                return root
        return None

    def _expand_path(self, ref: str) -> Path | None:
        """Expand a raw @ reference into an absolute, validated Path.

        Paths are restricted to the resolver's base directory. The only
        exception is explicit ``~`` expansion, which is allowed inside the
        user's home directory. Returns ``None`` for empty references, for
        paths that escape their allowed root, and for references that cannot
        be expanded (e.g. an unknown ``~user`` or a symlink loop).
        """
        # A trailing slash only selects recursion; strip it for expansion.
        clean = ref.rstrip("/")
        if not clean:
            return None
        try:
            if clean.startswith("~"):
                candidate = Path(clean).expanduser().resolve()
            else:
                candidate = (self.base_dir / clean).resolve()
        except (OSError, RuntimeError):
            return None

        if not _is_under_root(candidate, self._root_for(ref)):
            return None
        return candidate


def _is_glob(ref: str) -> bool:
    """Return True if ref contains glob metacharacters."""
    return "*" in ref or "?" in ref or "[" in ref


def _collect_files(path: Path, recursive: bool = False) -> Iterable[Path]:
    """Yield the non-sensitive regular files inside directory *path*, sorted.

    With ``recursive=True`` the whole tree is walked. Subdirectories named in
    SENSITIVE_DIR_NAMES are pruned during the walk, so trees such as
    ``node_modules`` or ``.venv`` are never traversed. Sensitive-directory
    checks look only at *path*'s own name and the components below it, so a
    directory that merely sits under e.g. ``/build`` is not emptied by the
    filter. Directory symlinks are not followed.

    Raises:
        OSError: if *path* cannot be listed in non-recursive mode.
    """
    import os

    if path.name in SENSITIVE_DIR_NAMES:
        return
    # Components are judged relative to the parent, so path's own name counts
    # but its ancestors do not.
    anchor = path.parent
    if recursive:
        found: list[Path] = []
        for dirpath, dirnames, filenames in os.walk(path):
            # In-place pruning stops os.walk from descending into these subtrees.
            dirnames[:] = [name for name in dirnames if name not in SENSITIVE_DIR_NAMES]
            here = Path(dirpath)
            found.extend(here / name for name in filenames)
        candidates: Iterable[Path] = found
    else:
        candidates = path.iterdir()
    for candidate in sorted(candidates):
        if candidate.is_file() and not _is_inside_sensitive_dir(candidate.relative_to(anchor)):
            yield candidate


def _display_path(path: Path, base: Path, root_dir: Path | None = None) -> str:
    for candidate in (root_dir, base):
        if candidate is not None:
            try:
                return str(path.relative_to(candidate))
            except ValueError:
                continue
    return str(path)


def _is_under_root(candidate: Path, root: Path) -> bool:
    """Return True if candidate stays inside root (after symlink resolution)."""
    try:
        candidate.relative_to(root)
        return True
    except ValueError:
        return False


def _is_inside_sensitive_dir(path: Path) -> bool:
    """Return True if path lives inside a directory that should be skipped.

    Only the directory components of *path* (everything except its final
    name) are inspected. A regular file that happens to be called ``build``
    or ``dist`` is therefore not mistaken for a sensitive directory. Besides
    exact matches against SENSITIVE_DIR_NAMES, any directory ending in
    ``.egg-info`` matches, since those are named ``<pkg>.egg-info``.
    """
    return any(
        part in SENSITIVE_DIR_NAMES or part.endswith(".egg-info")
        for part in path.parts[:-1]
    )


def _is_sensitive_path(path: Path) -> bool:
    """Return True if path matches a sensitive file pattern.

    A path is sensitive when any of the following holds:

    - it sits inside a sensitive directory;
    - its name matches SENSITIVE_NAMES, compared case-insensitively because on
      a case-insensitive file system ``@.ENV`` opens ``.env``;
    - it is a dotenv variant such as ``.env.development``, other than a
      shareable template;
    - it ends with one of SENSITIVE_SUFFIXES.
    """
    if _is_inside_sensitive_dir(path):
        return True
    name = path.name.lower()
    if name in {entry.lower() for entry in SENSITIVE_NAMES}:
        return True
    # SENSITIVE_NAMES spells out only a few dotenv variants; treat every
    # ".env.<stage>" file as secret except conventional committed templates.
    if name.startswith(".env.") and name not in (".env.example", ".env.sample", ".env.template"):
        return True
    return name.endswith(SENSITIVE_SUFFIXES)