diff --git a/navi/core/pg_session_store.py b/navi/core/pg_session_store.py index 54dd7fb..c342274 100644 --- a/navi/core/pg_session_store.py +++ b/navi/core/pg_session_store.py @@ -232,8 +232,8 @@ async with pool.acquire() as conn: await conn.execute( "INSERT INTO sessions " - "(id, profile_id, user_id, messages, context, pinned, created_at, last_active, context_token_count) " - "VALUES ($1, $2, $3, '[]', '', FALSE, $4, $5, 0)", + "(id, profile_id, user_id, pinned, created_at, last_active, context_token_count) " + "VALUES ($1, $2, $3, FALSE, $4, $5, 0)", session.id, session.profile_id, session.user_id, session.created_at, session.last_active, ) return session @@ -280,191 +280,10 @@ session.last_active = datetime.now(timezone.utc) pool = await self._get_pool() async with pool.acquire() as conn: - # Legacy JSON columns (dual-write for rollback safety) await conn.execute( - "UPDATE sessions SET profile_id = $1, user_id = $2, messages = $3, context = $4, " - "last_active = $5, context_token_count = $6, planning_logs = $7 WHERE id = $8", - session.profile_id, session.user_id, _serialize(session.messages), _serialize(session.context), - session.last_active, session.context_token_count, - json.dumps(session.planning_logs, ensure_ascii=False), session.id, - ) - - # Normalized rows — write session.messages with their flags. - await conn.execute("DELETE FROM session_messages WHERE session_id = $1", session.id) - - for seq, m in enumerate(session.messages): - await conn.execute( - """ - INSERT INTO session_messages - (session_id, sequence_number, role, content, images, tool_calls, tool_call_id, name, - created_at, is_summary, thinking, is_plan, is_compression, is_context, is_display, - elapsed_seconds, tool_call_count, token_count, files, metadata, is_recall) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21) - """, - session.id, - seq, - m.role, - m.content, - json.dumps(m.images, ensure_ascii=False) if m.images else None, - json.dumps([tc.model_dump(mode="json") for tc in m.tool_calls], ensure_ascii=False) if m.tool_calls else None, - m.tool_call_id, - m.name, - m.created_at, - m.is_summary, - m.thinking, - m.is_plan, - m.is_compression, - m.is_context, - m.is_display, - m.elapsed_seconds, - m.tool_call_count, - m.token_count, - json.dumps(m.files, ensure_ascii=False) if m.files else None, - json.dumps(m.metadata, ensure_ascii=False) if m.metadata else None, - m.is_recall, - ) - - async def set_pinned(self, session_id: str, pinned: bool) -> bool: - pool = await self._get_pool() - async with pool.acquire() as conn: - result = await conn.execute( - "UPDATE sessions SET pinned = $1 WHERE id = $2", - pinned, session_id, - ) - return result == "UPDATE 1" - - async def set_name(self, session_id: str, name: str) -> bool: - pool = await self._get_pool() - async with pool.acquire() as conn: - result = await conn.execute( - "UPDATE sessions SET name = $1 WHERE id = $2", - name, session_id, - ) - return result == "UPDATE 1" - -async def _load_messages_map(conn: asyncpg.Connection, session_ids: list[str]) -> dict[str, list[Message]]: - """Batch-load all messages for the given session IDs.""" - if not session_ids: - return {} - rows = await conn.fetch( - "SELECT * FROM session_messages WHERE session_id = ANY($1) ORDER BY sequence_number", - session_ids, - ) - result: dict[str, list[Message]] = {sid: [] for sid in session_ids} - for row in rows: - sid = row["session_id"] - result[sid].append(_row_to_message(row)) - return result - - -async def _build_sessions( - conn: asyncpg.Connection, - rows: list[asyncpg.Record], -) -> list[Session]: - """Hydrate session rows with messages from the normalized table.""" - session_ids = [r["id"] for r in rows] - messages_map = await _load_messages_map(conn, session_ids) - sessions: list[Session] = [] - for row in rows: - messages = messages_map.get(row["id"], []) - context = [m for m in messages if m.is_context] - planning_logs_raw = row["planning_logs"] - planning_logs = json.loads(planning_logs_raw) if planning_logs_raw else [] - sessions.append( - Session( - id=row["id"], - profile_id=row["profile_id"], - user_id=row["user_id"], - messages=messages, - context=context, - pinned=bool(row["pinned"]), - name=row["name"], - created_at=row["created_at"], - last_active=row["last_active"], - context_token_count=row["context_token_count"] or 0, - planning_logs=planning_logs, - ) - ) - return sessions - - -class PgSessionStore(SessionStore): - def __init__(self, pool: asyncpg.Pool) -> None: - self._pool = pool - self._initialized = False - self._lock = asyncio.Lock() - - async def _get_pool(self) -> asyncpg.Pool: - if not self._initialized: - async with self._lock: - if not self._initialized: - async with self._pool.acquire() as conn: - await conn.execute(_DDL) - await conn.execute(_MIGRATE) - await _ensure_normalized_tables(conn) - await _migrate_to_normalized(conn) - self._initialized = True - return self._pool - - async def create(self, profile_id: str, user_id: str | None = None) -> Session: - session = Session(profile_id=profile_id, user_id=user_id) - pool = await self._get_pool() - async with pool.acquire() as conn: - await conn.execute( - "INSERT INTO sessions " - "(id, profile_id, user_id, messages, context, pinned, created_at, last_active, context_token_count) " - "VALUES ($1, $2, $3, '[]', '', FALSE, $4, $5, 0)", - session.id, session.profile_id, session.user_id, session.created_at, session.last_active, - ) - return session - - async def get(self, session_id: str) -> Session | None: - pool = await self._get_pool() - async with pool.acquire() as conn: - row = await conn.fetchrow( - "SELECT id, profile_id, user_id, pinned, created_at, last_active, context_token_count, name, planning_logs " - "FROM sessions WHERE id = $1", - session_id, - ) - if not row: - return None - - # Load all messages once so messages[] and context[] share - # the same Python objects (id() matching in the agent works). - all_rows = await conn.fetch( - "SELECT * FROM session_messages WHERE session_id = $1 ORDER BY sequence_number", - session_id, - ) - all_messages = [_row_to_message(r) for r in all_rows] - messages = [m for m in all_messages if m.is_display] - context = [m for m in all_messages if m.is_context] - - planning_logs_raw = row["planning_logs"] - planning_logs = json.loads(planning_logs_raw) if planning_logs_raw else [] - - return Session( - id=row["id"], - profile_id=row["profile_id"], - user_id=row["user_id"], - messages=messages, - context=context, - pinned=bool(row["pinned"]), - name=row["name"], - created_at=row["created_at"], - last_active=row["last_active"], - context_token_count=row["context_token_count"] or 0, - planning_logs=planning_logs, - ) - - async def save(self, session: Session) -> None: - session.last_active = datetime.now(timezone.utc) - pool = await self._get_pool() - async with pool.acquire() as conn: - # Legacy JSON columns (dual-write for rollback safety) - await conn.execute( - "UPDATE sessions SET profile_id = $1, user_id = $2, messages = $3, context = $4, " - "last_active = $5, context_token_count = $6, planning_logs = $7 WHERE id = $8", - session.profile_id, session.user_id, _serialize(session.messages), _serialize(session.context), + "UPDATE sessions SET profile_id = $1, user_id = $2, " + "last_active = $3, context_token_count = $4, planning_logs = $5 WHERE id = $6", + session.profile_id, session.user_id, session.last_active, session.context_token_count, json.dumps(session.planning_logs, ensure_ascii=False), session.id, )