# navi-1 / navi / core / agent.py
"""
Agent: the tool-calling loop.

Flow:
1. Receive user message, load session + profile
2. Build tool schemas from profile's enabled_tools
3. Loop (up to max_iterations):
   a. Call LLM with session.context (may be compressed) + tool schemas
   b. If finish_reason == "stop" or no tool calls were requested -> final answer
      (run() returns it; run_stream() re-requests it via llm.stream() and streams it)
   c. If finish_reason == "tool_calls" -> execute tools, append to both
      session.messages (display history) and session.context (LLM context)
4. After StreamEnd (run_stream() only): run workers sequentially (e.g. context compression)

session.messages — full display history, never compressed
session.context  — what the LLM sees; workers may compress this
"""

import asyncio
import json
from datetime import datetime, timezone
from pathlib import Path
from typing import TYPE_CHECKING, AsyncGenerator

import structlog

from navi.config import settings
from navi.exceptions import MaxIterationsReached, SessionNotFound
from navi.llm.base import LLMBackend, Message, ToolCallRequest
from navi.tools.base import Tool

from .events import (
    AgentEvent,
    StreamEnd,
    TextDelta,
    ThinkingDelta,
    ThinkingEnd,
    ToolEvent,
)
from .registry import BackendRegistry, ProfileRegistry, ToolRegistry
from .session import SessionStore

if TYPE_CHECKING:
    from navi.workers.base import Worker, WorkerContext

# JSON file listing extra tool names the user enabled on top of a profile's tools.
_USER_ENABLED_FILE = Path(settings.tools_dir).joinpath("enabled.json")


def _load_user_enabled_tools() -> list[str]:
    try:
        return json.loads(_USER_ENABLED_FILE.read_text())
    except Exception:
        return []


# Module-level structlog logger used throughout this module (agent iterations,
# tool execution, worker failures).
log = structlog.get_logger()


class Agent:
    """Run the tool-calling loop for one session turn.

    Each turn records the user message in both ``session.messages`` (display
    history) and ``session.context`` (what the LLM sees). It then alternates
    LLM calls and tool execution until the model answers without requesting
    tools, or until ``profile.max_iterations`` rounds have been used.
    """

    def __init__(
        self,
        session_store: SessionStore,
        profile_registry: ProfileRegistry,
        tool_registry: ToolRegistry,
        backend_registry: BackendRegistry,
        workers: list["Worker"] | None = None,
    ) -> None:
        self._sessions = session_store
        self._profiles = profile_registry
        self._tools = tool_registry
        self._backends = backend_registry
        # Post-response workers, run in order after each streamed answer.
        self._workers: list["Worker"] = workers or []

    # ------------------------------------------------------------------
    # Public interface
    # ------------------------------------------------------------------

    async def run(self, session_id: str, user_message: str, images: list[str] | None = None) -> str:
        """Non-streaming: run the full tool-calling loop and return the final text.

        Args:
            session_id: ID of an existing session.
            user_message: The user's message text.
            images: Optional images attached to the user message (stored on
                ``Message.images``; an empty list is stored as ``None``).

        Returns:
            The assistant's final text, or ``""`` if the model returned no content.

        Raises:
            SessionNotFound: If ``session_id`` does not exist.
            MaxIterationsReached: If the model is still requesting tools after
                ``profile.max_iterations`` rounds. The session is saved first.
        """
        session, profile, tools, tool_schemas, llm = await self._begin_turn(
            session_id, user_message, images
        )

        for iteration in range(profile.max_iterations):
            log.debug("agent.iteration", session_id=session_id, iteration=iteration)
            response = await llm.complete(
                session.context,
                tools=tool_schemas if tools else None,
                temperature=profile.temperature,
                model=profile.model,
            )

            if response.finish_reason == "stop" or not response.tool_calls:
                content = response.content or ""
                self._append(session, Message(
                    role="assistant", content=content,
                    created_at=datetime.now(timezone.utc),
                ))
                await self._sessions.save(session)
                return content

            # Tool-call turn: record the request, then every tool's answer.
            self._append(session, Message(
                role="assistant", content=response.content, tool_calls=response.tool_calls,
            ))
            tool_results, image_injections = await self._execute_tool_calls(
                response.tool_calls, tools
            )
            for msg in tool_results:
                self._append(session, msg)
            # Image injections are synthetic LLM helpers — context only.
            session.context.extend(image_injections)

        await self._sessions.save(session)
        raise MaxIterationsReached(profile.max_iterations)

    async def run_stream(
        self, session_id: str, user_message: str, images: list[str] | None = None
    ) -> AsyncGenerator[AgentEvent, None]:
        """
        Streaming variant. Yields AgentEvent objects:
        - ThinkingDelta / ThinkingEnd: reasoning chunks
        - ToolEvent: tool call + result
        - TextDelta / StreamEnd: final streamed response
        - ContextCompressed: emitted by workers after compression

        Raises SessionNotFound and MaxIterationsReached like ``run``. Because
        this is an async generator, they surface while the caller iterates.
        """
        session, profile, tools, tool_schemas, llm = await self._begin_turn(
            session_id, user_message, images
        )

        # Tool-calling loop is non-streaming; only the final answer is streamed.
        for iteration in range(profile.max_iterations):
            log.debug("agent.iteration", session_id=session_id, iteration=iteration)
            response = await llm.complete(
                session.context,
                tools=tool_schemas if tools else None,
                temperature=profile.temperature,
                model=profile.model,
            )

            if response.finish_reason == "stop" or not response.tool_calls:
                # NOTE(review): response.content is discarded here. The answer
                # is regenerated by llm.stream() (called without tool schemas),
                # so it costs a second LLM call and may differ from the text
                # complete() produced.
                parts: list[str] = []
                thinking_active = False
                context_tokens: int | None = None

                async for chunk in llm.stream(
                    session.context.copy(), temperature=profile.temperature, model=profile.model
                ):
                    if chunk.prompt_tokens is not None or chunk.completion_tokens is not None:
                        # Keep the most recent usage figures reported by the stream.
                        context_tokens = (chunk.prompt_tokens or 0) + (chunk.completion_tokens or 0)
                    # Handle thinking and answer text independently so that a
                    # chunk carrying both does not lose its answer text.
                    if chunk.thinking:
                        thinking_active = True
                        yield ThinkingDelta(delta=chunk.thinking)
                    if chunk.delta:
                        if thinking_active:
                            thinking_active = False
                            yield ThinkingEnd()
                        parts.append(chunk.delta)
                        yield TextDelta(delta=chunk.delta)

                if thinking_active:
                    yield ThinkingEnd()

                full_content = "".join(parts)
                self._append(session, Message(
                    role="assistant", content=full_content,
                    created_at=datetime.now(timezone.utc),
                ))
                await self._sessions.save(session)

                yield StreamEnd(
                    full_content=full_content,
                    context_tokens=context_tokens,
                    max_context_tokens=settings.ollama_num_ctx,
                )

                # Run post-response workers (e.g. context compression).
                for event in await self._run_workers(session, llm, profile.model, context_tokens):
                    yield event
                return

            # Tool-call turn: record the request, then emit and record each result.
            self._append(session, Message(
                role="assistant", content=response.content, tool_calls=response.tool_calls,
            ))
            tool_results, image_injections = await self._execute_tool_calls_streaming(
                response.tool_calls, tools
            )
            for event, msg in tool_results:
                yield event
                self._append(session, msg)
            # Image injections are synthetic — context only.
            session.context.extend(image_injections)

        await self._sessions.save(session)
        raise MaxIterationsReached(profile.max_iterations)

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    async def _begin_turn(
        self, session_id: str, user_message: str, images: list[str] | None
    ) -> tuple:
        """Load the session, resolve profile/tools/backend, and record the user message.

        Returns:
            ``(session, profile, tools, tool_schemas, llm)``.

        Raises:
            SessionNotFound: If ``session_id`` does not exist.
        """
        session = await self._sessions.get(session_id)
        if session is None:
            raise SessionNotFound(session_id)

        profile = self._profiles.get(session.profile_id)
        tools = self._tool_list(profile.enabled_tools)
        tool_schemas = [t.schema() for t in tools]
        llm = self._get_backend(profile.llm_backend)

        # The system prompt goes into context only, never into display history.
        # Check for the role rather than for an empty context: backward-compat
        # sessions have their context initialised from messages, which never
        # contain a system message.
        if not any(m.role == "system" for m in session.context):
            session.context.insert(0, Message(
                role="system",
                content=self._build_system_prompt(profile.system_prompt),
            ))

        self._append(session, Message(
            role="user", content=user_message, images=images or None,
            created_at=datetime.now(timezone.utc),
        ))
        return session, profile, tools, tool_schemas, llm

    @staticmethod
    def _append(session, msg: Message) -> None:
        """Append *msg* to both the display history and the LLM context."""
        session.messages.append(msg)
        session.context.append(msg)

    async def _run_workers(
        self,
        session,
        llm: LLMBackend,
        model: str,
        context_tokens: int | None,
    ) -> list[AgentEvent]:
        """Run all workers sequentially and collect their events.

        A worker that raises is logged and skipped, so one failing worker does
        not prevent the remaining workers from running.
        """
        # Imported lazily; the module-level import is TYPE_CHECKING-only,
        # presumably to avoid an import cycle with navi.workers.
        from navi.workers.base import WorkerContext

        ctx = WorkerContext(
            session_id=session.id,
            context_tokens=context_tokens,
            max_context_tokens=settings.ollama_num_ctx,
            llm=llm,
            model=model,
            temperature=settings.context_summary_temperature,
            session_store=self._sessions,
        )
        events: list[AgentEvent] = []
        for worker in self._workers:
            try:
                result = await worker.run(session, ctx)
                events.extend(result.events)
            except Exception:
                log.warning("agent.worker_failed",
                            worker=type(worker).__name__, exc_info=True)
        return events

    def _build_system_prompt(self, profile_prompt: str) -> str:
        """Prefix *profile_prompt* with the configured persona, if any.

        When ``settings.navi_persona`` is non-blank, the persona and the profile
        prompt are joined by a ``---`` separator.
        """
        persona = settings.navi_persona.strip()
        if persona:
            return f"{persona}\n\n---\n\n{profile_prompt}"
        return profile_prompt

    def _tool_list(self, enabled: list[str]) -> list[Tool]:
        """Resolve the profile's enabled tools plus the user-enabled extras.

        Names are de-duplicated in first-seen order, with profile tools first.
        Names the registry cannot resolve are logged and skipped.
        """
        names: list[str] = []
        for name in [*enabled, *_load_user_enabled_tools()]:
            if name not in names:
                names.append(name)

        result: list[Tool] = []
        for name in names:
            try:
                result.append(self._tools.get(name))
            except Exception:
                # Best-effort: a stale or misspelled name must not break the turn.
                log.warning("agent.tool_unavailable", tool=name, exc_info=True)
        return result

    def _get_backend(self, backend_key: str) -> LLMBackend:
        """Return the LLM backend registered under *backend_key*."""
        return self._backends.get(backend_key)

    async def _run_tool_call(
        self, tc: ToolCallRequest, tool_map: dict[str, Tool]
    ) -> tuple[ToolEvent, Message, Message | None]:
        """Execute one tool call.

        Returns:
            ``(event, tool_msg, image_msg)``:
            - ``event`` is the UI event.
            - ``tool_msg`` is the ``role="tool"`` message answering ``tc``.
            - ``image_msg`` is a synthetic user message carrying an image the
              tool loaded, or ``None``.

        An unknown tool, or a tool whose ``execute`` raises, yields an error
        result with ``success=False`` instead of propagating. Every tool call in
        the assistant turn therefore still gets a matching tool message.
        """
        tool = tool_map.get(tc.name)
        image_msg: Message | None = None
        if tool is None:
            content = f"Error: tool '{tc.name}' not found."
            success = False
        else:
            log.info("tool.execute", tool=tc.name, args=tc.arguments)
            try:
                result = await tool.execute(tc.arguments)
            except Exception as exc:
                log.warning("tool.execute_failed", tool=tc.name, exc_info=True)
                content = f"Error: tool '{tc.name}' failed: {type(exc).__name__}: {exc}"
                success = False
            else:
                content = result.to_message_content()
                success = result.success
                image_msg = self._image_injection(tc.name, result)

        event = ToolEvent(
            tool_name=tc.name, arguments=tc.arguments, result=content, success=success
        )
        msg = Message(role="tool", content=content, tool_call_id=tc.id, name=tc.name)
        return event, msg, image_msg

    @staticmethod
    def _image_injection(tool_name: str, result) -> Message | None:
        """Build a user message carrying the image a tool loaded, if any.

        This only happens when the result succeeded, its metadata has a truthy
        ``is_image``, and it has a non-empty ``base64`` payload.
        """
        if not (result.success and result.metadata and result.metadata.get("is_image")):
            return None
        b64 = result.metadata.get("base64")
        if not b64:
            return None
        return Message(
            role="user",
            content=f"[Image loaded via {tool_name} — analyse it]",
            images=[b64],
        )

    async def _execute_tool_calls(
        self, tool_calls: list[ToolCallRequest], tools: list[Tool]
    ) -> tuple[list[Message], list[Message]]:
        """Run *tool_calls* concurrently.

        Returns:
            ``(tool_msgs, image_msgs)``. ``tool_msgs`` preserves the order of
            *tool_calls*.
        """
        pairs, image_msgs = await self._execute_tool_calls_streaming(tool_calls, tools)
        return [msg for _, msg in pairs], image_msgs

    async def _execute_tool_calls_streaming(
        self, tool_calls: list[ToolCallRequest], tools: list[Tool]
    ) -> tuple[list[tuple[ToolEvent, Message]], list[Message]]:
        """Run *tool_calls* concurrently and pair each result with its UI event.

        Returns:
            ``([(event, tool_msg), ...], image_msgs)``. The pairs preserve the
            order of *tool_calls*, since ``asyncio.gather`` keeps argument order.
        """
        tool_map = {t.name: t for t in tools}
        outcomes = await asyncio.gather(*(self._run_tool_call(tc, tool_map) for tc in tool_calls))
        pairs = [(event, msg) for event, msg, _ in outcomes]
        image_msgs = [img for _, _, img in outcomes if img is not None]
        return pairs, image_msgs