# vmk-360-data_collector/tests/unit/test_circuit_breaker.py
# Last commit: Eugene Sukhodolskiy, "fix: code review critical and high issues"
"""Unit tests for the in-memory circuit breaker."""

import asyncio
from unittest.mock import AsyncMock

import pytest

from vmk_data_collector.core.circuit_breaker import (
    CircuitBreaker,
    CircuitBreakerOpenError,
    CircuitState,
)


@pytest.fixture
def cb() -> CircuitBreaker:
    """Provide a fresh breaker configured for the state-machine tests.

    Only ``ValueError`` counts as a failure, three failures trip the breaker,
    and the short 0.1 s ``recovery_timeout`` keeps the recovery tests fast.
    """
    breaker = CircuitBreaker(
        expected_exception=(ValueError,),
        failure_threshold=3,
        recovery_timeout=0.1,
    )
    return breaker


class TestCircuitBreakerStates:
    """State-machine tests: CLOSED -> OPEN -> HALF_OPEN -> CLOSED / OPEN."""

    # Must match ``failure_threshold`` in the ``cb`` fixture.
    THRESHOLD = 3
    # Comfortably longer than the fixture's 0.1 s ``recovery_timeout``.
    RECOVERY_WAIT = 0.15

    @classmethod
    async def _trip(cls, cb: CircuitBreaker) -> AsyncMock:
        """Drive ``cb`` to OPEN with ``THRESHOLD`` expected failures.

        Each call must reach the wrapped function and surface its
        ``ValueError``. Returns the failing mock so callers can inspect its
        call count or reuse it as a probe.
        """
        func = AsyncMock(side_effect=ValueError("boom"))
        for _ in range(cls.THRESHOLD):
            with pytest.raises(ValueError):
                await cb.call(func)
        return func

    def test_initial_state_is_closed(self, cb: CircuitBreaker) -> None:
        assert cb.state == CircuitState.CLOSED

    @pytest.mark.asyncio
    async def test_success_keeps_closed(self, cb: CircuitBreaker) -> None:
        func = AsyncMock(return_value="ok")
        result = await cb.call(func)
        assert result == "ok"
        assert cb.state == CircuitState.CLOSED

    @pytest.mark.asyncio
    async def test_expected_failure_increments_count(
        self, cb: CircuitBreaker
    ) -> None:
        func = AsyncMock(side_effect=ValueError("boom"))
        with pytest.raises(ValueError):
            await cb.call(func)
        # White-box check on the private counter.
        assert cb._failure_count == 1
        assert cb.state == CircuitState.CLOSED

    @pytest.mark.asyncio
    async def test_failure_threshold_opens_circuit(
        self, cb: CircuitBreaker
    ) -> None:
        await self._trip(cb)
        assert cb.state == CircuitState.OPEN

    @pytest.mark.asyncio
    async def test_open_raises_without_calling(
        self, cb: CircuitBreaker
    ) -> None:
        func = await self._trip(cb)
        with pytest.raises(CircuitBreakerOpenError):
            await cb.call(func)
        # The rejected call must not have reached the wrapped function.
        assert func.call_count == self.THRESHOLD

    @pytest.mark.asyncio
    async def test_recovery_to_half_open(self, cb: CircuitBreaker) -> None:
        await self._trip(cb)
        assert cb.state == CircuitState.OPEN
        await asyncio.sleep(self.RECOVERY_WAIT)
        assert cb.state == CircuitState.HALF_OPEN

    @pytest.mark.asyncio
    async def test_half_open_success_closes(self, cb: CircuitBreaker) -> None:
        await self._trip(cb)
        await asyncio.sleep(self.RECOVERY_WAIT)
        probe = AsyncMock(return_value="ok")
        result = await cb.call(probe)
        assert result == "ok"
        # The trial call in HALF_OPEN must actually be executed.
        assert probe.call_count == 1
        assert cb.state == CircuitState.CLOSED

    @pytest.mark.asyncio
    async def test_half_open_failure_reopens(self, cb: CircuitBreaker) -> None:
        func = await self._trip(cb)
        await asyncio.sleep(self.RECOVERY_WAIT)
        with pytest.raises(ValueError):
            await cb.call(func)
        # Guard against a vacuous pass: the ValueError must come from the
        # wrapped function (the HALF_OPEN probe), not from the breaker
        # rejecting the call.
        assert func.call_count == self.THRESHOLD + 1
        assert cb.state == CircuitState.OPEN

    @pytest.mark.asyncio
    async def test_unexpected_exception_not_counted(
        self, cb: CircuitBreaker
    ) -> None:
        func = AsyncMock(side_effect=RuntimeError("unexpected"))
        with pytest.raises(RuntimeError):
            await cb.call(func)
        # Exceptions outside ``expected_exception`` propagate untouched and
        # neither count toward nor change the breaker state.
        assert cb._failure_count == 0
        assert cb.state == CircuitState.CLOSED


class TestCircuitBreakerOpenError:
    """Message formatting of ``CircuitBreakerOpenError``."""

    def test_default_message(self) -> None:
        # With no argument, the message should still mention the OPEN state.
        message = str(CircuitBreakerOpenError())
        assert "OPEN" in message

    def test_custom_message(self) -> None:
        # An explicit message is used verbatim.
        message = str(CircuitBreakerOpenError("custom"))
        assert message == "custom"