Newer
Older
vmk-360-data_collector / tests / unit / test_queue_worker.py
@Eugene Sukhodolskiy Eugene Sukhodolskiy 1 day ago 6 KB fix: code review critical and high issues
"""Unit tests for QueueWorker."""

import asyncio
from typing import Any
from unittest.mock import AsyncMock, MagicMock

import pytest

from vmk_data_collector.schemas.raw_data import IngestResponse
from vmk_data_collector.services.queue_worker import QueueWorker


def _make_execute_side_effect(raw: Any | None, times: int = 1):
    """Return a side_effect for session.execute that yields *raw* N times."""
    count = 0

    def side_effect(*_a, **_k):
        nonlocal count
        count += 1
        result = MagicMock()
        if count <= times:
            result.scalar_one_or_none.return_value = raw
        else:
            result.scalar_one_or_none.return_value = None
        return result

    return side_effect


@pytest.fixture
def mock_pipeline() -> AsyncMock:
    """Pipeline double whose awaited ``process`` yields a successful IngestResponse."""
    ok_response = IngestResponse(
        job_id=1,
        status="completed",
        message="ok",
    )
    fake_pipeline = AsyncMock()
    fake_pipeline.process.return_value = ok_response
    return fake_pipeline


@pytest.fixture
def mock_session(mock_async_session: AsyncMock) -> AsyncMock:
    """Adapt the shared ``mock_async_session`` fixture for queue-worker tests.

    ``execute`` is given a plain (synchronous) MagicMock as its return value.
    Presumably this is so accessors called on the awaited result return plain
    mocks rather than coroutines; individual tests may still override
    ``execute.side_effect``.
    """
    mock_async_session.execute.return_value = MagicMock()
    return mock_async_session


@pytest.fixture
def mock_session_factory(mock_session: AsyncMock) -> MagicMock:
    """Callable stand-in for the session factory; every call returns ``mock_session``."""
    return MagicMock(return_value=mock_session)


@pytest.fixture
def mock_pipeline_factory(mock_pipeline: AsyncMock) -> MagicMock:
    """Callable stand-in for the pipeline factory; every call returns ``mock_pipeline``."""
    return MagicMock(return_value=mock_pipeline)


@pytest.fixture
def stop_event() -> asyncio.Event:
    """Fresh, unset event handed to the worker as its stop signal."""
    event = asyncio.Event()
    return event


@pytest.fixture
def worker(
    mock_session_factory: MagicMock,
    mock_pipeline_factory: MagicMock,
    stop_event: asyncio.Event,
) -> QueueWorker:
    """QueueWorker wired to the mocked session/pipeline factories and stop event."""
    poll_seconds = 0.05  # short interval so the ~0.1 s test windows see several polls
    return QueueWorker(
        session_factory=mock_session_factory,
        pipeline_factory=mock_pipeline_factory,
        stop_event=stop_event,
        poll_interval=poll_seconds,
    )


class TestProcessing:
    """Worker behaviour when exactly one pending job is waiting in the queue."""

    # Id of the single raw-data row served by the mocked queue query.
    _JOB_ID = 42

    @classmethod
    def _enqueue_single_job(cls, session: AsyncMock) -> None:
        """Serve one row with id ``_JOB_ID`` from ``session.execute``.

        Only the first ``execute`` call yields the row; later calls yield
        ``None``, so the worker sees an empty queue afterwards.
        """
        raw = MagicMock()
        raw.id = cls._JOB_ID
        session.execute.side_effect = _make_execute_side_effect(raw, times=1)

    @staticmethod
    async def _run_briefly(worker: QueueWorker, duration: float = 0.1) -> None:
        """Run ``worker.run()`` for about *duration* seconds, then stop it.

        A stop event owned by the test is installed on the worker so it can be
        set directly; the run task must exit within one second of being
        signalled.
        """
        stop = asyncio.Event()
        worker._stop_event = stop

        task = asyncio.create_task(worker.run())
        await asyncio.sleep(duration)
        stop.set()
        await asyncio.wait_for(task, timeout=1)

    @pytest.mark.asyncio
    async def test_processes_pending_job(
        self,
        worker: QueueWorker,
        mock_session: AsyncMock,
        mock_pipeline: AsyncMock,
    ) -> None:
        """A pending job is handed to the pipeline and the session is committed."""
        self._enqueue_single_job(mock_session)

        await self._run_briefly(worker)

        mock_pipeline.process.assert_awaited_once_with(self._JOB_ID)
        mock_session.commit.assert_awaited()

    @pytest.mark.asyncio
    async def test_invalid_result_status(
        self,
        worker: QueueWorker,
        mock_session: AsyncMock,
        mock_pipeline: AsyncMock,
    ) -> None:
        """An ``invalid`` pipeline result is consumed without reprocessing the job."""
        self._enqueue_single_job(mock_session)
        mock_pipeline.process.return_value = IngestResponse(
            job_id=self._JOB_ID,
            status="invalid",
            reason="not real estate",
            message="bad",
        )

        await self._run_briefly(worker)

        mock_pipeline.process.assert_awaited_once_with(self._JOB_ID)

    @pytest.mark.asyncio
    async def test_failed_result_status(
        self,
        worker: QueueWorker,
        mock_session: AsyncMock,
        mock_pipeline: AsyncMock,
    ) -> None:
        """A ``failed`` pipeline result is consumed without reprocessing the job."""
        self._enqueue_single_job(mock_session)
        mock_pipeline.process.return_value = IngestResponse(
            job_id=self._JOB_ID,
            status="failed",
            reason="err",
            message="bad",
        )

        await self._run_briefly(worker)

        mock_pipeline.process.assert_awaited_once_with(self._JOB_ID)

    @pytest.mark.asyncio
    async def test_pipeline_exception_rollback_and_mark_failed(
        self,
        worker: QueueWorker,
        mock_session: AsyncMock,
        mock_pipeline: AsyncMock,
    ) -> None:
        """A pipeline exception triggers one rollback followed by a commit."""
        self._enqueue_single_job(mock_session)
        mock_pipeline.process.side_effect = RuntimeError("boom")

        await self._run_briefly(worker)

        # Confirm the rollback was caused by the pipeline call raising, not by
        # some earlier failure before the job reached the pipeline.
        mock_pipeline.process.assert_awaited_once_with(self._JOB_ID)
        mock_session.rollback.assert_awaited_once()
        mock_session.commit.assert_awaited()


class TestPollingAndShutdown:
    """Idle polling and stop-event handling of ``QueueWorker.run``."""

    @pytest.mark.asyncio
    async def test_no_jobs_waits_for_stop(
        self,
        worker: QueueWorker,
        mock_session: AsyncMock,
        mock_pipeline: AsyncMock,
    ) -> None:
        """With an empty queue the worker keeps polling until told to stop."""
        # Every poll finds nothing.
        mock_session.execute.side_effect = _make_execute_side_effect(
            None, times=0
        )

        stop = asyncio.Event()
        worker._stop_event = stop

        task = asyncio.create_task(worker.run())
        # Longer than one poll interval: an empty queue must not end the loop.
        await asyncio.sleep(0.08)
        assert not task.done()
        stop.set()
        await asyncio.wait_for(task, timeout=1)

        # Nothing was queued, so nothing may reach the pipeline.
        mock_pipeline.process.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_graceful_stop_during_processing(
        self,
        worker: QueueWorker,
        mock_session: AsyncMock,
        mock_pipeline: AsyncMock,
    ) -> None:
        """A stop signal arriving mid-job lets the in-flight job finish."""
        raw = MagicMock()
        raw.id = 42
        mock_session.execute.side_effect = _make_execute_side_effect(raw, times=1)

        # Ids of jobs whose processing ran to completion (i.e. past the sleep).
        finished: list[int] = []

        async def slow_process(job_id: int) -> IngestResponse:
            await asyncio.sleep(0.2)
            finished.append(job_id)
            return IngestResponse(
                job_id=job_id, status="completed", message="ok"
            )

        mock_pipeline.process.side_effect = slow_process

        stop = asyncio.Event()
        worker._stop_event = stop

        task = asyncio.create_task(worker.run())
        await asyncio.sleep(0.05)  # let pipeline start
        stop.set()
        await asyncio.wait_for(task, timeout=1)

        mock_pipeline.process.assert_awaited_once_with(42)
        # AsyncMock counts the await as soon as it begins, so the assertion
        # above would also pass if the job were cancelled mid-flight; check
        # that processing actually completed.
        assert finished == [42]