"""PostgreSQL-backed session store using asyncpg connection pool.
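Normalized storage:
- session_messages — one row per Message, with is_display / is_context flags
- session_images — base64 images linked to session_messages rows via message_id
  (reserved for future use; images are currently stored inline in
  session_messages.images)
All session message I/O goes through the normalized table. The legacy JSON
columns (sessions.messages, sessions.context) remain in the schema for backward
compatibility. They are never written, and are read only once: by the boot
migration that copies them into session_messages while that table is empty.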
"""
import asyncio
import json
from datetime import datetime, timezone
import asyncpg
from navi.llm.base import Message, ToolCallRequest
from .session import Session, SessionStore
# Base sessions table, created by PgSessionStore._get_pool on first use.
# The messages/context TEXT columns hold the legacy JSON message lists; in this
# module they are only read by _migrate_to_normalized and are never written.
# user_id references navi_users(id), so that table must exist before this runs.
_DDL = """
CREATE TABLE IF NOT EXISTS sessions (
id TEXT PRIMARY KEY,
profile_id TEXT NOT NULL,
user_id TEXT REFERENCES navi_users(id) ON DELETE SET NULL,
messages TEXT NOT NULL DEFAULT '[]',
context TEXT NOT NULL DEFAULT '',
pinned BOOLEAN NOT NULL DEFAULT FALSE,
created_at TIMESTAMPTZ NOT NULL,
last_active TIMESTAMPTZ NOT NULL,
context_token_count INTEGER NOT NULL DEFAULT 0,
name TEXT
)
"""
# Idempotent column additions for sessions tables created by older schema
# versions. Executed right after _DDL as one multi-statement string with no
# bind arguments. planning_logs is declared only here (not in _DDL), so fresh
# databases also depend on this running.
_MIGRATE = """
ALTER TABLE sessions ADD COLUMN IF NOT EXISTS name TEXT;
ALTER TABLE sessions ADD COLUMN IF NOT EXISTS planning_logs TEXT NOT NULL DEFAULT '[]';
ALTER TABLE sessions ADD COLUMN IF NOT EXISTS user_id TEXT REFERENCES navi_users(id) ON DELETE SET NULL
"""
# Normalized per-message storage: one row per Message, ordered by
# sequence_number within its session, and removed with its session via
# ON DELETE CASCADE. images/tool_calls/files/metadata hold JSON text.
# NOTE(review): PostgreSQL backs UNIQUE(session_id, sequence_number) with its
# own index, so idx_session_messages_session_seq looks redundant — confirm
# before dropping it.
_SESSION_MESSAGES_DDL = """
CREATE TABLE IF NOT EXISTS session_messages (
id SERIAL PRIMARY KEY,
session_id TEXT NOT NULL REFERENCES sessions(id) ON DELETE CASCADE,
sequence_number INT NOT NULL,
role TEXT NOT NULL,
content TEXT,
images TEXT, -- JSON array of base64 strings
tool_calls TEXT, -- JSON
tool_call_id TEXT,
name TEXT,
created_at TIMESTAMPTZ,
is_summary BOOLEAN NOT NULL DEFAULT FALSE,
thinking TEXT,
is_plan BOOLEAN NOT NULL DEFAULT FALSE,
is_compression BOOLEAN NOT NULL DEFAULT FALSE,
is_context BOOLEAN NOT NULL DEFAULT TRUE,
is_display BOOLEAN NOT NULL DEFAULT TRUE,
elapsed_seconds FLOAT,
tool_call_count INT,
token_count INT,
files TEXT, -- JSON
metadata TEXT, -- JSON
is_recall BOOLEAN NOT NULL DEFAULT FALSE,
UNIQUE(session_id, sequence_number)
);
CREATE INDEX IF NOT EXISTS idx_session_messages_session_seq ON session_messages(session_id, sequence_number);
CREATE INDEX IF NOT EXISTS idx_session_messages_context ON session_messages(session_id, is_context, sequence_number);
"""
# Per-image storage linked to session_messages rows via message_id (both FKs
# cascade on delete). Nothing else in this module reads or writes this table;
# images are currently stored inline as JSON in session_messages.images.
_SESSION_IMAGES_DDL = """
CREATE TABLE IF NOT EXISTS session_images (
id SERIAL PRIMARY KEY,
session_id TEXT NOT NULL REFERENCES sessions(id) ON DELETE CASCADE,
message_id INT REFERENCES session_messages(id) ON DELETE CASCADE,
base64 TEXT NOT NULL,
filename TEXT
);
CREATE INDEX IF NOT EXISTS idx_session_images_session ON session_images(session_id);
CREATE INDEX IF NOT EXISTS idx_session_images_message ON session_images(message_id);
"""
def _serialize(messages: list[Message]) -> str:
    """Encode *messages* as a JSON array string (the legacy column format).

    Each message is dumped with ``model_dump(mode="json", exclude_none=True)``;
    non-ASCII characters are written verbatim rather than escaped.
    """
    dumped = [msg.model_dump(mode="json", exclude_none=True) for msg in messages]
    return json.dumps(dumped, ensure_ascii=False)
def _deserialize(raw: str) -> list[Message]:
    """Decode a legacy JSON array string into Message objects.

    Empty or missing input (``""``/``None``) yields an empty list.
    """
    if raw:
        return [Message.model_validate(item) for item in json.loads(raw)]
    return []
def _message_key(m: Message) -> tuple:
    """Build a hashable key identifying a message by its stored fields.

    Used only by the one-shot boot migration for legacy JSON rows, to decide
    which entries of ``sessions.messages`` also appear in ``sessions.context``.
    Structured fields are rendered as JSON text so the tuple is hashable.
    """

    def as_json(value) -> str:
        return json.dumps(value, ensure_ascii=False)

    created = m.created_at.isoformat() if m.created_at else None
    tool_calls = (
        as_json([tc.model_dump(mode="json") for tc in m.tool_calls])
        if m.tool_calls
        else None
    )
    files = as_json(m.files) if m.files else None
    metadata = as_json(m.metadata) if m.metadata else None

    scalar_fields = (
        m.role,
        m.content,
        m.tool_call_id,
        m.name,
        m.is_summary,
        m.is_plan,
        m.is_compression,
        m.is_recall,
        m.thinking,
    )
    return scalar_fields + (created, tool_calls, files, metadata)
def _row_to_message(row: asyncpg.Record) -> Message:
    """Rebuild a Message from one ``session_messages`` record.

    JSON text columns are decoded. Empty or NULL ``images`` and ``files``
    become None, empty or NULL ``tool_calls`` becomes None, and empty or NULL
    ``metadata`` becomes ``{}``. Boolean flag columns missing from the record
    fall back to the same defaults the table declares.
    """

    def decoded(column: str, empty):
        text = row.get(column)
        return json.loads(text) if text else empty

    def flag(column: str, default: bool) -> bool:
        return bool(row.get(column, default))

    tool_calls_text = row.get("tool_calls")
    if tool_calls_text:
        tool_calls = [
            ToolCallRequest.model_validate(tc) for tc in json.loads(tool_calls_text)
        ]
    else:
        tool_calls = None

    return Message(
        role=row["role"],
        content=row["content"],
        images=decoded("images", None),
        tool_calls=tool_calls,
        tool_call_id=row.get("tool_call_id"),
        name=row.get("name"),
        created_at=row.get("created_at"),
        is_summary=flag("is_summary", False),
        thinking=row.get("thinking"),
        is_plan=flag("is_plan", False),
        is_compression=flag("is_compression", False),
        is_context=flag("is_context", True),
        is_display=flag("is_display", True),
        elapsed_seconds=row.get("elapsed_seconds"),
        tool_call_count=row.get("tool_call_count"),
        token_count=row.get("token_count"),
        files=decoded("files", None),
        metadata=decoded("metadata", {}),
        is_recall=flag("is_recall", False),
    )
async def _ensure_normalized_tables(conn: asyncpg.Connection) -> None:
    """Create session_messages, then session_images, and their indexes if missing."""
    for ddl in (_SESSION_MESSAGES_DDL, _SESSION_IMAGES_DDL):
        await conn.execute(ddl)
async def _migrate_to_normalized(conn: asyncpg.Connection) -> None:
    """One-shot migration: copy legacy JSON messages/context into session_messages.

    Does nothing once ``session_messages`` holds any row. Otherwise, each
    session's legacy ``messages`` list is copied in order, with
    ``sequence_number`` set to the list index and ``is_display`` set to TRUE.
    ``is_context`` is TRUE when an equal message (per ``_message_key``) occurs
    in the legacy ``context`` list. An empty ``context`` marks every message
    as context.

    The whole copy runs in one transaction. If it fails part-way, nothing is
    committed and the next start-up retries it. Without the transaction, a
    partial copy would make the emptiness check skip the remaining sessions
    forever.

    NOTE(review): messages that appear only in the legacy ``context`` list
    (not in ``messages``) are not copied — confirm that is acceptable.
    """
    # EXISTS stops at the first row, whereas COUNT(*) scans the whole table on
    # every start-up just to learn whether it is empty.
    if await conn.fetchval("SELECT EXISTS (SELECT 1 FROM session_messages)"):
        return
    insert_sql = """
        INSERT INTO session_messages
        (session_id, sequence_number, role, content, images, tool_calls, tool_call_id, name,
        created_at, is_summary, thinking, is_plan, is_compression, is_context, is_display,
        elapsed_seconds, tool_call_count, token_count, files, metadata, is_recall)
        VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21)
        ON CONFLICT (session_id, sequence_number) DO NOTHING
    """
    async with conn.transaction():
        rows = await conn.fetch("SELECT id, messages, context FROM sessions")
        for row in rows:
            session_id = row["id"]
            messages = _deserialize(row["messages"] or "[]")
            context = _deserialize(row["context"] or "")
            # Legacy rows with no stored context treated every message as context.
            if not context and messages:
                context = list(messages)
            context_keys = {_message_key(m) for m in context}
            batch = [
                (
                    session_id,
                    seq,
                    m.role,
                    m.content,
                    json.dumps(m.images, ensure_ascii=False) if m.images else None,
                    json.dumps([tc.model_dump(mode="json") for tc in m.tool_calls], ensure_ascii=False)
                    if m.tool_calls
                    else None,
                    m.tool_call_id,
                    m.name,
                    m.created_at,
                    m.is_summary,
                    m.thinking,
                    m.is_plan,
                    m.is_compression,
                    _message_key(m) in context_keys,  # is_context
                    True,  # is_display: every legacy message was displayed
                    m.elapsed_seconds,
                    m.tool_call_count,
                    m.token_count,
                    json.dumps(m.files, ensure_ascii=False) if m.files else None,
                    json.dumps(m.metadata, ensure_ascii=False) if m.metadata else None,
                    m.is_recall,
                )
                for seq, m in enumerate(messages)
            ]
            if batch:
                await conn.executemany(insert_sql, batch)
async def _load_messages_map(conn: asyncpg.Connection, session_ids: list[str]) -> dict[str, list[Message]]:
    """Fetch the messages of every session in *session_ids* with one query.

    Returns a dict keyed by session id. Each value lists that session's messages
    in ``sequence_number`` order, and sessions without messages map to ``[]``.
    An empty *session_ids* returns ``{}`` without querying.
    """
    if not session_ids:
        return {}
    records = await conn.fetch(
        "SELECT * FROM session_messages WHERE session_id = ANY($1) ORDER BY sequence_number",
        session_ids,
    )
    by_session: dict[str, list[Message]] = {sid: [] for sid in session_ids}
    for record in records:
        by_session[record["session_id"]].append(_row_to_message(record))
    return by_session
async def _build_sessions(
    conn: asyncpg.Connection,
    rows: list[asyncpg.Record],
) -> list[Session]:
    """Turn ``sessions`` rows into Session objects populated from session_messages.

    For each session, ``messages`` holds its ``is_display`` rows and ``context``
    holds its ``is_context`` rows. Both are filtered from one loaded list, so
    they share the same Message objects. ``db_message_count`` records how many
    rows the session has in the table.
    """
    by_session = await _load_messages_map(conn, [r["id"] for r in rows])

    def hydrate(row) -> Session:
        loaded = by_session.get(row["id"], [])
        raw_logs = row["planning_logs"]
        session = Session(
            id=row["id"],
            profile_id=row["profile_id"],
            user_id=row["user_id"],
            messages=[m for m in loaded if m.is_display],
            context=[m for m in loaded if m.is_context],
            pinned=bool(row["pinned"]),
            name=row["name"],
            created_at=row["created_at"],
            last_active=row["last_active"],
            context_token_count=row["context_token_count"] or 0,
            planning_logs=json.loads(raw_logs) if raw_logs else [],
        )
        session.db_message_count = len(loaded)
        return session

    return [hydrate(row) for row in rows]
class PgSessionStore(SessionStore):
    """SessionStore backed by PostgreSQL through an asyncpg connection pool.

    Session metadata is stored in ``sessions``, and messages are stored one row
    per message in ``session_messages``. The schema is created, and legacy
    JSON data migrated, lazily on first use (see ``_get_pool``).

    Visibility rule shared by the listing/search methods: a non-admin caller
    that passes a ``user_id`` only sees sessions owned by that user. With
    ``is_admin`` set, or with ``user_id`` None, no ownership filter applies.
    """

    # Columns of ``sessions`` that _build_sessions needs to hydrate a Session.
    _SESSION_COLUMNS = (
        "id, profile_id, user_id, pinned, created_at, last_active, "
        "context_token_count, name, planning_logs"
    )
    # search_list interpolates sort_by into SQL, so only these names are accepted.
    _SORTABLE_COLUMNS = frozenset(
        {"created_at", "last_active", "name", "profile_id", "user_id", "pinned"}
    )

    def __init__(self, pool: asyncpg.Pool) -> None:
        self._pool = pool
        self._initialized = False
        # Guards the one-time schema setup in _get_pool.
        self._lock = asyncio.Lock()

    async def _get_pool(self) -> asyncpg.Pool:
        """Return the pool, creating and migrating the schema on first use.

        If setup raises, ``_initialized`` stays False and the next call retries.
        """
        if not self._initialized:
            async with self._lock:
                # Re-check under the lock: another coroutine may have finished
                # setup while this one was waiting.
                if not self._initialized:
                    async with self._pool.acquire() as conn:
                        await conn.execute(_DDL)
                        await conn.execute(_MIGRATE)
                        await _ensure_normalized_tables(conn)
                        await _migrate_to_normalized(conn)
                    self._initialized = True
        return self._pool

    @staticmethod
    def _where_clause(
        *,
        user_id: str | None = None,
        is_admin: bool = False,
        profile_id: str | None = None,
        search: str | None = None,
    ) -> tuple[str, list]:
        """Build a ``WHERE`` clause (or ``""``) and its positional parameters.

        Applies the class's visibility rule, an optional exact ``profile_id``
        match, and an optional case-insensitive substring ``search``. The search
        covers the session id, name, user_id, profile_id, and the content of
        any displayed message. LIKE wildcards in ``search`` are escaped, so
        ``%`` and ``_`` match literally.
        """
        conditions: list[str] = []
        params: list = []

        def bind(value) -> str:
            params.append(value)
            return f"${len(params)}"

        if not is_admin and user_id is not None:
            conditions.append(f"user_id = {bind(user_id)}")
        if profile_id:
            conditions.append(f"profile_id = {bind(profile_id)}")
        if search:
            # Backslash is PostgreSQL's default LIKE escape character: escape it
            # first, then the wildcards, so "50%" or "a_b" match literally.
            escaped = search.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
            # One bound parameter may be referenced several times in the query.
            p = bind(f"%{escaped}%")
            conditions.append(
                f"(id ILIKE {p} OR name ILIKE {p} OR user_id ILIKE {p} "
                f"OR profile_id ILIKE {p} OR EXISTS (SELECT 1 FROM session_messages m "
                f"WHERE m.session_id = sessions.id AND m.is_display = true "
                f"AND m.content ILIKE {p}))"
            )
        where = ("WHERE " + " AND ".join(conditions)) if conditions else ""
        return where, params

    async def create(self, profile_id: str, user_id: str | None = None) -> Session:
        """Insert a new, empty, unpinned session row and return the Session."""
        session = Session(profile_id=profile_id, user_id=user_id)
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            await conn.execute(
                "INSERT INTO sessions "
                "(id, profile_id, user_id, pinned, created_at, last_active, context_token_count) "
                "VALUES ($1, $2, $3, FALSE, $4, $5, 0)",
                session.id, session.profile_id, session.user_id,
                session.created_at, session.last_active,
            )
        return session

    async def get(self, session_id: str) -> Session | None:
        """Load one session with its messages, or return None if it does not exist."""
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            row = await conn.fetchrow(
                f"SELECT {self._SESSION_COLUMNS} FROM sessions WHERE id = $1",
                session_id,
            )
            if row is None:
                return None
            # _build_sessions loads all messages once and filters messages[] and
            # context[] from that single list. Both therefore share the same
            # Python objects, so id() matching in the agent works.
            sessions = await _build_sessions(conn, [row])
        return sessions[0]

    async def save(self, session: Session) -> None:
        """Persist *session*'s metadata, mutable per-message fields, and new messages.

        Sets ``session.last_active`` to the current UTC time. Then:

        - Rows with ``sequence_number < session.db_message_count`` get their
          flags and stats refreshed from ``session.messages`` at the same
          index. Their content is not rewritten.
        - When ``db_message_count > 0``, rows at or beyond it are deleted.
        - ``session.messages[db_message_count:]`` are inserted.

        All statements run in one transaction, so a failure leaves the stored
        session unchanged. On success, ``session.db_message_count`` becomes the
        number of messages that were persisted.

        NOTE(review): this assumes ``session.messages[i]`` is the message stored
        at ``sequence_number`` ``i``. ``get`` and ``_build_sessions`` fill
        ``messages`` with only ``is_display`` rows, but set ``db_message_count``
        to the count of *all* rows. The indices therefore drift if any stored
        row has ``is_display = FALSE`` — TODO confirm whether such rows exist.
        """
        session.last_active = datetime.now(timezone.utc)
        # Snapshot the list so a message appended while this save is awaiting is
        # neither half-written nor counted in db_message_count without an insert.
        messages = list(session.messages)
        db_count = session.db_message_count
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            async with conn.transaction():
                await conn.execute(
                    "UPDATE sessions SET profile_id = $1, user_id = $2, "
                    "last_active = $3, context_token_count = $4, planning_logs = $5 WHERE id = $6",
                    session.profile_id, session.user_id,
                    session.last_active, session.context_token_count,
                    json.dumps(session.planning_logs, ensure_ascii=False), session.id,
                )
                if db_count > 0:
                    # 1. Refresh the mutable fields of rows that already exist.
                    update_rows = [
                        (
                            session.id,
                            seq,
                            m.is_context,
                            m.is_display,
                            m.is_summary,
                            m.is_plan,
                            m.is_compression,
                            m.is_recall,
                            m.thinking,
                            m.elapsed_seconds,
                            m.tool_call_count,
                            m.token_count,
                        )
                        for seq, m in enumerate(messages[:db_count])
                    ]
                    if update_rows:
                        await conn.executemany(
                            """
                            UPDATE session_messages
                            SET is_context = $3,
                            is_display = $4,
                            is_summary = $5,
                            is_plan = $6,
                            is_compression = $7,
                            is_recall = $8,
                            thinking = $9,
                            elapsed_seconds = $10,
                            tool_call_count = $11,
                            token_count = $12
                            WHERE session_id = $1 AND sequence_number = $2
                            """,
                            update_rows,
                        )
                    # 2. Safety delete in case rows past our count were added concurrently.
                    await conn.execute(
                        "DELETE FROM session_messages WHERE session_id = $1 AND sequence_number >= $2",
                        session.id,
                        db_count,
                    )
                # 3. Insert only the messages that are not stored yet.
                if len(messages) > db_count:
                    insert_rows = [
                        (
                            session.id,
                            seq,
                            m.role,
                            m.content,
                            json.dumps(m.images, ensure_ascii=False) if m.images else None,
                            json.dumps([tc.model_dump(mode="json") for tc in m.tool_calls], ensure_ascii=False)
                            if m.tool_calls
                            else None,
                            m.tool_call_id,
                            m.name,
                            m.created_at,
                            m.is_summary,
                            m.thinking,
                            m.is_plan,
                            m.is_compression,
                            m.is_context,
                            m.is_display,
                            m.elapsed_seconds,
                            m.tool_call_count,
                            m.token_count,
                            json.dumps(m.files, ensure_ascii=False) if m.files else None,
                            json.dumps(m.metadata, ensure_ascii=False) if m.metadata else None,
                            m.is_recall,
                        )
                        for seq, m in ((i, messages[i]) for i in range(db_count, len(messages)))
                    ]
                    await conn.executemany(
                        """
                        INSERT INTO session_messages
                        (session_id, sequence_number, role, content, images, tool_calls, tool_call_id, name,
                        created_at, is_summary, thinking, is_plan, is_compression, is_context, is_display,
                        elapsed_seconds, tool_call_count, token_count, files, metadata, is_recall)
                        VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21)
                        """,
                        insert_rows,
                    )
        # Only reached if the transaction committed.
        session.db_message_count = len(messages)

    async def set_pinned(self, session_id: str, pinned: bool) -> bool:
        """Set the pinned flag; return True if a session with *session_id* was updated."""
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            result = await conn.execute(
                "UPDATE sessions SET pinned = $1 WHERE id = $2",
                pinned, session_id,
            )
        return result == "UPDATE 1"

    async def set_name(self, session_id: str, name: str) -> bool:
        """Rename a session; return True if a session with *session_id* was updated."""
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            result = await conn.execute(
                "UPDATE sessions SET name = $1 WHERE id = $2",
                name, session_id,
            )
        return result == "UPDATE 1"

    async def list_all(self, user_id: str | None = None, is_admin: bool = False) -> list[Session]:
        """Return every visible session, pinned first, then most recently active."""
        pool = await self._get_pool()
        where, params = self._where_clause(user_id=user_id, is_admin=is_admin)
        async with pool.acquire() as conn:
            rows = await conn.fetch(
                f"SELECT {self._SESSION_COLUMNS} FROM sessions {where} "
                "ORDER BY pinned DESC, last_active DESC",
                *params,
            )
            return await _build_sessions(conn, rows)

    async def list_page(
        self,
        *,
        limit: int,
        offset: int,
        profile_id: str | None = None,
        user_id: str | None = None,
        is_admin: bool = False,
    ) -> list[Session]:
        """Return one page of visible sessions, optionally limited to one profile.

        Ordered pinned first, then most recently active. ``id`` breaks ties so
        consecutive LIMIT/OFFSET pages neither repeat nor skip sessions.
        """
        pool = await self._get_pool()
        where, params = self._where_clause(
            user_id=user_id, is_admin=is_admin, profile_id=profile_id
        )
        params.extend((limit, offset))
        n = len(params)
        async with pool.acquire() as conn:
            rows = await conn.fetch(
                f"SELECT {self._SESSION_COLUMNS} FROM sessions {where} "
                f"ORDER BY pinned DESC, last_active DESC, id LIMIT ${n - 1} OFFSET ${n}",
                *params,
            )
            return await _build_sessions(conn, rows)

    async def delete(self, session_id: str) -> bool:
        """Delete a session and return True if a row was removed.

        Its session_messages and session_images rows go with it through the
        ON DELETE CASCADE foreign keys.
        """
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            result = await conn.execute("DELETE FROM sessions WHERE id = $1", session_id)
        return result == "DELETE 1"

    async def count_all(
        self,
        *,
        user_id: str | None = None,
        is_admin: bool = False,
        search: str | None = None,
    ) -> int:
        """Count visible sessions matching *search*, using the same filter as search_list."""
        pool = await self._get_pool()
        where, params = self._where_clause(user_id=user_id, is_admin=is_admin, search=search)
        async with pool.acquire() as conn:
            count = await conn.fetchval(f"SELECT COUNT(*) FROM sessions {where}", *params)
        return count or 0

    async def search_list(
        self,
        *,
        limit: int,
        offset: int,
        user_id: str | None = None,
        is_admin: bool = False,
        search: str | None = None,
        sort_by: str = "last_active",
        sort_order: str = "desc",
    ) -> list[Session]:
        """Return one page of visible sessions matching *search*.

        ``sort_by`` must be one of ``_SORTABLE_COLUMNS``; any other value falls
        back to ``last_active``. ``sort_order`` is compared case-insensitively:
        ``"desc"`` sorts descending and anything else ascending. Pinned sessions
        always come first.
        """
        pool = await self._get_pool()
        where, params = self._where_clause(user_id=user_id, is_admin=is_admin, search=search)
        column = sort_by if sort_by in self._SORTABLE_COLUMNS else "last_active"
        direction = "DESC" if (sort_order or "").lower() == "desc" else "ASC"
        params.extend((limit, offset))
        n = len(params)
        # pinned is the primary key of the sort; the trailing ``id`` gives a total
        # order so LIMIT/OFFSET pages are deterministic when values tie.
        order_clause = (
            f"ORDER BY pinned DESC, {column} {direction}, id "
            f"LIMIT ${n - 1} OFFSET ${n}"
        )
        async with pool.acquire() as conn:
            rows = await conn.fetch(
                f"SELECT {self._SESSION_COLUMNS} FROM sessions {where} {order_clause}",
                *params,
            )
            return await _build_sessions(conn, rows)