diff --git a/navi/core/pg_session_store.py b/navi/core/pg_session_store.py index 3aa074c..2051528 100644 --- a/navi/core/pg_session_store.py +++ b/navi/core/pg_session_store.py @@ -1,10 +1,14 @@ """PostgreSQL-backed session store using asyncpg connection pool. Normalized storage: -- session_messages — one row per Message, with is_display / is_context flags -- session_images — base64 images referenced by session_messages (future use) +- session_messages — hot messages (recent N per session) +- session_messages_archive — old messages moved here to keep the hot table small +- session_images — base64 images referenced by session_messages (future use) -All session message I/O goes through the normalized table. The legacy JSON + sessions.next_sequence — monotonic global seq for the session (next free number) + sessions.archive_threshold — all seq < threshold live in the archive table + +All session message I/O goes through the normalized tables. The legacy JSON columns (sessions.messages, sessions.context) remain in the schema for backward compatibility but are no longer read or written. """ @@ -37,7 +41,9 @@ _MIGRATE = """ ALTER TABLE sessions ADD COLUMN IF NOT EXISTS name TEXT; ALTER TABLE sessions ADD COLUMN IF NOT EXISTS planning_logs TEXT NOT NULL DEFAULT '[]'; -ALTER TABLE sessions ADD COLUMN IF NOT EXISTS user_id TEXT REFERENCES navi_users(id) ON DELETE SET NULL +ALTER TABLE sessions ADD COLUMN IF NOT EXISTS user_id TEXT REFERENCES navi_users(id) ON DELETE SET NULL; +ALTER TABLE sessions ADD COLUMN IF NOT EXISTS next_sequence INTEGER NOT NULL DEFAULT 0; +ALTER TABLE sessions ADD COLUMN IF NOT EXISTS archive_threshold INTEGER NOT NULL DEFAULT 0 """ _SESSION_MESSAGES_DDL = """ @@ -84,6 +90,35 @@ CREATE INDEX IF NOT EXISTS idx_session_images_message ON session_images(message_id); """ +_SESSION_ARCHIVE_DDL = """ +CREATE TABLE IF NOT EXISTS session_messages_archive ( + id SERIAL PRIMARY KEY, + session_id TEXT NOT NULL REFERENCES sessions(id) ON DELETE CASCADE, + sequence_number INT NOT NULL, + role TEXT NOT NULL, + content TEXT, + images TEXT, + tool_calls TEXT, + tool_call_id TEXT, + name TEXT, + created_at TIMESTAMPTZ, + is_summary BOOLEAN NOT NULL DEFAULT FALSE, + thinking TEXT, + is_plan BOOLEAN NOT NULL DEFAULT FALSE, + is_compression BOOLEAN NOT NULL DEFAULT FALSE, + is_context BOOLEAN NOT NULL DEFAULT TRUE, + is_display BOOLEAN NOT NULL DEFAULT TRUE, + elapsed_seconds FLOAT, + tool_call_count INT, + token_count INT, + files TEXT, + metadata TEXT, + is_recall BOOLEAN NOT NULL DEFAULT FALSE, + UNIQUE(session_id, sequence_number) +); +CREATE INDEX IF NOT EXISTS idx_session_messages_archive_session_seq ON session_messages_archive(session_id, sequence_number); +""" + def _serialize(messages: list[Message]) -> str: return json.dumps( @@ -149,12 +184,14 @@ files=files, metadata=metadata, is_recall=bool(row.get("is_recall", False)), + sequence_number=row.get("sequence_number", 0), ) async def _ensure_normalized_tables(conn: asyncpg.Connection) -> None: await conn.execute(_SESSION_MESSAGES_DDL) await conn.execute(_SESSION_IMAGES_DDL) + await conn.execute(_SESSION_ARCHIVE_DDL) async def _migrate_to_normalized(conn: asyncpg.Connection) -> None: @@ -227,16 +264,24 @@ conn: asyncpg.Connection, rows: list[asyncpg.Record], ) -> list[Session]: - """Hydrate session rows with messages from the normalized table.""" + """Hydrate session rows with hot (non-archived) messages.""" session_ids = [r["id"] for r in rows] messages_map = await _load_messages_map(conn, session_ids) sessions: list[Session] = [] for row in rows: all_msgs = messages_map.get(row["id"], []) - messages = [m for m in all_msgs if m.is_display] - context = [m for m in all_msgs if m.is_context] + archive_threshold = row.get("archive_threshold", 0) or 0 + hot_msgs = [m for m in all_msgs if m.sequence_number >= archive_threshold] + messages = [m for m in hot_msgs if m.is_display] + context = [m for m in hot_msgs if m.is_context] planning_logs_raw = row["planning_logs"] planning_logs = json.loads(planning_logs_raw) if planning_logs_raw else [] + + next_seq = row.get("next_sequence", 0) or 0 + max_seq = max((m.sequence_number for m in hot_msgs), default=-1) + if next_seq == 0: + next_seq = max_seq + 1 + s = Session( id=row["id"], profile_id=row["profile_id"], @@ -250,7 +295,8 @@ context_token_count=row["context_token_count"] or 0, planning_logs=planning_logs, ) - s.db_message_count = len(all_msgs) + s.db_message_count = len(hot_msgs) + s.db_next_sequence = next_seq sessions.append(s) return sessions @@ -289,18 +335,20 @@ pool = await self._get_pool() async with pool.acquire() as conn: row = await conn.fetchrow( - "SELECT id, profile_id, user_id, pinned, created_at, last_active, context_token_count, name, planning_logs " + "SELECT id, profile_id, user_id, pinned, created_at, last_active, context_token_count, name, planning_logs, next_sequence, archive_threshold " "FROM sessions WHERE id = $1", session_id, ) if not row: return None - # Load all messages once so messages[] and context[] share + archive_threshold = row["archive_threshold"] or 0 + # Load hot (non-archived) messages so messages[] and context[] share # the same Python objects (id() matching in the agent works). all_rows = await conn.fetch( - "SELECT * FROM session_messages WHERE session_id = $1 ORDER BY sequence_number", + "SELECT * FROM session_messages WHERE session_id = $1 AND sequence_number >= $2 ORDER BY sequence_number", session_id, + archive_threshold, ) all_messages = [_row_to_message(r) for r in all_rows] messages = [m for m in all_messages if m.is_display] @@ -309,6 +357,11 @@ planning_logs_raw = row["planning_logs"] planning_logs = json.loads(planning_logs_raw) if planning_logs_raw else [] + next_seq = row["next_sequence"] or 0 + max_seq = max((m.sequence_number for m in all_messages), default=-1) + if next_seq == 0: + next_seq = max_seq + 1 + s = Session( id=row["id"], profile_id=row["profile_id"], @@ -323,6 +376,7 @@ planning_logs=planning_logs, ) s.db_message_count = len(all_messages) + s.db_next_sequence = next_seq return s async def save(self, session: Session) -> None: @@ -337,19 +391,18 @@ json.dumps(session.planning_logs, ensure_ascii=False), session.id, ) - db_count = session.db_message_count + db_next = session.db_next_sequence messages = session.messages - # 1. Update existing rows for mutable flags - if db_count > 0: + # 1. Update mutable flags for already-persisted rows (sequence_number >= 0) + existing = [m for m in messages if m.sequence_number >= 0] + if existing: update_rows = [] - limit = min(db_count, len(messages)) - for seq in range(limit): - m = messages[seq] + for m in existing: update_rows.append( ( session.id, - seq, + m.sequence_number, m.is_context, m.is_display, m.is_summary, @@ -362,38 +415,31 @@ m.token_count, ) ) - if update_rows: - await conn.executemany( - """ - UPDATE session_messages - SET is_context = $3, - is_display = $4, - is_summary = $5, - is_plan = $6, - is_compression = $7, - is_recall = $8, - thinking = $9, - elapsed_seconds = $10, - tool_call_count = $11, - token_count = $12 - WHERE session_id = $1 AND sequence_number = $2 - """, - update_rows, - ) - - # 2. Safety delete in case rows were added concurrently - if db_count > 0: - await conn.execute( - "DELETE FROM session_messages WHERE session_id = $1 AND sequence_number >= $2", - session.id, - db_count, + await conn.executemany( + """ + UPDATE session_messages + SET is_context = $3, + is_display = $4, + is_summary = $5, + is_plan = $6, + is_compression = $7, + is_recall = $8, + thinking = $9, + elapsed_seconds = $10, + tool_call_count = $11, + token_count = $12 + WHERE session_id = $1 AND sequence_number = $2 + """, + update_rows, ) - # 3. Insert only new messages - if len(messages) > db_count: + # 2. Insert new messages (sequence_number < 0 means "not yet persisted") + new_msgs = [m for m in messages if m.sequence_number < 0] + if new_msgs: insert_rows = [] - for seq in range(db_count, len(messages)): - m = messages[seq] + for i, m in enumerate(new_msgs): + seq = db_next + i + m.sequence_number = seq insert_rows.append( ( session.id, @@ -429,9 +475,46 @@ """, insert_rows, ) + new_next = db_next + len(new_msgs) + await conn.execute( + "UPDATE sessions SET next_sequence = $1 WHERE id = $2", + new_next, session.id, + ) + session.db_next_sequence = new_next session.db_message_count = len(messages) + async def archive_old_messages(self, session_id: str, keep_seq_threshold: int) -> int: + """Move messages older than keep_seq_threshold from hot to archive table. + + Returns number of rows archived. + """ + pool = await self._get_pool() + async with pool.acquire() as conn: + # Copy old rows to archive + copied = await conn.execute( + """ + INSERT INTO session_messages_archive + SELECT * FROM session_messages + WHERE session_id = $1 AND sequence_number < $2 + ON CONFLICT (session_id, sequence_number) DO NOTHING + """, + session_id, keep_seq_threshold, + ) + # Delete from hot table + await conn.execute( + "DELETE FROM session_messages WHERE session_id = $1 AND sequence_number < $2", + session_id, keep_seq_threshold, + ) + # Update threshold on session + await conn.execute( + "UPDATE sessions SET archive_threshold = $1 WHERE id = $2", + keep_seq_threshold, session_id, + ) + # asyncpg execute returns 'INSERT 0 N' — extract N + parts = copied.split() + return int(parts[-1]) if len(parts) >= 2 else 0 + async def set_pinned(self, session_id: str, pinned: bool) -> bool: pool = await self._get_pool() async with pool.acquire() as conn: @@ -455,13 +538,13 @@ async with pool.acquire() as conn: if not is_admin and user_id is not None: rows = await conn.fetch( - "SELECT id, profile_id, user_id, pinned, created_at, last_active, context_token_count, name, planning_logs " + "SELECT id, profile_id, user_id, pinned, created_at, last_active, context_token_count, name, planning_logs, next_sequence, archive_threshold " "FROM sessions WHERE user_id = $1 ORDER BY pinned DESC, last_active DESC", user_id, ) else: rows = await conn.fetch( - "SELECT id, profile_id, user_id, pinned, created_at, last_active, context_token_count, name, planning_logs " + "SELECT id, profile_id, user_id, pinned, created_at, last_active, context_token_count, name, planning_logs, next_sequence, archive_threshold " "FROM sessions ORDER BY pinned DESC, last_active DESC" ) return await _build_sessions(conn, rows) diff --git a/navi/core/session.py b/navi/core/session.py index e61681f..f94926b 100644 --- a/navi/core/session.py +++ b/navi/core/session.py @@ -22,6 +22,7 @@ last_active: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) planning_logs: list[dict] = Field(default_factory=list) # raw planning phase outputs per turn db_message_count: int = Field(default=0, exclude=True) # messages already persisted in DB + db_next_sequence: int = Field(default=0, exclude=True) # next global sequence_number for this session class SessionStore(ABC): @@ -74,11 +75,17 @@ async def delete(self, session_id: str) -> bool: ... @abstractmethod + async def delete(self, session_id: str) -> bool: ... + + @abstractmethod async def set_pinned(self, session_id: str, pinned: bool) -> bool: ... @abstractmethod async def set_name(self, session_id: str, name: str) -> bool: ... + @abstractmethod + async def archive_old_messages(self, session_id: str, keep_seq_threshold: int) -> int: ... + class InMemorySessionStore(SessionStore): def __init__(self) -> None: @@ -187,3 +194,7 @@ return False s.name = name return True + + async def archive_old_messages(self, session_id: str, keep_seq_threshold: int) -> int: + # In-memory store: no-op, everything stays in RAM + return 0 diff --git a/navi/llm/base.py b/navi/llm/base.py index dd4aa1b..10fb8b9 100644 --- a/navi/llm/base.py +++ b/navi/llm/base.py @@ -61,6 +61,9 @@ # normalized storage flags — written by the agent, read by PgSessionStore is_context: bool = True # included in LLM context is_display: bool = True # shown in chat UI + # DB sequence number — set by PgSessionStore on load, used for delta-save. + # -1 means "not yet persisted" (new messages created by the agent). + sequence_number: int = Field(default=-1, exclude=True) class LLMResponse(BaseModel):