"""Unit tests for QueueWorker."""
import asyncio
from typing import Any
from unittest.mock import AsyncMock, MagicMock
import pytest
from vmk_data_collector.schemas.raw_data import IngestResponse
from vmk_data_collector.services.queue_worker import QueueWorker
def _make_execute_side_effect(raw: Any | None, times: int = 1):
"""Return a side_effect for session.execute that yields *raw* N times."""
count = 0
def side_effect(*_a, **_k):
nonlocal count
count += 1
result = MagicMock()
if count <= times:
result.scalar_one_or_none.return_value = raw
else:
result.scalar_one_or_none.return_value = None
return result
return side_effect
@pytest.fixture
def mock_pipeline() -> AsyncMock:
    """Provide an ``AsyncMock`` pipeline whose ``process`` yields a completed response."""
    completed = IngestResponse(job_id=1, status="completed", message="ok")
    fake_pipeline = AsyncMock()
    fake_pipeline.process.return_value = completed
    return fake_pipeline
@pytest.fixture
def mock_session(mock_async_session: AsyncMock) -> AsyncMock:
    """Return the shared ``mock_async_session`` prepared for queue tests.

    ``execute`` is given a plain ``MagicMock`` as its default return value;
    individual tests may override it with a ``side_effect``.
    """
    # ``mock_async_session`` is supplied by a fixture defined outside this file.
    mock_async_session.execute.return_value = MagicMock()
    return mock_async_session
@pytest.fixture
def mock_session_factory(mock_session: AsyncMock) -> MagicMock:
    """Provide a callable mock that returns ``mock_session`` on every call."""
    session_factory = MagicMock()
    # Assigned after construction (not via the constructor keyword) on purpose.
    session_factory.return_value = mock_session
    return session_factory
@pytest.fixture
def mock_pipeline_factory(mock_pipeline: AsyncMock) -> MagicMock:
    """Provide a callable mock that returns ``mock_pipeline`` on every call."""
    pipeline_factory = MagicMock()
    # Assigned after construction (not via the constructor keyword) on purpose.
    pipeline_factory.return_value = mock_pipeline
    return pipeline_factory
@pytest.fixture
def stop_event() -> asyncio.Event:
    """Provide a fresh, unset ``asyncio.Event`` used as the worker's stop signal."""
    event = asyncio.Event()
    return event
@pytest.fixture
def worker(
    mock_session_factory: MagicMock,
    mock_pipeline_factory: MagicMock,
    stop_event: asyncio.Event,
) -> QueueWorker:
    """Build the ``QueueWorker`` under test from the mocked factories.

    The worker is created with a 0.05 s poll interval and the ``stop_event``
    fixture as its stop signal.
    """
    queue_worker = QueueWorker(
        session_factory=mock_session_factory,
        pipeline_factory=mock_pipeline_factory,
        poll_interval=0.05,
        stop_event=stop_event,
    )
    return queue_worker
class TestProcessing:
    """Exercise ``QueueWorker.run`` when exactly one pending row is available."""

    @staticmethod
    def _queue_single_job(session: AsyncMock, job_id: int = 42) -> None:
        """Make ``session.execute`` hand out one row with ``id == job_id``.

        The first ``execute`` result returns the row from
        ``scalar_one_or_none()``; every later result returns ``None``.
        """
        raw = MagicMock()
        raw.id = job_id
        session.execute.side_effect = _make_execute_side_effect(raw, times=1)

    @staticmethod
    async def _run_until_stopped(worker: QueueWorker, run_for: float = 0.1) -> None:
        """Run ``worker.run()`` in a task, signal stop, and await its exit.

        Args:
            worker: Worker under test; its ``_stop_event`` is replaced with a
                fresh event owned by this helper.
            run_for: Seconds to let the worker run before the stop signal.

        Raises:
            asyncio.TimeoutError: If the task has not finished within one
                second of the stop signal.
        """
        stop = asyncio.Event()
        worker._stop_event = stop
        task = asyncio.create_task(worker.run())
        await asyncio.sleep(run_for)
        stop.set()
        # A worker that ignores the stop signal fails the test here.
        await asyncio.wait_for(task, timeout=1)

    @pytest.mark.asyncio
    async def test_processes_pending_job(
        self,
        worker: QueueWorker,
        mock_session: AsyncMock,
        mock_pipeline: AsyncMock,
    ) -> None:
        """A pending row is processed once by id and the session is committed."""
        self._queue_single_job(mock_session)

        await self._run_until_stopped(worker)

        mock_pipeline.process.assert_awaited_once_with(42)
        mock_session.commit.assert_awaited()

    @pytest.mark.asyncio
    async def test_invalid_result_status(
        self,
        worker: QueueWorker,
        mock_session: AsyncMock,
        mock_pipeline: AsyncMock,
    ) -> None:
        """An ``invalid`` pipeline result still means one ``process`` call."""
        self._queue_single_job(mock_session)
        mock_pipeline.process.return_value = IngestResponse(
            job_id=42,
            status="invalid",
            reason="not real estate",
            message="bad",
        )

        await self._run_until_stopped(worker)

        mock_pipeline.process.assert_awaited_once_with(42)

    @pytest.mark.asyncio
    async def test_failed_result_status(
        self,
        worker: QueueWorker,
        mock_session: AsyncMock,
        mock_pipeline: AsyncMock,
    ) -> None:
        """A ``failed`` pipeline result still means one ``process`` call."""
        self._queue_single_job(mock_session)
        mock_pipeline.process.return_value = IngestResponse(
            job_id=42,
            status="failed",
            reason="err",
            message="bad",
        )

        await self._run_until_stopped(worker)

        mock_pipeline.process.assert_awaited_once_with(42)

    @pytest.mark.asyncio
    async def test_pipeline_exception_rollback_and_mark_failed(
        self,
        worker: QueueWorker,
        mock_session: AsyncMock,
        mock_pipeline: AsyncMock,
    ) -> None:
        """An exception from ``process`` triggers one rollback and a commit."""
        self._queue_single_job(mock_session)
        mock_pipeline.process.side_effect = RuntimeError("boom")

        await self._run_until_stopped(worker)

        mock_session.rollback.assert_awaited_once()
        # NOTE(review): presumably this commit persists the "failed" mark the
        # test name refers to; that logic lives in QueueWorker -- confirm there.
        mock_session.commit.assert_awaited()
class TestPollingAndShutdown:
    """Exercise idle polling and stop-signal handling of ``QueueWorker.run``."""

    @pytest.mark.asyncio
    async def test_no_jobs_waits_for_stop(
        self,
        worker: QueueWorker,
        mock_session: AsyncMock,
    ) -> None:
        """With no pending rows the worker keeps running until stop is set."""
        # times=0: every ``execute`` result reports no row.
        mock_session.execute.side_effect = _make_execute_side_effect(
            None, times=0
        )
        stop = asyncio.Event()
        worker._stop_event = stop
        task = asyncio.create_task(worker.run())
        try:
            # 0.08 s is longer than the 0.05 s poll_interval set by the
            # ``worker`` fixture.
            await asyncio.sleep(0.08)
            assert not task.done()
        finally:
            # Always stop and reap the task so a failed assertion above does
            # not leave it pending in the event loop.
            stop.set()
            await asyncio.wait_for(task, timeout=1)

    @pytest.mark.asyncio
    async def test_graceful_stop_during_processing(
        self,
        worker: QueueWorker,
        mock_session: AsyncMock,
        mock_pipeline: AsyncMock,
    ) -> None:
        """Setting stop while ``process`` is in flight still lets ``run`` exit.

        The worker must return within one second of the stop signal, and
        ``process`` must have been awaited exactly once with the row id.
        """
        raw = MagicMock()
        raw.id = 42
        mock_session.execute.side_effect = _make_execute_side_effect(raw, times=1)

        async def slow_process(_id: int) -> IngestResponse:
            # Takes longer than the 0.05 s delay before stop is set below.
            await asyncio.sleep(0.2)
            return IngestResponse(
                job_id=_id, status="completed", message="ok"
            )

        mock_pipeline.process.side_effect = slow_process
        stop = asyncio.Event()
        worker._stop_event = stop
        task = asyncio.create_task(worker.run())
        await asyncio.sleep(0.05)  # let pipeline start
        stop.set()
        await asyncio.wait_for(task, timeout=1)
        mock_pipeline.process.assert_awaited_once_with(42)