import base64
import json
from collections.abc import Awaitable, Callable
from io import BytesIO
from pathlib import Path
from typing import Any

import httpx
import structlog
from PIL import Image
from tenacity import (
    AsyncRetrying,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential,
)

from vmk_data_collector.core.circuit_breaker import CircuitBreaker
from vmk_data_collector.core.exceptions import OllamaFatalError, OllamaRetryableError
from vmk_data_collector.core.metrics import ai_requests_total
# Module-level structlog logger used by the retry hook and by OllamaClient
# for its request/response events.
logger = structlog.get_logger()
def _classify_httpx_error(
    exc: httpx.HTTPStatusError,
) -> OllamaRetryableError | OllamaFatalError:
    """Translate an HTTP status failure into the matching Ollama exception.

    Server-side failures (status 500 and above) and rate limiting (429) are
    reported as ``OllamaRetryableError``; every other status is reported as
    ``OllamaFatalError``. The exception is returned, not raised, so the
    caller can chain it with ``raise ... from exc``.
    """
    code = exc.response.status_code
    message = f"Ollama returned {code}: {exc}"
    is_transient = code == 429 or code >= 500
    error_type = OllamaRetryableError if is_transient else OllamaFatalError
    return error_type(message)
def _log_retry_attempt(retry_state: Any) -> Any:
    """Log a warning carrying the attempt number before a retry sleep."""
    return logger.warning("ollama_retry", attempt=retry_state.attempt_number)


# Keyword arguments shared by every ``AsyncRetrying`` loop in this module:
# at most three attempts, exponential back-off bounded to 1-10 seconds, retry
# only on ``OllamaRetryableError``, and re-raise the last error once attempts
# are exhausted.
_RETRY_CONFIG = dict(
    stop=stop_after_attempt(3),
    wait=wait_exponential(min=1, max=10),
    retry=retry_if_exception_type(OllamaRetryableError),
    before_sleep=_log_retry_attempt,
    reraise=True,
)
class OllamaClient:
    """Async client for an Ollama server with retries, circuit breaking and metrics.

    Each public request method (``chat``, ``embed``, ``chat_with_images``)
    runs its raw HTTP call through the instance's circuit breaker inside a
    tenacity retry loop configured by ``_RETRY_CONFIG`` and records the
    outcome in the ``ai_requests_total`` counter.

    The instance owns an ``httpx.AsyncClient``. Release it with ``close()``
    or by using the instance as an async context manager.
    """

    def __init__(self, base_url: str, timeout: int = 120) -> None:
        """Create the HTTP client and the circuit breaker.

        Args:
            base_url: Base URL that request paths such as ``/api/chat`` are
                resolved against.
            timeout: Timeout passed to ``httpx.AsyncClient``.
        """
        self._client = httpx.AsyncClient(
            base_url=base_url,
            timeout=timeout,
        )
        self._circuit_breaker = CircuitBreaker(
            failure_threshold=5,
            recovery_timeout=30.0,
            expected_exception=(OllamaRetryableError,),
        )

    async def __aenter__(self) -> "OllamaClient":
        """Return the client itself so it can be used with ``async with``."""
        return self

    async def __aexit__(self, *exc_info: object) -> None:
        """Close the underlying HTTP client when the ``async with`` block exits."""
        await self.close()

    async def _call_with_retry(
        self,
        model: str,
        raw_call: Callable[..., Awaitable[Any]],
        *args: Any,
    ) -> Any:
        """Run ``raw_call(*args)`` through the circuit breaker with retries.

        Increments ``ai_requests_total`` with ``status="success"`` when an
        attempt succeeds, or ``status="error"`` once when the call finally
        fails, then re-raises the failure.

        Args:
            model: Model name, used only as the metric label.
            raw_call: Coroutine function performing one request attempt.
            *args: Positional arguments forwarded to ``raw_call``.

        Returns:
            Whatever ``raw_call`` returns.
        """
        try:
            async for attempt in AsyncRetrying(**_RETRY_CONFIG):
                # The return must stay inside ``with attempt``: the attempt
                # context swallows retryable failures, so code placed after
                # it would run without a result.
                with attempt:
                    result = await self._circuit_breaker.call(raw_call, *args)
                    ai_requests_total.labels(model=model, status="success").inc()
                    return result
        except Exception:
            ai_requests_total.labels(model=model, status="error").inc()
            raise

    async def _post_json(self, path: str, payload: dict[str, Any]) -> Any:
        """POST ``payload`` as JSON to ``path`` and return the decoded body.

        Raises:
            OllamaRetryableError: On connection errors, timeouts, and HTTP
                statuses classified as retryable by ``_classify_httpx_error``.
            OllamaFatalError: On other HTTP error statuses or when the
                response body is not valid JSON.
        """
        # NOTE(review): other httpx transport errors (e.g. httpx.ReadError)
        # are not mapped here and propagate unchanged, so they are not
        # retried -- confirm this is intended.
        try:
            response = await self._client.post(path, json=payload)
            response.raise_for_status()
            return response.json()
        except httpx.ConnectError as exc:
            raise OllamaRetryableError(f"Connection error: {exc}") from exc
        except httpx.TimeoutException as exc:
            raise OllamaRetryableError(f"Timeout: {exc}") from exc
        except httpx.HTTPStatusError as exc:
            raise _classify_httpx_error(exc) from exc
        except json.JSONDecodeError as exc:
            raise OllamaFatalError(f"Invalid JSON response: {exc}") from exc

    @staticmethod
    def _attach_images(
        messages: list[dict[str, Any]],
        images_base64: list[str],
    ) -> list[dict[str, Any]]:
        """Return ``messages`` with ``images_base64`` attached to the last one.

        The caller's list and message dicts are never modified: only the
        last message is copied, and the copy receives the ``"images"`` key.
        When either argument is empty, ``messages`` is returned as-is.
        """
        if not (messages and images_base64):
            return messages
        *earlier, last = messages
        return [*earlier, {**last, "images": images_base64}]

    async def chat(
        self,
        model: str,
        messages: list[dict[str, Any]],
        json_mode: bool = False,
    ) -> dict[str, Any]:
        """Send a non-streaming chat request with retries and metrics.

        Args:
            model: Name of the model to query.
            messages: Chat messages sent as the ``messages`` field.
            json_mode: When true, request ``format: "json"`` output.

        Returns:
            The decoded JSON response body.

        Raises:
            OllamaRetryableError: When retryable failures persist through
                every attempt.
            OllamaFatalError: On non-retryable HTTP errors or invalid JSON.
        """
        return await self._call_with_retry(
            model, self._chat_raw, model, messages, json_mode
        )

    async def _chat_raw(
        self,
        model: str,
        messages: list[dict[str, Any]],
        json_mode: bool = False,
    ) -> dict[str, Any]:
        """Perform a single ``/api/chat`` request attempt (no retry logic)."""
        payload: dict[str, Any] = {
            "model": model,
            "messages": messages,
            "stream": False,
        }
        if json_mode:
            payload["format"] = "json"
        logger.info(
            "ollama_chat_request",
            model=model,
            json_mode=json_mode,
            message_count=len(messages),
        )
        data = await self._post_json("/api/chat", payload)
        logger.info("ollama_chat_response", model=model)
        return data

    async def embed(
        self,
        model: str,
        texts: list[str],
    ) -> list[list[float]]:
        """Request embeddings for ``texts`` with retries and metrics.

        Args:
            model: Name of the embedding model.
            texts: Input strings sent as the ``input`` field.

        Returns:
            The value of the response's ``embeddings`` field.

        Raises:
            OllamaRetryableError: When retryable failures persist through
                every attempt.
            OllamaFatalError: On non-retryable HTTP errors, invalid JSON, or
                a response whose ``embeddings`` field is missing or empty.
        """
        return await self._call_with_retry(model, self._embed_raw, model, texts)

    async def _embed_raw(
        self,
        model: str,
        texts: list[str],
    ) -> list[list[float]]:
        """Perform a single ``/api/embed`` request attempt (no retry logic)."""
        payload = {
            "model": model,
            "input": texts,
        }
        logger.info(
            "ollama_embed_request",
            model=model,
            text_count=len(texts),
        )
        data = await self._post_json("/api/embed", payload)
        embeddings = data.get("embeddings")
        # A missing key and an empty list are both treated as a fatal error.
        if not embeddings:
            raise OllamaFatalError(
                f"Missing embeddings in response: {data.keys()}"
            )
        logger.info(
            "ollama_embed_response",
            model=model,
            embedding_count=len(embeddings),
        )
        return embeddings

    async def chat_with_images(
        self,
        model: str,
        messages: list[dict[str, Any]],
        images_base64: list[str],
    ) -> dict[str, Any]:
        """Send a non-streaming chat request with images on the last message.

        Args:
            model: Name of the model to query.
            messages: Chat messages; this list and its dicts are not modified.
            images_base64: Base64-encoded images sent under the last
                message's ``images`` key.

        Returns:
            The decoded JSON response body.

        Raises:
            OllamaRetryableError: When retryable failures persist through
                every attempt.
            OllamaFatalError: On non-retryable HTTP errors or invalid JSON.
        """
        return await self._call_with_retry(
            model, self._chat_with_images_raw, model, messages, images_base64
        )

    async def _chat_with_images_raw(
        self,
        model: str,
        messages: list[dict[str, Any]],
        images_base64: list[str],
    ) -> dict[str, Any]:
        """Perform a single image-bearing ``/api/chat`` attempt (no retry logic)."""
        # Build the outgoing messages on a copy so the caller's conversation
        # history does not end up carrying the base64 payloads.
        outgoing = self._attach_images(messages, images_base64)
        logger.info(
            "ollama_vision_request",
            model=model,
            image_count=len(images_base64),
        )
        data = await self._post_json(
            "/api/chat",
            {
                "model": model,
                "messages": outgoing,
                "stream": False,
            },
        )
        logger.info("ollama_vision_response", model=model)
        return data

    async def close(self) -> None:
        """Close the underlying ``httpx.AsyncClient``."""
        await self._client.aclose()

    @staticmethod
    def image_to_base64(
        image_path: str,
        resize: bool = True,
        max_size: int = 1024,
        quality: int = 85,
    ) -> str:
        """Return the image at ``image_path`` as a base64 string.

        When ``resize`` is true and either dimension exceeds ``max_size``,
        the image is shrunk in place with ``thumbnail`` to fit within
        ``max_size`` x ``max_size``, converted to RGB and re-encoded as JPEG.
        Otherwise the file's original bytes are encoded unchanged, in
        whatever format the file is stored in.

        Args:
            image_path: Path of the image file.
            resize: Whether oversized images may be downscaled.
            max_size: Largest allowed width or height before downscaling.
            quality: JPEG quality, used only when the image is re-encoded.

        Returns:
            The base64-encoded image data as a UTF-8 string.
        """
        # The file is opened with PIL even when ``resize`` is false, so the
        # open itself still runs for every input.
        with Image.open(image_path) as img:
            if resize and (img.width > max_size or img.height > max_size):
                img.thumbnail((max_size, max_size))
                buffer = BytesIO()
                with img.convert("RGB") as rgb_img:
                    rgb_img.save(buffer, format="JPEG", quality=quality)
                return base64.b64encode(buffer.getvalue()).decode("utf-8")
        with Path(image_path).open("rb") as f:
            return base64.b64encode(f.read()).decode("utf-8")