diff --git a/navi/core/pg_session_store.py b/navi/core/pg_session_store.py index 660add9..3aa074c 100644 --- a/navi/core/pg_session_store.py +++ b/navi/core/pg_session_store.py @@ -237,21 +237,21 @@ context = [m for m in all_msgs if m.is_context] planning_logs_raw = row["planning_logs"] planning_logs = json.loads(planning_logs_raw) if planning_logs_raw else [] - sessions.append( - Session( - id=row["id"], - profile_id=row["profile_id"], - user_id=row["user_id"], - messages=messages, - context=context, - pinned=bool(row["pinned"]), - name=row["name"], - created_at=row["created_at"], - last_active=row["last_active"], - context_token_count=row["context_token_count"] or 0, - planning_logs=planning_logs, - ) + s = Session( + id=row["id"], + profile_id=row["profile_id"], + user_id=row["user_id"], + messages=messages, + context=context, + pinned=bool(row["pinned"]), + name=row["name"], + created_at=row["created_at"], + last_active=row["last_active"], + context_token_count=row["context_token_count"] or 0, + planning_logs=planning_logs, ) + s.db_message_count = len(all_msgs) + sessions.append(s) return sessions @@ -309,7 +309,7 @@ planning_logs_raw = row["planning_logs"] planning_logs = json.loads(planning_logs_raw) if planning_logs_raw else [] - return Session( + s = Session( id=row["id"], profile_id=row["profile_id"], user_id=row["user_id"], @@ -322,6 +322,8 @@ context_token_count=row["context_token_count"] or 0, planning_logs=planning_logs, ) + s.db_message_count = len(all_messages) + return s async def save(self, session: Session) -> None: session.last_active = datetime.now(timezone.utc) @@ -335,11 +337,89 @@ json.dumps(session.planning_logs, ensure_ascii=False), session.id, ) - # Normalized rows — write session.messages with their flags. - await conn.execute("DELETE FROM session_messages WHERE session_id = $1", session.id) + db_count = session.db_message_count + messages = session.messages - for seq, m in enumerate(session.messages): + # 1. Update existing rows for mutable flags + if db_count > 0: + update_rows = [] + limit = min(db_count, len(messages)) + for seq in range(limit): + m = messages[seq] + update_rows.append( + ( + session.id, + seq, + m.is_context, + m.is_display, + m.is_summary, + m.is_plan, + m.is_compression, + m.is_recall, + m.thinking, + m.elapsed_seconds, + m.tool_call_count, + m.token_count, + ) + ) + if update_rows: + await conn.executemany( + """ + UPDATE session_messages + SET is_context = $3, + is_display = $4, + is_summary = $5, + is_plan = $6, + is_compression = $7, + is_recall = $8, + thinking = $9, + elapsed_seconds = $10, + tool_call_count = $11, + token_count = $12 + WHERE session_id = $1 AND sequence_number = $2 + """, + update_rows, + ) + + # 2. Safety delete in case rows were added concurrently + if db_count > 0: await conn.execute( + "DELETE FROM session_messages WHERE session_id = $1 AND sequence_number >= $2", + session.id, + db_count, + ) + + # 3. Insert only new messages + if len(messages) > db_count: + insert_rows = [] + for seq in range(db_count, len(messages)): + m = messages[seq] + insert_rows.append( + ( + session.id, + seq, + m.role, + m.content, + json.dumps(m.images, ensure_ascii=False) if m.images else None, + json.dumps([tc.model_dump(mode="json") for tc in m.tool_calls], ensure_ascii=False) if m.tool_calls else None, + m.tool_call_id, + m.name, + m.created_at, + m.is_summary, + m.thinking, + m.is_plan, + m.is_compression, + m.is_context, + m.is_display, + m.elapsed_seconds, + m.tool_call_count, + m.token_count, + json.dumps(m.files, ensure_ascii=False) if m.files else None, + json.dumps(m.metadata, ensure_ascii=False) if m.metadata else None, + m.is_recall, + ) + ) + await conn.executemany( """ INSERT INTO session_messages (session_id, sequence_number, role, content, images, tool_calls, tool_call_id, name, @@ -347,29 +427,11 @@ elapsed_seconds, tool_call_count, token_count, files, metadata, is_recall) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21) """, - session.id, - seq, - m.role, - m.content, - json.dumps(m.images, ensure_ascii=False) if m.images else None, - json.dumps([tc.model_dump(mode="json") for tc in m.tool_calls], ensure_ascii=False) if m.tool_calls else None, - m.tool_call_id, - m.name, - m.created_at, - m.is_summary, - m.thinking, - m.is_plan, - m.is_compression, - m.is_context, - m.is_display, - m.elapsed_seconds, - m.tool_call_count, - m.token_count, - json.dumps(m.files, ensure_ascii=False) if m.files else None, - json.dumps(m.metadata, ensure_ascii=False) if m.metadata else None, - m.is_recall, + insert_rows, ) + session.db_message_count = len(messages) + async def set_pinned(self, session_id: str, pinned: bool) -> bool: pool = await self._get_pool() async with pool.acquire() as conn: diff --git a/navi/core/session.py b/navi/core/session.py index e0e263e..e61681f 100644 --- a/navi/core/session.py +++ b/navi/core/session.py @@ -21,6 +21,7 @@ created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) last_active: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) planning_logs: list[dict] = Field(default_factory=list) # raw planning phase outputs per turn + db_message_count: int = Field(default=0, exclude=True) # messages already persisted in DB class SessionStore(ABC):