# navi-1 / navi / core / agent.py
"""
Agent: the tool-calling loop.

Flow:
1. Receive user message, load session + profile
2. Build tool schemas from profile's enabled_tools
3. Loop (up to max_iterations; MaxIterationsReached is raised when the budget is exhausted):
   a. Call LLM with current messages + tool schemas
   b. If finish_reason == "stop" or no tool calls were requested -> done, return content
   c. If finish_reason == "tool_calls" -> execute tools concurrently, append results, continue
4. Final streaming path: use llm.stream() to yield text deltas to WebSocket clients

For multi-agent extension: instantiate multiple Agent objects with different profiles.
An Orchestrator (core/orchestrator.py) dispatches tasks to worker agents via asyncio Queues.
"""

import asyncio
import json
from dataclasses import dataclass
from typing import AsyncGenerator

import structlog

from navi.exceptions import MaxIterationsReached, SessionNotFound
from navi.llm.base import LLMBackend, Message, ToolCallRequest
from navi.tools.base import Tool

from .registry import BackendRegistry, ProfileRegistry, ToolRegistry
from .session import SessionStore

# Module-level structlog logger used by the agent loop and its tool helpers.
log = structlog.get_logger()


@dataclass
class ToolEvent:
    """Streaming event that reports one tool call together with its outcome."""

    # Name of the tool the model requested.
    tool_name: str
    # Arguments that came with the tool call request.
    arguments: dict
    # Text recorded for the call: the tool's output or an error description.
    result: str
    # Whether the call succeeded (False, for example, when the tool is unknown).
    success: bool


@dataclass
class TextDelta:
    """One incremental piece of text taken from the streamed LLM answer."""

    # The text fragment carried by this event.
    delta: str


@dataclass
class StreamEnd:
    """Terminal streaming event; signals that the response is complete."""

    # The whole response text, assembled from every streamed chunk.
    full_content: str


# Union of every event type that Agent.run_stream() yields.
AgentEvent = ToolEvent | TextDelta | StreamEnd


class Agent:
    def __init__(
        self,
        session_store: SessionStore,
        profile_registry: ProfileRegistry,
        tool_registry: ToolRegistry,
        backend_registry: BackendRegistry,
    ) -> None:
        self._sessions = session_store
        self._profiles = profile_registry
        self._tools = tool_registry
        self._backends = backend_registry

    # ------------------------------------------------------------------
    # Public interface
    # ------------------------------------------------------------------

    async def run(self, session_id: str, user_message: str) -> str:
        """
        Non-streaming entry point: drive the tool-calling loop to completion.

        Loads the session, appends the user message (preceded by the profile's
        system prompt when the history is still empty) and keeps calling the
        LLM, executing the tools it requests in between, until it answers.

        Returns:
            The final assistant text ("" when the LLM returned no content).

        Raises:
            SessionNotFound: the session store has no session for ``session_id``.
            MaxIterationsReached: ``profile.max_iterations`` rounds went by
                without a final answer; the session is saved before raising.
        """
        session = await self._sessions.get(session_id)
        if session is None:
            raise SessionNotFound(session_id)

        profile = self._profiles.get(session.profile_id)
        available_tools = self._tool_list(profile.enabled_tools)
        schemas = [tool.schema() for tool in available_tools]
        backend = self._get_backend(profile.llm_backend, profile.model)

        history = session.messages

        # A brand-new conversation starts with the profile's system prompt.
        if not history:
            history.append(Message(role="system", content=profile.system_prompt))
        history.append(Message(role="user", content=user_message))

        # Schemas are only offered to the LLM when at least one tool is enabled.
        offered = schemas if available_tools else None

        for step in range(profile.max_iterations):
            log.debug("agent.iteration", session_id=session_id, iteration=step)
            reply = await backend.complete(
                history,
                tools=offered,
                temperature=profile.temperature,
            )

            if reply.finish_reason != "stop" and reply.tool_calls:
                # Tool turn: record the request, run the tools, feed results back.
                history.append(
                    Message(
                        role="assistant",
                        content=reply.content,
                        tool_calls=reply.tool_calls,
                    )
                )
                history.extend(
                    await self._execute_tool_calls(reply.tool_calls, available_tools)
                )
                continue

            # Final answer: persist the conversation and hand the text back.
            final_text = reply.content or ""
            history.append(Message(role="assistant", content=final_text))
            await self._sessions.save(session)
            return final_text

        # Iteration budget exhausted: keep what happened so far, then fail.
        await self._sessions.save(session)
        raise MaxIterationsReached(profile.max_iterations)

    async def run_stream(
        self, session_id: str, user_message: str
    ) -> AsyncGenerator[AgentEvent, None]:
        """
        Streaming variant of the tool-calling loop.

        Yields AgentEvent objects:
        - ToolEvent: one per executed tool call, emitted once the whole batch
          of calls requested in that round has finished
        - TextDelta: each non-empty text chunk from the final LLM response
        - StreamEnd: final event with the full assembled content

        Raises:
            SessionNotFound: the session store has no session for ``session_id``.
            MaxIterationsReached: ``profile.max_iterations`` rounds went by
                without a final answer; the session is saved before raising.
        """
        session = await self._sessions.get(session_id)
        if session is None:
            raise SessionNotFound(session_id)

        profile = self._profiles.get(session.profile_id)
        tools = self._tool_list(profile.enabled_tools)
        tool_schemas = [t.schema() for t in tools]
        llm = self._get_backend(profile.llm_backend, profile.model)

        # Inject system prompt on first message
        if not session.messages:
            session.messages.append(Message(role="system", content=profile.system_prompt))

        session.messages.append(Message(role="user", content=user_message))

        # Tool-calling loop (non-streaming)
        for iteration in range(profile.max_iterations):
            log.debug("agent.iteration", session_id=session_id, iteration=iteration)
            response = await llm.complete(
                session.messages,
                tools=tool_schemas if tools else None,
                temperature=profile.temperature,
            )

            if response.finish_reason == "stop" or not response.tool_calls:
                # The model is done with tools. The text in `response` is NOT
                # re-used: the final answer is requested again through
                # llm.stream() (without tool schemas) so the client receives
                # incremental deltas.
                # NOTE(review): this means a second LLM call per turn whose
                # text may differ from `response.content` — confirm intended.
                final_messages = session.messages.copy()
                chunks: list[str] = []

                async for chunk in llm.stream(final_messages, temperature=profile.temperature):
                    if chunk.delta:
                        chunks.append(chunk.delta)
                        yield TextDelta(delta=chunk.delta)

                accumulated = "".join(chunks)
                session.messages.append(Message(role="assistant", content=accumulated))
                await self._sessions.save(session)
                yield StreamEnd(full_content=accumulated)
                return

            # Tool calls: execute, record results, emit events, continue loop
            assistant_msg = Message(
                role="assistant",
                content=response.content,
                tool_calls=response.tool_calls,
            )
            session.messages.append(assistant_msg)

            tool_results_msgs = await self._execute_tool_calls_streaming(
                response.tool_calls, tools
            )
            # Record every tool result BEFORE yielding anything. If the
            # consumer stops iterating mid-batch, the history must still hold
            # a result for each tool call announced in `assistant_msg`.
            session.messages.extend(msg for _, msg in tool_results_msgs)
            for event, _ in tool_results_msgs:
                yield event

        await self._sessions.save(session)
        raise MaxIterationsReached(profile.max_iterations)

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _tool_list(self, enabled: list[str]) -> list[Tool]:
        return self._tools.resolve(enabled)

    def _get_backend(self, backend_key: str, model: str) -> LLMBackend:
        return self._backends.get(backend_key, model)

    async def _execute_tool_calls(
        self, tool_calls: list[ToolCallRequest], tools: list[Tool]
    ) -> list[Message]:
        """
        Execute the requested tool calls concurrently.

        Args:
            tool_calls: the calls requested by the LLM.
            tools: the tools available for this turn, matched by ``name``.

        Returns:
            One ``role="tool"`` message per request, in request order. A
            request for an unknown tool, or a tool that raises, produces a
            message whose content starts with ``"Error:"`` instead of
            aborting the whole batch.
        """
        tool_map = {t.name: t for t in tools}

        async def _run_one(tc: ToolCallRequest) -> Message:
            tool = tool_map.get(tc.name)
            if tool is None:
                content = f"Error: tool '{tc.name}' not found."
            else:
                log.info("tool.execute", tool=tc.name, args=tc.arguments)
                try:
                    result = await tool.execute(tc.arguments)
                    content = result.to_message_content()
                except Exception as exc:
                    # Report the failure to the LLM like the "not found" case.
                    # Letting it escape would abort the turn, leave the other
                    # gathered calls running detached, and leave the history
                    # with a tool_calls message that has no results.
                    # (Cancellation is a BaseException and still propagates.)
                    log.exception("tool.error", tool=tc.name)
                    content = f"Error: tool '{tc.name}' failed: {exc}"
            return Message(
                role="tool",
                content=content,
                tool_call_id=tc.id,
                name=tc.name,
            )

        # gather() returns results in the order of the awaitables passed in.
        return await asyncio.gather(*[_run_one(tc) for tc in tool_calls])

    async def _execute_tool_calls_streaming(
        self, tool_calls: list[ToolCallRequest], tools: list[Tool]
    ) -> list[tuple[ToolEvent, Message]]:
        """
        Execute the requested tool calls concurrently, pairing each result
        with the event that describes it.

        Args:
            tool_calls: the calls requested by the LLM.
            tools: the tools available for this turn, matched by ``name``.

        Returns:
            One ``(ToolEvent, Message)`` pair per request, in request order.
            A request for an unknown tool, or a tool that raises, produces a
            pair with ``success=False`` and content starting with
            ``"Error:"`` instead of aborting the whole batch.
        """
        tool_map = {t.name: t for t in tools}

        async def _run_one(tc: ToolCallRequest) -> tuple[ToolEvent, Message]:
            tool = tool_map.get(tc.name)
            if tool is None:
                content = f"Error: tool '{tc.name}' not found."
                success = False
            else:
                log.info("tool.execute", tool=tc.name, args=tc.arguments)
                try:
                    result = await tool.execute(tc.arguments)
                    content = result.to_message_content()
                    success = result.success
                except Exception as exc:
                    # Report the failure like the "not found" case. Letting it
                    # escape would abort the turn, leave the other gathered
                    # calls running detached, and leave the history with a
                    # tool_calls message that has no results.
                    # (Cancellation is a BaseException and still propagates.)
                    log.exception("tool.error", tool=tc.name)
                    content = f"Error: tool '{tc.name}' failed: {exc}"
                    success = False
            event = ToolEvent(
                tool_name=tc.name,
                arguments=tc.arguments,
                result=content,
                success=success,
            )
            msg = Message(role="tool", content=content, tool_call_id=tc.id, name=tc.name)
            return event, msg

        # gather() returns results in the order of the awaitables passed in.
        return await asyncio.gather(*[_run_one(tc) for tc in tool_calls])