Newer
Older
navi-1 / navi / core / agent.py
"""
Agent: the tool-calling loop.

Flow:
1. Receive user message, load session + profile
2. Build tool schemas from profile's enabled_tools
3. Loop (up to max_iterations):
   a. Call LLM with session.context (may be compressed) + the optional user
      memory message + tool schemas
   b. If the turn requests no tool calls -> its text is the final response
      (streamed by run_stream, returned by run)
   c. Otherwise execute the requested tools and append the assistant turn and
      the tool results to both session.messages (display history) and
      session.context (LLM context); image injections go to context only
4. After StreamEnd (run_stream only): run workers sequentially (e.g. context compression)

session.messages — full display history, never compressed
session.context  — what the LLM sees; workers may compress this
"""

import asyncio
import json
from datetime import datetime, timezone
from pathlib import Path
from typing import TYPE_CHECKING, AsyncGenerator

import structlog

from navi.config import settings
from navi.exceptions import MaxIterationsReached, SessionNotFound
from navi.llm.base import LLMBackend, Message, ToolCallRequest
from navi.tools.base import Tool, current_event_sink

from .compressor import compress_context, should_compress
from .events import (
    AgentEvent,
    ContextCompressed,
    StreamEnd,
    TextDelta,
    ThinkingDelta,
    ThinkingEnd,
    ToolEvent,
    ToolStarted,
    TurnThinking,
)
from .registry import BackendRegistry, ProfileRegistry, ToolRegistry
from .session import SessionStore

if TYPE_CHECKING:
    from navi.memory.store import MemoryStore
    from navi.workers.base import Worker, WorkerContext

# JSON file listing extra tool names the user enabled; _tool_list() appends
# them to every profile's enabled_tools.
_USER_ENABLED_FILE = Path(settings.tools_dir, "enabled.json")


def _load_user_enabled_tools() -> list[str]:
    try:
        return json.loads(_USER_ENABLED_FILE.read_text())
    except Exception:
        return []


log = structlog.get_logger()

# Unique end-of-stream marker. The tool-runner task puts this object on its
# per-call event sink when it finishes, telling the relaying loop to stop.
_TOOL_DONE: object = object()


class Agent:
    """Tool-calling agent: drives an LLM backend through multi-turn tool use.

    A profile selects the LLM backend, model, temperature, system prompt,
    iteration budget and enabled tools. The agent alternates LLM calls with
    tool executions until the model answers without requesting tools, or the
    iteration budget is exhausted.

    Entry points:

    * :meth:`run` — non-streaming, persisted session.
    * :meth:`run_stream` — streaming, persisted session; yields AgentEvent objects.
    * :meth:`run_ephemeral` — session-less sub-agent (e.g. for SpawnAgentTool)
      that forwards its progress events to the parent's event sink, if set.
    """

    def __init__(
        self,
        session_store: "SessionStore | None",
        profile_registry: ProfileRegistry,
        tool_registry: ToolRegistry,
        backend_registry: BackendRegistry,
        workers: list["Worker"] | None = None,
        memory_store: "MemoryStore | None" = None,
    ) -> None:
        """Store the agent's collaborators.

        Args:
            session_store: Session persistence. ``run`` and ``run_stream`` call
                it unconditionally; only ``run_ephemeral`` works when it is None.
            profile_registry: Resolves profile ids to profiles.
            tool_registry: Resolves tool names to Tool instances.
            backend_registry: Resolves backend keys to LLM backends.
            workers: Post-turn workers, run by ``run_stream`` after StreamEnd.
            memory_store: Optional source of a user-memory summary that is
                injected (ephemerally) into every LLM call.
        """
        self._sessions = session_store
        self._profiles = profile_registry
        self._tools = tool_registry
        self._backends = backend_registry
        self._workers: list["Worker"] = workers or []
        self._memory = memory_store

    # ------------------------------------------------------------------
    # Public interface
    # ------------------------------------------------------------------

    async def run(self, session_id: str, user_message: str, images: list[str] | None = None) -> str:
        """Non-streaming: run the full tool-calling loop and return the final text.

        Tool calls requested in one turn are executed concurrently.

        Args:
            session_id: Id of an existing session.
            user_message: Text of the new user turn.
            images: Optional images attached to the user turn (passed through
                to ``Message.images`` unchanged).

        Returns:
            The assistant's final text ("" if the model returned no content).

        Raises:
            SessionNotFound: If ``session_id`` does not exist.
            MaxIterationsReached: If the model is still requesting tools after
                ``profile.max_iterations`` turns.
        """
        session = await self._sessions.get(session_id)
        if session is None:
            raise SessionNotFound(session_id)

        profile = self._profiles.get(session.profile_id)
        tools = self._tool_list(profile.enabled_tools)
        tool_schemas = [t.schema() for t in tools]
        llm = self._get_backend(profile.llm_backend)

        self._ensure_system_prompt(session, profile.system_prompt)
        mem = await self._memory_msg()

        # Expose session_id to tools (e.g. SSH connection pool) via ContextVar.
        # Reset in ``finally`` so the id does not leak into the caller's context.
        from navi.tools.base import current_session_id as _sid_var
        sid_token = _sid_var.set(session_id)
        try:
            user_msg = Message(role="user", content=user_message, images=images or None,
                               created_at=datetime.now(timezone.utc))
            session.messages.append(user_msg)
            session.context.append(user_msg)
            await self._sessions.save(session)

            for iteration in range(profile.max_iterations):
                log.debug("agent.iteration", session_id=session_id, iteration=iteration)
                response = await llm.complete(
                    self._with_memory(session.context, mem),
                    tools=tool_schemas if tools else None,
                    temperature=profile.temperature,
                    model=profile.model,
                )

                if response.finish_reason == "stop" or not response.tool_calls:
                    content = response.content or ""
                    assistant_msg = Message(role="assistant", content=content,
                                            created_at=datetime.now(timezone.utc))
                    session.messages.append(assistant_msg)
                    session.context.append(assistant_msg)
                    await self._sessions.save(session)
                    return content

                # Tool-calls turn: record it in both display history and LLM context.
                assistant_msg = Message(
                    role="assistant",
                    content=response.content,
                    tool_calls=response.tool_calls,
                )
                session.messages.append(assistant_msg)
                session.context.append(assistant_msg)

                tool_results, image_injections = await self._execute_tool_calls(
                    response.tool_calls, tools
                )
                session.messages.extend(tool_results)
                session.context.extend(tool_results)
                # Image injections are synthetic LLM helpers — context only.
                session.context.extend(image_injections)

            await self._sessions.save(session)
            raise MaxIterationsReached(profile.max_iterations)
        finally:
            self._reset_context_var(_sid_var, sid_token)

    async def run_ephemeral(
        self,
        user_message: str,
        profile_id: str,
        max_iterations: int = 20,
        exclude_tools: list[str] | None = None,
    ) -> str:
        """
        Run a sub-agent loop without a persistent session.

        Intended for spawning from tools (e.g. SpawnAgentTool).
        No DB reads/writes — uses a temporary in-memory context.
        Tools listed in exclude_tools are stripped from the tool list
        (use this to prevent recursion: exclude 'spawn_agent').

        Progress (TurnThinking, ToolStarted, ToolEvent — all flagged
        ``is_subagent=True``) is pushed to the event sink found in
        ``current_event_sink``; when that is None the events are dropped.
        Tool calls within a turn are executed sequentially.

        Args:
            user_message: The task for the sub-agent.
            profile_id: Profile supplying backend, model, prompt and tools.
            max_iterations: Maximum number of LLM turns.
            exclude_tools: Tool names to withhold from the sub-agent.

        Returns:
            The final answer text, or a fixed notice string if the iteration
            limit is reached (no exception is raised in that case).
        """
        profile = self._profiles.get(profile_id)
        exclude = set(exclude_tools or [])
        tools = [t for t in self._tool_list(profile.enabled_tools) if t.name not in exclude]
        tool_schemas = [t.schema() for t in tools]
        tool_map = {t.name: t for t in tools}
        llm = self._get_backend(profile.llm_backend)

        mem = await self._memory_msg()

        context: list[Message] = [
            Message(role="system", content=self._build_system_prompt(profile.system_prompt)),
            Message(role="user", content=user_message, created_at=datetime.now(timezone.utc)),
        ]

        # Read the event sink set by the parent run_stream() for this tool call.
        # If None (e.g. called from run(), not run_stream()), events are silently dropped.
        sink = current_event_sink.get()

        log.info("agent.subagent.start", profile_id=profile_id, max_iterations=max_iterations)

        for iteration in range(max_iterations):
            log.debug("agent.subagent.iteration", iteration=iteration)

            accumulated_text = ""
            accumulated_thinking = ""
            turn_tool_calls: list[ToolCallRequest] | None = None

            async for chunk in llm.stream_complete(
                self._with_memory(context, mem),
                tools=tool_schemas if tools else None,
                temperature=profile.temperature,
                model=profile.model,
            ):
                if chunk.thinking:
                    accumulated_thinking += chunk.thinking
                if chunk.delta:
                    accumulated_text += chunk.delta
                if chunk.tool_calls:
                    turn_tool_calls = chunk.tool_calls

            if not turn_tool_calls:
                log.info("agent.subagent.complete", iterations=iteration + 1,
                         result_len=len(accumulated_text))
                return accumulated_text

            # Emit accumulated thinking before tool calls
            if accumulated_thinking and sink is not None:
                log.debug("agent.subagent.turn_thinking", length=len(accumulated_thinking))
                await sink.put(TurnThinking(thinking=accumulated_thinking, is_subagent=True))

            context.append(Message(
                role="assistant",
                content=accumulated_text or None,
                tool_calls=turn_tool_calls,
            ))

            # Execute each tool call sequentially, emitting events to parent sink
            for tc in turn_tool_calls:
                if sink is not None:
                    await sink.put(ToolStarted(
                        tool_name=tc.name, arguments=tc.arguments, is_subagent=True
                    ))

                tool = tool_map.get(tc.name)
                image_msg = None
                if tool is None:
                    content = f"Error: tool '{tc.name}' not found."
                    success = False
                else:
                    log.info("tool.execute.subagent", tool=tc.name, args=tc.arguments)
                    result = await tool.execute(tc.arguments)
                    content = result.to_message_content()
                    success = result.success
                    image_msg = self._image_injection(tc.name, result)

                if sink is not None:
                    await sink.put(ToolEvent(
                        tool_name=tc.name, arguments=tc.arguments,
                        result=content, success=success, is_subagent=True,
                    ))

                context.append(Message(role="tool", content=content,
                                       tool_call_id=tc.id, name=tc.name))
                if image_msg:
                    context.append(image_msg)

        log.warning("agent.subagent.max_iterations", max_iterations=max_iterations)
        return "[Sub-agent reached iteration limit without a final answer]"

    async def run_stream(
        self, session_id: str, user_message: str, images: list[str] | None = None
    ) -> AsyncGenerator[AgentEvent, None]:
        """
        Streaming variant. Yields AgentEvent objects:
        - ThinkingDelta / ThinkingEnd: reasoning chunks (every run of
          ThinkingDelta is closed by a ThinkingEnd)
        - ToolStarted: a tool call is about to run (pending UI card)
        - sub-agent ToolStarted/ToolEvent/TurnThinking relayed from tools
        - ToolEvent: tool call + result
        - TextDelta / StreamEnd: final streamed response
        - ContextCompressed: emitted by workers after compression

        Tool calls within a turn run one at a time; each runs in its own task
        so that sub-agent events can be relayed while it executes. If the
        consumer stops iterating mid-tool, that task is cancelled.

        Raises:
            SessionNotFound: If ``session_id`` does not exist.
            MaxIterationsReached: If the model is still requesting tools after
                ``profile.max_iterations`` turns.
            Exception: Whatever a tool raised is re-raised here.
        """
        session = await self._sessions.get(session_id)
        if session is None:
            raise SessionNotFound(session_id)

        profile = self._profiles.get(session.profile_id)
        tools = self._tool_list(profile.enabled_tools)
        tool_schemas = [t.schema() for t in tools]
        tool_map = {t.name: t for t in tools}
        llm = self._get_backend(profile.llm_backend)

        self._ensure_system_prompt(session, profile.system_prompt)
        mem = await self._memory_msg()

        # Expose session_id to tools (e.g. SSH connection pool) via ContextVar.
        # Reset in ``finally`` so the id does not leak into the consumer's context.
        from navi.tools.base import current_session_id as _sid_var
        sid_token = _sid_var.set(session_id)
        try:
            await self._precompress(session, session_id, llm, profile.model)

            user_msg = Message(role="user", content=user_message, images=images or None,
                               created_at=datetime.now(timezone.utc))
            session.messages.append(user_msg)
            session.context.append(user_msg)
            # Persist user message immediately so it survives a client disconnect
            # before the assistant reply is ready.
            await self._sessions.save(session)

            # Tool-calling loop — uses stream_complete() for every turn so thinking
            # is captured in real-time via ThinkingDelta/ThinkingEnd events.
            for _ in range(profile.max_iterations):
                accumulated_text = ""
                turn_tool_calls: list[ToolCallRequest] | None = None
                thinking_active = False
                context_tokens: int | None = None

                async for chunk in llm.stream_complete(
                    self._with_memory(session.context, mem),
                    tools=tool_schemas if tools else None,
                    temperature=profile.temperature,
                    model=profile.model,
                ):
                    if chunk.prompt_tokens is not None or chunk.completion_tokens is not None:
                        context_tokens = (chunk.prompt_tokens or 0) + (chunk.completion_tokens or 0)
                    if chunk.thinking:
                        thinking_active = True
                        yield ThinkingDelta(delta=chunk.thinking)
                    # Deliberately not ``elif``: if a chunk carries both reasoning
                    # and answer text, the answer text must not be dropped.
                    if chunk.delta:
                        if thinking_active:
                            thinking_active = False
                            yield ThinkingEnd()
                        accumulated_text += chunk.delta
                        yield TextDelta(delta=chunk.delta)
                    if chunk.tool_calls:
                        turn_tool_calls = chunk.tool_calls
                    if chunk.finish_reason and thinking_active:
                        thinking_active = False
                        yield ThinkingEnd()

                # The stream may end without a finish_reason chunk; close any
                # open reasoning block so consumers always see a ThinkingEnd.
                if thinking_active:
                    yield ThinkingEnd()

                if not turn_tool_calls:
                    # Final response — text already streamed above
                    assistant_msg = Message(role="assistant", content=accumulated_text,
                                            created_at=datetime.now(timezone.utc))
                    session.messages.append(assistant_msg)
                    session.context.append(assistant_msg)
                    session.context_token_count = context_tokens or 0
                    await self._sessions.save(session)

                    yield StreamEnd(
                        full_content=accumulated_text,
                        context_tokens=context_tokens,
                        max_context_tokens=settings.ollama_num_ctx,
                    )

                    for event in await self._run_workers(session, llm, profile.model, context_tokens):
                        yield event
                    return

                # Tool calls turn — record to session and execute
                assistant_msg = Message(
                    role="assistant",
                    content=accumulated_text or None,
                    tool_calls=turn_tool_calls,
                )
                session.messages.append(assistant_msg)
                session.context.append(assistant_msg)

                for tc in turn_tool_calls:
                    # 1. Announce immediately so the UI shows a pending card
                    yield ToolStarted(tool_name=tc.name, arguments=tc.arguments)

                    # 2. Per-call sink queue for sub-agent events. create_task()
                    #    copies the current context, so the task inherits
                    #    current_event_sink = sink while ours is restored below.
                    sink: asyncio.Queue = asyncio.Queue()
                    result_holder: list = []

                    async def _run_with_sentinel(_tc=tc, _holder=result_holder, _sink=sink):
                        try:
                            _holder.append(await self._run_single_tool(_tc, tool_map))
                        except BaseException as exc:
                            # Includes CancelledError: always deliver the sentinel,
                            # and let the generator decide whether to re-raise.
                            _holder.append(exc)
                        finally:
                            await _sink.put(_TOOL_DONE)

                    sink_token = current_event_sink.set(sink)
                    try:
                        # Keep a strong reference: the event loop only holds weak
                        # references to tasks, which may otherwise vanish mid-run.
                        task = asyncio.create_task(_run_with_sentinel())
                    finally:
                        current_event_sink.reset(sink_token)

                    # 3. Relay sub-agent events until the sentinel arrives. If the
                    #    consumer abandons the stream here, cancel the tool task
                    #    instead of leaving it running orphaned.
                    try:
                        while True:
                            item = await sink.get()
                            if item is _TOOL_DONE:
                                break
                            yield item
                    finally:
                        if not task.done():
                            task.cancel()

                    # 4. Re-raise tool exception or unpack result
                    r = result_holder[0] if result_holder else RuntimeError("tool task produced no result")
                    if isinstance(r, BaseException):
                        raise r
                    tool_event, msg, image_msg = r

                    # 5. Yield the completed ToolEvent and record in session
                    yield tool_event
                    session.messages.append(msg)
                    session.context.append(msg)
                    if image_msg:
                        # Synthetic image helper — LLM context only.
                        session.context.append(image_msg)

            await self._sessions.save(session)
            raise MaxIterationsReached(profile.max_iterations)
        finally:
            self._reset_context_var(_sid_var, sid_token)

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _reset_context_var(var, token) -> None:
        """Restore ``var`` to its value before ``token`` was created, tolerating foreign contexts.

        ``ContextVar.reset`` raises ValueError when called from a different
        Context than the one that created the token. That can happen when an
        async generator is finalised by the event loop in a new task. In that
        case there is nothing meaningful to restore, so the error is ignored.
        """
        try:
            var.reset(token)
        except ValueError:
            pass

    def _ensure_system_prompt(self, session, profile_prompt: str) -> None:
        """Insert the system prompt at the head of ``session.context`` if absent.

        The system prompt only goes into the LLM context, never into the
        display history. A role check is used rather than an emptiness check
        because backward-compat sessions have their context initialised from
        ``messages``, which never contain a system message.
        """
        if not any(m.role == "system" for m in session.context):
            session.context.insert(0, Message(
                role="system",
                content=self._build_system_prompt(profile_prompt),
            ))

    async def _precompress(self, session, session_id: str, llm: LLMBackend, model: str) -> None:
        """Compress ``session.context`` before a new turn if the last turn overflowed.

        Runs only when compression is enabled, the previous turn recorded a
        token count, and ``should_compress`` reports the threshold is exceeded.
        Compressing before the LLM call prevents the model from seeing an
        over-full context and generating gibberish (e.g. a "summary of the
        conversation" instead of a real answer). Failures are logged and
        ignored; the uncompressed context is used instead.
        """
        if not (
            settings.context_compression_enabled
            and session.context_token_count > 0
            and should_compress(
                session.context_token_count,
                settings.ollama_num_ctx,
                settings.context_compression_threshold,
            )
        ):
            return
        try:
            new_context = await compress_context(
                context=session.context,
                llm=llm,
                model=model,
                temperature=settings.context_summary_temperature,
                keep_recent=settings.context_keep_recent,
            )
        except Exception:
            log.warning("agent.preturn_compress_failed", session_id=session_id, exc_info=True)
            return
        if new_context is None:
            return
        count_before = len(session.context)
        session.context = new_context
        session.context_token_count = 0
        log.info(
            "agent.preturn_compress",
            session_id=session_id,
            before=count_before,
            after=len(new_context),
        )

    async def _run_workers(
        self,
        session,
        llm: LLMBackend,
        model: str,
        context_tokens: int | None,
    ) -> list[AgentEvent]:
        """Run all workers sequentially and collect their events.

        A failing worker is logged and skipped; the remaining workers still run.
        """
        from navi.workers.base import WorkerContext

        ctx = WorkerContext(
            session_id=session.id,
            context_tokens=context_tokens,
            max_context_tokens=settings.ollama_num_ctx,
            llm=llm,
            model=model,
            temperature=settings.context_summary_temperature,
            session_store=self._sessions,
        )
        events: list[AgentEvent] = []
        for worker in self._workers:
            try:
                result = await worker.run(session, ctx)
                events.extend(result.events)
            except Exception:
                log.warning("agent.worker_failed",
                            worker=type(worker).__name__, exc_info=True)
        return events

    async def _memory_msg(self) -> "Message | None":
        """Return an ephemeral system message with the user memory summary.

        Returns None when no memory store is configured or the summary is empty.
        """
        if not self._memory:
            return None
        summary = await self._memory.get_summary()
        if not summary:
            return None
        return Message(role="system", content=f"## What I remember about the user\n\n{summary}")

    def _with_memory(self, ctx: list[Message], mem: "Message | None") -> list[Message]:
        """Return ``ctx`` with ``mem`` inserted after a leading system message.

        If ``ctx`` does not start with a system message, ``mem`` is prepended.
        ``ctx`` itself is never mutated; when ``mem`` is None the same list
        object is returned.
        """
        if mem is None:
            return ctx
        if ctx and ctx[0].role == "system":
            return [ctx[0], mem] + ctx[1:]
        return [mem] + ctx

    def _build_system_prompt(self, profile_prompt: str) -> str:
        """Prefix the profile prompt with the global persona, if one is configured.

        The persona (``settings.navi_persona``, stripped) and the profile prompt
        are separated by a ``---`` rule.
        """
        persona = settings.navi_persona.strip()
        if persona:
            return f"{persona}\n\n---\n\n{profile_prompt}"
        return profile_prompt

    def _tool_list(self, enabled: list[str]) -> list[Tool]:
        """Resolve the enabled tool names, plus user-enabled extras, to Tool objects.

        Names from ``enabled`` come first, in order. Names from the user's
        ``enabled.json`` are appended if not already present. Names the
        registry cannot resolve are skipped with a warning instead of failing
        the run.
        """
        names = list(enabled)
        for name in _load_user_enabled_tools():
            if name not in names:
                names.append(name)
        resolved: list[Tool] = []
        for name in names:
            try:
                resolved.append(self._tools.get(name))
            except Exception as exc:
                log.warning("agent.tool_unavailable", tool=name, error=str(exc))
        return resolved

    def _get_backend(self, backend_key: str) -> LLMBackend:
        """Look up the LLM backend registered under ``backend_key``."""
        return self._backends.get(backend_key)

    @staticmethod
    def _image_injection(tool_name: str, result) -> "Message | None":
        """Build the synthetic user message that hands a tool-loaded image to the LLM.

        ``result`` is the value returned by ``Tool.execute``. A message is
        returned only when the result succeeded and its metadata has a truthy
        ``is_image`` flag and a non-empty ``base64`` payload; otherwise None.
        Callers append it to the LLM context only, never to display history.
        """
        if not (result.success and result.metadata and result.metadata.get("is_image")):
            return None
        b64 = result.metadata.get("base64")
        if not b64:
            return None
        return Message(
            role="user",
            content=f"[Image loaded via {tool_name} — analyse it]",
            images=[b64],
        )

    async def _run_single_tool(
        self,
        tc: ToolCallRequest,
        tool_map: dict[str, Tool],
    ) -> tuple[ToolEvent, Message, "Message | None"]:
        """Execute one tool call and return (ToolEvent, tool_msg, optional_image_msg).

        An unknown tool name does not raise; it produces an error result with
        ``success=False``. Exceptions raised by ``Tool.execute`` propagate.

        run_stream() calls this via asyncio.create_task() so the parent
        generator can drain the event sink concurrently. The batch helpers
        _execute_tool_calls / _execute_tool_calls_streaming call it via gather().
        """
        tool = tool_map.get(tc.name)
        if tool is None:
            content = f"Error: tool '{tc.name}' not found."
            success = False
            image_msg = None
        else:
            log.info("tool.execute", tool=tc.name, args=tc.arguments)
            result = await tool.execute(tc.arguments)
            content = result.to_message_content()
            success = result.success
            image_msg = self._image_injection(tc.name, result)
        event = ToolEvent(tool_name=tc.name, arguments=tc.arguments,
                          result=content, success=success)
        msg = Message(role="tool", content=content, tool_call_id=tc.id, name=tc.name)
        return event, msg, image_msg

    async def _execute_tool_calls(
        self, tool_calls: list[ToolCallRequest], tools: list[Tool]
    ) -> tuple[list[Message], list[Message]]:
        """Run tool calls concurrently; return (tool messages, image injections).

        Tool messages follow ``tool_calls`` order (asyncio.gather preserves
        order). Image injections contain only the calls that produced one. If
        any tool raises, gather propagates the first exception.
        """
        tool_map = {t.name: t for t in tools}
        triples = await asyncio.gather(*(self._run_single_tool(tc, tool_map) for tc in tool_calls))
        tool_msgs = [msg for _, msg, _ in triples]
        image_msgs = [img for _, _, img in triples if img is not None]
        return tool_msgs, image_msgs

    async def _execute_tool_calls_streaming(
        self, tool_calls: list[ToolCallRequest], tools: list[Tool]
    ) -> tuple[list[tuple[ToolEvent, Message]], list[Message]]:
        """Run tool calls concurrently; return ((ToolEvent, tool_msg) pairs, image injections).

        Same semantics as _execute_tool_calls, but keeps each call's ToolEvent.
        NOTE(review): not called by the methods in this class — confirm whether
        external callers still use it before removing.
        """
        tool_map = {t.name: t for t in tools}
        triples = await asyncio.gather(*(self._run_single_tool(tc, tool_map) for tc in tool_calls))
        pairs = [(event, msg) for event, msg, _ in triples]
        image_msgs = [img for _, _, img in triples if img is not None]
        return pairs, image_msgs