diff --git a/navi/core/agent.py b/navi/core/agent.py index 4fd372c..e44743a 100644 --- a/navi/core/agent.py +++ b/navi/core/agent.py @@ -207,7 +207,7 @@ display_text = display_message if display_message is not None else user_message user_msg_display = Message(role="user", content=display_text, images=images or None, files=files or None, created_at=datetime.now(timezone.utc), - is_recall=is_recall) + is_recall=is_recall, is_context=False) # Image token budgeting: fit as many images as possible into the LLM context. # Overflow images are saved to the session directory so Navi can view them @@ -246,8 +246,9 @@ user_msg_context = Message(role="user", content=context_content, images=images_for_context or None, files=files or None, created_at=datetime.now(timezone.utc), - is_recall=is_recall) + is_recall=is_recall, is_display=False) session.messages.append(user_msg_display) + session.messages.append(user_msg_context) session.context.append(user_msg_context) # Persist user message immediately so it survives a client disconnect # before the assistant reply is ready. @@ -370,6 +371,7 @@ session.messages.append(Message( role="assistant", content=state.accumulated_text, created_at=datetime.now(timezone.utc), + is_context=False, )) await self._sessions.save(session) yield StreamStopped() @@ -516,13 +518,30 @@ if result is None: return None new_context, summary_text = result - session.context = new_context - session.context_token_count = self._compressor.estimate_context_tokens(new_context) + + # Mark messages that are no longer part of the LLM context + new_context_ids = {id(m) for m in new_context} + for msg in session.messages: + if id(msg) not in new_context_ids and msg.role != "system": + msg.is_context = False + + # The summary returned by the compressor must also live in messages so + # save() writes it to the normalized table, but it is not displayed. + summary_msg = next((m for m in new_context if m.is_summary), None) + if summary_msg and summary_msg not in session.messages: + summary_msg.is_display = False + session.messages.append(summary_msg) + + # UI marker showing that compression happened session.messages.append(Message( role="system", is_compression=True, + is_context=False, content=summary_text, )) + + session.context = new_context + session.context_token_count = self._compressor.estimate_context_tokens(new_context) await self._sessions.save(session) log.info( @@ -692,6 +711,7 @@ session.messages.append(Message( role="tool", content="Tool execution was stopped by the user.", tool_call_id=tc.id, name=tc.name, metadata={}, + is_context=False, )) await self._sessions.save(session) yield StreamStopped() diff --git a/navi/core/pg_session_store.py b/navi/core/pg_session_store.py index e17170d..d793530 100644 --- a/navi/core/pg_session_store.py +++ b/navi/core/pg_session_store.py @@ -1,4 +1,13 @@ -"""PostgreSQL-backed session store using asyncpg connection pool.""" +"""PostgreSQL-backed session store using asyncpg connection pool. + +Phase-1 migration to normalized tables: +- session_messages — one row per Message, with is_display / is_context flags +- session_images — base64 images referenced by session_messages (future use) + +Dual-write: every save() writes to both JSON columns (legacy) and normalized +rows (new). get() prefers normalized rows when they exist, falling back to JSON. +This lets us migrate incrementally without breaking existing sessions. +""" import asyncio import json @@ -6,7 +15,7 @@ import asyncpg -from navi.llm.base import Message +from navi.llm.base import Message, ToolCallRequest from .session import Session, SessionStore @@ -31,6 +40,50 @@ ALTER TABLE sessions ADD COLUMN IF NOT EXISTS user_id TEXT REFERENCES navi_users(id) ON DELETE SET NULL """ +_SESSION_MESSAGES_DDL = """ +CREATE TABLE IF NOT EXISTS session_messages ( + id SERIAL PRIMARY KEY, + session_id TEXT NOT NULL REFERENCES sessions(id) ON DELETE CASCADE, + sequence_number INT NOT NULL, + role TEXT NOT NULL, + content TEXT, + images TEXT, -- JSON array of base64 strings + tool_calls TEXT, -- JSON + tool_call_id TEXT, + name TEXT, + created_at TIMESTAMPTZ, + is_summary BOOLEAN NOT NULL DEFAULT FALSE, + thinking TEXT, + is_plan BOOLEAN NOT NULL DEFAULT FALSE, + is_compression BOOLEAN NOT NULL DEFAULT FALSE, + is_context BOOLEAN NOT NULL DEFAULT TRUE, + is_display BOOLEAN NOT NULL DEFAULT TRUE, + elapsed_seconds FLOAT, + tool_call_count INT, + token_count INT, + files TEXT, -- JSON + metadata TEXT, -- JSON + is_recall BOOLEAN NOT NULL DEFAULT FALSE, + UNIQUE(session_id, sequence_number) +); + +CREATE INDEX IF NOT EXISTS idx_session_messages_session_seq ON session_messages(session_id, sequence_number); +CREATE INDEX IF NOT EXISTS idx_session_messages_context ON session_messages(session_id, is_context, sequence_number); +""" + +_SESSION_IMAGES_DDL = """ +CREATE TABLE IF NOT EXISTS session_images ( + id SERIAL PRIMARY KEY, + session_id TEXT NOT NULL REFERENCES sessions(id) ON DELETE CASCADE, + message_id INT REFERENCES session_messages(id) ON DELETE CASCADE, + base64 TEXT NOT NULL, + filename TEXT +); + +CREATE INDEX IF NOT EXISTS idx_session_images_session ON session_images(session_id); +CREATE INDEX IF NOT EXISTS idx_session_images_message ON session_images(message_id); +""" + def _serialize(messages: list[Message]) -> str: return json.dumps( @@ -45,6 +98,116 @@ return [Message.model_validate(m) for m in json.loads(raw)] +def _message_key(m: Message) -> tuple: + """Stable key for matching a message between messages[] and context[]. + + Used only by the one-shot boot migration for legacy JSON rows. + """ + return ( + m.role, + m.content, + m.tool_call_id, + m.name, + m.is_summary, + m.is_plan, + m.is_compression, + m.is_recall, + m.thinking, + m.created_at.isoformat() if m.created_at else None, + json.dumps([tc.model_dump(mode="json") for tc in m.tool_calls], ensure_ascii=False) if m.tool_calls else None, + json.dumps(m.files, ensure_ascii=False) if m.files else None, + json.dumps(m.metadata, ensure_ascii=False) if m.metadata else None, + ) + + +def _row_to_message(row: asyncpg.Record) -> Message: + images = json.loads(row["images"]) if row.get("images") else None + tool_calls = None + raw_tool_calls = row.get("tool_calls") + if raw_tool_calls: + tool_calls = [ToolCallRequest.model_validate(tc) for tc in json.loads(raw_tool_calls)] + files = json.loads(row["files"]) if row.get("files") else None + metadata = json.loads(row["metadata"]) if row.get("metadata") else {} + + return Message( + role=row["role"], + content=row["content"], + images=images, + tool_calls=tool_calls, + tool_call_id=row.get("tool_call_id"), + name=row.get("name"), + created_at=row.get("created_at"), + is_summary=bool(row.get("is_summary", False)), + thinking=row.get("thinking"), + is_plan=bool(row.get("is_plan", False)), + is_compression=bool(row.get("is_compression", False)), + is_context=bool(row.get("is_context", True)), + is_display=bool(row.get("is_display", True)), + elapsed_seconds=row.get("elapsed_seconds"), + tool_call_count=row.get("tool_call_count"), + token_count=row.get("token_count"), + files=files, + metadata=metadata, + is_recall=bool(row.get("is_recall", False)), + ) + + +async def _ensure_normalized_tables(conn: asyncpg.Connection) -> None: + await conn.execute(_SESSION_MESSAGES_DDL) + await conn.execute(_SESSION_IMAGES_DDL) + + +async def _migrate_to_normalized(conn: asyncpg.Connection) -> None: + """One-shot migration: copy existing JSON messages/context into session_messages.""" + migrated = await conn.fetchval("SELECT COUNT(*) FROM session_messages") + if migrated: + return + + rows = await conn.fetch("SELECT id, messages, context FROM sessions") + for row in rows: + session_id = row["id"] + messages = _deserialize(row["messages"] or "[]") + context = _deserialize(row["context"] or "") + if not context and messages: + context = list(messages) + + context_set = {_message_key(m) for m in context} + + for seq, m in enumerate(messages): + is_ctx = _message_key(m) in context_set + await conn.execute( + """ + INSERT INTO session_messages + (session_id, sequence_number, role, content, images, tool_calls, tool_call_id, name, + created_at, is_summary, thinking, is_plan, is_compression, is_context, is_display, + elapsed_seconds, tool_call_count, token_count, files, metadata, is_recall) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21) + ON CONFLICT (session_id, sequence_number) DO NOTHING + """, + session_id, + seq, + m.role, + m.content, + json.dumps(m.images, ensure_ascii=False) if m.images else None, + json.dumps([tc.model_dump(mode="json") for tc in m.tool_calls], ensure_ascii=False) if m.tool_calls else None, + m.tool_call_id, + m.name, + m.created_at, + m.is_summary, + m.thinking, + m.is_plan, + m.is_compression, + is_ctx, + True, + m.elapsed_seconds, + m.tool_call_count, + m.token_count, + json.dumps(m.files, ensure_ascii=False) if m.files else None, + json.dumps(m.metadata, ensure_ascii=False) if m.metadata else None, + m.is_recall, + ) + + class PgSessionStore(SessionStore): def __init__(self, pool: asyncpg.Pool) -> None: self._pool = pool @@ -58,6 +221,8 @@ async with self._pool.acquire() as conn: await conn.execute(_DDL) await conn.execute(_MIGRATE) + await _ensure_normalized_tables(conn) + await _migrate_to_normalized(conn) self._initialized = True return self._pool @@ -81,12 +246,47 @@ "FROM sessions WHERE id = $1", session_id, ) - return self._row_to_session(row) if row else None + if not row: + return None + + msg_rows = await conn.fetch( + "SELECT * FROM session_messages WHERE session_id = $1 AND is_display = true ORDER BY sequence_number", + session_id, + ) + if msg_rows: + messages = [_row_to_message(r) for r in msg_rows] + ctx_rows = await conn.fetch( + "SELECT * FROM session_messages WHERE session_id = $1 AND is_context = true ORDER BY sequence_number", + session_id, + ) + context = [_row_to_message(r) for r in ctx_rows] + else: + messages = _deserialize(row["messages"]) + context_json = row["context"] + context = _deserialize(context_json) if context_json else list(messages) + + planning_logs_raw = row["planning_logs"] + planning_logs = json.loads(planning_logs_raw) if planning_logs_raw else [] + + return Session( + id=row["id"], + profile_id=row["profile_id"], + user_id=row["user_id"], + messages=messages, + context=context, + pinned=bool(row["pinned"]), + name=row["name"], + created_at=row["created_at"], + last_active=row["last_active"], + context_token_count=row["context_token_count"] or 0, + planning_logs=planning_logs, + ) async def save(self, session: Session) -> None: session.last_active = datetime.now(timezone.utc) pool = await self._get_pool() async with pool.acquire() as conn: + # Legacy JSON columns (dual-write for rollback safety) await conn.execute( "UPDATE sessions SET profile_id = $1, user_id = $2, messages = $3, context = $4, " "last_active = $5, context_token_count = $6, planning_logs = $7 WHERE id = $8", @@ -95,6 +295,41 @@ json.dumps(session.planning_logs, ensure_ascii=False), session.id, ) + # Normalized rows — write session.messages with their flags. + await conn.execute("DELETE FROM session_messages WHERE session_id = $1", session.id) + + for seq, m in enumerate(session.messages): + await conn.execute( + """ + INSERT INTO session_messages + (session_id, sequence_number, role, content, images, tool_calls, tool_call_id, name, + created_at, is_summary, thinking, is_plan, is_compression, is_context, is_display, + elapsed_seconds, tool_call_count, token_count, files, metadata, is_recall) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21) + """, + session.id, + seq, + m.role, + m.content, + json.dumps(m.images, ensure_ascii=False) if m.images else None, + json.dumps([tc.model_dump(mode="json") for tc in m.tool_calls], ensure_ascii=False) if m.tool_calls else None, + m.tool_call_id, + m.name, + m.created_at, + m.is_summary, + m.thinking, + m.is_plan, + m.is_compression, + m.is_context, + m.is_display, + m.elapsed_seconds, + m.tool_call_count, + m.token_count, + json.dumps(m.files, ensure_ascii=False) if m.files else None, + json.dumps(m.metadata, ensure_ascii=False) if m.metadata else None, + m.is_recall, + ) + async def set_pinned(self, session_id: str, pinned: bool) -> bool: pool = await self._get_pool() async with pool.acquire() as conn: @@ -249,6 +484,10 @@ return [self._row_to_session(r) for r in rows] def _row_to_session(self, row: asyncpg.Record) -> Session: + # Prefer normalized rows; fallback to legacy JSON columns + session_id = row["id"] + # We can't easily do another query here (no async in sync method), + # so rely on the caller path (get) or legacy JSON (list paths). messages = _deserialize(row["messages"]) context_json = row["context"] context = _deserialize(context_json) if context_json else list(messages) diff --git a/navi/core/planning.py b/navi/core/planning.py index 67164e3..5cc6a48 100644 --- a/navi/core/planning.py +++ b/navi/core/planning.py @@ -407,14 +407,20 @@ if not re.search(r"(TOOL:|AGENT:|→\s*SELF)", plan_text): log.warning("agent.planning_no_executors", hint="plan lacks TOOL/AGENT/SELF assignments") - context.append(Message(role="assistant", content=plan_text)) + plan_ctx_msg = Message(role="assistant", content=plan_text, is_display=False) + context.append(plan_ctx_msg) if messages is not None: - messages.append(Message(role="assistant", content=plan_text, is_plan=True)) + messages.append(plan_ctx_msg) + messages.append(Message(role="assistant", content=plan_text, is_plan=True, is_context=False)) - context.append(Message( + prompt_msg = Message( role="user", content="Plan is ready. Execute it now step by step, starting with step 1. Use the todo tool to track progress.", - )) + is_display=False, + ) + context.append(prompt_msg) + if messages is not None: + messages.append(prompt_msg) _todo_steps = _parse_plan_steps(plan_text) if _todo_steps: diff --git a/navi/llm/base.py b/navi/llm/base.py index a2186c9..dd4aa1b 100644 --- a/navi/llm/base.py +++ b/navi/llm/base.py @@ -58,6 +58,9 @@ metadata: dict = Field(default_factory=dict) # marks a scheduled-recall trigger message (display styling + LLM context hint) is_recall: bool = False + # normalized storage flags — written by the agent, read by PgSessionStore + is_context: bool = True # included in LLM context + is_display: bool = True # shown in chat UI class LLMResponse(BaseModel): diff --git a/tests/unit/core/test_agent.py b/tests/unit/core/test_agent.py index 022e44d..f7b35b1 100644 --- a/tests/unit/core/test_agent.py +++ b/tests/unit/core/test_agent.py @@ -63,10 +63,12 @@ result = await agent.run(session.id, "hi") assert result == "hello" saved = await agent._sessions.get(session.id) - assert len(saved.messages) == 2 # user + assistant + # user display + user context + assistant + assert len(saved.messages) == 3 assert saved.messages[0].role == "user" - assert saved.messages[1].role == "assistant" - assert saved.messages[1].content == "hello" + assert saved.messages[1].role == "user" + assert saved.messages[2].role == "assistant" + assert saved.messages[2].content == "hello" @pytest.mark.asyncio async def test_run_session_not_found(self, agent): @@ -88,10 +90,10 @@ result = await agent.run(session.id, "do something") assert result == "done" saved = await agent._sessions.get(session.id) - # user + assistant(tool) + tool_result + assistant(final) - assert len(saved.messages) == 4 - assert saved.messages[2].role == "tool" - assert saved.messages[3].content == "done" + # user display + user context + assistant(tool) + tool_result + assistant(final) + assert len(saved.messages) == 5 + assert saved.messages[3].role == "tool" + assert saved.messages[4].content == "done" @pytest.mark.asyncio async def test_run_token_accumulation(self, agent, session):