# (Page header captured with this file; kept as comments so the module parses.)
# Newer / Older
# vmk-360-data_collector / tests / unit / test_ollama_client.py
# @Eugene Sukhodolskiy — 1 day ago — 10 KB — fix: code review critical and high issues
"""Unit tests for OllamaClient."""

import base64
import json
from io import BytesIO
from pathlib import Path
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch

import httpx
import pytest
from PIL import Image

from vmk_data_collector.core.exceptions import OllamaFatalError, OllamaRetryableError
from vmk_data_collector.services.ollama_client import OllamaClient


@pytest.fixture(autouse=True)
def _patch_before_sleep(monkeypatch: pytest.MonkeyPatch) -> None:
    """Swap the retry config's ``before_sleep`` hook for a no-op in every test.

    The client's tenacity ``before_sleep`` logging breaks under structlog in
    the test environment, so the hook is replaced with a callable that
    ignores the retry state. ``monkeypatch`` restores the original entry
    after each test.

    NOTE(review): this only has an effect if ``_RETRY_CONFIG`` is read when a
    retry loop is constructed; if the dict is unpacked once at decoration
    time, mutating it here is too late — confirm against ``ollama_client``.
    """
    import vmk_data_collector.services.ollama_client as ollama_client_module

    def _silent_before_sleep(retry_state):
        # Deliberately does nothing; tenacity just needs a callable here.
        return None

    monkeypatch.setitem(
        ollama_client_module._RETRY_CONFIG, "before_sleep", _silent_before_sleep
    )


@pytest.fixture
def client() -> OllamaClient:
    """Build a fresh ``OllamaClient`` pointed at the default local Ollama URL.

    The tests in this module replace ``_client.post`` with mocks, so the
    base URL is never actually contacted.
    """
    local_client = OllamaClient(base_url="http://localhost:11434", timeout=30)
    return local_client


class TestChat:
    """Tests for ``OllamaClient.chat``: payload shape, retry policy, error mapping.

    Every test stubs ``client._client.post`` so no request leaves the process.
    Retryable failures are expected to be attempted 3 times; fatal ones once.
    """

    @staticmethod
    def _messages() -> list[dict[str, str]]:
        """Return a fresh single-turn conversation.

        A new list per call keeps tests independent even if the client were
        to mutate the messages it is given.
        """
        return [{"role": "user", "content": "hello"}]

    @staticmethod
    def _ok_response(body: dict[str, Any]) -> MagicMock:
        """Build a mocked successful response whose ``json()`` returns *body*."""
        response = MagicMock()
        response.raise_for_status = lambda: None
        response.json.return_value = body
        return response

    @staticmethod
    def _status_error(status_code: int, message: str) -> httpx.HTTPStatusError:
        """Build an ``HTTPStatusError`` whose response reports *status_code*."""
        response = MagicMock()
        response.status_code = status_code
        return httpx.HTTPStatusError(message, request=MagicMock(), response=response)

    @pytest.mark.asyncio
    async def test_chat_success(self, client: OllamaClient) -> None:
        client._client.post = AsyncMock(
            return_value=self._ok_response({"message": {"content": '{"key": "value"}'}})
        )

        result = await client.chat(
            model="llama3",
            messages=self._messages(),
            json_mode=True,
        )

        assert result == {"message": {"content": '{"key": "value"}'}}
        client._client.post.assert_awaited_once()
        payload = client._client.post.call_args.kwargs["json"]
        assert payload["model"] == "llama3"
        # json_mode=True must show up as the request's ``format`` field.
        assert payload["format"] == "json"

    @pytest.mark.asyncio
    async def test_connect_error_is_retryable(self, client: OllamaClient) -> None:
        client._client.post = AsyncMock(
            side_effect=httpx.ConnectError("connection refused")
        )

        with pytest.raises(OllamaRetryableError, match="Connection error"):
            await client.chat(model="llama3", messages=self._messages())

        # Retried up to 3 attempts
        assert client._client.post.await_count == 3

    @pytest.mark.asyncio
    async def test_timeout_is_retryable(self, client: OllamaClient) -> None:
        client._client.post = AsyncMock(
            side_effect=httpx.TimeoutException("timed out")
        )

        with pytest.raises(OllamaRetryableError, match="Timeout"):
            await client.chat(model="llama3", messages=self._messages())
        assert client._client.post.await_count == 3

    @pytest.mark.asyncio
    async def test_500_is_retryable(self, client: OllamaClient) -> None:
        client._client.post = AsyncMock(
            side_effect=self._status_error(500, "server error")
        )

        with pytest.raises(OllamaRetryableError, match="Ollama returned 500"):
            await client.chat(model="llama3", messages=self._messages())
        assert client._client.post.await_count == 3

    @pytest.mark.asyncio
    async def test_429_is_retryable(self, client: OllamaClient) -> None:
        client._client.post = AsyncMock(
            side_effect=self._status_error(429, "rate limited")
        )

        with pytest.raises(OllamaRetryableError, match="Ollama returned 429"):
            await client.chat(model="llama3", messages=self._messages())
        assert client._client.post.await_count == 3

    @pytest.mark.asyncio
    async def test_400_is_fatal(self, client: OllamaClient) -> None:
        client._client.post = AsyncMock(
            side_effect=self._status_error(400, "bad request")
        )

        with pytest.raises(OllamaFatalError, match="Ollama returned 400"):
            await client.chat(model="llama3", messages=self._messages())
        # No retries for fatal error
        assert client._client.post.await_count == 1

    @pytest.mark.asyncio
    async def test_invalid_json_is_fatal(self, client: OllamaClient) -> None:
        response = MagicMock()
        response.raise_for_status = lambda: None
        response.json.side_effect = json.JSONDecodeError("not json", doc="", pos=0)
        client._client.post = AsyncMock(return_value=response)

        with pytest.raises(OllamaFatalError, match="Invalid JSON response"):
            await client.chat(model="llama3", messages=self._messages())
        # Fatal errors are not retried (same contract as the 400 case), so a
        # malformed body must cost exactly one request.
        assert client._client.post.await_count == 1


class TestChatWithImages:
    """Tests for ``OllamaClient.chat_with_images``."""

    @pytest.mark.asyncio
    async def test_appends_images_to_last_message(self, client: OllamaClient) -> None:
        expected_body = {"message": {"content": '{"overall_condition": "good"}'}}

        fake_response = MagicMock()
        fake_response.raise_for_status = lambda: None
        fake_response.json.return_value = {
            "message": {"content": '{"overall_condition": "good"}'}
        }
        post_mock = AsyncMock(return_value=fake_response)
        client._client.post = post_mock

        result = await client.chat_with_images(
            model="vision",
            messages=[{"role": "user", "content": "describe"}],
            images_base64=["data1", "data2"],
        )

        # The base64 payloads must ride on the final message of the request.
        sent_messages = post_mock.call_args.kwargs["json"]["messages"]
        assert sent_messages[-1]["images"] == ["data1", "data2"]
        assert result == expected_body


class TestImageToBase64:
    """Tests for ``OllamaClient.image_to_base64`` (encoding, resizing, cleanup)."""

    def test_reads_small_image_without_resize(self, tmp_path: Path) -> None:
        img_path = tmp_path / "small.png"
        img = Image.new("RGB", (100, 100), color=(255, 0, 0))
        img.save(img_path)

        b64 = OllamaClient.image_to_base64(str(img_path), resize=True, max_size=1024)
        decoded = base64.b64decode(b64)
        with Image.open(BytesIO(decoded)) as restored:
            # Already within max_size, so dimensions must be untouched.
            assert restored.size == (100, 100)

    def test_resizes_large_image(self, tmp_path: Path) -> None:
        img_path = tmp_path / "large.png"
        img = Image.new("RGB", (2000, 2000), color=(0, 255, 0))
        img.save(img_path)

        b64 = OllamaClient.image_to_base64(
            str(img_path), resize=True, max_size=512, quality=80
        )
        decoded = base64.b64decode(b64)
        with Image.open(BytesIO(decoded)) as restored:
            assert restored.width <= 512
            assert restored.height <= 512

    def test_image_is_closed_after_use(self, tmp_path: Path) -> None:
        """Regression test for memory leak: Image.open must be closed."""
        img_path = tmp_path / "test.png"
        Image.new("RGB", (100, 100)).save(img_path)

        # Track every image handed out by PIL.Image.open via a spy.
        opened: list[Any] = []
        original_open = Image.open

        def spy_open(*args: Any, **kwargs: Any) -> Any:
            # Forward all arguments so the spy keeps working however the
            # client calls Image.open (e.g. with an explicit mode/formats).
            img = original_open(*args, **kwargs)
            opened.append(img)
            return img

        with patch("PIL.Image.open", spy_open):
            OllamaClient.image_to_base64(str(img_path))

        # Without this guard the loop below checks nothing and the test
        # passes vacuously if the client stops going through PIL.Image.open.
        assert opened, "image_to_base64 did not open the image via PIL.Image.open"
        for img in opened:
            # Once the client releases the image (e.g. by leaving a
            # ``with Image.open(...)`` block), its file object must be gone
            # or closed.
            # NOTE(review): Pillow may already drop ``fp`` after ``load()``
            # for single-frame files, so this check may not catch every leak.
            assert img.fp is None or img.fp.closed


class TestCircuitBreakerIntegration:
    """Tests for the circuit breaker guarding ``OllamaClient.chat``.

    The scenarios below encode the expected policy: every failed attempt
    (including retries) counts toward the breaker, the breaker opens on the
    5th consecutive failure, and a successful call resets the count.
    """

    @staticmethod
    def _messages() -> list[dict[str, str]]:
        """Return a fresh single-turn conversation for a ``chat`` call."""
        return [{"role": "user", "content": "hello"}]

    @staticmethod
    def _refuse_connections(client: OllamaClient) -> AsyncMock:
        """Make every ``post`` raise ``ConnectError``; return the installed mock."""
        post = AsyncMock(side_effect=httpx.ConnectError("connection refused"))
        client._client.post = post
        return post

    @pytest.mark.asyncio
    async def test_opens_after_5_retryable_failures(self, client: OllamaClient) -> None:
        from vmk_data_collector.core.circuit_breaker import CircuitBreakerOpenError

        post = self._refuse_connections(client)

        # 1st call: 3 retries = 3 failures, circuit stays closed
        with pytest.raises(OllamaRetryableError):
            await client.chat(model="llama3", messages=self._messages())
        assert post.await_count == 3

        # 2nd call: 2 more failures (total=5) open circuit on 5th failure.
        # 3rd attempt gets CircuitBreakerOpenError (not retried).
        with pytest.raises(CircuitBreakerOpenError):
            await client.chat(model="llama3", messages=self._messages())
        # The open breaker rejects the 3rd attempt without sending a request.
        assert post.await_count == 5

    @pytest.mark.asyncio
    async def test_success_resets_failure_count(self, client: OllamaClient) -> None:
        from vmk_data_collector.core.circuit_breaker import CircuitBreakerOpenError

        # 1st call: 3 failures (counter=3), circuit closed
        self._refuse_connections(client)
        with pytest.raises(OllamaRetryableError):
            await client.chat(model="llama3", messages=self._messages())

        # 1 success resets counter to 0
        ok_response = MagicMock()
        ok_response.raise_for_status = lambda: None
        ok_response.json.return_value = {"message": {"content": "ok"}}
        client._client.post = AsyncMock(return_value=ok_response)
        result = await client.chat(model="llama3", messages=self._messages())
        assert result == {"message": {"content": "ok"}}

        # Need 5 failures again to open circuit.
        post = self._refuse_connections(client)

        # 1st call after reset: 3 failures (counter=3). Had the success not
        # reset the counter, the breaker would open partway through this call.
        with pytest.raises(OllamaRetryableError):
            await client.chat(model="llama3", messages=self._messages())
        assert post.await_count == 3

        # 2nd call: 2 more failures (counter=5), circuit opens
        with pytest.raises(CircuitBreakerOpenError):
            await client.chat(model="llama3", messages=self._messages())
        assert post.await_count == 5