Newer
Older
vmk-360-data_collector / scripts / backfill_embeddings.py
@Eugene Sukhodolskiy Eugene Sukhodolskiy 1 day ago 4 KB feat: add pgvector semantic search
#!/usr/bin/env python3
"""Backfill embeddings for existing property_listings.

Usage:
    python -m scripts.backfill_embeddings --batch-size 10 --limit 1000
"""

import argparse
import asyncio

import structlog
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine

from vmk_data_collector.core.config import settings
from vmk_data_collector.db.repositories.property import PropertyRepository
from vmk_data_collector.models.property_listing import PropertyListing
from vmk_data_collector.services.ollama_client import OllamaClient
from vmk_data_collector.services.property_pipeline import PropertyPipeline

# Module-level structlog logger; backfill() emits its progress/error events here.
logger = structlog.get_logger()


def _build_embedding_text(listing: PropertyListing) -> str:
    """Render ``listing`` as one natural-language string for the embedder.

    Intended to mirror the ingestion pipeline's text builder so backfilled
    vectors are comparable with freshly ingested ones. The logic is copied
    here rather than imported, so the two must be kept in sync by hand
    (NOTE(review): confirm against ``PropertyPipeline``).

    Segment order: free-text fields (deal type, title, description,
    generated description), one comma-separated location phrase, then rooms,
    area, floor, building type, renovation status and price. Falsy fields
    are skipped, except ``rooms_count`` which is kept unless it is ``None``.
    Segments are joined with ``". "``.
    """
    # Free-text fields are emitted verbatim, in this fixed order.
    segments: list[str] = [
        f"{text}"
        for text in (
            listing.deal_type,
            listing.title,
            listing.description,
            listing.generated_description,
        )
        if text
    ]

    # All location pieces are folded into a single segment.
    location_fields = (
        ("город", listing.city),
        ("район", listing.district),
        ("микрорайон", listing.micro_district),
        ("улица", listing.street),
    )
    location = ", ".join(
        f"{label} {value}" for label, value in location_fields if value
    )
    if location:
        segments.append(location)

    # `is not None` rather than truthiness, so a rooms_count of 0 is rendered.
    if listing.rooms_count is not None:
        segments.append(f"{listing.rooms_count} комнат")
    if listing.total_area:
        segments.append(f"площадь {listing.total_area} м²")

    floor = listing.floor
    if floor:
        floors_total = listing.floors_total
        segments.append(
            f"этаж {floor} из {floors_total}" if floors_total else f"этаж {floor}"
        )

    for label, value in (
        ("тип дома", listing.building_type),
        ("ремонт", listing.renovation_status),
    ):
        if value:
            segments.append(f"{label} {value}")

    if listing.price:
        segments.append(f"цена {listing.price} {listing.currency or ''}")

    return ". ".join(segments)


async def backfill(
    batch_size: int = 10,
    limit: int | None = None,
) -> None:
    engine = create_async_engine(settings.database_url_async, echo=False)
    async_session = async_sessionmaker(engine, expire_on_commit=False)
    client = OllamaClient(base_url=settings.ollama_base_url)

    async with async_session() as session:
        from sqlalchemy import select

        stmt = select(PropertyListing).where(
            PropertyListing.embedding.is_(None)
        )
        if limit:
            stmt = stmt.limit(limit)

        result = await session.execute(stmt)
        listings = result.scalars().all()
        total = len(listings)
        logger.info("backfill_start", total=total, batch_size=batch_size)

        processed = 0
        for i in range(0, total, batch_size):
            batch = listings[i : i + batch_size]
            texts = []
            for listing in batch:
                text = _build_embedding_text(listing)
                texts.append(text)

            try:
                embeddings = await client.embed(
                    model=settings.ollama_embedding_model,
                    texts=texts,
                )
            except Exception as exc:
                logger.error(
                    "backfill_batch_failed",
                    batch_start=i,
                    error=str(exc),
                )
                continue

            repo = PropertyRepository(session)
            for listing, vector in zip(batch, embeddings):
                if vector:
                    await repo.update_embedding(listing.id, vector)
                    processed += 1

            await session.commit()
            logger.info(
                "backfill_batch_done",
                processed=processed,
                total=total,
                batch_start=i,
                batch_end=min(i + batch_size, total),
            )

    await client.close()
    await engine.dispose()
    logger.info("backfill_complete", processed=processed, total=total)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Backfill embeddings for existing listings"
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=10,
        help="How many listings to embed per Ollama request",
    )
    parser.add_argument(
        "--limit",
        type=int,
        default=None,
        help="Maximum listings to process (default: all)",
    )
    args = parser.parse_args()
    # Reject out-of-range values with a usage message (exit code 2) instead
    # of a traceback or a silent no-op.
    if args.batch_size < 1:
        parser.error("--batch-size must be a positive integer")
    if args.limit is not None and args.limit < 0:
        parser.error("--limit must be a non-negative integer")
    asyncio.run(backfill(batch_size=args.batch_size, limit=args.limit))