Newer
Older
vmk-360-data_collector / src / vmk_data_collector / services / ollama_client.py
@Eugene Sukhodolskiy Eugene Sukhodolskiy 1 day ago 5 KB fix: code review critical and high issues
import base64
import json
from collections.abc import Awaitable, Callable
from io import BytesIO
from pathlib import Path
from typing import Any

import httpx
import structlog
from PIL import Image
from tenacity import (
    AsyncRetrying,
    RetryCallState,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential,
)

from vmk_data_collector.core.circuit_breaker import CircuitBreaker
from vmk_data_collector.core.exceptions import OllamaFatalError, OllamaRetryableError
from vmk_data_collector.core.metrics import ai_requests_total

# Module-level structlog logger; used by the retry hook and by OllamaClient's
# request/response logging.
logger = structlog.get_logger()


def _classify_httpx_error(
    exc: httpx.HTTPStatusError,
) -> OllamaRetryableError | OllamaFatalError:
    """Translate an HTTP error status from Ollama into a client exception.

    429 (rate limited) and every 5xx status are treated as transient and
    produce an ``OllamaRetryableError``; any other status produces an
    ``OllamaFatalError``. Both carry the message
    ``"Ollama returned <status>: <exc>"``. The exception is returned, not
    raised, so the caller can chain it with ``raise ... from exc``.
    """
    code = exc.response.status_code
    message = f"Ollama returned {code}: {exc}"
    is_transient = code == 429 or code >= 500
    error_cls = OllamaRetryableError if is_transient else OllamaFatalError
    return error_cls(message)


def _log_retry(retry_state: RetryCallState) -> None:
    """Log a warning before tenacity sleeps ahead of the next attempt.

    Includes the exception that failed the attempt so retry warnings can be
    diagnosed without digging for the eventual final error.
    """
    outcome = retry_state.outcome
    error = outcome.exception() if outcome is not None else None
    logger.warning(
        "ollama_retry",
        attempt=retry_state.attempt_number,
        error=repr(error) if error is not None else None,
    )


# Shared tenacity settings for every OllamaClient call: at most 3 attempts,
# exponential backoff clamped to 1-10 s, retrying only on
# OllamaRetryableError. reraise=True surfaces the last real exception instead
# of tenacity's RetryError wrapper.
_RETRY_CONFIG: dict[str, Any] = {
    "stop": stop_after_attempt(3),
    "wait": wait_exponential(min=1, max=10),
    "retry": retry_if_exception_type(OllamaRetryableError),
    "before_sleep": _log_retry,
    "reraise": True,
}


class OllamaClient:
    """Async client for an Ollama server's ``/api/chat`` endpoint.

    Each public call (``chat`` / ``chat_with_images``) runs inside a tenacity
    retry loop configured by ``_RETRY_CONFIG``, goes through a circuit breaker
    shared by the instance, and is counted in the ``ai_requests_total``
    metric. Call :meth:`close`, or use the client as an ``async with``
    context manager, to release the underlying HTTP connection pool.
    """

    def __init__(self, base_url: str, timeout: int = 120) -> None:
        """Create the HTTP client and circuit breaker.

        Args:
            base_url: Root URL of the Ollama server; ``/api/chat`` is resolved
                against it.
            timeout: Timeout handed to ``httpx.AsyncClient``, in seconds.
        """
        self._client = httpx.AsyncClient(
            base_url=base_url,
            timeout=timeout,
        )
        # NOTE(review): presumably only OllamaRetryableError counts toward the
        # breaker's failure threshold (fatal errors such as a 400 say nothing
        # about server health) -- confirm in core.circuit_breaker.
        self._circuit_breaker = CircuitBreaker(
            failure_threshold=5,
            recovery_timeout=30.0,
            expected_exception=(OllamaRetryableError,),
        )

    async def __aenter__(self) -> "OllamaClient":
        """Enter ``async with``; returns the client itself."""
        return self

    async def __aexit__(self, *exc_info: object) -> None:
        """Close the HTTP client on leaving ``async with``."""
        await self.close()

    async def chat(
        self,
        model: str,
        messages: list[dict[str, Any]],
        json_mode: bool = False,
    ) -> dict[str, Any]:
        """Send a (non-streaming) chat request and return the decoded response.

        Args:
            model: Ollama model name.
            messages: Chat messages forwarded verbatim as ``messages``.
            json_mode: If true, sets ``"format": "json"`` in the request.

        Returns:
            The JSON body returned by ``/api/chat``.

        Raises:
            OllamaRetryableError: A transient failure persisted after the
                retries allowed by ``_RETRY_CONFIG``.
            OllamaFatalError: A non-retryable HTTP status or invalid JSON.
            Whatever ``CircuitBreaker.call`` raises while the breaker is open
            (type defined in core.circuit_breaker).
        """
        return await self._call_with_resilience(
            model, self._chat_raw, model, messages, json_mode
        )

    async def chat_with_images(
        self,
        model: str,
        messages: list[dict[str, Any]],
        images_base64: list[str],
    ) -> dict[str, Any]:
        """Send a chat request with images attached to the last message.

        The caller's ``messages`` list and its dicts are not modified. If
        ``messages`` is empty, the images are not sent.

        Args:
            model: Ollama (vision) model name.
            messages: Chat messages; the last one receives an ``images`` field.
            images_base64: Base64-encoded images, e.g. from
                :meth:`image_to_base64`.

        Returns:
            The JSON body returned by ``/api/chat``.

        Raises:
            Same as :meth:`chat`.
        """
        return await self._call_with_resilience(
            model, self._chat_with_images_raw, model, messages, images_base64
        )

    async def _call_with_resilience(
        self,
        model: str,
        func: Callable[..., Awaitable[dict[str, Any]]],
        *args: Any,
    ) -> dict[str, Any]:
        """Run ``func(*args)`` under the retry policy and the circuit breaker.

        Records one ``ai_requests_total`` sample per call, labelled
        ``success`` or ``error``, and re-raises whatever ends the retry loop.
        """
        try:
            async for attempt in AsyncRetrying(**_RETRY_CONFIG):
                with attempt:
                    result = await self._circuit_breaker.call(func, *args)
            ai_requests_total.labels(model=model, status="success").inc()
            return result
        except Exception:
            ai_requests_total.labels(model=model, status="error").inc()
            raise

    async def _post_chat(self, payload: dict[str, Any]) -> dict[str, Any]:
        """POST ``payload`` to ``/api/chat`` and return the decoded JSON body.

        Raises:
            OllamaRetryableError: Connection/network failures, timeouts, and
                HTTP 429 / 5xx responses.
            OllamaFatalError: Other HTTP error statuses, or a non-JSON body.
        """
        try:
            response = await self._client.post("/api/chat", json=payload)
            response.raise_for_status()
            data = response.json()
        except httpx.ConnectError as exc:
            raise OllamaRetryableError(f"Connection error: {exc}") from exc
        except httpx.TimeoutException as exc:
            raise OllamaRetryableError(f"Timeout: {exc}") from exc
        except (httpx.NetworkError, httpx.RemoteProtocolError) as exc:
            # Connections dropped or reset mid-request (ReadError, WriteError,
            # RemoteProtocolError, ...) are transient as well; otherwise they
            # would escape as raw httpx errors that are neither retried nor
            # seen by the circuit breaker.
            raise OllamaRetryableError(f"Network error: {exc}") from exc
        except httpx.HTTPStatusError as exc:
            raise _classify_httpx_error(exc) from exc
        except json.JSONDecodeError as exc:
            raise OllamaFatalError(f"Invalid JSON response: {exc}") from exc
        return data

    async def _chat_raw(
        self,
        model: str,
        messages: list[dict[str, Any]],
        json_mode: bool = False,
    ) -> dict[str, Any]:
        """Perform a single text chat request (no retries, no breaker)."""
        payload: dict[str, Any] = {
            "model": model,
            "messages": messages,
            "stream": False,
        }
        if json_mode:
            payload["format"] = "json"

        logger.info(
            "ollama_chat_request",
            model=model,
            json_mode=json_mode,
            message_count=len(messages),
        )
        data = await self._post_chat(payload)
        logger.info("ollama_chat_response", model=model)
        return data

    async def _chat_with_images_raw(
        self,
        model: str,
        messages: list[dict[str, Any]],
        images_base64: list[str],
    ) -> dict[str, Any]:
        """Perform a single vision chat request (no retries, no breaker)."""
        if messages and images_base64:
            # Build new list/dict instead of writing into the caller's last
            # message, so the "images" key never leaks into caller-owned data.
            messages = [
                *messages[:-1],
                {**messages[-1], "images": images_base64},
            ]

        logger.info(
            "ollama_vision_request",
            model=model,
            image_count=len(images_base64),
        )
        data = await self._post_chat({
            "model": model,
            "messages": messages,
            "stream": False,
        })
        logger.info("ollama_vision_response", model=model)
        return data

    async def close(self) -> None:
        """Close the underlying ``httpx.AsyncClient``."""
        await self._client.aclose()

    @staticmethod
    def image_to_base64(
        image_path: str | Path,
        resize: bool = True,
        max_size: int = 1024,
        quality: int = 85,
    ) -> str:
        """Return the image at ``image_path`` as a base64 (UTF-8) string.

        If ``resize`` is true and either side exceeds ``max_size``, the image
        is shrunk to fit in ``max_size`` x ``max_size`` (``Image.thumbnail``
        keeps the aspect ratio), converted to RGB and re-encoded as JPEG at
        ``quality``. Otherwise the file's original bytes are encoded as-is.
        The file is always opened with PIL first, so non-images raise there.
        """
        with Image.open(image_path) as img:
            if resize and (img.width > max_size or img.height > max_size):
                img.thumbnail((max_size, max_size))
                buffer = BytesIO()
                # JPEG cannot store alpha or palette modes, hence RGB first.
                with img.convert("RGB") as rgb_img:
                    rgb_img.save(buffer, format="JPEG", quality=quality)
                return base64.b64encode(buffer.getvalue()).decode("utf-8")
        # Small enough (or resizing disabled): send the original file bytes.
        return base64.b64encode(Path(image_path).read_bytes()).decode("utf-8")