"""Session model and in-memory session store."""

import uuid
from abc import ABC, abstractmethod
from datetime import datetime, timezone

from pydantic import BaseModel, Field

from navi.llm.base import Message


class Session(BaseModel):
    """A conversation session: ownership, message histories and metadata.

    Two message lists are kept side by side: ``messages`` is the full display
    history, while ``context`` is the list of messages intended for the LLM,
    which (per the field comments) may be compressed. ``id`` is a fresh UUID4
    string and both timestamps default to the current time in UTC unless
    supplied explicitly.
    """

    # Random UUID4 string, generated per instance when not supplied.
    id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    # Required; has no default.
    profile_id: str
    user_id: str | None = None  # owner; null for legacy single-user sessions
    messages: list[Message] = Field(default_factory=list)   # full display history (never compressed)
    context: list[Message] = Field(default_factory=list)    # LLM context (may be compressed)
    context_token_count: int = 0   # accumulated total; reset to 0 after compression
    pinned: bool = False
    # Optional display name; None until one is set.
    name: str | None = None
    # Timezone-aware (UTC) creation time.
    created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
    # Timezone-aware (UTC) time of last activity; stores may refresh it on save.
    last_active: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
    planning_logs: list[dict] = Field(default_factory=list)  # raw planning phase outputs per turn


class SessionStore(ABC):
    """Abstract async interface for storing and querying :class:`Session` objects.

    Every method is declared with ``@abstractmethod``, so a subclass cannot be
    instantiated until it overrides all of them.
    """

    @abstractmethod
    async def create(self, profile_id: str, user_id: str | None = None) -> Session:
        """Create a new session for *profile_id*, optionally owned by *user_id*, and return it."""
        ...

    @abstractmethod
    async def get(self, session_id: str) -> Session | None:
        """Return the session with *session_id*, or ``None`` (presumably: not found)."""
        ...

    @abstractmethod
    async def save(self, session: Session) -> None:
        """Store *session*, presumably inserting or replacing by ``session.id`` — confirm per implementation."""
        ...

    @abstractmethod
    async def list_all(self, user_id: str | None = None, is_admin: bool = False) -> list[Session]:
        """Return the sessions visible to the caller.

        ``user_id`` and ``is_admin`` presumably scope visibility; the exact
        filtering and ordering rules are left to the implementation.
        """
        ...

    @abstractmethod
    async def list_page(
        self,
        *,
        limit: int,
        offset: int,
        profile_id: str | None = None,
        user_id: str | None = None,
        is_admin: bool = False,
    ) -> list[Session]:
        """Return one page of sessions (all arguments are keyword-only).

        Presumably up to ``limit`` sessions starting at ``offset``, optionally
        restricted to ``profile_id`` and scoped by ``user_id``/``is_admin`` as
        in :meth:`list_all` — confirm per implementation.
        """
        ...

    @abstractmethod
    async def delete(self, session_id: str) -> bool:
        """Delete the session; the bool presumably reports whether it existed."""
        ...

    @abstractmethod
    async def set_pinned(self, session_id: str, pinned: bool) -> bool:
        """Set the session's ``pinned`` flag; the bool presumably reports whether it existed."""
        ...

    @abstractmethod
    async def set_name(self, session_id: str, name: str) -> bool:
        """Set the session's display ``name``; the bool presumably reports whether it existed."""
        ...


class InMemorySessionStore(SessionStore):
    """Process-local :class:`SessionStore` backed by a plain dict.

    Sessions live only as long as this object. The ``Session`` instances
    returned by :meth:`create` and :meth:`get` are the stored objects
    themselves, so mutating them changes the store even without :meth:`save`.
    """

    def __init__(self) -> None:
        # session id -> Session
        self._sessions: dict[str, Session] = {}

    async def create(self, profile_id: str, user_id: str | None = None) -> Session:
        """Create a session for *profile_id* owned by *user_id*, store it and return it."""
        session = Session(profile_id=profile_id, user_id=user_id)
        self._sessions[session.id] = session
        return session

    async def get(self, session_id: str) -> Session | None:
        """Return the stored session with *session_id*, or ``None`` if unknown."""
        return self._sessions.get(session_id)

    async def save(self, session: Session) -> None:
        """Insert or replace *session* by id, bumping ``last_active`` to now (UTC)."""
        session.last_active = datetime.now(timezone.utc)
        self._sessions[session.id] = session

    async def list_all(self, user_id: str | None = None, is_admin: bool = False) -> list[Session]:
        """Return sessions ordered pinned-first, then most recently active first.

        A non-admin caller that passes ``user_id`` sees only sessions owned by
        that user; admins and callers without a ``user_id`` see every session.
        """
        sessions = list(self._sessions.values())
        # NOTE(review): a non-admin caller with user_id=None is not filtered at
        # all; presumably intended for legacy single-user mode — confirm.
        if not is_admin and user_id is not None:
            sessions = [s for s in sessions if s.user_id == user_id]
        # reverse=True puts pinned (True) before unpinned and newer before older.
        return sorted(
            sessions,
            key=lambda s: (s.pinned, s.last_active),
            reverse=True,
        )

    async def list_page(
        self,
        *,
        limit: int,
        offset: int,
        profile_id: str | None = None,
        user_id: str | None = None,
        is_admin: bool = False,
    ) -> list[Session]:
        """Return up to *limit* sessions starting at *offset*, in :meth:`list_all` order.

        Visibility follows :meth:`list_all`; a truthy *profile_id* further
        restricts the result to that profile. Negative *limit*/*offset* are
        treated as 0, yielding an empty page or the first page respectively.
        """
        # Python slicing interprets negative bounds as counting from the end,
        # which would return an unrelated window (e.g. limit=-1 returned every
        # session but the last), so clamp them first.
        start = max(offset, 0)
        count = max(limit, 0)
        sessions = await self.list_all(user_id=user_id, is_admin=is_admin)
        if profile_id:
            sessions = [s for s in sessions if s.profile_id == profile_id]
        return sessions[start:start + count]

    async def delete(self, session_id: str) -> bool:
        """Remove the session; return ``True`` if it existed, ``False`` otherwise."""
        # Stored values are always Session objects, never None.
        return self._sessions.pop(session_id, None) is not None

    async def set_pinned(self, session_id: str, pinned: bool) -> bool:
        """Set ``pinned`` on the stored session (``last_active`` is left untouched).

        Returns ``False`` if no session has *session_id*, else ``True``.
        """
        s = self._sessions.get(session_id)
        if s is None:
            return False
        s.pinned = pinned
        return True

    async def set_name(self, session_id: str, name: str) -> bool:
        """Set ``name`` on the stored session (``last_active`` is left untouched).

        Returns ``False`` if no session has *session_id*, else ``True``.
        """
        s = self._sessions.get(session_id)
        if s is None:
            return False
        s.name = name
        return True
