Newer
Older
vmk-360_data_mcp / src / vmk_data_mcp / db.py
"""Асинхронный пул PostgreSQL с read-only защитой и таймаутами."""

from contextlib import asynccontextmanager

import asyncpg

from vmk_data_mcp.config import settings

_pool: asyncpg.Pool | None = None

# Разрешённые команды начинаются только с SELECT / WITH / VALUES / EXPLAIN
_SAFE_PREFIXES = ("select", "with", "values", "explain")

# Колонки, доступные для выборки и фильтрации (белый список)
USER_COLUMNS = frozenset(
    {
        "id",
        "title",
        "description",
        "generated_description",
        "price",
        "currency",
        "deal_type",
        "city",
        "district",
        "rooms_count",
        "total_area",
        "living_area",
        "kitchen_area",
        "floor",
        "floors_count",
        "building_type",
        "building_year",
        "renovation_status",
        "balcony_count",
        "bathroom_type",
        "parking_type",
        "heating_type",
        "layout_type",
        "window_view",
        "metro_station",
        "metro_distance_type",
        "metro_distance_meters",
        "url_source",
        "publish_date",
        "images_count",
        "contact_phone",
        "listing_status",
        "archived_at",
        "created_at",
        "updated_at",
        "search_vector",
        "embedding",
    }
)


def _is_safe_query(sql: str) -> bool:
    """Проверяет, что запрос начинается с безопасного префикса и не содержит
    дополнительных команд после точки с запятой."""
    cleaned = sql.strip().lower()
    # Запрещаем точку с запятой внутри запроса (multi-statement)
    if ";" in cleaned.rstrip(";"):
        return False
    return any(cleaned.startswith(prefix) for prefix in _SAFE_PREFIXES)


async def _init_conn(conn: asyncpg.Connection) -> None:
    """Инициализатор нового соединения: включает read-only по умолчанию."""
    await conn.execute("SET default_transaction_read_only = on")


async def init_pool() -> asyncpg.Pool:
    """Инициализирует пул соединений с PostgreSQL."""
    global _pool
    if _pool is not None:
        return _pool

    _pool = await asyncpg.create_pool(
        settings.database_url,
        min_size=settings.db_pool_min_size,
        max_size=settings.db_pool_max_size,
        command_timeout=settings.db_query_timeout,
        init=_init_conn,
    )
    return _pool


async def close_pool() -> None:
    """Закрывает пул соединений."""
    global _pool
    if _pool is not None:
        await _pool.close()
        _pool = None


@asynccontextmanager
async def get_connection():
    """Контекстный менеджер для получения соединения из пула."""
    pool = await init_pool()
    async with pool.acquire() as conn:
        yield conn


async def fetch(sql: str, *args) -> list[asyncpg.Record]:
    """Выполняет SELECT-запрос и возвращает строки."""
    if not _is_safe_query(sql):
        raise ValueError("Only read-only queries are allowed")

    async with get_connection() as conn:
        return await conn.fetch(sql, *args)


async def fetchrow(sql: str, *args) -> asyncpg.Record | None:
    """Выполняет SELECT-запрос и возвращает одну строку."""
    if not _is_safe_query(sql):
        raise ValueError("Only read-only queries are allowed")

    async with get_connection() as conn:
        return await conn.fetchrow(sql, *args)