diff --git a/navi/api/deps.py b/navi/api/deps.py index 0881e00..1381243 100644 --- a/navi/api/deps.py +++ b/navi/api/deps.py @@ -15,6 +15,7 @@ ToolRegistry, build_default_registries, ) +from navi.workers import Worker, build_default_workers @lru_cache @@ -41,10 +42,16 @@ return _session_store +@lru_cache +def get_workers() -> list[Worker]: + return build_default_workers() + + def get_agent( session_store: Annotated[SessionStore, Depends(get_session_store)], profile_registry: Annotated[ProfileRegistry, Depends(get_profile_registry)], tool_registry: Annotated[ToolRegistry, Depends(get_tool_registry)], backend_registry: Annotated[BackendRegistry, Depends(get_backend_registry)], ) -> Agent: - return Agent(session_store, profile_registry, tool_registry, backend_registry) + return Agent(session_store, profile_registry, tool_registry, backend_registry, + workers=get_workers()) diff --git a/navi/api/websocket.py b/navi/api/websocket.py index 2721d0e..8a32e8d 100644 --- a/navi/api/websocket.py +++ b/navi/api/websocket.py @@ -41,9 +41,9 @@ log.info("ws.connected", session_id=session_id) # Build agent (can't use FastAPI Depends inside WebSocket directly) - from navi.api.deps import get_registries + from navi.api.deps import get_registries, get_workers tools, profiles, backends = get_registries() - agent = Agent(session_store, profiles, tools, backends) + agent = Agent(session_store, profiles, tools, backends, workers=get_workers()) try: while True: diff --git a/navi/config.py b/navi/config.py index 6c676fc..6826cc1 100644 --- a/navi/config.py +++ b/navi/config.py @@ -33,7 +33,7 @@ # Context compression context_compression_enabled: bool = True context_compression_threshold: float = 0.80 # trigger at 80% of ollama_num_ctx - context_keep_recent: int = 6 # conversational turns to keep verbatim + context_keep_recent: int = 10 # conversational turns to keep verbatim context_summary_temperature: float = 0.3 # Global personality prompt prepended to every agent's system prompt. diff --git a/navi/core/agent.py b/navi/core/agent.py index 60547b4..74ac68c 100644 --- a/navi/core/agent.py +++ b/navi/core/agent.py @@ -5,13 +5,14 @@ 1. Receive user message, load session + profile 2. Build tool schemas from profile's enabled_tools 3. Loop (up to max_iterations): - a. Call LLM with current messages + tool schemas - b. If finish_reason == "stop" -> done, return content - c. If finish_reason == "tool_calls" -> execute tools concurrently, append results, continue -4. Final streaming path: use llm.stream() to yield text deltas to WebSocket clients + a. Call LLM with session.context (may be compressed) + tool schemas + b. If finish_reason == "stop" -> stream final response + c. If finish_reason == "tool_calls" -> execute tools, append to both + session.messages (display history) and session.context (LLM context) +4. After StreamEnd: run workers sequentially (e.g. context compression) -For multi-agent extension: instantiate multiple Agent objects with different profiles. -An Orchestrator (core/orchestrator.py) dispatches tasks to worker agents via asyncio Queues. +session.messages — full display history, never compressed +session.context — what the LLM sees; workers may compress this """ import asyncio @@ -19,7 +20,7 @@ from dataclasses import dataclass from datetime import datetime, timezone from pathlib import Path -from typing import AsyncGenerator +from typing import TYPE_CHECKING, AsyncGenerator import structlog @@ -28,10 +29,12 @@ from navi.llm.base import LLMBackend, Message, ToolCallRequest from navi.tools.base import Tool -from .compressor import compress_session, should_compress from .registry import BackendRegistry, ProfileRegistry, ToolRegistry from .session import SessionStore +if TYPE_CHECKING: + from navi.workers.base import Worker, WorkerContext + _USER_ENABLED_FILE = Path(settings.tools_dir) / "enabled.json" @@ -41,6 +44,7 @@ except Exception: return [] + log = structlog.get_logger() @@ -100,11 +104,13 @@ profile_registry: ProfileRegistry, tool_registry: ToolRegistry, backend_registry: BackendRegistry, + workers: list["Worker"] | None = None, ) -> None: self._sessions = session_store self._profiles = profile_registry self._tools = tool_registry self._backends = backend_registry + self._workers: list["Worker"] = workers or [] # ------------------------------------------------------------------ # Public interface @@ -121,16 +127,22 @@ tool_schemas = [t.schema() for t in tools] llm = self._get_backend(profile.llm_backend) - # Inject system prompt on first message - if not session.messages: - session.messages.append(Message(role="system", content=self._build_system_prompt(profile.system_prompt))) + # System prompt only goes into context (not display history) + if not session.context: + session.context.append(Message( + role="system", + content=self._build_system_prompt(profile.system_prompt), + )) - session.messages.append(Message(role="user", content=user_message, images=images or None, created_at=datetime.now(timezone.utc))) + user_msg = Message(role="user", content=user_message, images=images or None, + created_at=datetime.now(timezone.utc)) + session.messages.append(user_msg) + session.context.append(user_msg) for iteration in range(profile.max_iterations): log.debug("agent.iteration", session_id=session_id, iteration=iteration) response = await llm.complete( - session.messages, + session.context, tools=tool_schemas if tools else None, temperature=profile.temperature, model=profile.model, @@ -138,21 +150,27 @@ if response.finish_reason == "stop" or not response.tool_calls: content = response.content or "" - session.messages.append(Message(role="assistant", content=content)) + assistant_msg = Message(role="assistant", content=content, + created_at=datetime.now(timezone.utc)) + session.messages.append(assistant_msg) + session.context.append(assistant_msg) await self._sessions.save(session) return content - # Tool calls turn + # Tool calls turn — append to both messages and context assistant_msg = Message( role="assistant", content=response.content, tool_calls=response.tool_calls, ) session.messages.append(assistant_msg) + session.context.append(assistant_msg) tool_results, image_injections = await self._execute_tool_calls(response.tool_calls, tools) session.messages.extend(tool_results) - session.messages.extend(image_injections) + session.context.extend(tool_results) + # Image injections are synthetic LLM helpers — context only + session.context.extend(image_injections) await self._sessions.save(session) raise MaxIterationsReached(profile.max_iterations) @@ -162,9 +180,10 @@ ) -> AsyncGenerator[AgentEvent, None]: """ Streaming variant. Yields AgentEvent objects: - - ToolEvent: when a tool is called and its result arrives - - TextDelta: each text chunk from the final LLM response - - StreamEnd: final event with the full assembled content + - ThinkingDelta / ThinkingEnd: reasoning chunks + - ToolEvent: tool call + result + - TextDelta / StreamEnd: final streamed response + - ContextCompressed: emitted by workers after compression """ session = await self._sessions.get(session_id) if session is None: @@ -175,28 +194,36 @@ tool_schemas = [t.schema() for t in tools] llm = self._get_backend(profile.llm_backend) - if not session.messages: - session.messages.append(Message(role="system", content=self._build_system_prompt(profile.system_prompt))) + # System prompt only goes into context (not display history) + if not session.context: + session.context.append(Message( + role="system", + content=self._build_system_prompt(profile.system_prompt), + )) - session.messages.append(Message(role="user", content=user_message, images=images or None, created_at=datetime.now(timezone.utc))) + user_msg = Message(role="user", content=user_message, images=images or None, + created_at=datetime.now(timezone.utc)) + session.messages.append(user_msg) + session.context.append(user_msg) # Tool-calling loop (non-streaming) for iteration in range(profile.max_iterations): response = await llm.complete( - session.messages, + session.context, tools=tool_schemas if tools else None, temperature=profile.temperature, model=profile.model, ) if response.finish_reason == "stop" or not response.tool_calls: - # Switch to streaming for the final text response - final_messages = session.messages.copy() + # Stream the final response accumulated = "" thinking_active = False context_tokens: int | None = None - async for chunk in llm.stream(final_messages, temperature=profile.temperature, model=profile.model): + async for chunk in llm.stream( + session.context.copy(), temperature=profile.temperature, model=profile.model + ): if chunk.prompt_tokens is not None or chunk.completion_tokens is not None: context_tokens = (chunk.prompt_tokens or 0) + (chunk.completion_tokens or 0) if chunk.thinking: @@ -213,7 +240,10 @@ if thinking_active: yield ThinkingEnd() - session.messages.append(Message(role="assistant", content=accumulated, created_at=datetime.now(timezone.utc))) + assistant_msg = Message(role="assistant", content=accumulated, + created_at=datetime.now(timezone.utc)) + session.messages.append(assistant_msg) + session.context.append(assistant_msg) await self._sessions.save(session) yield StreamEnd( @@ -222,46 +252,19 @@ max_context_tokens=settings.ollama_num_ctx, ) - # Post-response compression — runs after client receives StreamEnd - if ( - settings.context_compression_enabled - and context_tokens is not None - and should_compress(context_tokens, settings.ollama_num_ctx, settings.context_compression_threshold) - ): - count_before = len(session.messages) - try: - new_messages = await compress_session( - messages=session.messages, - llm=llm, - model=profile.model, - temperature=settings.context_summary_temperature, - keep_recent=settings.context_keep_recent, - ) - if new_messages is not None: - session.messages = new_messages - await self._sessions.save(session) - log.info( - "agent.compressed", - session_id=session_id, - before=count_before, - after=len(session.messages), - ) - yield ContextCompressed( - messages_before=count_before, - messages_after=len(session.messages), - ) - except Exception: - log.warning("agent.compress_failed", session_id=session_id, exc_info=True) - + # Run post-response workers (e.g. context compression) + for event in await self._run_workers(session, llm, profile.model, context_tokens): + yield event return - # Tool calls: emit events, execute, continue loop + # Tool calls: emit events, append to both messages and context assistant_msg = Message( role="assistant", content=response.content, tool_calls=response.tool_calls, ) session.messages.append(assistant_msg) + session.context.append(assistant_msg) tool_results_msgs, image_injections = await self._execute_tool_calls_streaming( response.tool_calls, tools @@ -269,7 +272,9 @@ for event, msg in tool_results_msgs: yield event session.messages.append(msg) - session.messages.extend(image_injections) + session.context.append(msg) + # Image injections are synthetic — context only + session.context.extend(image_injections) await self._sessions.save(session) raise MaxIterationsReached(profile.max_iterations) @@ -278,6 +283,35 @@ # Internal helpers # ------------------------------------------------------------------ + async def _run_workers( + self, + session, + llm: LLMBackend, + model: str, + context_tokens: int | None, + ) -> list[AgentEvent]: + """Run all workers sequentially; collect their events.""" + from navi.workers.base import WorkerContext + + ctx = WorkerContext( + session_id=session.id, + context_tokens=context_tokens, + max_context_tokens=settings.ollama_num_ctx, + llm=llm, + model=model, + temperature=settings.context_summary_temperature, + session_store=self._sessions, + ) + events: list[AgentEvent] = [] + for worker in self._workers: + try: + result = await worker.run(session, ctx) + events.extend(result.events) + except Exception: + log.warning("agent.worker_failed", + worker=type(worker).__name__, exc_info=True) + return events + def _build_system_prompt(self, profile_prompt: str) -> str: persona = settings.navi_persona.strip() if persona: @@ -286,12 +320,10 @@ def _tool_list(self, enabled: list[str]) -> list[Tool]: names = list(enabled) - # Merge in user-created tools from tools/enabled.json extra = _load_user_enabled_tools() for name in extra: if name not in names: names.append(name) - # Silently skip any names not registered (e.g. tool was deleted) result = [] for name in names: try: diff --git a/navi/core/compressor.py b/navi/core/compressor.py index 492dced..7f9f501 100644 --- a/navi/core/compressor.py +++ b/navi/core/compressor.py @@ -103,21 +103,21 @@ return "\n".join(lines) -async def compress_session( - messages: list[Message], +async def compress_context( + context: list[Message], llm: LLMBackend, model: str, temperature: float, keep_recent: int, ) -> list[Message] | None: """ - Summarize old messages and return a new (shorter) message list. - Returns None if there is nothing to compress or if the LLM call fails. - + Summarize old messages in the LLM context and return a shorter context list. + Only operates on `context` — the full display history (session.messages) is never touched. + Returns None if there is nothing to compress. Raises LLMBackendError on LLM failure — caller decides how to handle. """ - system_msgs = [m for m in messages if m.role == "system"] - to_summarize, to_keep = partition_messages(messages, keep_recent) + system_msgs = [m for m in context if m.role == "system"] + to_summarize, to_keep = partition_messages(context, keep_recent) if len(to_summarize) < 2: return None # nothing substantial to compress diff --git a/navi/core/session.py b/navi/core/session.py index 4588e22..4b5a7e8 100644 --- a/navi/core/session.py +++ b/navi/core/session.py @@ -12,7 +12,8 @@ class Session(BaseModel): id: str = Field(default_factory=lambda: str(uuid.uuid4())) profile_id: str - messages: list[Message] = Field(default_factory=list) + messages: list[Message] = Field(default_factory=list) # full display history (never compressed) + context: list[Message] = Field(default_factory=list) # LLM context (may be compressed) pinned: bool = False created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) last_active: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) diff --git a/navi/core/sqlite_session_store.py b/navi/core/sqlite_session_store.py index a28022f..8583c53 100644 --- a/navi/core/sqlite_session_store.py +++ b/navi/core/sqlite_session_store.py @@ -15,6 +15,7 @@ id TEXT PRIMARY KEY, profile_id TEXT NOT NULL, messages TEXT NOT NULL DEFAULT '[]', + context TEXT NOT NULL DEFAULT '', pinned INTEGER NOT NULL DEFAULT 0, created_at TEXT NOT NULL, last_active TEXT NOT NULL @@ -22,24 +23,40 @@ """ +def _serialize(messages: list[Message]) -> str: + return json.dumps( + [m.model_dump(mode='json', exclude_none=True) for m in messages], + ensure_ascii=False, + ) + + +def _deserialize(raw: str) -> list[Message]: + if not raw: + return [] + return [Message.model_validate(m) for m in json.loads(raw)] + + class SqliteSessionStore(SessionStore): def __init__(self, db_path: str = "navi.db") -> None: self._db_path = db_path with sqlite3.connect(db_path) as conn: conn.execute(_CREATE_TABLE) - # Migrate: add pinned column to existing tables that don't have it - try: - conn.execute("ALTER TABLE sessions ADD COLUMN pinned INTEGER NOT NULL DEFAULT 0") - except sqlite3.OperationalError: - pass # column already exists + for migration in [ + "ALTER TABLE sessions ADD COLUMN pinned INTEGER NOT NULL DEFAULT 0", + "ALTER TABLE sessions ADD COLUMN context TEXT NOT NULL DEFAULT ''", + ]: + try: + conn.execute(migration) + except sqlite3.OperationalError: + pass # column already exists conn.commit() async def create(self, profile_id: str) -> Session: session = Session(profile_id=profile_id) async with aiosqlite.connect(self._db_path) as db: await db.execute( - "INSERT INTO sessions (id, profile_id, messages, pinned, created_at, last_active) " - "VALUES (?, ?, '[]', 0, ?, ?)", + "INSERT INTO sessions (id, profile_id, messages, context, pinned, created_at, last_active) " + "VALUES (?, ?, '[]', '', 0, ?, ?)", (session.id, session.profile_id, session.created_at.isoformat(), session.last_active.isoformat()), ) @@ -49,7 +66,7 @@ async def get(self, session_id: str) -> Session | None: async with aiosqlite.connect(self._db_path) as db: async with db.execute( - "SELECT id, profile_id, messages, pinned, created_at, last_active " + "SELECT id, profile_id, messages, context, pinned, created_at, last_active " "FROM sessions WHERE id = ?", (session_id,), ) as cur: @@ -58,14 +75,11 @@ async def save(self, session: Session) -> None: session.last_active = datetime.now(timezone.utc) - messages_json = json.dumps( - [m.model_dump(mode='json', exclude_none=True) for m in session.messages], - ensure_ascii=False, - ) async with aiosqlite.connect(self._db_path) as db: await db.execute( - "UPDATE sessions SET messages = ?, last_active = ? WHERE id = ?", - (messages_json, session.last_active.isoformat(), session.id), + "UPDATE sessions SET messages = ?, context = ?, last_active = ? WHERE id = ?", + (_serialize(session.messages), _serialize(session.context), + session.last_active.isoformat(), session.id), ) await db.commit() @@ -81,7 +95,7 @@ async def list_all(self) -> list[Session]: async with aiosqlite.connect(self._db_path) as db: async with db.execute( - "SELECT id, profile_id, messages, pinned, created_at, last_active " + "SELECT id, profile_id, messages, context, pinned, created_at, last_active " "FROM sessions ORDER BY pinned DESC, last_active DESC" ) as cur: rows = await cur.fetchall() @@ -94,12 +108,16 @@ return cur.rowcount > 0 def _row_to_session(self, row: tuple) -> Session: - id_, profile_id, messages_json, pinned, created_at, last_active = row - messages = [Message.model_validate(m) for m in json.loads(messages_json)] + id_, profile_id, messages_json, context_json, pinned, created_at, last_active = row + messages = _deserialize(messages_json) + # Backward compat: existing sessions have empty context column — + # initialize context from messages so they work without re-compression. + context = _deserialize(context_json) if context_json else list(messages) return Session( id=id_, profile_id=profile_id, messages=messages, + context=context, pinned=bool(pinned), created_at=datetime.fromisoformat(created_at), last_active=datetime.fromisoformat(last_active), diff --git a/navi/workers/__init__.py b/navi/workers/__init__.py new file mode 100644 index 0000000..f4003b1 --- /dev/null +++ b/navi/workers/__init__.py @@ -0,0 +1,15 @@ +from .base import Worker, WorkerContext, WorkerResult +from .compressor import CompressionWorker + + +def build_default_workers() -> list[Worker]: + return [CompressionWorker()] + + +__all__ = [ + "Worker", + "WorkerContext", + "WorkerResult", + "CompressionWorker", + "build_default_workers", +] diff --git a/navi/workers/base.py b/navi/workers/base.py new file mode 100644 index 0000000..658f17e --- /dev/null +++ b/navi/workers/base.py @@ -0,0 +1,46 @@ +"""Base classes for post-response background workers.""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from navi.core.session import Session, SessionStore + from navi.llm.base import LLMBackend + + +@dataclass +class WorkerContext: + """Runtime data passed to every worker after a response completes.""" + + session_id: str + context_tokens: int | None # tokens used in last turn (from Ollama) + max_context_tokens: int # ollama_num_ctx + llm: LLMBackend + model: str + temperature: float + session_store: SessionStore + + +@dataclass +class WorkerResult: + """Returned by a worker. `events` will be yielded to the WebSocket client.""" + + events: list[Any] = field(default_factory=list) # list[AgentEvent] + + +class Worker(ABC): + """ + A post-response background task. + + Workers run sequentially after `StreamEnd` is yielded. + Each worker may modify `session` (e.g. compress context) and return + events to forward to the client. Failures are isolated — one broken + worker does not block others. + """ + + @abstractmethod + async def run(self, session: Session, ctx: WorkerContext) -> WorkerResult: + """Execute the worker. May mutate session and save via ctx.session_store.""" diff --git a/navi/workers/compressor.py b/navi/workers/compressor.py new file mode 100644 index 0000000..abd299e --- /dev/null +++ b/navi/workers/compressor.py @@ -0,0 +1,59 @@ +"""Context compression worker.""" + +import structlog + +from navi.config import settings +from navi.core.compressor import compress_context, should_compress + +from .base import Worker, WorkerContext, WorkerResult + +log = structlog.get_logger() + + +class CompressionWorker(Worker): + """ + Compresses session.context when it approaches the token limit. + session.messages (full display history) is never modified. + """ + + async def run(self, session, ctx: WorkerContext) -> WorkerResult: + if not settings.context_compression_enabled: + return WorkerResult() + if ctx.context_tokens is None: + return WorkerResult() + if not should_compress(ctx.context_tokens, ctx.max_context_tokens, + settings.context_compression_threshold): + return WorkerResult() + + count_before = len(session.context) + try: + new_context = await compress_context( + context=session.context, + llm=ctx.llm, + model=ctx.model, + temperature=settings.context_summary_temperature, + keep_recent=settings.context_keep_recent, + ) + except Exception: + log.warning("compression_worker.llm_failed", session_id=ctx.session_id, exc_info=True) + return WorkerResult() + + if new_context is None: + return WorkerResult() + + session.context = new_context + await ctx.session_store.save(session) + + log.info( + "compression_worker.done", + session_id=ctx.session_id, + before=count_before, + after=len(session.context), + ) + + # Import here to avoid circular dependency + from navi.core.agent import ContextCompressed + return WorkerResult(events=[ContextCompressed( + messages_before=count_before, + messages_after=len(session.context), + )])