# navi-1 / navi / core / agent.py
"""
Agent: the tool-calling loop.

Flow:
1. Receive user message, load session + profile
2. Build tool schemas from profile's enabled_tools
3. Loop (up to max_iterations):
   a. Call LLM with session.context (may be compressed) + tool schemas
   b. If finish_reason == "stop"  -> stream final response
   c. If finish_reason == "tool_calls" -> execute tools, append to both
      session.messages (display history) and session.context (LLM context)
4. After StreamEnd: run workers sequentially (e.g. context compression)

session.messages — full display history, never compressed
session.context  — what the LLM sees; workers may compress this
"""

import asyncio
import json
from datetime import datetime, timezone
from pathlib import Path
from typing import TYPE_CHECKING, AsyncGenerator

import structlog

from navi.config import settings
from navi.exceptions import MaxIterationsReached, SessionNotFound
from navi.llm.base import LLMBackend, Message, ToolCallRequest
from navi.tools.base import Tool, current_event_sink

from .compressor import compress_context, should_compress
from .events import (
    AgentEvent,
    ContextCompressed,
    StreamEnd,
    TextDelta,
    ThinkingDelta,
    ThinkingEnd,
    ToolEvent,
    ToolStarted,
    TurnThinking,
)
from .registry import BackendRegistry, ProfileRegistry, ToolRegistry
from .session import SessionStore

if TYPE_CHECKING:
    from navi.memory.store import MemoryStore
    from navi.workers.base import Worker, WorkerContext

# JSON file inside the configured tools directory; _load_user_enabled_tools()
# reads it to obtain extra tool names on top of a profile's enabled_tools.
_USER_ENABLED_FILE = Path(settings.tools_dir) / "enabled.json"


def _load_user_enabled_tools() -> list[str]:
    """Return the extra tool names listed in the user's ``enabled.json``.

    Best-effort by design: a missing, unreadable or malformed file yields
    ``[]`` so a broken user file never blocks an agent turn.

    The payload must be a JSON array.  Any other JSON value (object, string,
    number, ``null``) is ignored, and non-string entries inside the array are
    dropped, so callers can always iterate the result as tool names.
    """
    try:
        payload = json.loads(_USER_ENABLED_FILE.read_text(encoding="utf-8"))
    except (OSError, ValueError):
        # OSError: file absent or unreadable (the common case).
        # ValueError: invalid JSON (JSONDecodeError) or invalid UTF-8.
        return []
    if not isinstance(payload, list):
        return []
    return [name for name in payload if isinstance(name, str)]


# Module-level structlog logger used by every Agent method in this file.
log = structlog.get_logger()

# Sentinel: placed in the event sink by the tool wrapper to signal completion.
# run_stream() compares queue items against it by identity (`is _TOOL_DONE`).
_TOOL_DONE = object()


class Agent:
    def __init__(
        self,
        session_store: "SessionStore | None",
        profile_registry: ProfileRegistry,
        tool_registry: ToolRegistry,
        backend_registry: BackendRegistry,
        workers: list["Worker"] | None = None,
        memory_store: "MemoryStore | None" = None,
    ) -> None:
        self._sessions = session_store
        self._profiles = profile_registry
        self._tools = tool_registry
        self._backends = backend_registry
        self._workers: list["Worker"] = workers or []
        self._memory = memory_store

    # ------------------------------------------------------------------
    # Public interface
    # ------------------------------------------------------------------

    async def run(self, session_id: str, user_message: str, images: list[str] | None = None) -> str:
        """Non-streaming: run the full tool-calling loop and return the final text.

        Args:
            session_id: Id of an existing session in the session store.
            user_message: Text of the new user turn.
            images: Optional images attached to the user turn; an empty list
                is stored as None.

        Returns:
            The assistant's final text ("" when the backend returned no content).

        Raises:
            SessionNotFound: the store has no session for ``session_id``.
            MaxIterationsReached: the iteration budget was used up without a
                final answer (the session is saved before raising).
        """
        session = await self._sessions.get(session_id)
        if session is None:
            raise SessionNotFound(session_id)

        profile = self._profiles.get(session.profile_id)
        tools = self._tool_list(profile.enabled_tools)
        tool_schemas = [t.schema() for t in tools]
        llm = self._get_backend(profile.llm_backend)

        mem = await self._memory_msg()

        # Expose session_id to tools (e.g. SSH connection pool) via ContextVar.
        # The token is restored in `finally` so the id does not leak into
        # whatever the calling task does after this turn.
        from navi.tools.base import current_session_id as _sid_var
        sid_token = _sid_var.set(session_id)
        try:
            user_msg = Message(role="user", content=user_message, images=images or None,
                               created_at=datetime.now(timezone.utc))
            session.messages.append(user_msg)
            session.context.append(user_msg)
            await self._sessions.save(session)

            for iteration in range(profile.max_iterations):
                log.debug("agent.iteration", session_id=session_id, iteration=iteration)
                response = await llm.complete(
                    self._build_context(session.context, profile, mem),
                    tools=tool_schemas if tools else None,
                    temperature=profile.temperature,
                    model=profile.model,
                )

                # NOTE(review): run_stream()/run_ephemeral() decide purely on
                # the presence of tool calls; here finish_reason == "stop" wins
                # even if tool calls are present — confirm no backend combines
                # the two.
                if response.finish_reason == "stop" or not response.tool_calls:
                    content = response.content or ""
                    assistant_msg = Message(role="assistant", content=content,
                                            created_at=datetime.now(timezone.utc))
                    session.messages.append(assistant_msg)
                    session.context.append(assistant_msg)
                    await self._sessions.save(session)
                    return content

                # Tool calls turn — append to both messages and context
                assistant_msg = Message(
                    role="assistant",
                    content=response.content,
                    tool_calls=response.tool_calls,
                )
                session.messages.append(assistant_msg)
                session.context.append(assistant_msg)

                tool_results, image_injections = await self._execute_tool_calls(
                    response.tool_calls, tools
                )
                session.messages.extend(tool_results)
                session.context.extend(tool_results)
                # Image injections are synthetic LLM helpers — context only
                session.context.extend(image_injections)

                # Same guard as run_stream(): switch_profile updates the stored
                # session while we hold a local copy.  Without re-reading it,
                # the save below would overwrite the switch and the next LLM
                # call would still use the old profile's tools and backend.
                fresh = await self._sessions.get(session_id)
                if fresh and fresh.profile_id != session.profile_id:
                    session.profile_id = fresh.profile_id
                    profile = self._profiles.get(session.profile_id)
                    tools = self._tool_list(profile.enabled_tools)
                    tool_schemas = [t.schema() for t in tools]
                    llm = self._get_backend(profile.llm_backend)
                    log.info(
                        "agent.profile_reloaded",
                        session_id=session_id,
                        new_profile=session.profile_id,
                    )

            await self._sessions.save(session)
            raise MaxIterationsReached(profile.max_iterations)
        finally:
            _sid_var.reset(sid_token)

    async def run_ephemeral(
        self,
        user_message: str,
        profile_id: str,
        max_iterations: int = 20,
        exclude_tools: list[str] | None = None,
    ) -> str:
        """
        Drive a throw-away sub-agent conversation and return its final text.

        Meant to be launched from inside a tool (e.g. SpawnAgentTool).  The
        conversation lives in a local list only — the session store is never
        read or written.  Any tool named in exclude_tools is removed from the
        sub-agent's tool list (pass 'spawn_agent' to prevent recursion).

        If the surrounding tool call installed an event sink, TurnThinking,
        ToolStarted and ToolEvent objects (all with is_subagent=True) are put
        on it; without a sink nothing is reported.

        Returns the text of the first turn that requests no tool calls, or a
        fixed notice string once max_iterations turns have been used.
        """
        profile = self._profiles.get(profile_id)
        banned = set(exclude_tools or [])
        tools = [
            candidate
            for candidate in self._tool_list(profile.enabled_tools)
            if candidate.name not in banned
        ]
        tool_schemas = [candidate.schema() for candidate in tools]
        llm = self._get_backend(profile.llm_backend)

        mem = await self._memory_msg()

        # Only user/assistant/tool messages are kept here; the system prompt
        # is added by _build_context() on every call.
        first_turn = Message(
            role="user", content=user_message, created_at=datetime.now(timezone.utc)
        )
        context: list[Message] = [first_turn]

        # Sink installed by the parent run_stream() for this tool call; it is
        # None when there is no such parent (e.g. the caller used run()).
        sink = current_event_sink.get()

        log.info("agent.subagent.start", profile_id=profile_id, max_iterations=max_iterations)

        tool_map = {candidate.name: candidate for candidate in tools}

        for turn in range(max_iterations):
            log.debug("agent.subagent.iteration", iteration=turn)

            text_parts: list[str] = []
            thinking_parts: list[str] = []
            requested_calls: list[ToolCallRequest] | None = None

            stream = llm.stream_complete(
                self._build_context(context, profile, mem),
                tools=tool_schemas if tools else None,
                temperature=profile.temperature,
                model=profile.model,
            )
            async for chunk in stream:
                if chunk.thinking:
                    thinking_parts.append(chunk.thinking)
                if chunk.delta:
                    text_parts.append(chunk.delta)
                if chunk.tool_calls:
                    requested_calls = chunk.tool_calls

            text = "".join(text_parts)
            thinking = "".join(thinking_parts)

            # A turn without tool calls is the final answer.
            if not requested_calls:
                log.info("agent.subagent.complete", iterations=turn + 1,
                         result_len=len(text))
                return text

            # Report this turn's reasoning before its tool calls.
            if sink is not None and thinking:
                log.debug("agent.subagent.turn_thinking", length=len(thinking))
                await sink.put(TurnThinking(thinking=thinking, is_subagent=True))

            context.append(Message(
                role="assistant",
                content=text or None,
                tool_calls=requested_calls,
            ))

            # Tool calls run one after another; each is announced and then
            # reported on the parent sink.
            for call in requested_calls:
                if sink is not None:
                    await sink.put(ToolStarted(
                        tool_name=call.name, arguments=call.arguments, is_subagent=True
                    ))

                image_msg = None
                tool = tool_map.get(call.name)
                if tool is not None:
                    log.info("tool.execute.subagent", tool=call.name, args=call.arguments)
                    result = await tool.execute(call.arguments)
                    content = result.to_message_content()
                    success = result.success
                    is_image = (
                        result.success
                        and result.metadata
                        and result.metadata.get("is_image")
                    )
                    b64 = result.metadata.get("base64") if is_image else None
                    if b64:
                        # Synthetic user turn that hands the image to the LLM.
                        image_msg = Message(
                            role="user",
                            content=f"[Image loaded via {call.name} — analyse it]",
                            images=[b64],
                        )
                else:
                    content = f"Error: tool '{call.name}' not found."
                    success = False

                if sink is not None:
                    await sink.put(ToolEvent(
                        tool_name=call.name, arguments=call.arguments,
                        result=content, success=success, is_subagent=True,
                    ))

                context.append(Message(role="tool", content=content,
                                       tool_call_id=call.id, name=call.name))
                if image_msg:
                    context.append(image_msg)

        log.warning("agent.subagent.max_iterations", max_iterations=max_iterations)
        return "[Sub-agent reached iteration limit without a final answer]"

    async def run_stream(
        self, session_id: str, user_message: str, images: list[str] | None = None
    ) -> AsyncGenerator[AgentEvent, None]:
        """
        Streaming variant. Yields AgentEvent objects:
        - ThinkingDelta / ThinkingEnd: reasoning chunks
        - ToolStarted: a tool call is about to run
        - ToolEvent: tool call + result
        - TextDelta / StreamEnd: final streamed response
        - events returned by the workers after StreamEnd (e.g. ContextCompressed)

        Anything a running tool puts on its event sink (e.g. the sub-agent
        events emitted by run_ephemeral()) is forwarded as it arrives.

        Raises:
            SessionNotFound: the store has no session for ``session_id``.
            MaxIterationsReached: the iteration budget was used up without a
                final answer (the session is saved before raising).
            BaseException: whatever a tool's execution raised is re-raised here.
        """
        session = await self._sessions.get(session_id)
        if session is None:
            raise SessionNotFound(session_id)

        profile = self._profiles.get(session.profile_id)
        tools = self._tool_list(profile.enabled_tools)
        tool_schemas = [t.schema() for t in tools]
        llm = self._get_backend(profile.llm_backend)

        mem = await self._memory_msg()

        # Expose session_id to tools (e.g. SSH connection pool) via ContextVar.
        # The token is deliberately not kept: an async generator can be
        # finalised from a different Context than the one that ran this line,
        # and ContextVar.reset() rejects tokens created in another Context.
        from navi.tools.base import current_session_id as _sid_var
        _sid_var.set(session_id)

        # Pre-turn compression: if the last turn filled the context past the
        # threshold, compress NOW before calling the LLM.  This prevents the
        # model from seeing an over-full context and generating gibberish
        # (e.g. a "summary of the conversation" instead of a real answer).
        if (
            settings.context_compression_enabled
            and session.context_token_count > 0
            and should_compress(
                session.context_token_count,
                settings.ollama_num_ctx,
                settings.context_compression_threshold,
            )
        ):
            try:
                new_context = await compress_context(
                    context=session.context,
                    llm=llm,
                    model=profile.model,
                    temperature=settings.context_summary_temperature,
                    keep_recent=settings.context_keep_recent,
                )
                if new_context is not None:
                    count_before = len(session.context)
                    session.context = new_context
                    session.context_token_count = 0
                    log.info(
                        "agent.preturn_compress",
                        session_id=session_id,
                        before=count_before,
                        after=len(new_context),
                    )
            except Exception:
                # Compression is an optimisation; the turn proceeds without it.
                log.warning("agent.preturn_compress_failed", session_id=session_id, exc_info=True)

        user_msg = Message(role="user", content=user_message, images=images or None,
                           created_at=datetime.now(timezone.utc))
        session.messages.append(user_msg)
        session.context.append(user_msg)
        # Persist user message immediately so it survives a client disconnect
        # before the assistant reply is ready.
        await self._sessions.save(session)

        # Tool-calling loop — uses stream_complete() for every turn so thinking
        # is captured in real-time via ThinkingDelta/ThinkingEnd events.
        for _iteration in range(profile.max_iterations):
            accumulated_text = ""
            turn_tool_calls: list[ToolCallRequest] | None = None
            thinking_active = False
            context_tokens: int | None = None

            async for chunk in llm.stream_complete(
                self._build_context(session.context, profile, mem),
                tools=tool_schemas if tools else None,
                temperature=profile.temperature,
                model=profile.model,
            ):
                if chunk.prompt_tokens is not None or chunk.completion_tokens is not None:
                    context_tokens = (chunk.prompt_tokens or 0) + (chunk.completion_tokens or 0)
                if chunk.thinking:
                    thinking_active = True
                    yield ThinkingDelta(delta=chunk.thinking)
                # Independent `if` (not `elif`): a chunk that carries both
                # thinking and text must not lose its text.
                if chunk.delta:
                    if thinking_active:
                        thinking_active = False
                        yield ThinkingEnd()
                    accumulated_text += chunk.delta
                    yield TextDelta(delta=chunk.delta)
                if chunk.tool_calls:
                    turn_tool_calls = chunk.tool_calls
                if chunk.finish_reason and thinking_active:
                    thinking_active = False
                    yield ThinkingEnd()

            if not turn_tool_calls:
                # Final response — text already streamed above
                assistant_msg = Message(role="assistant", content=accumulated_text,
                                        created_at=datetime.now(timezone.utc))
                session.messages.append(assistant_msg)
                session.context.append(assistant_msg)
                session.context_token_count = context_tokens or 0
                await self._sessions.save(session)

                yield StreamEnd(
                    full_content=accumulated_text,
                    context_tokens=context_tokens,
                    max_context_tokens=settings.ollama_num_ctx,
                )

                for event in await self._run_workers(session, llm, profile.model, context_tokens):
                    yield event
                return

            # Tool calls turn — record to session and execute
            assistant_msg = Message(
                role="assistant",
                content=accumulated_text or None,
                tool_calls=turn_tool_calls,
            )
            session.messages.append(assistant_msg)
            session.context.append(assistant_msg)

            tool_map = {t.name: t for t in tools}
            for tc in turn_tool_calls:
                # 1. Announce immediately so the UI shows a pending card
                yield ToolStarted(tool_name=tc.name, arguments=tc.arguments)

                # 2. Create a sink queue for sub-agent events from this tool call.
                #    create_task() snapshots the current ContextVar values, so the
                #    task will inherit current_event_sink = sink.
                sink: asyncio.Queue = asyncio.Queue()
                sink_token = current_event_sink.set(sink)
                result_holder: list = []

                async def _run_with_sentinel(_tc=tc, _holder=result_holder, _sink=sink):
                    try:
                        _holder.append(await self._run_single_tool(_tc, tool_map))
                    except BaseException as exc:
                        _holder.append(exc)
                    finally:
                        await _sink.put(_TOOL_DONE)

                # Keep a strong reference: the event loop only holds weak
                # references to tasks, so an unreferenced task may be
                # garbage-collected before it finishes.
                tool_task = asyncio.create_task(_run_with_sentinel())
                current_event_sink.reset(sink_token)  # outer ctx restored; task has its own copy

                # 3. Block on the sink until the sentinel arrives.
                #    Sub-agent ToolStarted/ToolEvent objects come through here in real time.
                try:
                    while True:
                        item = await sink.get()
                        if item is _TOOL_DONE:
                            break
                        yield item
                finally:
                    # Reached with the task still running only when this
                    # generator is closed or cancelled mid-tool (e.g. client
                    # disconnect); cancel it rather than leave it orphaned.
                    if not tool_task.done():
                        tool_task.cancel()

                # 4. Re-raise tool exception or unpack result
                r = result_holder[0] if result_holder else RuntimeError("tool task produced no result")
                if isinstance(r, BaseException):
                    raise r
                tool_event, msg, image_msg = r

                # 5. Yield the completed ToolEvent and record in session
                yield tool_event
                session.messages.append(msg)
                session.context.append(msg)
                if image_msg:
                    # Image injections go to the LLM context only.
                    session.context.append(image_msg)

            # 6. If switch_profile was called this iteration, reload profile + tools.
            #    switch_profile updates the DB but run_stream() holds a local session
            #    object — without this check the final save would overwrite the change
            #    and the next LLM call would still use the old tool schemas.
            fresh = await self._sessions.get(session_id)
            if fresh and fresh.profile_id != session.profile_id:
                session.profile_id = fresh.profile_id
                profile = self._profiles.get(session.profile_id)
                tools = self._tool_list(profile.enabled_tools)
                tool_schemas = [t.schema() for t in tools]
                llm = self._get_backend(profile.llm_backend)
                log.info(
                    "agent.profile_reloaded",
                    session_id=session_id,
                    new_profile=session.profile_id,
                )

        await self._sessions.save(session)
        raise MaxIterationsReached(profile.max_iterations)

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    async def _run_workers(
        self,
        session,
        llm: LLMBackend,
        model: str,
        context_tokens: int | None,
    ) -> list[AgentEvent]:
        """Run every configured worker in order and gather the events they return.

        All workers receive the same WorkerContext.  A worker that raises is
        logged as "agent.worker_failed" and skipped; the remaining workers
        still run.
        """
        from navi.workers.base import WorkerContext

        worker_ctx = WorkerContext(
            session_id=session.id,
            context_tokens=context_tokens,
            max_context_tokens=settings.ollama_num_ctx,
            llm=llm,
            model=model,
            temperature=settings.context_summary_temperature,
            session_store=self._sessions,
        )

        collected: list[AgentEvent] = []
        for worker in self._workers:
            try:
                outcome = await worker.run(session, worker_ctx)
                collected += outcome.events
            except Exception:
                log.warning("agent.worker_failed",
                            worker=type(worker).__name__, exc_info=True)
        return collected

    async def _memory_msg(self) -> "Message | None":
        """Return an ephemeral system message with the user memory summary, or None."""
        if not self._memory:
            return None
        summary = await self._memory.get_summary()
        if not summary:
            return None
        return Message(role="system", content=f"## What I remember about the user\n\n{summary}")

    def _build_context(
        self,
        session_context: list[Message],
        profile: "AgentProfile",
        mem: "Message | None",
    ) -> list[Message]:
        """Assemble the message list sent to the LLM for a single call.

        Layout: [system prompt, memory message (if any), conversation...].

        The system prompt is rebuilt from the given profile on every call and
        is never stored in session.context, so a profile switch takes effect
        on the next call without rewriting history.  System messages that are
        already present in session_context are left out (migration safety).

        NOTE(review): "AgentProfile" is not imported in the visible part of
        this module, not even under TYPE_CHECKING — confirm where it lives.
        """
        system_msg = Message(
            role="system",
            content=self._build_system_prompt(profile.system_prompt),
        )
        conversation = [msg for msg in session_context if msg.role != "system"]
        head: list[Message] = [system_msg, mem] if mem else [system_msg]
        return head + conversation

    def _build_system_prompt(self, profile_prompt: str) -> str:
        """Prefix the profile prompt with the global persona, if one is set.

        The persona comes from settings.navi_persona (whitespace-stripped) and
        is separated from the profile prompt by a "---" rule.  With an empty
        persona the profile prompt is returned unchanged.
        """
        persona = settings.navi_persona.strip()
        return f"{persona}\n\n---\n\n{profile_prompt}" if persona else profile_prompt

    def _tool_list(self, enabled: list[str]) -> list[Tool]:
        """Resolve tool names to Tool objects.

        The names are the profile's ``enabled`` list followed by the
        user-enabled extras from _load_user_enabled_tools().  Order is
        preserved and each name is resolved at most once, so a name listed
        twice (or in both sources) yields a single tool.

        Resolution stays best-effort: a name the registry cannot resolve is
        skipped, but it is now logged instead of being dropped silently.
        """
        resolved: list[Tool] = []
        seen: set[str] = set()
        for name in (*enabled, *_load_user_enabled_tools()):
            # Skip anything that cannot be a tool name, and repeats.
            if not isinstance(name, str) or name in seen:
                continue
            seen.add(name)
            try:
                resolved.append(self._tools.get(name))
            except Exception as exc:
                log.warning("agent.tool_unavailable", tool=name, error=str(exc))
        return resolved

    def _get_backend(self, backend_key: str) -> LLMBackend:
        return self._backends.get(backend_key)

    async def _run_single_tool(
        self,
        tc: ToolCallRequest,
        tool_map: dict[str, Tool],
    ) -> tuple[ToolEvent, Message, "Message | None"]:
        """Run one tool call; return (ToolEvent, tool message, image message or None).

        An unknown tool name does not raise: it produces a failed ToolEvent
        and an error text as the tool message.  A successful result whose
        metadata has "is_image" set and carries a "base64" payload also
        produces a synthetic user message holding that image.

        run_stream() runs this inside asyncio.create_task() so the parent
        generator can drain the event sink queue while the tool executes.
        """
        tool = tool_map.get(tc.name)

        # Guard: unknown tool -> error result, nothing executed.
        if tool is None:
            content = f"Error: tool '{tc.name}' not found."
            failed = ToolEvent(tool_name=tc.name, arguments=tc.arguments,
                               result=content, success=False)
            error_msg = Message(role="tool", content=content,
                                tool_call_id=tc.id, name=tc.name)
            return failed, error_msg, None

        log.info("tool.execute", tool=tc.name, args=tc.arguments)
        result = await tool.execute(tc.arguments)
        content = result.to_message_content()
        event = ToolEvent(tool_name=tc.name, arguments=tc.arguments,
                          result=content, success=result.success)

        image_msg = None
        is_image = result.success and result.metadata and result.metadata.get("is_image")
        b64 = result.metadata.get("base64") if is_image else None
        if b64:
            image_msg = Message(
                role="user",
                content=f"[Image loaded via {tc.name} — analyse it]",
                images=[b64],
            )

        tool_msg = Message(role="tool", content=content, tool_call_id=tc.id, name=tc.name)
        return event, tool_msg, image_msg

    async def _execute_tool_calls(
        self, tool_calls: list[ToolCallRequest], tools: list[Tool]
    ) -> tuple[list[Message], list[Message]]:
        """Run all tool calls of one turn concurrently.

        Each call is delegated to _run_single_tool() instead of repeating its
        logic here; the ToolEvent it builds is not needed on this
        non-streaming path and is discarded.

        Returns:
            (tool_msgs, image_msgs): one tool message per call, in call order,
            plus the synthetic image messages of the calls that produced one.
        """
        tool_map = {t.name: t for t in tools}
        triples = await asyncio.gather(
            *(self._run_single_tool(tc, tool_map) for tc in tool_calls)
        )
        tool_msgs = [msg for _event, msg, _image in triples]
        image_msgs = [image for _event, _msg, image in triples if image is not None]
        return tool_msgs, image_msgs

    async def _execute_tool_calls_streaming(
        self, tool_calls: list[ToolCallRequest], tools: list[Tool]
    ) -> tuple[list[tuple[ToolEvent, Message]], list[Message]]:
        """Run all tool calls of one turn concurrently, keeping their ToolEvents.

        Each call is delegated to _run_single_tool() instead of repeating its
        logic here.

        Returns:
            (pairs, image_msgs): one (ToolEvent, tool message) pair per call,
            in call order, plus the synthetic image messages of the calls
            that produced one.
        """
        tool_map = {t.name: t for t in tools}
        triples = await asyncio.gather(
            *(self._run_single_tool(tc, tool_map) for tc in tool_calls)
        )
        pairs = [(event, msg) for event, msg, _image in triples]
        image_msgs = [image for _event, _msg, image in triples if image is not None]
        return pairs, image_msgs