diff --git a/docs/architecture_weak_spots.md b/docs/architecture_weak_spots.md index 3cbd020..a651db5 100644 --- a/docs/architecture_weak_spots.md +++ b/docs/architecture_weak_spots.md @@ -115,7 +115,7 @@ --- -## 8. Скрытые глобальные зависимости через ContextVar +## 8. Скрытые глобальные зависимости через ContextVar ✅ **Severity:** Medium **Файл:** `navi/tools/_internal/base.py` (строки 19–42) @@ -123,6 +123,15 @@ **Почему блокер:** Инструмент нельзя вызвать вне контекста агента (из CLI, фоновой задачи, теста) без установки всех ContextVar. **Направление:** Передавать контекст выполнения явным параметром в `execute()`. ContextVar оставить как optional fallback. +**Решение 2026-05-24:** +- Создан `ToolContext` dataclass в `navi/tools/_internal/base.py` — явный контейнер для всех 7 значений +- `Tool.execute()` теперь принимает `ctx: ToolContext | None = None` +- `ToolExecutor._execute_one()` собирает `ToolContext` и передаёт его инструменту +- `Agent._execute_tools_with_sink()` и `SubAgentRunner` строят `ToolContext` из значений в scope и передают в цепочку +- Все ~25 инструментов обновлены: читающие ContextVar теперь предпочитают `ctx`, остальные получили только новый параметр +- Все тесты инструментов переведены на явный `ctx=ToolContext(...)` — больше никаких фикстур с `current_session_id.set()` +- ContextVar setters оставлены как fallback для не-инструментных потребителей (`ai_helper.py`, `context_builder.py`, `planning.py`) + --- ## 9. Сессионное состояние в памяти процесса diff --git a/navi/core/agent.py b/navi/core/agent.py index 5b40dd0..cafadfc 100644 --- a/navi/core/agent.py +++ b/navi/core/agent.py @@ -25,7 +25,7 @@ from navi.config import settings from navi.exceptions import ContextTooLargeError, LLMBackendError, LLMConnectionError, MaxIterationsReached, SessionNotFound from navi.llm.base import LLMBackend, Message, ToolCallRequest -from navi.tools._internal.base import Tool, current_event_sink, current_stop_event +from navi.tools._internal.base import Tool, ToolContext, current_event_sink, current_stop_event, current_user_role, current_user_info from .agent_run_context import AgentTurnContext, StreamState from .anti_stall import AntiStallMonitor @@ -379,7 +379,16 @@ session.messages.append(assistant_msg) session.context.append(assistant_msg) - async for _ev in self._execute_tools_with_sink(turn_tool_calls, tools, turn_ctx, session, stop_event): + tool_ctx = ToolContext( + session_id=session_id, + event_sink=None, # set per-tool inside _execute_tools_with_sink + stop_event=stop_event, + model=profile.model, + user_id=session.user_id, + user_role=current_user_role.get(), + user_info=current_user_info.get(), + ) + async for _ev in self._execute_tools_with_sink(turn_tool_calls, tools, turn_ctx, session, stop_event, tool_ctx): yield _ev # 6. Cooperative stop: check after tool execution before next LLM call @@ -582,7 +591,7 @@ state.thinking_active = False yield ThinkingEnd() - async def _execute_tools_with_sink(self, turn_tool_calls, tools, turn_ctx: AgentTurnContext, session, stop_event): + async def _execute_tools_with_sink(self, turn_tool_calls, tools, turn_ctx: AgentTurnContext, session, stop_event, tool_ctx=None): """Execute tool calls with cooperative stop support. Polls *stop_event* every second while draining the event sink so the @@ -599,7 +608,7 @@ async def _run_with_sentinel(_tc=tc, _holder=result_holder, _sink=sink): try: - _holder.append(await self._tool_executor._run_single_tool(_tc, tool_map)) + _holder.append(await self._tool_executor._run_single_tool(_tc, tool_map, ctx=tool_ctx)) except Exception as exc: _holder.append(exc) finally: diff --git a/navi/core/subagent_runner.py b/navi/core/subagent_runner.py index 2cfc397..72f8210 100644 --- a/navi/core/subagent_runner.py +++ b/navi/core/subagent_runner.py @@ -12,7 +12,7 @@ from navi.config import settings from navi.exceptions import ContextTooLargeError from navi.llm.base import LLMBackend, Message, ToolCallRequest -from navi.tools._internal.base import current_event_sink, current_stop_event +from navi.tools._internal.base import ToolContext, current_event_sink, current_stop_event from .events import AIHelperTokensUsed, SubagentComplete, ToolEvent, ToolStarted, TurnThinking from .stream_guard import _iter_stream_guarded @@ -326,6 +326,16 @@ ) ) + tool_ctx = ToolContext( + session_id=tool_session_id, + event_sink=sink, + stop_event=stop_event, + model=profile.model, + user_id=_uid_var.get(None), + user_role=_role_var.get(), + user_info=_uinfo_var.get(None), + ) + for tc in turn_tool_calls: # Cooperative stop: if the user clicked Stop before this tool, # skip remaining tools in this batch. @@ -354,7 +364,7 @@ "tool.execute.subagent", tool=tc.name, args=tc.arguments ) try: - result = await tool.execute(tc.arguments) + result = await tool.execute(tc.arguments, ctx=tool_ctx) content = result.to_message_content() success = result.success metadata = result.metadata or {} diff --git a/navi/core/tool_executor.py b/navi/core/tool_executor.py index c9aac68..4cc77c9 100644 --- a/navi/core/tool_executor.py +++ b/navi/core/tool_executor.py @@ -82,6 +82,7 @@ self, tc: ToolCallRequest, tool_map: dict[str, Tool], + ctx=None, ) -> tuple["ToolEvent", Message, "Message | None"]: """Execute a single tool call and return (ToolEvent, tool_msg, optional_image_msg). @@ -102,7 +103,7 @@ middlewares = getattr(self._tools, "_middlewares", []) for mw in middlewares: await mw.before_execute(resolved_name, tc.arguments) - result = await tool.execute(tc.arguments) + result = await tool.execute(tc.arguments, ctx=ctx) for mw in middlewares: await mw.after_execute(resolved_name, tc.arguments, result) content = result.to_message_content() @@ -126,26 +127,27 @@ self, tc: ToolCallRequest, tool_map: dict[str, Tool], + ctx=None, ) -> tuple["ToolEvent", Message, "Message | None"]: """Execute one tool call and return (ToolEvent, tool_msg, optional_image_msg). Called via asyncio.create_task() from run_stream() so that the parent generator can drain the event sink queue concurrently. """ - return await self._execute_one(tc, tool_map) + return await self._execute_one(tc, tool_map, ctx=ctx) async def _execute_tool_calls( - self, tool_calls: list[ToolCallRequest], tools: list[Tool] + self, tool_calls: list[ToolCallRequest], tools: list[Tool], ctx=None ) -> tuple[list[Message], list[Message]]: tool_map = {t.name: t for t in tools} - pairs = await asyncio.gather(*[self._execute_one(tc, tool_map) for tc in tool_calls]) + pairs = await asyncio.gather(*[self._execute_one(tc, tool_map, ctx=ctx) for tc in tool_calls]) tool_msgs = [p[1] for p in pairs] image_msgs = [p[2] for p in pairs if p[2] is not None] return tool_msgs, image_msgs async def _execute_tool_calls_streaming( - self, tool_calls: list[ToolCallRequest], tools: list[Tool] + self, tool_calls: list[ToolCallRequest], tools: list[Tool], ctx=None ) -> tuple[list[tuple["ToolEvent", Message]], list[Message]]: tool_map = {t.name: t for t in tools} - triples = await asyncio.gather(*[self._execute_one(tc, tool_map) for tc in tool_calls]) + triples = await asyncio.gather(*[self._execute_one(tc, tool_map, ctx=ctx) for tc in tool_calls]) return [(t[0], t[1]) for t in triples], [t[2] for t in triples if t[2] is not None] diff --git a/navi/mcp/tools.py b/navi/mcp/tools.py index 7bd1298..c9e4f7d 100644 --- a/navi/mcp/tools.py +++ b/navi/mcp/tools.py @@ -5,7 +5,7 @@ from pathlib import Path from navi.config import settings -from navi.tools._internal.base import Tool, ToolResult, current_session_id +from navi.tools._internal.base import Tool, ToolContext, ToolResult, current_session_id from .manager import McpManager @@ -77,13 +77,13 @@ return p.name return value - async def execute(self, params: dict[str, Any]) -> ToolResult: + async def execute(self, params: dict[str, Any], ctx: ToolContext | None = None) -> ToolResult: # Defensive copy — never mutate the caller's dict forwarded = dict(params) # 1. Force the real session_id from the agent context so the LLM # cannot hallucinate a wrong UUID (ghost-session bug). - sid = current_session_id.get() + sid = ctx.session_id if ctx else current_session_id.get() if sid is not None: forwarded["session_id"] = sid diff --git a/navi/tools/_internal/base.py b/navi/tools/_internal/base.py index 08d1954..32d1b54 100644 --- a/navi/tools/_internal/base.py +++ b/navi/tools/_internal/base.py @@ -43,6 +43,19 @@ @dataclass +class ToolContext: + """Explicit execution context passed to tools instead of hidden ContextVars.""" + + session_id: str | None = None + event_sink: asyncio.Queue | None = None + stop_event: asyncio.Event | None = None + model: list[str] | str | None = None + user_id: str | None = None + user_role: str = "user" + user_info: dict | None = None + + +@dataclass class ToolResult: success: bool output: str # always a string — LLM consumes this @@ -69,7 +82,7 @@ parameters: dict # JSON Schema object @abstractmethod - async def execute(self, params: dict) -> ToolResult: + async def execute(self, params: dict, ctx: ToolContext | None = None) -> ToolResult: """Execute the tool with given parameters.""" def schema(self) -> ToolSchema: diff --git a/navi/tools/code_exec.py b/navi/tools/code_exec.py index b3499fc..211fc61 100644 --- a/navi/tools/code_exec.py +++ b/navi/tools/code_exec.py @@ -11,15 +11,15 @@ import tempfile from pathlib import Path -from ._internal.base import Tool, ToolResult, current_user_id, current_user_role +from ._internal.base import Tool, ToolContext, ToolResult, current_user_id, current_user_role _TIMEOUT = 30 -def _resolve_working_dir(working_dir: str | None) -> Path: +def _resolve_working_dir(working_dir: str | None, user_id: str | None = None, role: str | None = None) -> Path: """Resolve working directory with sandbox enforcement for non-admins.""" - user_id = current_user_id.get(None) - role = current_user_role.get() + user_id = user_id or current_user_id.get(None) + role = role or current_user_role.get() if user_id and role != "admin": sandbox = Path("user_data") / user_id @@ -65,13 +65,13 @@ "required": ["code"], } - async def execute(self, params: dict) -> ToolResult: + async def execute(self, params: dict, ctx: ToolContext | None = None) -> ToolResult: code = params["code"] - role = current_user_role.get() - user_id = current_user_id.get(None) + role = ctx.user_role if ctx else current_user_role.get() + user_id = ctx.user_id if ctx else current_user_id.get(None) - cwd = _resolve_working_dir(params.get("working_dir")) + cwd = _resolve_working_dir(params.get("working_dir"), user_id, role) if user_id and role != "admin": # Write temp file inside the sandbox so file I/O in user code diff --git a/navi/tools/content_publish.py b/navi/tools/content_publish.py index dd8b1a9..307b72e 100644 --- a/navi/tools/content_publish.py +++ b/navi/tools/content_publish.py @@ -12,7 +12,7 @@ from navi.config import settings from navi.session_files import session_dir -from ._internal.base import Tool, ToolResult, current_session_id +from ._internal.base import Tool, ToolContext, ToolResult, current_session_id class ContentPublishTool(Tool): @@ -72,8 +72,8 @@ "required": ["filename"], } - async def execute(self, params: dict) -> ToolResult: - session_id = current_session_id.get() + async def execute(self, params: dict, ctx: ToolContext | None = None) -> ToolResult: + session_id = ctx.session_id if ctx else current_session_id.get() if not session_id: return ToolResult( success=False, diff --git a/navi/tools/create_mcp_server.py b/navi/tools/create_mcp_server.py index f3af499..a151c74 100644 --- a/navi/tools/create_mcp_server.py +++ b/navi/tools/create_mcp_server.py @@ -8,7 +8,7 @@ from navi.config import settings -from ._internal.base import Tool, ToolResult +from ._internal.base import Tool, ToolContext, ToolResult # Template for pyproject.toml — placeholders {name}, {description}, {deps} _PYPROJECT_TEMPLATE = """[build-system] @@ -94,7 +94,7 @@ "required": ["name", "description"], } - async def execute(self, params: dict) -> ToolResult: + async def execute(self, params: dict, ctx: ToolContext | None = None) -> ToolResult: name = (params.get("name") or "").strip() description = params.get("description", "") dependencies: list[str] = params.get("dependencies") or [] diff --git a/navi/tools/filesystem.py b/navi/tools/filesystem.py index bfcc2f9..0832744 100644 --- a/navi/tools/filesystem.py +++ b/navi/tools/filesystem.py @@ -15,7 +15,7 @@ from navi.config import settings -from ._internal.base import Tool, ToolResult, current_user_id, current_user_role +from ._internal.base import Tool, ToolContext, ToolResult, current_user_id, current_user_role _READ_WARN_BYTES = 100_000 # 100 KB — add size warning in output _READ_HARD_BYTES = 1_000_000 # 1 MB — refuse full read without offset/limit @@ -69,7 +69,7 @@ # ── Path helpers ────────────────────────────────────────────────────────────── -def _check_path(path_str: str) -> Path | None: +def _check_path(path_str: str, user_id: str | None = None, role: str | None = None) -> Path | None: """Return resolved Path if access is allowed, else None. When a user_id is active (multi-user mode), all paths are resolved @@ -83,8 +83,8 @@ except Exception: return None - user_id = current_user_id.get(None) - role = current_user_role.get() + user_id = user_id or current_user_id.get(None) + role = role or current_user_role.get() # Admins bypass sandbox and use FS_ALLOWED_PATHS directly if user_id and role != "admin": @@ -350,10 +350,12 @@ # ai_helper is optional — standard actions work without it self._ai = ai_helper - async def execute(self, params: dict) -> ToolResult: + async def execute(self, params: dict, ctx: ToolContext | None = None) -> ToolResult: action = params.get("action", "") raw_path = params.get("path", "") - path = _check_path(raw_path) + user_id = ctx.user_id if ctx else current_user_id.get(None) + role = ctx.user_role if ctx else current_user_role.get() + path = _check_path(raw_path, user_id, role) if path is None: if not raw_path or raw_path.strip() == "": diff --git a/navi/tools/image_view.py b/navi/tools/image_view.py index 1ce1769..a774c28 100644 --- a/navi/tools/image_view.py +++ b/navi/tools/image_view.py @@ -16,7 +16,7 @@ import httpx from PIL import Image -from ._internal.base import Tool, ToolResult +from ._internal.base import Tool, ToolContext, ToolResult _TIMEOUT = 30 _SUPPORTED = {".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp"} @@ -46,7 +46,7 @@ "required": ["source"], } - async def execute(self, params: dict) -> ToolResult: + async def execute(self, params: dict, ctx: ToolContext | None = None) -> ToolResult: source = params["source"].strip() try: if source.startswith(("http://", "https://")): diff --git a/navi/tools/list_profiles.py b/navi/tools/list_profiles.py index 24c250d..27f16cb 100644 --- a/navi/tools/list_profiles.py +++ b/navi/tools/list_profiles.py @@ -1,6 +1,6 @@ """Built-in tool: list available agent profiles with structured descriptions.""" -from navi.tools._internal.base import Tool, ToolResult +from navi.tools._internal.base import Tool, ToolContext, ToolResult class ListProfilesTool(Tool): @@ -24,7 +24,7 @@ def __init__(self, profile_registry) -> None: self._profiles = profile_registry - async def execute(self, params: dict) -> ToolResult: + async def execute(self, params: dict, ctx: ToolContext | None = None) -> ToolResult: profile_id = (params.get("profile_id") or "").strip() if profile_id: diff --git a/navi/tools/list_tools.py b/navi/tools/list_tools.py index cc47f0f..3fd49f6 100644 --- a/navi/tools/list_tools.py +++ b/navi/tools/list_tools.py @@ -7,7 +7,7 @@ from navi.mcp.tools import build_mcp_name -from ._internal.base import Tool, ToolResult +from ._internal.base import Tool, ToolContext, ToolResult _USER_ENABLED_FILE = Path("tools/enabled.json") @@ -46,7 +46,7 @@ self._profile_registry = profile_registry self._mcp_manager = mcp_manager - async def execute(self, params: dict) -> ToolResult: + async def execute(self, params: dict, ctx: ToolContext | None = None) -> ToolResult: if self._registry is None: return ToolResult(success=False, output="Registry not available.", error="no_registry") diff --git a/navi/tools/manage_recall.py b/navi/tools/manage_recall.py index c988021..5855bcf 100644 --- a/navi/tools/manage_recall.py +++ b/navi/tools/manage_recall.py @@ -2,7 +2,7 @@ from __future__ import annotations -from ._internal.base import Tool, ToolResult, current_session_id, current_user_id, current_user_role +from ._internal.base import Tool, ToolContext, ToolResult, current_session_id, current_user_id, current_user_role class ManageRecallTool(Tool): @@ -42,7 +42,7 @@ def __init__(self, scheduler: "RecallScheduler" | None = None) -> None: self._scheduler = scheduler - async def execute(self, params: dict) -> ToolResult: + async def execute(self, params: dict, ctx: ToolContext | None = None) -> ToolResult: from navi.core.scheduler import RecallScheduler scheduler = self._scheduler @@ -63,7 +63,7 @@ target_session = (params.get("session_id") or "").strip() if not target_session: - target_session = current_session_id.get(None) + target_session = ctx.session_id if ctx else current_session_id.get(None) if not target_session: return ToolResult( success=False, @@ -73,7 +73,7 @@ # cancel / skip need current-session context for ownership check if action in ("cancel", "skip"): - current_sid = current_session_id.get(None) + current_sid = ctx.session_id if ctx else current_session_id.get(None) if not current_sid: return ToolResult( success=False, @@ -116,9 +116,9 @@ ) # list - role = current_user_role.get() + role = ctx.user_role if ctx else current_user_role.get() is_admin = role == "admin" - user_id = current_user_id.get(None) + user_id = ctx.user_id if ctx else current_user_id.get(None) recalls = await scheduler.list_recalls( session_id=target_session, user_id=user_id if not is_admin else None, diff --git a/navi/tools/mcp_status.py b/navi/tools/mcp_status.py index 3cbd42e..624d99e 100644 --- a/navi/tools/mcp_status.py +++ b/navi/tools/mcp_status.py @@ -2,7 +2,7 @@ from navi.mcp import McpManager -from ._internal.base import Tool, ToolResult +from ._internal.base import Tool, ToolContext, ToolResult class McpStatusTool(Tool): @@ -22,7 +22,7 @@ def __init__(self, mcp_manager: McpManager | None = None) -> None: self._mcp_manager = mcp_manager - async def execute(self, params: dict) -> ToolResult: + async def execute(self, params: dict, ctx: ToolContext | None = None) -> ToolResult: manager = self._mcp_manager if manager is None: from navi.api.deps import _mcp_manager as _global_mcp_manager diff --git a/navi/tools/memory.py b/navi/tools/memory.py index 244c29d..5a7a71c 100644 --- a/navi/tools/memory.py +++ b/navi/tools/memory.py @@ -5,7 +5,7 @@ from navi.memory.store import MemoryStore from navi.tools._internal.base import current_session_id, current_user_id -from ._internal.base import Tool, ToolResult +from ._internal.base import Tool, ToolContext, ToolResult _VALID_CATEGORIES = {"profile", "preferences", "technical", "projects", "other"} _VALID_SOURCES = {"conversation", "tool_call", "auto_discovery", "user_explicit"} @@ -85,20 +85,23 @@ def __init__(self, memory_store: MemoryStore) -> None: self._store = memory_store - async def execute(self, params: dict) -> ToolResult: + async def execute(self, params: dict, ctx: ToolContext | None = None) -> ToolResult: action = params.get("action", "") + session_id = ctx.session_id if ctx else current_session_id.get(None) + user_id = ctx.user_id if ctx else current_user_id.get(None) + if action == "save": - return await self._save(params) + return await self._save(params, session_id, user_id) if action == "search": - return await self._search(params) + return await self._search(params, user_id) if action == "forget": - return await self._forget(params) + return await self._forget(params, user_id) if action == "list": - return await self._list() + return await self._list(user_id) return ToolResult(success=False, output=f"Unknown action '{action}'.", error="invalid action") - async def _save(self, params: dict) -> ToolResult: + async def _save(self, params: dict, session_id: str | None = None, user_id: str | None = None) -> ToolResult: category = (params.get("category") or "").strip().lower() key = (params.get("key") or "").strip() value = (params.get("value") or "").strip() @@ -135,8 +138,6 @@ except Exception: pass - session_id = current_session_id.get(None) - user_id = current_user_id.get(None) await self._store.upsert_fact( category=category, key=key, @@ -150,12 +151,11 @@ ) return ToolResult(success=True, output=f"Saved [{category}] {key}: {value}") - async def _search(self, params: dict) -> ToolResult: + async def _search(self, params: dict, user_id: str | None = None) -> ToolResult: query = (params.get("query") or "").strip() if not query: return ToolResult(success=False, output="query is required for search.", error="missing query") - user_id = current_user_id.get(None) facts = await self._store.search_facts(query, user_id=user_id, limit=15) if not facts: return ToolResult(success=True, output="No matching facts found in memory.") @@ -173,14 +173,13 @@ lines.append(f"[{f['category']}] {f['key']}: {f['value']}{prov}") return ToolResult(success=True, output=f"Found {len(facts)} fact(s):\n" + "\n".join(lines)) - async def _forget(self, params: dict) -> ToolResult: + async def _forget(self, params: dict, user_id: str | None = None) -> ToolResult: key = (params.get("key") or "").strip() category = (params.get("category") or "").strip() or None if not key: return ToolResult(success=False, output="key is required for forget.", error="missing key") - user_id = current_user_id.get(None) deleted = await self._store.delete_fact(key, category, user_id=user_id) if deleted == 0: return ToolResult(success=False, output=f"No fact found with key '{key}'.", error="not found") @@ -188,8 +187,7 @@ noun = "fact" if deleted == 1 else "facts" return ToolResult(success=True, output=f"Deleted {deleted} {noun} with key '{key}'.") - async def _list(self) -> ToolResult: - user_id = current_user_id.get(None) + async def _list(self, user_id: str | None = None) -> ToolResult: categories = await self._store.get_categories(user_id=user_id) if not categories: return ToolResult(success=True, output="Memory is empty.") diff --git a/navi/tools/reflect.py b/navi/tools/reflect.py index 286d1c6..1c45055 100644 --- a/navi/tools/reflect.py +++ b/navi/tools/reflect.py @@ -12,7 +12,7 @@ import asyncio -from navi.tools._internal.base import Tool, ToolResult +from navi.tools._internal.base import Tool, ToolContext, ToolResult # ── Advisor system prompts ───────────────────────────────────────────────── @@ -125,7 +125,7 @@ def __init__(self, ai_helper) -> None: self._ai = ai_helper - async def execute(self, params: dict) -> ToolResult: + async def execute(self, params: dict, ctx: ToolContext | None = None) -> ToolResult: situation = (params.get("situation") or "").strip() assumptions = params.get("assumptions") or [] tried = (params.get("tried") or "").strip() or None diff --git a/navi/tools/reload_tools.py b/navi/tools/reload_tools.py index 1b23d0e..afb335b 100644 --- a/navi/tools/reload_tools.py +++ b/navi/tools/reload_tools.py @@ -2,7 +2,7 @@ from navi.config import settings -from ._internal.base import Tool, ToolResult +from ._internal.base import Tool, ToolContext, ToolResult class ReloadToolsTool(Tool): @@ -25,7 +25,7 @@ self._cp_registry = cp_registry self._mcp_manager = mcp_manager - async def execute(self, params: dict) -> ToolResult: + async def execute(self, params: dict, ctx: ToolContext | None = None) -> ToolResult: if self._registry is None: return ToolResult(success=False, output="Tool registry not available.", error="no_registry") diff --git a/navi/tools/schedule_recall.py b/navi/tools/schedule_recall.py index c3c493d..21e3e56 100644 --- a/navi/tools/schedule_recall.py +++ b/navi/tools/schedule_recall.py @@ -4,7 +4,7 @@ from datetime import datetime, timedelta, timezone -from ._internal.base import Tool, ToolResult, current_session_id, current_user_role +from ._internal.base import Tool, ToolContext, ToolResult, current_session_id, current_user_role from ._internal.time_parser import parse_when @@ -95,7 +95,7 @@ def __init__(self, scheduler: "RecallScheduler" | None = None) -> None: self._scheduler = scheduler - async def execute(self, params: dict) -> ToolResult: + async def execute(self, params: dict, ctx: ToolContext | None = None) -> ToolResult: from navi.core.scheduler import RecallExistsError, RecallScheduler scheduler = self._scheduler @@ -106,7 +106,7 @@ error="no scheduler", ) - session_id = current_session_id.get(None) + session_id = ctx.session_id if ctx else current_session_id.get(None) if not session_id: return ToolResult( success=False, diff --git a/navi/tools/scratchpad.py b/navi/tools/scratchpad.py index 0741ee1..401be1a 100644 --- a/navi/tools/scratchpad.py +++ b/navi/tools/scratchpad.py @@ -1,7 +1,7 @@ """Session-scoped scratchpad for capturing working notes during task execution — backed by PostgreSQL KV store.""" from __future__ import annotations -from navi.tools._internal.base import Tool, ToolResult, current_session_id, current_user_id +from navi.tools._internal.base import Tool, ToolContext, ToolResult, current_session_id, current_user_id # Global KV store reference — injected at startup by registry.py _kv_store = None @@ -13,20 +13,20 @@ _kv_store = kv -def _sid() -> str: - return current_session_id.get() or "__default__" +def _sid(explicit: str | None = None) -> str: + return explicit or current_session_id.get() or "__default__" -def _uid() -> str | None: - return current_user_id.get(None) +def _uid(explicit: str | None = None) -> str | None: + return explicit if explicit is not None else current_user_id.get(None) -async def get_section(session_id: str, section: str) -> str: +async def get_section(session_id: str, section: str, user_id: str | None = None) -> str: """Read one scratchpad section for the given session. Returns '' if absent.""" if _kv_store is None: return "" try: - val = await _kv_store.get(_uid(), session_id, "scratchpad", section) + val = await _kv_store.get(_uid(user_id), session_id, "scratchpad", section) return val or "" except Exception: return "" @@ -78,8 +78,8 @@ if kv_store is not None: set_kv_store(kv_store) - async def execute(self, params: dict) -> ToolResult: - sid = _sid() + async def execute(self, params: dict, ctx: ToolContext | None = None) -> ToolResult: + sid = _sid(ctx.session_id if ctx else None) op = params.get("op") section: str | None = params.get("section") or None content: str = params.get("content", "") @@ -89,7 +89,7 @@ return ToolResult(success=False, output="", error="'content' is required for 'write'") key = section or "main" if _kv_store is not None: - await _kv_store.set(_uid(), sid, "scratchpad", key, content) + await _kv_store.set(_uid(ctx.user_id if ctx else None), sid, "scratchpad", key, content) return ToolResult(success=True, output=f"[{key}] written ({len(content)} chars).") if op == "append": @@ -97,9 +97,9 @@ return ToolResult(success=False, output="", error="'content' is required for 'append'") key = section or "main" if _kv_store is not None: - existing = await _kv_store.get(_uid(), sid, "scratchpad", key) or "" + existing = await _kv_store.get(_uid(ctx.user_id if ctx else None), sid, "scratchpad", key) or "" new = (existing + "\n" + content).lstrip("\n") if existing else content - await _kv_store.set(_uid(), sid, "scratchpad", key, new) + await _kv_store.set(_uid(ctx.user_id if ctx else None), sid, "scratchpad", key, new) return ToolResult(success=True, output=f"[{key}] updated ({len(new)} chars total).") return ToolResult(success=True, output=f"[{key}] updated.") @@ -107,12 +107,12 @@ if _kv_store is None: return ToolResult(success=True, output="Scratchpad is empty.") if section is not None: - text = await _kv_store.get(_uid(), sid, "scratchpad", section) + text = await _kv_store.get(_uid(ctx.user_id if ctx else None), sid, "scratchpad", section) if not text: return ToolResult(success=True, output=f"[{section}] is empty.") return ToolResult(success=True, output=f"[{section}]:\n{text}") # No section → read all - all_data = await _kv_store.get_all(_uid(), sid, "scratchpad") + all_data = await _kv_store.get_all(_uid(ctx.user_id if ctx else None), sid, "scratchpad") if not all_data: return ToolResult(success=True, output="Scratchpad is empty.") parts = [f"[{k}]:\n{v}" for k, v in all_data.items()] @@ -122,13 +122,13 @@ if _kv_store is None: return ToolResult(success=True, output="Scratchpad cleared.") if section is not None: - existing = await _kv_store.get(_uid(), sid, "scratchpad", section) - await _kv_store.delete(_uid(), sid, "scratchpad", section) + existing = await _kv_store.get(_uid(ctx.user_id if ctx else None), sid, "scratchpad", section) + await _kv_store.delete(_uid(ctx.user_id if ctx else None), sid, "scratchpad", section) return ToolResult( success=True, output=f"[{section}] cleared." if existing else f"[{section}] was already empty.", ) - await _kv_store.clear_scope(_uid(), sid, "scratchpad") + await _kv_store.clear_scope(_uid(ctx.user_id if ctx else None), sid, "scratchpad") return ToolResult(success=True, output="Scratchpad cleared.") return ToolResult(success=False, output="", error=f"Unknown op: {op!r}") diff --git a/navi/tools/share_file.py b/navi/tools/share_file.py index c1ccb4f..65e38ed 100644 --- a/navi/tools/share_file.py +++ b/navi/tools/share_file.py @@ -8,7 +8,7 @@ from navi.config import settings from navi.session_files import ensure_session_dir -from ._internal.base import Tool, ToolResult, current_session_id, current_user_role, current_user_id +from ._internal.base import Tool, ToolContext, ToolResult, current_session_id, current_user_role, current_user_id def _fmt_size(n: int) -> str: @@ -57,14 +57,14 @@ "required": ["path"], } - async def execute(self, params: dict) -> ToolResult: - session_id = current_session_id.get() + async def execute(self, params: dict, ctx: ToolContext | None = None) -> ToolResult: + session_id = ctx.session_id if ctx else current_session_id.get() if not session_id: return ToolResult(success=False, output="No active session context.", error="no_session") raw_path = Path(params["path"]).expanduser() - user_id = current_user_id.get(None) - role = current_user_role.get() + user_id = ctx.user_id if ctx else current_user_id.get(None) + role = ctx.user_role if ctx else current_user_role.get() if user_id and role != "admin": sandbox = Path("user_data") / user_id diff --git a/navi/tools/spawn_agent.py b/navi/tools/spawn_agent.py index 285449e..14e2828 100644 --- a/navi/tools/spawn_agent.py +++ b/navi/tools/spawn_agent.py @@ -10,7 +10,7 @@ from navi.exceptions import ProfileNotFound -from ._internal.base import Tool, ToolResult, current_session_id +from ._internal.base import Tool, ToolContext, ToolResult, current_session_id log = structlog.get_logger() @@ -104,7 +104,7 @@ self._memory_store = memory_store self._mcp_manager = mcp_manager - async def execute(self, params: dict) -> ToolResult: + async def execute(self, params: dict, ctx: ToolContext | None = None) -> ToolResult: # Import here to avoid module-level circular import from navi.core.agent import Agent from navi.tools.scratchpad import get_section @@ -122,7 +122,7 @@ # Resolve profile: explicit override → parent session's profile → first available profile_id = params.get("profile_id", "").strip() if not profile_id: - profile_id = await self._resolve_parent_profile() + profile_id = await self._resolve_parent_profile(ctx) try: selected_profile = self._profile_registry.get(profile_id) except ProfileNotFound: @@ -137,7 +137,7 @@ ) # Read parent scratchpad context_transfer section and pass it to the sub-agent. - parent_sid = current_session_id.get() + parent_sid = ctx.session_id if ctx else current_session_id.get() context_transfer = (await get_section(parent_sid, "context_transfer")) if parent_sid else "" scope = selected_profile.get_subagent_tools() @@ -197,9 +197,9 @@ log.error("spawn_agent.error", error=str(e), exc_info=True) return ToolResult(success=False, output=f"Sub-agent failed: {e}", error=str(e)) - async def _resolve_parent_profile(self) -> str: + async def _resolve_parent_profile(self, ctx: ToolContext | None = None) -> str: """Return the profile of the current parent session, or fallback to first profile.""" - session_id = current_session_id.get() + session_id = ctx.session_id if ctx else current_session_id.get() if session_id and self._session_store: try: session = await self._session_store.get(session_id) diff --git a/navi/tools/ssh_exec.py b/navi/tools/ssh_exec.py index ca7fd74..3294ad6 100644 --- a/navi/tools/ssh_exec.py +++ b/navi/tools/ssh_exec.py @@ -38,7 +38,7 @@ from navi.config import settings -from ._internal.base import Tool, ToolResult, current_session_id +from ._internal.base import Tool, ToolContext, ToolResult, current_session_id _TIMEOUT = 60 _TTL = 20 * 60 # 20 minutes in seconds @@ -191,7 +191,7 @@ "required": [], } - async def execute(self, params: dict) -> ToolResult: + async def execute(self, params: dict, ctx: ToolContext | None = None) -> ToolResult: action = params.get("action", "exec") timeout = int(params.get("timeout") or _TIMEOUT) @@ -210,7 +210,7 @@ if not command: return ToolResult(success=False, output="'command' is required for exec action.", error="missing_command") - session_id = current_session_id.get() + session_id = ctx.session_id if ctx else current_session_id.get() host = connect_kwargs["host"] port = int(connect_kwargs.get("port", 22)) username = connect_kwargs.get("username", "") diff --git a/navi/tools/switch_profile.py b/navi/tools/switch_profile.py index 16643a2..2baa39c 100644 --- a/navi/tools/switch_profile.py +++ b/navi/tools/switch_profile.py @@ -1,6 +1,6 @@ """Built-in tool for switching the active agent profile mid-session.""" -from navi.tools._internal.base import Tool, ToolResult, current_event_sink, current_session_id +from navi.tools._internal.base import Tool, ToolContext, ToolResult, current_event_sink, current_session_id class SwitchProfileTool(Tool): @@ -27,7 +27,7 @@ self._sessions = session_store self._profiles = profile_registry - async def execute(self, params: dict) -> ToolResult: + async def execute(self, params: dict, ctx: ToolContext | None = None) -> ToolResult: profile_id = (params.get("profile_id") or "").strip() available = ", ".join(p.id for p in self._profiles.all()) @@ -48,7 +48,7 @@ error="subagent_only", ) - sid = current_session_id.get() + sid = ctx.session_id if ctx else current_session_id.get() if not sid: return ToolResult(success=False, output="", error="No active session context.") @@ -66,7 +66,7 @@ await self._sessions.save(session) # Notify the client immediately so it can update the UI. - sink = current_event_sink.get() + sink = ctx.event_sink if ctx else current_event_sink.get() if sink is not None: from navi.core.events import ProfileSwitched await sink.put(ProfileSwitched(profile_id=profile_id, profile_name=profile.name)) diff --git a/navi/tools/terminal.py b/navi/tools/terminal.py index 658b92d..72a8de1 100644 --- a/navi/tools/terminal.py +++ b/navi/tools/terminal.py @@ -21,7 +21,7 @@ from navi.config import settings -from ._internal.base import Tool, ToolResult, current_user_id, current_user_role +from ._internal.base import Tool, ToolContext, ToolResult, current_user_id, current_user_role _DEFAULT_TIMEOUT = 20 _MAX_TIMEOUT = 300 @@ -58,10 +58,10 @@ return None -def _resolve_working_dir(working_dir: str | None) -> Path | None: +def _resolve_working_dir(working_dir: str | None, user_id: str | None = None, role: str | None = None) -> Path | None: """Resolve working directory with sandbox enforcement for non-admins.""" - user_id = current_user_id.get(None) - role = current_user_role.get() + user_id = user_id or current_user_id.get(None) + role = role or current_user_role.get() if user_id and role != "admin": sandbox = Path("user_data") / user_id @@ -111,7 +111,7 @@ "required": ["command"], } - async def execute(self, params: dict) -> ToolResult: + async def execute(self, params: dict, ctx: ToolContext | None = None) -> ToolResult: command = params["command"].strip() working_dir = params.get("working_dir") or None raw_timeout = params.get("timeout") @@ -123,8 +123,8 @@ if not command: return ToolResult(success=False, output="Empty command.", error="empty_command") - role = current_user_role.get() - user_id = current_user_id.get(None) + role = ctx.user_role if ctx else current_user_role.get() + user_id = ctx.user_id if ctx else current_user_id.get(None) # Admins and single-user / legacy mode: use the existing restriction logic if not user_id or role == "admin": @@ -135,7 +135,7 @@ return await self._run_restricted(command, working_dir, timeout) # Non-admin multi-user mode: sandbox + curated allowlist + dangerous-pattern block - cwd = _resolve_working_dir(working_dir) + cwd = _resolve_working_dir(working_dir, user_id, role) if cwd is None and working_dir: return ToolResult( success=False, diff --git a/navi/tools/test_mcp_tool.py b/navi/tools/test_mcp_tool.py index 4115e8b..80eb4d0 100644 --- a/navi/tools/test_mcp_tool.py +++ b/navi/tools/test_mcp_tool.py @@ -2,7 +2,7 @@ import asyncio -from ._internal.base import Tool, ToolResult +from ._internal.base import Tool, ToolContext, ToolResult class TestMcpToolTool(Tool): @@ -37,7 +37,7 @@ def __init__(self, mcp_manager=None) -> None: self._mcp_manager = mcp_manager - async def execute(self, params: dict) -> ToolResult: + async def execute(self, params: dict, ctx: ToolContext | None = None) -> ToolResult: server_name = (params.get("server_name") or "").strip() tool_name = (params.get("tool_name") or "").strip() arguments: dict = params.get("arguments") or {} diff --git a/navi/tools/todo.py b/navi/tools/todo.py index 02abe2b..14a7ab2 100644 --- a/navi/tools/todo.py +++ b/navi/tools/todo.py @@ -4,7 +4,7 @@ import json from dataclasses import dataclass, field -from navi.tools._internal.base import Tool, ToolResult, current_session_id, current_user_id +from navi.tools._internal.base import Tool, ToolContext, ToolResult, current_session_id, current_user_id _STATUS_ICON: dict[str, str] = { "pending": "○", @@ -31,12 +31,12 @@ _kv_store = kv -def _sid() -> str: - return current_session_id.get() or "__default__" +def _sid(explicit: str | None = None) -> str: + return explicit or current_session_id.get() or "__default__" -def _uid() -> str | None: - return current_user_id.get(None) +def _uid(explicit: str | None = None) -> str | None: + return explicit if explicit is not None else current_user_id.get(None) async def _load_tasks(sid: str) -> list[_Task]: @@ -116,8 +116,8 @@ if kv_store is not None: set_kv_store(kv_store) - async def execute(self, params: dict) -> ToolResult: - sid = _sid() + async def execute(self, params: dict, ctx: ToolContext | None = None) -> ToolResult: + sid = _sid(ctx.session_id if ctx else None) op = params.get("op") if op == "set": @@ -179,7 +179,7 @@ if op == "clear": if _kv_store is not None: - await _kv_store.clear_scope(_uid(), sid, "todo") + await _kv_store.clear_scope(_uid(ctx.user_id if ctx else None), sid, "todo") return ToolResult(success=True, output="Plan cleared.") return ToolResult(success=False, output="", error=f"Unknown op: {op!r}") diff --git a/navi/tools/tool_manual.py b/navi/tools/tool_manual.py index 7d0c3e2..07ecfb1 100644 --- a/navi/tools/tool_manual.py +++ b/navi/tools/tool_manual.py @@ -2,7 +2,7 @@ from pathlib import Path -from ._internal.base import Tool, ToolResult +from ._internal.base import Tool, ToolContext, ToolResult MANUALS_DIR = Path(__file__).parent.parent.parent / "manuals" @@ -27,7 +27,7 @@ def __init__(self, registry=None) -> None: self._registry = registry - async def execute(self, params: dict) -> ToolResult: + async def execute(self, params: dict, ctx: ToolContext | None = None) -> ToolResult: tool_name = params["tool_name"].strip() manual_file = MANUALS_DIR / f"{tool_name}.md" diff --git a/tests/conftest_factory.py b/tests/conftest_factory.py index 03a0da4..fd6aa7b 100644 --- a/tests/conftest_factory.py +++ b/tests/conftest_factory.py @@ -127,7 +127,7 @@ }, } - async def execute(self, arguments: dict) -> "FakeToolResult": + async def execute(self, arguments: dict, ctx=None) -> "FakeToolResult": from navi.tools._internal.base import ToolResult return ToolResult(success=True, output=f"executed {self.name}") diff --git a/tests/unit/tools/test_content_publish.py b/tests/unit/tools/test_content_publish.py index 4f42557..980d1f7 100644 --- a/tests/unit/tools/test_content_publish.py +++ b/tests/unit/tools/test_content_publish.py @@ -5,7 +5,7 @@ import navi.content_store as content_store_mod import navi.session_files as session_files_mod import navi.tools.content_publish as content_publish_mod -from navi.tools._internal.base import current_session_id +from navi.tools._internal.base import ToolContext from navi.tools.content_publish import ContentPublishTool @@ -18,24 +18,16 @@ monkeypatch.setattr(content_publish_mod, "settings", _test_settings) monkeypatch.setattr(session_files_mod, "settings", _test_settings) monkeypatch.setattr(content_store_mod, "settings", _test_settings) - token = current_session_id.set("sess-1") - try: - yield ContentPublishTool() - finally: - current_session_id.reset(token) + yield ContentPublishTool() async def test_requires_active_session(self): - token = current_session_id.set(None) - try: - result = await ContentPublishTool().execute({"filename": "logo.svg"}) - finally: - current_session_id.reset(token) + result = await ContentPublishTool().execute({"filename": "logo.svg"}, ctx=ToolContext(session_id=None)) assert not result.success assert result.error == "no_session" async def test_missing_file_reports_session_dir(self, tool, tmp_path): - result = await tool.execute({"filename": "missing.svg"}) + result = await tool.execute({"filename": "missing.svg"}, ctx=ToolContext(session_id="sess-1")) assert not result.success assert result.error == "not_found" @@ -46,7 +38,7 @@ sess_dir = tmp_path / "sessions" / "sess-1" (sess_dir / "folder").mkdir(parents=True) - result = await tool.execute({"filename": "folder"}) + result = await tool.execute({"filename": "folder"}, ctx=ToolContext(session_id="sess-1")) assert not result.success assert result.error == "not_a_file" @@ -73,7 +65,7 @@ "filename": "../logo.svg", "title": "Logo", "content_type": "svg", - }) + }, ctx=ToolContext(session_id="sess-1")) assert result.success assert calls == [{ @@ -109,7 +101,7 @@ result = await tool.execute({ "filename": "model.stl", "source_filename": "model.scad", - }) + }, ctx=ToolContext(session_id="sess-1")) assert result.success assert calls[0]["source_filename"] == "model.scad" diff --git a/tests/unit/tools/test_memory.py b/tests/unit/tools/test_memory.py index 775ba6f..94c7881 100644 --- a/tests/unit/tools/test_memory.py +++ b/tests/unit/tools/test_memory.py @@ -5,7 +5,7 @@ import pytest from navi.tools.memory import MemoryTool -from navi.tools._internal.base import current_session_id, current_user_id +from navi.tools._internal.base import ToolContext @pytest.fixture @@ -18,15 +18,6 @@ return store -@pytest.fixture(autouse=True) -def _fake_ctx(): - token = current_session_id.set("sess1") - token2 = current_user_id.set("user1") - yield - current_session_id.reset(token) - current_user_id.reset(token2) - - @pytest.mark.asyncio async def test_save(fake_store): tool = MemoryTool(fake_store) @@ -37,7 +28,7 @@ "value": "Arch Linux", "source": "tool_call", "confidence": 95, - }) + }, ctx=ToolContext(session_id="sess1", user_id="user1")) assert result.success is True assert "Saved [technical] primary_os" in result.output fake_store.upsert_fact.assert_awaited_once() @@ -50,7 +41,7 @@ "action": "save", "key": "primary_os", "value": "Arch Linux", - }) + }, ctx=ToolContext(session_id="sess1", user_id="user1")) assert result.success is False assert "category" in result.output.lower() @@ -63,7 +54,7 @@ "category": "invalid", "key": "primary_os", "value": "Arch Linux", - }) + }, ctx=ToolContext(session_id="sess1", user_id="user1")) assert result.success is False assert "Invalid category" in result.output @@ -74,7 +65,7 @@ result = await tool.execute({ "action": "search", "query": "nonexistent", - }) + }, ctx=ToolContext(session_id="sess1", user_id="user1")) assert result.success is True assert "No matching facts" in result.output @@ -93,7 +84,7 @@ result = await tool.execute({ "action": "search", "query": "os", - }) + }, ctx=ToolContext(session_id="sess1", user_id="user1")) assert result.success is True assert "Arch Linux" in result.output assert "technical" in result.output @@ -105,7 +96,7 @@ result = await tool.execute({ "action": "forget", "key": "primary_os", - }) + }, ctx=ToolContext(session_id="sess1", user_id="user1")) assert result.success is True assert "Deleted 1 fact" in result.output @@ -117,7 +108,7 @@ result = await tool.execute({ "action": "forget", "key": "missing", - }) + }, ctx=ToolContext(session_id="sess1", user_id="user1")) assert result.success is False assert "No fact found" in result.output @@ -126,7 +117,7 @@ async def test_list_categories(fake_store): fake_store.get_categories = AsyncMock(return_value=["technical", "preferences"]) tool = MemoryTool(fake_store) - result = await tool.execute({"action": "list"}) + result = await tool.execute({"action": "list"}, ctx=ToolContext(session_id="sess1", user_id="user1")) assert result.success is True assert "technical" in result.output assert "preferences" in result.output @@ -135,6 +126,6 @@ @pytest.mark.asyncio async def test_list_empty(fake_store): tool = MemoryTool(fake_store) - result = await tool.execute({"action": "list"}) + result = await tool.execute({"action": "list"}, ctx=ToolContext(session_id="sess1", user_id="user1")) assert result.success is True assert "empty" in result.output.lower() diff --git a/tests/unit/tools/test_recall_tools.py b/tests/unit/tools/test_recall_tools.py index 8fd9a11..823ff31 100644 --- a/tests/unit/tools/test_recall_tools.py +++ b/tests/unit/tools/test_recall_tools.py @@ -8,7 +8,7 @@ from navi.core.scheduler import Recall, RecallExistsError from navi.tools.schedule_recall import ScheduleRecallTool from navi.tools.manage_recall import ManageRecallTool -from navi.tools._internal.base import current_session_id, current_user_role +from navi.tools._internal.base import ToolContext class TestScheduleRecallTool: @@ -22,13 +22,12 @@ updated_at=datetime.now(timezone.utc), ) tool = ScheduleRecallTool(scheduler) - current_session_id.set("s1") result = await tool.execute({ "call_type": "once", "when": "1h", "additional_context_message": "ctx", - }) + }, ctx=ToolContext(session_id="s1")) assert result.success is True assert "r1" in result.output @@ -44,12 +43,11 @@ updated_at=datetime.now(timezone.utc), ) tool = ScheduleRecallTool(scheduler) - current_session_id.set("s1") result = await tool.execute({ "call_type": "immediate", "additional_context_message": "ctx", - }) + }, ctx=ToolContext(session_id="s1")) assert result.success is True assert "immediate" in result.output @@ -64,14 +62,13 @@ updated_at=datetime.now(timezone.utc), ) tool = ScheduleRecallTool(scheduler) - current_session_id.set("s1") result = await tool.execute({ "call_type": "recurring", "when": "1h", "interval_seconds": 3600, "additional_context_message": "ctx", - }) + }, ctx=ToolContext(session_id="s1")) assert result.success is True assert "3600s" in result.output @@ -79,12 +76,11 @@ async def test_missing_context_message(self, monkeypatch): scheduler = AsyncMock() tool = ScheduleRecallTool(scheduler) - current_session_id.set("s1") result = await tool.execute({ "call_type": "once", "when": "1h", - }) + }, ctx=ToolContext(session_id="s1")) assert result.success is False assert "missing context" in result.error @@ -93,13 +89,12 @@ scheduler = AsyncMock() scheduler.schedule_recall.side_effect = RecallExistsError("exists") tool = ScheduleRecallTool(scheduler) - current_session_id.set("s1") result = await tool.execute({ "call_type": "once", "when": "1h", "additional_context_message": "ctx", - }) + }, ctx=ToolContext(session_id="s1")) assert result.success is False assert result.error == "recall_exists" @@ -107,13 +102,12 @@ async def test_invalid_call_type(self, monkeypatch): scheduler = AsyncMock() tool = ScheduleRecallTool(scheduler) - current_session_id.set("s1") result = await tool.execute({ "call_type": "invalid", "when": "1h", "additional_context_message": "ctx", - }) + }, ctx=ToolContext(session_id="s1")) assert result.success is False assert "bad_call_type" in result.error @@ -124,9 +118,8 @@ scheduler = AsyncMock() scheduler.cancel_recall.return_value = True tool = ManageRecallTool(scheduler) - current_session_id.set("s1") - result = await tool.execute({"action": "cancel"}) + result = await tool.execute({"action": "cancel"}, ctx=ToolContext(session_id="s1")) assert result.success is True assert "cancelled" in result.output @@ -136,9 +129,8 @@ scheduler = AsyncMock() scheduler.skip_next_recall.return_value = True tool = ManageRecallTool(scheduler) - current_session_id.set("s1") - result = await tool.execute({"action": "skip"}) + result = await tool.execute({"action": "skip"}, ctx=ToolContext(session_id="s1")) assert result.success is True assert "skipped" in result.output @@ -156,10 +148,8 @@ ) ] tool = ManageRecallTool(scheduler) - current_session_id.set("s1") - current_user_role.set("admin") - result = await tool.execute({"action": "list"}) + result = await tool.execute({"action": "list"}, ctx=ToolContext(session_id="s1", user_role="admin")) assert result.success is True assert "Recalls for session" in result.output @@ -168,9 +158,8 @@ async def test_bad_action(self, monkeypatch): scheduler = AsyncMock() tool = ManageRecallTool(scheduler) - current_session_id.set("s1") - result = await tool.execute({"action": "unknown"}) + result = await tool.execute({"action": "unknown"}, ctx=ToolContext(session_id="s1")) assert result.success is False assert result.error == "bad_action" @@ -179,9 +168,8 @@ scheduler = AsyncMock() scheduler.cancel_recall.return_value = False tool = ManageRecallTool(scheduler) - current_session_id.set("s1") - result = await tool.execute({"action": "cancel"}) + result = await tool.execute({"action": "cancel"}, ctx=ToolContext(session_id="s1")) assert result.success is False assert "no_pending_recall" in result.error diff --git a/tests/unit/tools/test_scratchpad.py b/tests/unit/tools/test_scratchpad.py index efbc560..26ad10a 100644 --- a/tests/unit/tools/test_scratchpad.py +++ b/tests/unit/tools/test_scratchpad.py @@ -3,7 +3,7 @@ import pytest from navi.tools.scratchpad import ScratchpadTool, get_section, _kv_store -from navi.tools._internal.base import current_session_id, current_user_id +from navi.tools._internal.base import ToolContext from navi.store import KvStore from tests.conftest_factory import FakePool @@ -46,22 +46,13 @@ _mod._kv_store = None -@pytest.fixture(autouse=True) -def _fake_ctx(): - token = current_session_id.set("sess1") - token2 = current_user_id.set("user1") - yield - current_session_id.reset(token) - current_user_id.reset(token2) - - @pytest.mark.asyncio async def test_write_then_read(_fake_kv): tool = ScratchpadTool() - result = await tool.execute({"op": "write", "section": "findings", "content": "found X"}) + result = await tool.execute({"op": "write", "section": "findings", "content": "found X"}, ctx=ToolContext(session_id="sess1", user_id="user1")) assert result.success is True - result = await tool.execute({"op": "read", "section": "findings"}) + result = await tool.execute({"op": "read", "section": "findings"}, ctx=ToolContext(session_id="sess1", user_id="user1")) assert result.success is True assert "found X" in result.output @@ -69,11 +60,11 @@ @pytest.mark.asyncio async def test_append(_fake_kv): tool = ScratchpadTool() - await tool.execute({"op": "write", "section": "findings", "content": "line1"}) - result = await tool.execute({"op": "append", "section": "findings", "content": "line2"}) + await tool.execute({"op": "write", "section": "findings", "content": "line1"}, ctx=ToolContext(session_id="sess1", user_id="user1")) + result = await tool.execute({"op": "append", "section": "findings", "content": "line2"}, ctx=ToolContext(session_id="sess1", user_id="user1")) assert result.success is True - result = await tool.execute({"op": "read", "section": "findings"}) + result = await tool.execute({"op": "read", "section": "findings"}, ctx=ToolContext(session_id="sess1", user_id="user1")) assert "line1" in result.output assert "line2" in result.output @@ -81,9 +72,9 @@ @pytest.mark.asyncio async def test_read_all_sections(_fake_kv): tool = ScratchpadTool() - await tool.execute({"op": "write", "section": "goal", "content": "do thing"}) - await tool.execute({"op": "write", "section": "findings", "content": "found"}) - result = await tool.execute({"op": "read"}) + await tool.execute({"op": "write", "section": "goal", "content": "do thing"}, ctx=ToolContext(session_id="sess1", user_id="user1")) + await tool.execute({"op": "write", "section": "findings", "content": "found"}, ctx=ToolContext(session_id="sess1", user_id="user1")) + result = await tool.execute({"op": "read"}, ctx=ToolContext(session_id="sess1", user_id="user1")) assert result.success is True assert "goal" in result.output assert "findings" in result.output @@ -92,27 +83,26 @@ @pytest.mark.asyncio async def test_clear_section(_fake_kv): tool = ScratchpadTool() - await tool.execute({"op": "write", "section": "goal", "content": "do thing"}) - result = await tool.execute({"op": "clear", "section": "goal"}) + await tool.execute({"op": "write", "section": "goal", "content": "do thing"}, ctx=ToolContext(session_id="sess1", user_id="user1")) + result = await tool.execute({"op": "clear", "section": "goal"}, ctx=ToolContext(session_id="sess1", user_id="user1")) assert result.success is True - result = await tool.execute({"op": "read", "section": "goal"}) + result = await tool.execute({"op": "read", "section": "goal"}, ctx=ToolContext(session_id="sess1", user_id="user1")) assert "empty" in result.output.lower() @pytest.mark.asyncio async def test_scope_isolation(_fake_kv): tool = ScratchpadTool() - await tool.execute({"op": "write", "section": "main", "content": "sess1 data"}) + await tool.execute({"op": "write", "section": "main", "content": "sess1 data"}, ctx=ToolContext(session_id="sess1", user_id="user1")) - current_session_id.set("sess2") - result = await tool.execute({"op": "read", "section": "main"}) + result = await tool.execute({"op": "read", "section": "main"}, ctx=ToolContext(session_id="sess2", user_id="user1")) assert "empty" in result.output.lower() @pytest.mark.asyncio async def test_get_section(_fake_kv): tool = ScratchpadTool() - await tool.execute({"op": "write", "section": "artifacts", "content": "path/to/file"}) - text = await get_section("sess1", "artifacts") + await tool.execute({"op": "write", "section": "artifacts", "content": "path/to/file"}, ctx=ToolContext(session_id="sess1", user_id="user1")) + text = await get_section("sess1", "artifacts", user_id="user1") assert text == "path/to/file" diff --git a/tests/unit/tools/test_share_file.py b/tests/unit/tools/test_share_file.py index bac9422..8dd6422 100644 --- a/tests/unit/tools/test_share_file.py +++ b/tests/unit/tools/test_share_file.py @@ -7,7 +7,7 @@ import navi.session_files as session_files_mod import navi.tools.share_file as share_file_mod from navi.config import Settings -from navi.tools._internal.base import current_session_id +from navi.tools._internal.base import ToolContext from navi.tools.share_file import ShareFileTool @@ -25,14 +25,10 @@ ) monkeypatch.setattr(share_file_mod, "settings", _test_settings) monkeypatch.setattr(session_files_mod, "settings", _test_settings) - token = current_session_id.set("sess 1") - try: - yield ShareFileTool() - finally: - current_session_id.reset(token) + yield ShareFileTool() async def test_rejects_relative_path(self, tool): - result = await tool.execute({"path": "workspace/report.txt"}) + result = await tool.execute({"path": "workspace/report.txt"}, ctx=ToolContext(session_id="sess 1")) assert not result.success assert result.error == "path_not_absolute" @@ -41,7 +37,7 @@ src = tmp_path / "report.txt" src.write_text("hello") - result = await tool.execute({"path": str(src), "filename": "clean report.txt"}) + result = await tool.execute({"path": str(src), "filename": "clean report.txt"}, ctx=ToolContext(session_id="sess 1")) assert result.success dest = tmp_path / "sessions" / "sess 1" / "clean report.txt" @@ -59,7 +55,7 @@ src = tmp_path / "too_large.bin" src.write_bytes(b"x") - result = await tool.execute({"path": str(src)}) + result = await tool.execute({"path": str(src)}, ctx=ToolContext(session_id="sess 1")) assert not result.success assert result.error == "file_too_large" @@ -68,7 +64,7 @@ src = tmp_path / "source.txt" src.write_text("hello") - result = await tool.execute({"path": str(src), "filename": "отчёт #1.txt"}) + result = await tool.execute({"path": str(src), "filename": "отчёт #1.txt"}, ctx=ToolContext(session_id="sess 1")) assert result.success parsed = urlparse(result.metadata["url"]) @@ -84,7 +80,7 @@ existing.parent.mkdir(parents=True) existing.write_text("old") - result = await tool.execute({"path": str(src), "filename": "report.txt"}) + result = await tool.execute({"path": str(src), "filename": "report.txt"}, ctx=ToolContext(session_id="sess 1")) assert result.success assert existing.read_text() == "old" diff --git a/tests/unit/tools/test_spawn_agent.py b/tests/unit/tools/test_spawn_agent.py index 1c5fa70..b542312 100644 --- a/tests/unit/tools/test_spawn_agent.py +++ b/tests/unit/tools/test_spawn_agent.py @@ -1,7 +1,7 @@ import pytest from navi.core.session import InMemorySessionStore -from navi.tools._internal.base import current_session_id +from navi.tools._internal.base import ToolContext from navi.tools.spawn_agent import SpawnAgentTool from tests.conftest_factory import ( FakeLLMBackend, @@ -36,7 +36,7 @@ result = await tool.execute({ "task": "inspect code", "profile_id": "developer", - }) + }, ctx=ToolContext()) assert result.success is True assert captured["profile_id"] == "developer" @@ -47,7 +47,6 @@ async def test_spawn_agent_defaults_to_parent_profile(monkeypatch, spawn_tool): tool, _, store = spawn_tool session = await store.create("secretary") - token = current_session_id.set(session.id) captured = {} async def fake_run_ephemeral(self, **kwargs): @@ -56,10 +55,7 @@ monkeypatch.setattr("navi.core.agent.Agent.run_ephemeral", fake_run_ephemeral) - try: - result = await tool.execute({"task": "research this"}) - finally: - current_session_id.reset(token) + result = await tool.execute({"task": "research this"}, ctx=ToolContext(session_id=session.id)) assert result.success is True assert captured["profile_id"] == "secretary" @@ -72,7 +68,7 @@ result = await tool.execute({ "task": "do work", "profile_id": "missing_profile", - }) + }, ctx=ToolContext()) assert result.success is False assert result.error == "unknown_profile:missing_profile" diff --git a/tests/unit/tools/test_todo.py b/tests/unit/tools/test_todo.py index 79de193..c271f89 100644 --- a/tests/unit/tools/test_todo.py +++ b/tests/unit/tools/test_todo.py @@ -13,7 +13,7 @@ set_tasks, _kv_store, ) -from navi.tools._internal.base import current_session_id, current_user_id +from navi.tools._internal.base import ToolContext from navi.store import KvStore from tests.conftest_factory import FakeConnection, FakePool @@ -58,22 +58,13 @@ _mod._kv_store = None -@pytest.fixture(autouse=True) -def _fake_ctx(): - token = current_session_id.set("sess1") - token2 = current_user_id.set("user1") - yield - current_session_id.reset(token) - current_user_id.reset(token2) - - # ── TodoTool execute tests ──────────────────────────────────────────────────── @pytest.mark.asyncio async def test_set_tasks(_fake_kv): tool = TodoTool() - result = await tool.execute({"op": "set", "tasks": ["task A", "task B"]}) + result = await tool.execute({"op": "set", "tasks": ["task A", "task B"]}, ctx=ToolContext(session_id="sess1", user_id="user1")) assert result.success is True assert "task A" in result.output assert "task B" in result.output @@ -82,7 +73,7 @@ @pytest.mark.asyncio async def test_view_empty(_fake_kv): tool = TodoTool() - result = await tool.execute({"op": "view"}) + result = await tool.execute({"op": "view"}, ctx=ToolContext(session_id="sess1", user_id="user1")) assert result.success is True assert "No plan set" in result.output @@ -90,8 +81,8 @@ @pytest.mark.asyncio async def test_update_status(_fake_kv): tool = TodoTool() - await tool.execute({"op": "set", "tasks": ["task A"]}) - result = await tool.execute({"op": "update", "index": 1, "status": "done", "validation": "tested"}) + await tool.execute({"op": "set", "tasks": ["task A"]}, ctx=ToolContext(session_id="sess1", user_id="user1")) + result = await tool.execute({"op": "update", "index": 1, "status": "done", "validation": "tested"}, ctx=ToolContext(session_id="sess1", user_id="user1")) assert result.success is True assert "done" in result.output @@ -99,8 +90,8 @@ @pytest.mark.asyncio async def test_done_requires_validation(_fake_kv): tool = TodoTool() - await tool.execute({"op": "set", "tasks": ["task A"]}) - result = await tool.execute({"op": "update", "index": 1, "status": "done"}) + await tool.execute({"op": "set", "tasks": ["task A"]}, ctx=ToolContext(session_id="sess1", user_id="user1")) + result = await tool.execute({"op": "update", "index": 1, "status": "done"}, ctx=ToolContext(session_id="sess1", user_id="user1")) assert result.success is False assert "validation" in result.error.lower() @@ -108,8 +99,8 @@ @pytest.mark.asyncio async def test_failed_without_validation_warns(_fake_kv): tool = TodoTool() - await tool.execute({"op": "set", "tasks": ["task A"]}) - result = await tool.execute({"op": "update", "index": 1, "status": "failed"}) + await tool.execute({"op": "set", "tasks": ["task A"]}, ctx=ToolContext(session_id="sess1", user_id="user1")) + result = await tool.execute({"op": "update", "index": 1, "status": "failed"}, ctx=ToolContext(session_id="sess1", user_id="user1")) assert result.success is True assert "Tip" in result.output @@ -117,8 +108,8 @@ @pytest.mark.asyncio async def test_clear(_fake_kv): tool = TodoTool() - await tool.execute({"op": "set", "tasks": ["task A"]}) - result = await tool.execute({"op": "clear"}) + await tool.execute({"op": "set", "tasks": ["task A"]}, ctx=ToolContext(session_id="sess1", user_id="user1")) + result = await tool.execute({"op": "clear"}, ctx=ToolContext(session_id="sess1", user_id="user1")) assert result.success is True assert "cleared" in result.output.lower()