Newer
Older
vmk-360-data_collector / src / vmk_data_collector / services / ollama_client.py
import base64
from pathlib import Path
from typing import Any

import httpx
import structlog

logger = structlog.get_logger()


class OllamaClient:
    """Thin async wrapper around an Ollama server's ``/api/chat`` endpoint.

    A single ``httpx.AsyncClient`` is created per instance and reused for
    every request. Release it with :meth:`close`, or use the instance as an
    async context manager (``async with OllamaClient(...) as client:``).
    """

    def __init__(self, base_url: str, timeout: float = 120) -> None:
        """Create the underlying HTTP client.

        Args:
            base_url: Root URL of the Ollama server; request paths such as
                ``/api/chat`` are resolved against it by httpx.
            timeout: httpx timeout (seconds) applied to every request.
        """
        self._client = httpx.AsyncClient(
            base_url=base_url,
            timeout=timeout,
        )

    async def __aenter__(self) -> "OllamaClient":
        """Return the client itself so it can be used with ``async with``."""
        return self

    async def __aexit__(self, *exc_info: object) -> None:
        """Close the HTTP client when leaving an ``async with`` block."""
        await self.close()

    async def _post_chat(self, payload: dict[str, Any]) -> dict[str, Any]:
        """POST *payload* to ``/api/chat`` and return the decoded JSON body.

        Raises:
            httpx.HTTPStatusError: If the server answers with a 4xx/5xx status.
        """
        response = await self._client.post("/api/chat", json=payload)
        response.raise_for_status()
        return response.json()

    async def chat(
        self,
        model: str,
        messages: list[dict[str, Any]],
        json_mode: bool = False,
    ) -> dict[str, Any]:
        """Send a non-streaming chat request.

        Args:
            model: Name of the Ollama model to run.
            messages: Chat messages, sent unchanged as the ``messages`` field.
            json_mode: When true, sets ``"format": "json"`` on the request.

        Returns:
            The decoded JSON response body.

        Raises:
            httpx.HTTPStatusError: If the server answers with a 4xx/5xx status.
        """
        payload: dict[str, Any] = {
            "model": model,
            "messages": messages,
            "stream": False,
        }
        if json_mode:
            payload["format"] = "json"

        logger.info(
            "ollama_chat_request",
            model=model,
            json_mode=json_mode,
            message_count=len(messages),
        )
        data = await self._post_chat(payload)
        logger.info("ollama_chat_response", model=model)
        return data

    @staticmethod
    def _attach_images(
        messages: list[dict[str, Any]],
        images_base64: list[str],
    ) -> list[dict[str, Any]]:
        """Return a copy of *messages* whose last message carries *images_base64*.

        The caller's list and message dicts are never modified. Only the last
        message is shallow-copied, because only it gains the ``images`` key.
        If either argument is empty, a plain shallow copy of *messages* is
        returned and no ``images`` key is added.
        """
        if not (messages and images_base64):
            return list(messages)
        last_with_images = {**messages[-1], "images": list(images_base64)}
        return [*messages[:-1], last_with_images]

    async def chat_with_images(
        self,
        model: str,
        messages: list[dict[str, Any]],
        images_base64: list[str],
    ) -> dict[str, Any]:
        """Send a non-streaming chat request with images attached.

        The images are attached to the last message of a copy of *messages*.
        The caller's ``messages`` list is left unmodified.

        Args:
            model: Name of the (vision-capable) Ollama model to run.
            messages: Chat messages; the last one receives the images.
            images_base64: Base64-encoded image payloads, e.g. produced by
                :meth:`image_to_base64`.

        Returns:
            The decoded JSON response body.

        Raises:
            httpx.HTTPStatusError: If the server answers with a 4xx/5xx status.
        """
        if images_base64 and not messages:
            # There is no message to attach the images to; surface this
            # instead of silently dropping them.
            logger.warning(
                "ollama_vision_images_dropped",
                model=model,
                image_count=len(images_base64),
            )
        outgoing = self._attach_images(messages, images_base64)

        logger.info(
            "ollama_vision_request",
            model=model,
            image_count=len(images_base64),
        )
        data = await self._post_chat({
            "model": model,
            "messages": outgoing,
            "stream": False,
        })
        logger.info("ollama_vision_response", model=model)
        return data

    async def close(self) -> None:
        """Close the underlying ``httpx.AsyncClient``."""
        await self._client.aclose()

    @staticmethod
    def image_to_base64(image_path: str) -> str:
        """Read the file at *image_path* and return its bytes as base64 text.

        Any value accepted by ``pathlib.Path`` works as *image_path*.
        """
        with Path(image_path).open("rb") as f:
            return base64.b64encode(f.read()).decode("utf-8")