# Newer
# Older
# vmk-360-data_collector / src / vmk_data_collector / services / property_pipeline.py
# @Eugene Sukhodolskiy Eugene Sukhodolskiy 1 day ago 17 KB feat: implement review items 8-14
import time
from dataclasses import dataclass, field
from typing import Any

import structlog
from sqlalchemy import inspect

from vmk_data_collector.core.circuit_breaker import CircuitBreakerOpenError
from vmk_data_collector.core.metrics import (
    pipeline_duration_seconds,
    pipeline_results_total,
)
from vmk_data_collector.db.repositories.ai_enrichment import (
    AiEnrichmentRepository,
)
from vmk_data_collector.db.repositories.custom_field import (
    CustomFieldRepository,
)
from vmk_data_collector.db.repositories.data_source import (
    DataSourceRepository,
)
from vmk_data_collector.db.repositories.image import ImageRepository
from vmk_data_collector.db.repositories.property import PropertyRepository
from vmk_data_collector.db.repositories.property_type import (
    PropertyTypeRepository,
)
from vmk_data_collector.db.repositories.raw_data import RawDataRepository
from vmk_data_collector.db.repositories.snapshot import SnapshotRepository
from vmk_data_collector.domain.entities import NormalizedProperty
from vmk_data_collector.domain.enums import RawDataStatus
from vmk_data_collector.schemas.raw_data import IngestResponse
from vmk_data_collector.services.ai_enricher import AiEnricher
from vmk_data_collector.services.ai_image_analyzer import AiImageAnalyzer
from vmk_data_collector.services.ai_normalizer import AiNormalizer
from vmk_data_collector.services.image_downloader import ImageDownloader
from vmk_data_collector.services.ollama_client import OllamaClient

logger = structlog.get_logger()


@dataclass
class PipelineContext:
    """Mutable scratchpad carrying intermediate results between pipeline stages.

    An empty instance is created at the start of ``PropertyPipeline.process``
    and each stage's return value is stored on it as the run progresses.
    """

    # Raw-data record returned by the raw repository lookup (stage 1).
    raw: Any | None = None
    # Response object returned by the AI normalizer (stage 2).
    norm_response: Any | None = None
    # Entity built from ``norm_response.normalized``.
    normalized: NormalizedProperty | None = None
    # Data-source record resolved (or created) from the payload's source slug.
    data_source: Any | None = None
    # Listing already stored for the same source/external id, if any.
    existing_listing: Any | None = None
    # Id of the listing that was created or updated (stage 4).
    property_id: int | None = None
    # Id of the snapshot taken before an update; stays None for new listings.
    snapshot_id: int | None = None
    # Image analysis results keyed by image URL (stage 6).
    aggregated_analysis: dict[str, Any] = field(default_factory=dict)
    # Enrichment result (stage 7); None when the enricher failed.
    enrichment: Any | None = None


class PropertyPipeline:
    def __init__(
        self,
        raw_repo: RawDataRepository,
        property_repo: PropertyRepository,
        image_repo: ImageRepository,
        custom_field_repo: CustomFieldRepository,
        snapshot_repo: SnapshotRepository,
        enrichment_repo: AiEnrichmentRepository,
        data_source_repo: DataSourceRepository,
        property_type_repo: PropertyTypeRepository,
        normalizer: AiNormalizer,
        image_downloader: ImageDownloader,
        image_analyzer: AiImageAnalyzer,
        enricher: AiEnricher,
        active_jobs: set[int] | None = None,
    ) -> None:
        self._raw_repo = raw_repo
        self._property_repo = property_repo
        self._image_repo = image_repo
        self._custom_field_repo = custom_field_repo
        self._snapshot_repo = snapshot_repo
        self._enrichment_repo = enrichment_repo
        self._data_source_repo = data_source_repo
        self._property_type_repo = property_type_repo
        self._active_jobs = active_jobs
        self._normalizer = normalizer
        self._image_downloader = image_downloader
        self._image_analyzer = image_analyzer
        self._enricher = enricher

    async def process(self, raw_data_id: int) -> IngestResponse:
        """Run every pipeline stage for one raw-data record.

        The stages run strictly in order: load, normalize, resolve source,
        persist listing, custom fields, images, enrichment, finalize.

        Args:
            raw_data_id: Id of the raw-data record to ingest. It doubles as
                the ``job_id`` of the returned response.

        Returns:
            An ``IngestResponse`` whose ``status`` is ``"completed"``,
            ``"invalid"`` (the normalizer reported the payload is not real
            estate) or ``"failed"`` (any exception, including an open
            circuit breaker). Exceptions raised by the stages are converted
            into a ``"failed"`` response rather than propagated.
        """
        if self._active_jobs is not None:
            self._active_jobs.add(raw_data_id)
        start = time.perf_counter()
        context = PipelineContext()

        try:
            # Stage 1: Load raw data
            context.raw = await self._stage_load_raw(raw_data_id)

            # Stage 2: Normalize via AI
            context.norm_response = await self._stage_normalize(
                raw_data_id, context.raw
            )
            if not context.norm_response.is_real_estate:
                # NOTE(review): no status write happens on this path, so the
                # record keeps the ``processing`` status set in stage 1 --
                # confirm this is intended.
                pipeline_results_total.labels(status="invalid").inc()
                return IngestResponse(
                    job_id=raw_data_id,
                    status="invalid",
                    reason=context.norm_response.reason,
                    message="Payload is not real estate",
                )

            context.normalized = NormalizedProperty(
                **(context.norm_response.normalized or {})
            )

            # Stage 3: Resolve source and existing listing
            (
                context.data_source,
                context.existing_listing,
            ) = await self._stage_resolve_source_and_listing(
                raw_data_id, context.raw, context.normalized
            )

            # Stage 4: Persist listing (create or update)
            (
                context.property_id,
                context.snapshot_id,
            ) = await self._stage_persist_listing(
                raw_data_id,
                context.raw,
                context.normalized,
                context.data_source,
                context.existing_listing,
            )

            # Stage 5: Custom fields
            await self._stage_persist_custom_fields(
                context.property_id,
                context.normalized.custom_fields,
            )

            # Stage 6: Download and analyze images
            context.aggregated_analysis = await self._stage_process_images(
                context.property_id,
                context.normalized.images,
            )

            # Stage 7: AI enrichment
            context.enrichment = await self._stage_enrich(
                context.property_id,
                context.normalized,
                context.aggregated_analysis,
            )

            # Stage 8: Finalize
            result = await self._stage_finalize(
                raw_data_id,
                context.property_id,
                context.snapshot_id,
            )
            pipeline_results_total.labels(status="completed").inc()
            return result

        except CircuitBreakerOpenError as exc:
            logger.warning(
                "pipeline_circuit_breaker_open",
                raw_id=raw_data_id,
                error=str(exc),
            )
            return await self._fail_job(
                raw_data_id,
                reason="circuit_breaker_open",
                error_message=f"Circuit breaker open: {exc}",
                message=f"Ollama circuit breaker is open: {exc}",
            )
        except Exception as exc:
            logger.error(
                "pipeline_unhandled_error",
                raw_id=raw_data_id,
                error=str(exc),
                exc_info=True,
            )
            return await self._fail_job(
                raw_data_id,
                reason="pipeline_error",
                error_message=str(exc),
                message=f"Pipeline error: {exc}",
            )
        finally:
            # Always record the duration and release the job slot, whatever
            # the outcome was.
            pipeline_duration_seconds.observe(time.perf_counter() - start)
            if self._active_jobs is not None:
                self._active_jobs.discard(raw_data_id)

    async def _fail_job(
        self,
        raw_data_id: int,
        *,
        reason: str,
        error_message: str,
        message: str,
    ) -> IngestResponse:
        """Count the failure, mark the record failed and build the response.

        The status write is best-effort: if it raises as well, that
        secondary error is logged and the ``"failed"`` response is still
        returned, so ``process`` keeps its contract of never raising.
        """
        pipeline_results_total.labels(status="failed").inc()
        try:
            await self._raw_repo.update_status(
                raw_data_id,
                RawDataStatus.failed,
                error_message=error_message,
            )
        except Exception as status_exc:
            logger.error(
                "pipeline_status_update_failed",
                raw_id=raw_data_id,
                error=str(status_exc),
                exc_info=True,
            )
        return IngestResponse(
            job_id=raw_data_id,
            status="failed",
            reason=reason,
            message=message,
        )

    # ------------------------------------------------------------------
    # Stages
    # ------------------------------------------------------------------

    async def _stage_load_raw(self, raw_data_id: int) -> Any:
        raw = await self._raw_repo.get_by_id(raw_data_id)
        if raw is None:
            raise ValueError(f"Raw data {raw_data_id} not found")
        await self._raw_repo.update_status(
            raw_data_id, RawDataStatus.processing
        )
        return raw

    async def _stage_normalize(
        self, raw_data_id: int, raw: Any
    ) -> Any:
        try:
            return await self._normalizer.normalize(raw.payload)
        except CircuitBreakerOpenError:
            raise
        except Exception as exc:
            logger.error(
                "pipeline_normalizer_error",
                raw_id=raw_data_id,
                error=str(exc),
            )
            await self._raw_repo.update_status(
                raw_data_id,
                RawDataStatus.failed,
                error_message=str(exc),
            )
            raise

    async def _stage_resolve_source_and_listing(
        self,
        raw_data_id: int,
        raw: Any,
        normalized: NormalizedProperty,
    ) -> tuple[Any, Any]:
        source_slug = raw.payload.get("source_slug", "unknown")
        data_source = await self._data_source_repo.get_or_create_by_slug(
            source_slug, name=source_slug
        )
        raw.source_id = data_source.id

        existing = await self._property_repo.get_by_source_and_external(
            data_source.id, raw.external_id
        )
        return data_source, existing

    async def _stage_persist_listing(
        self,
        raw_data_id: int,
        raw: Any,
        normalized: NormalizedProperty,
        data_source: Any,
        existing: Any,
    ) -> tuple[int, int | None]:
        listing_kwargs = self._normalized_to_kwargs(normalized)
        if normalized.property_type:
            property_type = await self._property_type_repo.get_or_create_by_slug(
                normalized.property_type,
                name=normalized.property_type,
            )
            listing_kwargs["property_type_id"] = property_type.id
        listing_kwargs.pop("property_type", None)
        listing_kwargs["source_id"] = data_source.id
        listing_kwargs["external_id"] = raw.external_id
        listing_kwargs["raw_data_id"] = raw_data_id

        snapshot_id: int | None = None
        if existing:
            snapshot_data = self._listing_to_dict(existing)
            changed_fields = self._compute_changed_fields(
                snapshot_data, listing_kwargs
            )
            snapshot = await self._snapshot_repo.create(
                existing.id, snapshot_data, changed_fields
            )
            snapshot_id = snapshot.id
            await self._property_repo.update(existing.id, **listing_kwargs)
            property_id = existing.id
            await self._custom_field_repo.delete_by_property(property_id)
        else:
            property_obj = await self._property_repo.create(
                **listing_kwargs
            )
            property_id = property_obj.id

        return property_id, snapshot_id

    async def _stage_persist_custom_fields(
        self,
        property_id: int,
        custom_fields: dict[str, Any],
    ) -> None:
        if not custom_fields:
            return
        fields = [
            {
                "field_name": k,
                "field_value": str(v),
                "field_type": self._infer_field_type(v),
            }
            for k, v in custom_fields.items()
        ]
        await self._custom_field_repo.bulk_create(property_id, fields)

    async def _stage_process_images(
        self,
        property_id: int,
        images: list[str],
    ) -> dict[str, Any]:
        aggregated_analysis: dict[str, Any] = {}
        for order, url in enumerate(images):
            try:
                result = await self._image_downloader.download(
                    property_id, url, order
                )
            except Exception as exc:
                logger.error(
                    "image_download_failed",
                    property_id=property_id,
                    url=url,
                    error=str(exc),
                )
                continue

            duplicate = await self._image_repo.get_by_hash(
                property_id, result.image_hash
            )
            if duplicate:
                logger.info(
                    "image_duplicate_skipped",
                    property_id=property_id,
                    hash=result.image_hash,
                )
                continue

            image = await self._image_repo.create(
                property_id, url, order
            )
            await self._image_repo.update_downloaded(
                image.id,
                result.local_path,
                result.file_size,
                result.width,
                result.height,
                result.image_hash,
            )

            if result.local_path:
                try:
                    b64 = OllamaClient.image_to_base64(
                        result.local_path
                    )
                    analysis = await self._image_analyzer.analyze(
                        b64
                    )
                    await self._image_repo.update_analysis(
                        image.id,
                        analysis.overall_condition or "",
                    )
                    aggregated_analysis[url] = (
                        analysis.model_dump()
                    )
                except Exception as exc:
                    logger.error(
                        "image_analysis_failed",
                        property_id=property_id,
                        image_id=image.id,
                        error=str(exc),
                    )

        all_images = await self._image_repo.get_by_property(property_id)
        await self._property_repo.update(
            property_id, images_count=len(all_images)
        )

        return aggregated_analysis

    async def _stage_enrich(
        self,
        property_id: int,
        normalized: NormalizedProperty,
        aggregated_analysis: dict[str, Any],
    ) -> Any:
        try:
            enrichment = await self._enricher.enrich(
                normalized, aggregated_analysis
            )
        except CircuitBreakerOpenError:
            raise
        except Exception as exc:
            logger.error(
                "pipeline_enricher_error",
                property_id=property_id,
                error=str(exc),
            )
            return None

        if enrichment:
            await self._enrichment_repo.delete_by_property(property_id)
            await self._enrichment_repo.create(
                property_id,
                extracted_features=enrichment.extracted_features,
                price_assessment=enrichment.price_assessment,
                listing_quality_score=enrichment.listing_quality_score,
                reliability_rating=enrichment.reliability_rating,
                sentiment_score=enrichment.sentiment_score,
                classification=enrichment.classification,
                image_analysis_results=aggregated_analysis,
                generated_description=enrichment.generated_description,
                summary=enrichment.summary,
                model_version=enrichment.model_version,
                processing_time_ms=enrichment.processing_time_ms,
            )
            await self._property_repo.update(
                property_id,
                listing_quality_score=enrichment.listing_quality_score,
                reliability_rating=enrichment.reliability_rating,
                sentiment_score=enrichment.sentiment_score,
                generated_description=enrichment.generated_description,
            )

        return enrichment

    async def _stage_finalize(
        self,
        raw_data_id: int,
        property_id: int,
        snapshot_id: int | None,
    ) -> IngestResponse:
        """Stage 8: mark the raw record processed and build the response.

        Returns:
            An ``IngestResponse`` with status ``"completed"`` carrying the
            job, property and (optional) snapshot ids.
        """
        await self._raw_repo.set_processed(raw_data_id)
        response_fields = {
            "job_id": raw_data_id,
            "property_id": property_id,
            "status": "completed",
            "message": "Property ingested successfully",
            "snapshot_id": snapshot_id,
        }
        return IngestResponse(**response_fields)

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _normalized_to_kwargs(
        normalized: NormalizedProperty,
    ) -> dict[str, Any]:
        return {
            k: v
            for k, v in normalized.__dict__.items()
            if k not in ("images", "custom_fields")
        }

    @staticmethod
    def _listing_to_dict(listing: Any) -> dict[str, Any]:
        """Dump a mapped listing's column attributes into a plain dict.

        Keys are the mapped attribute names (``ColumnProperty.key``), not
        the database column names. ``getattr`` needs the attribute name, so
        this is the only spelling that works when the two differ, and it
        matches the attribute-named keyword arguments the result is later
        compared against. When both names are equal the output is the same
        as iterating ``mapper.columns``.

        Each value is passed through ``_serialize_value``.
        """
        mapper = inspect(listing).mapper
        return {
            attr.key: PropertyPipeline._serialize_value(
                getattr(listing, attr.key)
            )
            for attr in mapper.column_attrs
        }

    @staticmethod
    def _serialize_value(value: Any) -> Any:
        from datetime import datetime
        from decimal import Decimal
        from enum import Enum

        if value is None:
            return None
        if isinstance(value, Decimal):
            return float(value)
        if isinstance(value, datetime):
            return value.isoformat()
        if isinstance(value, Enum):
            return value.value
        return value

    @staticmethod
    def _compute_changed_fields(
        old: dict[str, Any],
        new: dict[str, Any],
    ) -> dict[str, Any]:
        changed: dict[str, Any] = {}
        for key, new_val in new.items():
            old_val = old.get(key)
            if old_val != new_val:
                changed[key] = {"old": old_val, "new": new_val}
        return changed

    @staticmethod
    def _infer_field_type(value: Any) -> str:
        if isinstance(value, bool):
            return "bool"
        if isinstance(value, int):
            return "int"
        if isinstance(value, float):
            return "float"
        return "str"