diff --git a/navi/core/pg_session_store.py b/navi/core/pg_session_store.py index a1d55ca..54dd7fb 100644 --- a/navi/core/pg_session_store.py +++ b/navi/core/pg_session_store.py @@ -342,21 +342,201 @@ ) return result == "UPDATE 1" +async def _load_messages_map(conn: asyncpg.Connection, session_ids: list[str]) -> dict[str, list[Message]]: + """Batch-load all messages for the given session IDs.""" + if not session_ids: + return {} + rows = await conn.fetch( + "SELECT * FROM session_messages WHERE session_id = ANY($1) ORDER BY sequence_number", + session_ids, + ) + result: dict[str, list[Message]] = {sid: [] for sid in session_ids} + for row in rows: + sid = row["session_id"] + result[sid].append(_row_to_message(row)) + return result + + +async def _build_sessions( + conn: asyncpg.Connection, + rows: list[asyncpg.Record], +) -> list[Session]: + """Hydrate session rows with messages from the normalized table.""" + session_ids = [r["id"] for r in rows] + messages_map = await _load_messages_map(conn, session_ids) + sessions: list[Session] = [] + for row in rows: + messages = messages_map.get(row["id"], []) + context = [m for m in messages if m.is_context] + planning_logs_raw = row["planning_logs"] + planning_logs = json.loads(planning_logs_raw) if planning_logs_raw else [] + sessions.append( + Session( + id=row["id"], + profile_id=row["profile_id"], + user_id=row["user_id"], + messages=messages, + context=context, + pinned=bool(row["pinned"]), + name=row["name"], + created_at=row["created_at"], + last_active=row["last_active"], + context_token_count=row["context_token_count"] or 0, + planning_logs=planning_logs, + ) + ) + return sessions + + +class PgSessionStore(SessionStore): + def __init__(self, pool: asyncpg.Pool) -> None: + self._pool = pool + self._initialized = False + self._lock = asyncio.Lock() + + async def _get_pool(self) -> asyncpg.Pool: + if not self._initialized: + async with self._lock: + if not self._initialized: + async with self._pool.acquire() as conn: + await conn.execute(_DDL) + await conn.execute(_MIGRATE) + await _ensure_normalized_tables(conn) + await _migrate_to_normalized(conn) + self._initialized = True + return self._pool + + async def create(self, profile_id: str, user_id: str | None = None) -> Session: + session = Session(profile_id=profile_id, user_id=user_id) + pool = await self._get_pool() + async with pool.acquire() as conn: + await conn.execute( + "INSERT INTO sessions " + "(id, profile_id, user_id, messages, context, pinned, created_at, last_active, context_token_count) " + "VALUES ($1, $2, $3, '[]', '', FALSE, $4, $5, 0)", + session.id, session.profile_id, session.user_id, session.created_at, session.last_active, + ) + return session + + async def get(self, session_id: str) -> Session | None: + pool = await self._get_pool() + async with pool.acquire() as conn: + row = await conn.fetchrow( + "SELECT id, profile_id, user_id, pinned, created_at, last_active, context_token_count, name, planning_logs " + "FROM sessions WHERE id = $1", + session_id, + ) + if not row: + return None + + # Load all messages once so messages[] and context[] share + # the same Python objects (id() matching in the agent works). + all_rows = await conn.fetch( + "SELECT * FROM session_messages WHERE session_id = $1 ORDER BY sequence_number", + session_id, + ) + all_messages = [_row_to_message(r) for r in all_rows] + messages = [m for m in all_messages if m.is_display] + context = [m for m in all_messages if m.is_context] + + planning_logs_raw = row["planning_logs"] + planning_logs = json.loads(planning_logs_raw) if planning_logs_raw else [] + + return Session( + id=row["id"], + profile_id=row["profile_id"], + user_id=row["user_id"], + messages=messages, + context=context, + pinned=bool(row["pinned"]), + name=row["name"], + created_at=row["created_at"], + last_active=row["last_active"], + context_token_count=row["context_token_count"] or 0, + planning_logs=planning_logs, + ) + + async def save(self, session: Session) -> None: + session.last_active = datetime.now(timezone.utc) + pool = await self._get_pool() + async with pool.acquire() as conn: + # Legacy JSON columns (dual-write for rollback safety) + await conn.execute( + "UPDATE sessions SET profile_id = $1, user_id = $2, messages = $3, context = $4, " + "last_active = $5, context_token_count = $6, planning_logs = $7 WHERE id = $8", + session.profile_id, session.user_id, _serialize(session.messages), _serialize(session.context), + session.last_active, session.context_token_count, + json.dumps(session.planning_logs, ensure_ascii=False), session.id, + ) + + # Normalized rows — write session.messages with their flags. + await conn.execute("DELETE FROM session_messages WHERE session_id = $1", session.id) + + for seq, m in enumerate(session.messages): + await conn.execute( + """ + INSERT INTO session_messages + (session_id, sequence_number, role, content, images, tool_calls, tool_call_id, name, + created_at, is_summary, thinking, is_plan, is_compression, is_context, is_display, + elapsed_seconds, tool_call_count, token_count, files, metadata, is_recall) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21) + """, + session.id, + seq, + m.role, + m.content, + json.dumps(m.images, ensure_ascii=False) if m.images else None, + json.dumps([tc.model_dump(mode="json") for tc in m.tool_calls], ensure_ascii=False) if m.tool_calls else None, + m.tool_call_id, + m.name, + m.created_at, + m.is_summary, + m.thinking, + m.is_plan, + m.is_compression, + m.is_context, + m.is_display, + m.elapsed_seconds, + m.tool_call_count, + m.token_count, + json.dumps(m.files, ensure_ascii=False) if m.files else None, + json.dumps(m.metadata, ensure_ascii=False) if m.metadata else None, + m.is_recall, + ) + + async def set_pinned(self, session_id: str, pinned: bool) -> bool: + pool = await self._get_pool() + async with pool.acquire() as conn: + result = await conn.execute( + "UPDATE sessions SET pinned = $1 WHERE id = $2", + pinned, session_id, + ) + return result == "UPDATE 1" + + async def set_name(self, session_id: str, name: str) -> bool: + pool = await self._get_pool() + async with pool.acquire() as conn: + result = await conn.execute( + "UPDATE sessions SET name = $1 WHERE id = $2", + name, session_id, + ) + return result == "UPDATE 1" + async def list_all(self, user_id: str | None = None, is_admin: bool = False) -> list[Session]: pool = await self._get_pool() async with pool.acquire() as conn: if not is_admin and user_id is not None: rows = await conn.fetch( - "SELECT id, profile_id, user_id, messages, context, pinned, created_at, last_active, context_token_count, name, planning_logs " + "SELECT id, profile_id, user_id, pinned, created_at, last_active, context_token_count, name, planning_logs " "FROM sessions WHERE user_id = $1 ORDER BY pinned DESC, last_active DESC", user_id, ) else: rows = await conn.fetch( - "SELECT id, profile_id, user_id, messages, context, pinned, created_at, last_active, context_token_count, name, planning_logs " + "SELECT id, profile_id, user_id, pinned, created_at, last_active, context_token_count, name, planning_logs " "FROM sessions ORDER BY pinned DESC, last_active DESC" ) - return [self._row_to_session(r) for r in rows] + return await _build_sessions(conn, rows) async def list_page( self, @@ -388,11 +568,11 @@ order_limit = f"ORDER BY pinned DESC, last_active DESC LIMIT {add_param(limit)} OFFSET {add_param(offset)}" rows = await conn.fetch( - "SELECT id, profile_id, user_id, messages, context, pinned, created_at, last_active, context_token_count, name, planning_logs " + "SELECT id, profile_id, user_id, pinned, created_at, last_active, context_token_count, name, planning_logs " f"FROM sessions {where} {order_limit}", *params, ) - return [self._row_to_session(r) for r in rows] + return await _build_sessions(conn, rows) async def delete(self, session_id: str) -> bool: pool = await self._get_pool() @@ -424,7 +604,7 @@ if search: like = f"%{search}%" conditions.append( - f"(id ILIKE {add_param(like)} OR name ILIKE {add_param(like)} OR user_id ILIKE {add_param(like)} OR profile_id ILIKE {add_param(like)} OR messages ILIKE {add_param(like)})" + f"(id ILIKE {add_param(like)} OR name ILIKE {add_param(like)} OR user_id ILIKE {add_param(like)} OR profile_id ILIKE {add_param(like)} OR EXISTS (SELECT 1 FROM session_messages m WHERE m.session_id = sessions.id AND m.content ILIKE {add_param(like)}))" ) where = "WHERE " + " AND ".join(conditions) if conditions else "" @@ -459,7 +639,7 @@ if search: like = f"%{search}%" conditions.append( - f"(id ILIKE {add_param(like)} OR name ILIKE {add_param(like)} OR user_id ILIKE {add_param(like)} OR profile_id ILIKE {add_param(like)} OR messages ILIKE {add_param(like)})" + f"(id ILIKE {add_param(like)} OR name ILIKE {add_param(like)} OR user_id ILIKE {add_param(like)} OR profile_id ILIKE {add_param(like)} OR EXISTS (SELECT 1 FROM session_messages m WHERE m.session_id = sessions.id AND m.content ILIKE {add_param(like)}))" ) where = "WHERE " + " AND ".join(conditions) if conditions else "" @@ -471,32 +651,8 @@ order_clause = f"ORDER BY pinned DESC, {col} {order} LIMIT {add_param(limit)} OFFSET {add_param(offset)}" rows = await conn.fetch( - "SELECT id, profile_id, user_id, messages, context, pinned, created_at, last_active, context_token_count, name, planning_logs " + "SELECT id, profile_id, user_id, pinned, created_at, last_active, context_token_count, name, planning_logs " f"FROM sessions {where} {order_clause}", *params, ) - return [self._row_to_session(r) for r in rows] - - def _row_to_session(self, row: asyncpg.Record) -> Session: - # Prefer normalized rows; fallback to legacy JSON columns - session_id = row["id"] - # We can't easily do another query here (no async in sync method), - # so rely on the caller path (get) or legacy JSON (list paths). - messages = _deserialize(row["messages"]) - context_json = row["context"] - context = _deserialize(context_json) if context_json else list(messages) - planning_logs_raw = row["planning_logs"] - planning_logs = json.loads(planning_logs_raw) if planning_logs_raw else [] - return Session( - id=row["id"], - profile_id=row["profile_id"], - user_id=row["user_id"], - messages=messages, - context=context, - pinned=bool(row["pinned"]), - name=row["name"], - created_at=row["created_at"], - last_active=row["last_active"], - context_token_count=row["context_token_count"] or 0, - planning_logs=planning_logs, - ) + return await _build_sessions(conn, rows)