"""Unit tests for the in-memory circuit breaker."""
import asyncio
from unittest.mock import AsyncMock
import pytest
from vmk_data_collector.core.circuit_breaker import (
CircuitBreaker,
CircuitBreakerOpenError,
CircuitState,
)
@pytest.fixture
def cb() -> CircuitBreaker:
    """Provide a fresh breaker for each test.

    The breaker counts only ``ValueError`` as a failure and uses a failure
    threshold of 3. Its recovery timeout is 0.1; the state tests sleep 0.15 to
    get past it.
    """
    settings = {
        "failure_threshold": 3,
        "recovery_timeout": 0.1,
        # Only these exception types are counted towards the threshold.
        "expected_exception": (ValueError,),
    }
    return CircuitBreaker(**settings)
class TestCircuitBreakerStates:
    """State-machine tests: CLOSED -> OPEN -> HALF_OPEN -> CLOSED / OPEN.

    Every test uses the ``cb`` fixture. The class constants below mirror that
    fixture's configuration and must be kept in sync with it.
    """

    # Mirrors the fixture's ``failure_threshold``.
    _FAILURE_THRESHOLD = 3
    # Longer than the fixture's ``recovery_timeout`` (0.1), so after sleeping
    # this long the recovery window has elapsed.
    _PAST_RECOVERY = 0.15

    async def _trip(self, cb: CircuitBreaker) -> AsyncMock:
        """Drive *cb* to OPEN with threshold-many expected failures.

        Returns the failing mock, so callers can inspect its call count or
        reuse it for further failing calls.
        """
        func = AsyncMock(side_effect=ValueError("boom"))
        for _ in range(self._FAILURE_THRESHOLD):
            with pytest.raises(ValueError):
                await cb.call(func)
        return func

    def test_initial_state_is_closed(self, cb: CircuitBreaker) -> None:
        """A freshly constructed breaker starts CLOSED."""
        assert cb.state == CircuitState.CLOSED

    @pytest.mark.asyncio
    async def test_success_keeps_closed(self, cb: CircuitBreaker) -> None:
        """A successful call returns the wrapped result and stays CLOSED."""
        func = AsyncMock(return_value="ok")
        result = await cb.call(func)
        assert result == "ok"
        func.assert_awaited_once()
        assert cb.state == CircuitState.CLOSED

    @pytest.mark.asyncio
    async def test_expected_failure_increments_count(
        self, cb: CircuitBreaker
    ) -> None:
        """One expected failure is re-raised and counted, without opening."""
        func = AsyncMock(side_effect=ValueError("boom"))
        with pytest.raises(ValueError):
            await cb.call(func)
        assert cb._failure_count == 1
        assert cb.state == CircuitState.CLOSED

    @pytest.mark.asyncio
    async def test_failure_threshold_opens_circuit(
        self, cb: CircuitBreaker
    ) -> None:
        """The breaker stays CLOSED below the threshold and opens exactly at it."""
        func = AsyncMock(side_effect=ValueError("boom"))
        for attempt in range(1, self._FAILURE_THRESHOLD + 1):
            with pytest.raises(ValueError):
                await cb.call(func)
            if attempt < self._FAILURE_THRESHOLD:
                # Pin the boundary: opening one failure early is a bug.
                assert cb.state == CircuitState.CLOSED
        assert cb.state == CircuitState.OPEN

    @pytest.mark.asyncio
    async def test_open_raises_without_calling(
        self, cb: CircuitBreaker
    ) -> None:
        """While OPEN, calls are rejected without reaching the function."""
        func = await self._trip(cb)
        with pytest.raises(CircuitBreakerOpenError):
            await cb.call(func)
        # Only the tripping calls reached the mock; the rejected one did not.
        assert func.call_count == self._FAILURE_THRESHOLD

    @pytest.mark.asyncio
    async def test_recovery_to_half_open(self, cb: CircuitBreaker) -> None:
        """After the recovery timeout an OPEN breaker reports HALF_OPEN."""
        await self._trip(cb)
        assert cb.state == CircuitState.OPEN
        await asyncio.sleep(self._PAST_RECOVERY)
        assert cb.state == CircuitState.HALF_OPEN

    @pytest.mark.asyncio
    async def test_half_open_success_closes(self, cb: CircuitBreaker) -> None:
        """A successful trial call in HALF_OPEN closes the breaker."""
        await self._trip(cb)
        await asyncio.sleep(self._PAST_RECOVERY)
        # Precondition: we really are testing the HALF_OPEN transition.
        assert cb.state == CircuitState.HALF_OPEN
        func = AsyncMock(return_value="ok")
        result = await cb.call(func)
        assert result == "ok"
        func.assert_awaited_once()
        assert cb.state == CircuitState.CLOSED

    @pytest.mark.asyncio
    async def test_half_open_failure_reopens(self, cb: CircuitBreaker) -> None:
        """A failing trial call in HALF_OPEN re-opens the breaker."""
        func = await self._trip(cb)
        await asyncio.sleep(self._PAST_RECOVERY)
        assert cb.state == CircuitState.HALF_OPEN
        with pytest.raises(ValueError):
            await cb.call(func)
        # The trial must actually have reached the function rather than being
        # short-circuited by the breaker.
        assert func.call_count == self._FAILURE_THRESHOLD + 1
        assert cb.state == CircuitState.OPEN

    @pytest.mark.asyncio
    async def test_unexpected_exception_not_counted(
        self, cb: CircuitBreaker
    ) -> None:
        """Non-expected exceptions propagate but never count, even past the threshold."""
        func = AsyncMock(side_effect=RuntimeError("unexpected"))
        for _ in range(self._FAILURE_THRESHOLD):
            with pytest.raises(RuntimeError):
                await cb.call(func)
        assert cb._failure_count == 0
        assert cb.state == CircuitState.CLOSED
class TestCircuitBreakerOpenError:
    """Message formatting of ``CircuitBreakerOpenError``."""

    def test_default_message(self) -> None:
        """With no arguments, the message mentions the OPEN state."""
        message = str(CircuitBreakerOpenError())
        assert "OPEN" in message

    def test_custom_message(self) -> None:
        """An explicitly supplied message is used verbatim."""
        expected = "custom"
        assert str(CircuitBreakerOpenError(expected)) == expected