import base64
import json
import logging
from io import BytesIO
from pathlib import Path
from typing import Any

import httpx
import structlog
from PIL import Image
from tenacity import (
    AsyncRetrying,
    before_sleep_log,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential,
)

from vmk_data_collector.core.circuit_breaker import CircuitBreaker
from vmk_data_collector.core.exceptions import OllamaFatalError, OllamaRetryableError

# Module-level structlog logger; used for the request/response events emitted
# by OllamaClient and handed to tenacity's before-sleep hook in _RETRY_CONFIG.
logger = structlog.get_logger()
def _classify_httpx_error(
    exc: httpx.HTTPStatusError,
) -> OllamaRetryableError | OllamaFatalError:
    """Translate an HTTP status failure into the matching Ollama exception.

    A 429 or any status of 500 and above yields ``OllamaRetryableError``;
    every other status yields ``OllamaFatalError``. The exception is returned,
    not raised, so the caller can ``raise ... from exc``.

    Args:
        exc: The error raised by ``response.raise_for_status()``.

    Returns:
        A new exception instance whose message includes the status code and
        the original error text.
    """
    status = exc.response.status_code
    is_transient = status == 429 or status >= 500
    error_cls = OllamaRetryableError if is_transient else OllamaFatalError
    return error_cls(f"Ollama returned {status}: {exc}")
# Keyword arguments shared by every AsyncRetrying loop in this module:
# up to 3 attempts, exponential backoff bounded to 1-10 (tenacity waits are in
# seconds), retrying only on OllamaRetryableError, and re-raising the last
# error itself rather than wrapping it in tenacity's RetryError.
_RETRY_CONFIG = {
    "stop": stop_after_attempt(3),
    "wait": wait_exponential(min=1, max=10),
    "retry": retry_if_exception_type(OllamaRetryableError),
    # before_sleep_log forwards this value to ``logger.log(level, ...)``, which
    # expects a numeric level; the string "warning" fails there at retry time.
    "before_sleep": before_sleep_log(logger, logging.WARNING),
    "reraise": True,
}
class OllamaClient:
def __init__(self, base_url: str, timeout: int = 120) -> None:
self._client = httpx.AsyncClient(
base_url=base_url,
timeout=timeout,
)
self._circuit_breaker = CircuitBreaker(
failure_threshold=5,
recovery_timeout=30.0,
expected_exception=(OllamaRetryableError,),
)
async def chat(
self,
model: str,
messages: list[dict[str, Any]],
json_mode: bool = False,
) -> dict[str, Any]:
async for attempt in AsyncRetrying(**_RETRY_CONFIG):
with attempt:
return await self._circuit_breaker.call(
self._chat_raw,
model,
messages,
json_mode,
)
async def _chat_raw(
self,
model: str,
messages: list[dict[str, Any]],
json_mode: bool = False,
) -> dict[str, Any]:
payload: dict[str, Any] = {
"model": model,
"messages": messages,
"stream": False,
}
if json_mode:
payload["format"] = "json"
logger.info(
"ollama_chat_request",
model=model,
json_mode=json_mode,
message_count=len(messages),
)
try:
response = await self._client.post("/api/chat", json=payload)
response.raise_for_status()
data = response.json()
except httpx.ConnectError as exc:
raise OllamaRetryableError(f"Connection error: {exc}") from exc
except httpx.TimeoutException as exc:
raise OllamaRetryableError(f"Timeout: {exc}") from exc
except httpx.HTTPStatusError as exc:
raise _classify_httpx_error(exc) from exc
except json.JSONDecodeError as exc:
raise OllamaFatalError(f"Invalid JSON response: {exc}") from exc
logger.info("ollama_chat_response", model=model)
return data
async def chat_with_images(
self,
model: str,
messages: list[dict[str, Any]],
images_base64: list[str],
) -> dict[str, Any]:
async for attempt in AsyncRetrying(**_RETRY_CONFIG):
with attempt:
return await self._circuit_breaker.call(
self._chat_with_images_raw,
model,
messages,
images_base64,
)
async def _chat_with_images_raw(
self,
model: str,
messages: list[dict[str, Any]],
images_base64: list[str],
) -> dict[str, Any]:
if messages and images_base64:
messages[-1]["images"] = images_base64
logger.info(
"ollama_vision_request",
model=model,
image_count=len(images_base64),
)
try:
response = await self._client.post("/api/chat", json={
"model": model,
"messages": messages,
"stream": False,
})
response.raise_for_status()
data = response.json()
except httpx.ConnectError as exc:
raise OllamaRetryableError(f"Connection error: {exc}") from exc
except httpx.TimeoutException as exc:
raise OllamaRetryableError(f"Timeout: {exc}") from exc
except httpx.HTTPStatusError as exc:
raise _classify_httpx_error(exc) from exc
except json.JSONDecodeError as exc:
raise OllamaFatalError(f"Invalid JSON response: {exc}") from exc
logger.info("ollama_vision_response", model=model)
return data
async def close(self) -> None:
await self._client.aclose()
@staticmethod
def image_to_base64(
image_path: str,
resize: bool = True,
max_size: int = 1024,
quality: int = 85,
) -> str:
img = Image.open(image_path)
if resize and (img.width > max_size or img.height > max_size):
img.thumbnail((max_size, max_size))
buffer = BytesIO()
img = img.convert("RGB")
img.save(buffer, format="JPEG", quality=quality)
return base64.b64encode(buffer.getvalue()).decode("utf-8")
with Path(image_path).open("rb") as f:
return base64.b64encode(f.read()).decode("utf-8")