"""Application service layer."""

import secrets as py_secrets
import uuid
from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import Any

from sqlalchemy import func, or_, select
from sqlalchemy.orm import Session, selectinload

from gnexus_creds import crypto
from gnexus_creds.config import Settings, get_settings
from gnexus_creds.errors import AppError
from gnexus_creds.models import (
    ApiToken,
    AuditEvent,
    RateLimitBucket,
    Secret,
    SecretTag,
    SecretVersion,
    SessionRecord,
    User,
    UserEncryptionKey,
    utcnow,
)
from gnexus_creds.schemas import (
    ApiTokenCreate,
    Scope,
    SecretCreate,
    SecretFieldIn,
    SecretFieldOut,
    SecretRead,
    SecretReveal,
    SecretStatus,
    SecretUpdate,
    SecretVersionRead,
)


@dataclass
class Actor:
    """Identity and request context on whose behalf a service call runs."""

    # Authenticated user performing the call.
    user: User
    # Entry point the call arrived through; "ui" bypasses scope checks in `require`.
    channel: str = "rest"
    # API token used to authenticate, when there is one.
    api_token: ApiToken | None = None
    # Client network details.
    ip_address: str | None = None
    user_agent: str | None = None

    def require(self, scope: Scope) -> None:
        """Ensure the actor may use *scope*.

        Calls on the "ui" channel always pass. Every other channel needs an
        API token whose ``scopes`` contain ``scope.value``.

        Raises:
            AppError: ``forbidden`` (HTTP 403) when the scope is not granted.
        """
        if self.channel != "ui":
            token = self.api_token
            granted = token is not None and scope.value in token.scopes
            if not granted:
                raise AppError("forbidden", "Missing required scope.", status_code=403)


def _master_key(settings: Settings | None = None) -> bytes:
    """Derive the master key bytes from the configured master key.

    Uses *settings* when one is passed, otherwise the result of ``get_settings()``.
    """
    active = settings or get_settings()
    return crypto.derive_master_key(active.master_key)


def ensure_user_key(db: Session, user: User) -> UserEncryptionKey:
    """Return the encryption-key row for *user*, creating it on first use.

    When no row exists, a fresh raw key is generated, encrypted under the
    master key, stored with a new ``uk_``-prefixed key id, and flushed.
    """
    lookup = select(UserEncryptionKey).where(UserEncryptionKey.user_id == user.id)
    found = db.scalar(lookup)
    if found:
        return found

    fresh_key = crypto.new_raw_key()
    # Only the master-key-encrypted form of the key is persisted.
    key_nonce, wrapped_key = crypto.encrypt_bytes(_master_key(), fresh_key)
    created = UserEncryptionKey(
        user_id=user.id,
        key_id=f"uk_{uuid.uuid4().hex[:16]}",
        encrypted_key=wrapped_key,
        nonce=key_nonce,
        algorithm=crypto.ALGORITHM,
    )
    db.add(created)
    db.flush()
    return created


def get_user_key(db: Session, user: User) -> tuple[str, bytes]:
    """Return ``(key_id, raw_key)`` for *user*.

    The key row is created if missing, then its stored key is decrypted with
    the master key.
    """
    key_row = ensure_user_key(db, user)
    plaintext_key = crypto.decrypt_bytes(_master_key(), key_row.nonce, key_row.encrypted_key)
    return (key_row.key_id, plaintext_key)


def upsert_user_from_auth(
    db: Session,
    *,
    auth_subject: str,
    email: str,
    display_name: str | None,
    profile: dict,
    status: str = "enabled",
    system_role: str = "user",
    locale: str | None = None,
) -> User:
    """Create or refresh the local user matching *auth_subject*.

    An existing user has every profile attribute overwritten with the given
    values. A new user is added, flushed, and given an encryption key. Any
    ``system_role`` other than "user" or "admin" is stored as "user", and
    ``last_seen_at`` is set to the current time in both cases.
    """
    role = system_role if system_role in {"user", "admin"} else "user"
    lookup = select(User).where(User.auth_subject == auth_subject)
    user = db.scalar(lookup)

    if user is not None:
        user.email = email
        user.display_name = display_name
        user.profile = profile
        user.status = status
        user.system_role = role
        user.locale = locale
        user.last_seen_at = utcnow()
        return user

    user = User(
        auth_subject=auth_subject,
        email=email,
        display_name=display_name,
        profile=profile,
        status=status,
        system_role=role,
        locale=locale,
        last_seen_at=utcnow(),
    )
    db.add(user)
    # Flush before creating the key row, which references user.id.
    db.flush()
    ensure_user_key(db, user)
    return user


def audit(
    db: Session,
    actor: Actor | None,
    *,
    action: str,
    user_id: uuid.UUID | None = None,
    secret_id: uuid.UUID | None = None,
    metadata: dict | None = None,
) -> AuditEvent:
    """Add an audit event for *action* to the session and return it.

    Actor-derived columns (acting user, API token, channel, IP, user agent)
    come from *actor*; without an actor they are ``None`` and the channel is
    "rest". ``user_id`` falls back to the actor's user id. The event is added
    to the session but not flushed.
    """
    if actor:
        acting_user_id = actor.user.id
        token_id = actor.api_token.id if actor.api_token else None
        channel = actor.channel
        ip_address = actor.ip_address
        user_agent = actor.user_agent
    else:
        acting_user_id = token_id = ip_address = user_agent = None
        channel = "rest"

    event = AuditEvent(
        user_id=user_id or acting_user_id,
        actor_user_id=acting_user_id,
        api_token_id=token_id,
        secret_id=secret_id,
        channel=channel,
        action=action,
        ip_address=ip_address,
        user_agent=user_agent,
        audit_metadata=metadata or {},
    )
    db.add(event)
    return event


def is_expired(start: datetime, *, now: datetime, delta: timedelta | None = None) -> bool:
    if start.tzinfo is None and now.tzinfo is not None:
        now = now.replace(tzinfo=None)
    return start + (delta or timedelta(seconds=0)) <= now


def check_rate_limit(
    db: Session,
    *,
    key: str,
    max_requests: int | None = None,
    window_seconds: int | None = None,
) -> None:
    """Count one request against the bucket *key* and enforce its limit.

    Missing (or zero) limits fall back to the configured settings. A missing
    bucket is created with a count of one, and an expired window is restarted
    at one; otherwise the count is incremented.

    Raises:
        AppError: ``rate_limited`` (HTTP 429) once the count exceeds the limit.
    """
    settings = get_settings()
    limit = max_requests or settings.rate_limit_max_sensitive_requests
    window = window_seconds or settings.rate_limit_window_seconds
    now = utcnow()

    bucket = db.get(RateLimitBucket, key)
    if bucket is None:
        db.add(RateLimitBucket(key=key, count=1, window_start=now))
    elif is_expired(bucket.window_start, delta=timedelta(seconds=window), now=now):
        # Window elapsed: start a new one with this request as the first hit.
        bucket.count = 1
        bucket.window_start = now
    else:
        bucket.count += 1
        if bucket.count > limit:
            raise AppError("rate_limited", "Too many requests.", status_code=429)


def check_actor_rate_limit(db: Session, actor: Actor, action: str) -> None:
    """Rate-limit *action* per user and per API token (or "session" without a token)."""
    if actor.api_token:
        origin = actor.api_token.public_id
    else:
        origin = "session"
    bucket_key = f"{action}:{actor.user.id}:{origin}"
    check_rate_limit(db, key=bucket_key)


def log_failed_access_once(
    db: Session,
    *,
    key: str,
    channel: str,
    ip_address: str | None,
    user_agent: str | None,
) -> None:
    """Record a failed access for *key*, auditing at most once per five minutes.

    Failures are counted in the ``failed_access:<key>`` bucket. An
    ``access.failed`` audit event is written only when the bucket is new or
    its five-minute window has expired; later failures within the window just
    bump the counter.
    """
    now = utcnow()
    bucket_key = f"failed_access:{key}"
    bucket = db.get(RateLimitBucket, bucket_key)

    window_open = bucket is not None and not is_expired(
        bucket.window_start, delta=timedelta(minutes=5), now=now
    )
    if window_open:
        bucket.count += 1
        return

    if bucket is None:
        db.add(RateLimitBucket(key=bucket_key, count=1, window_start=now))
    else:
        bucket.count = 1
        bucket.window_start = now

    event = AuditEvent(
        user_id=None,
        actor_user_id=None,
        api_token_id=None,
        secret_id=None,
        channel=channel,
        action="access.failed",
        ip_address=ip_address,
        user_agent=user_agent,
        audit_metadata={"key": key},
    )
    db.add(event)


def _store_fields(db: Session, user: User, fields: list[SecretFieldIn]) -> list[dict]:
    """Convert input fields into the dict form stored on a secret version.

    Fields are ordered by ``position``; a field without a position sorts by
    its index in *fields* and is stored with its index in the sorted order.
    Values of fields marked ``encrypted`` are encrypted with the user's key,
    all other values are stored as given.

    Returns:
        One dict per field with the keys ``name``, ``encrypted``, ``masked``,
        ``position`` and ``value``.
    """
    key_id, key = get_user_key(db, user)

    def _sort_key(pair: tuple[int, SecretFieldIn]) -> int:
        # A None position cannot be compared with ints, so fall back to input order.
        original_index, item = pair
        return item.position if item.position is not None else original_index

    ordered = [item for _, item in sorted(enumerate(fields), key=_sort_key)]

    stored = []
    for index, field in enumerate(ordered):
        payload: dict[str, Any] = {
            "name": field.name,
            "encrypted": field.encrypted,
            "masked": field.masked,
            "position": field.position if field.position is not None else index,
        }
        payload["value"] = (
            crypto.encrypt_text(key, key_id, field.value) if field.encrypted else field.value
        )
        stored.append(payload)
    return stored


def _field_search_text(fields: list[SecretFieldIn]) -> str:
    chunks: list[str] = []
    for field in fields:
        chunks.append(field.name)
        if not field.encrypted:
            chunks.append(field.value)
    return " ".join(chunks).lower()


def _public_fields(
    fields: list[dict], *, reveal: bool, db: Session | None = None, user: User | None = None
):
    """Turn stored field dicts into ``SecretFieldOut`` objects, ordered by position.

    Unencrypted values are always included. Encrypted values are decrypted
    only when *reveal* is set and both *db* and *user* are supplied (the
    user's key is needed); otherwise their value is ``None``.
    """
    decrypt_key: bytes | None = None
    if reveal and db is not None and user is not None:
        decrypt_key = get_user_key(db, user)[1]

    def _visible_value(entry: dict):
        # Plain values are public in both modes.
        if not entry.get("encrypted"):
            return entry.get("value")
        if reveal and decrypt_key:
            return crypto.decrypt_text(decrypt_key, entry["value"])
        return None

    ordered = sorted(fields, key=lambda entry: entry.get("position", 0))
    return [
        SecretFieldOut(
            name=entry["name"],
            value=_visible_value(entry),
            encrypted=bool(entry.get("encrypted")),
            masked=bool(entry.get("masked")),
            position=int(entry.get("position", 0)),
        )
        for entry in ordered
    ]


def _tags(secret: Secret) -> list[str]:
    return [tag.name for tag in secret.tags]


def _current_version(secret: Secret) -> SecretVersion:
    if not secret.versions:
        raise AppError("secret_has_no_versions", "Secret has no versions.", status_code=500)
    return secret.versions[-1]


def serialize_secret(
    secret: Secret, *, reveal: bool = False, db: Session | None = None, user: User | None = None
):
    """Serialize *secret* together with the fields of its current version.

    Returns a ``SecretRead`` by default. With *reveal* set, returns a
    ``SecretReveal`` that also carries the current version's id and number;
    *db* and *user* are passed on so encrypted values can be decrypted.
    """
    version = _current_version(secret)

    copied_attrs = (
        "id",
        "title",
        "purpose",
        "category",
        "source",
        "notes",
        "status",
        "archived",
        "allow_ui",
        "allow_rest_api",
        "allow_mcp",
        "created_at",
        "updated_at",
    )
    payload = {attr: getattr(secret, attr) for attr in copied_attrs}
    payload["tags"] = _tags(secret)
    payload["fields"] = _public_fields(version.fields, reveal=reveal, db=db, user=user)

    if not reveal:
        return SecretRead(**payload)

    payload["version_id"] = version.id
    payload["version_number"] = version.version_number
    return SecretReveal(**payload)


def create_secret(db: Session, actor: Actor, payload: SecretCreate) -> SecretRead:
    """Create a secret with its tags and first version for the acting user.

    Requires the ``write`` scope. Writes a ``secret.created`` audit event and
    returns the secret serialized without revealing encrypted values.
    """
    actor.require(Scope.write)
    owner = actor.user

    secret = Secret(
        user_id=owner.id,
        title=payload.title,
        purpose=payload.purpose,
        category=payload.category,
        source=payload.source,
        notes=payload.notes,
        status=payload.status.value,
        archived=payload.archived,
        allow_ui=payload.allow_ui,
        allow_rest_api=payload.allow_rest_api,
        allow_mcp=payload.allow_mcp,
    )
    db.add(secret)
    # Flush first so secret.id is available to the tag and version rows.
    db.flush()

    db.add_all(
        SecretTag(secret_id=secret.id, user_id=owner.id, name=tag_name)
        for tag_name in payload.tags
    )
    first_version = SecretVersion(
        secret_id=secret.id,
        version_number=1,
        fields=_store_fields(db, owner, payload.fields),
        search_text=_field_search_text(payload.fields),
    )
    db.add(first_version)

    audit(db, actor, action="secret.created", secret_id=secret.id, metadata={"title": secret.title})
    db.flush()
    # Reload the relationships so serialization sees the rows added above.
    db.refresh(secret, attribute_names=["versions", "tags"])
    return serialize_secret(secret)


def _load_secret(db: Session, user_id: uuid.UUID, secret_id: uuid.UUID) -> Secret:
    """Load the secret *secret_id* owned by *user_id*, with versions and tags.

    Raises:
        AppError: ``secret_not_found`` (HTTP 404) when no such secret belongs
            to the user.
    """
    owned = select(Secret).where(Secret.id == secret_id, Secret.user_id == user_id)
    with_relations = owned.options(selectinload(Secret.versions), selectinload(Secret.tags))
    found = db.scalar(with_relations)
    if found is not None:
        return found
    raise AppError("secret_not_found", "Secret not found.", status_code=404)


def list_secrets(
    db: Session,
    actor: Actor,
    *,
    q: str | None = None,
    category: str | None = None,
    status: SecretStatus | None = None,
    include_archived: bool = False,
    offset: int = 0,
    limit: int = 50,
    mcp: bool = False,
) -> tuple[list[SecretRead], int]:
    """List the acting user's secrets, most recently updated first.

    Requires the ``read`` scope.

    Args:
        q: Case-insensitive substring matched against the secret's title,
            purpose, category, source and notes, its tag names, and the
            search text of any of its versions.
        category: Exact category filter.
        status: Exact status filter.
        include_archived: Include archived secrets (never applies with *mcp*).
        offset: Number of matching secrets to skip.
        limit: Maximum number of secrets to return.
        mcp: Restrict to non-archived secrets with ``allow_mcp`` set.

    Returns:
        ``(page, total)``: the serialized page and the number of secrets
        matching the filters.
    """
    actor.require(Scope.read)
    stmt = select(Secret).where(Secret.user_id == actor.user.id)
    if not include_archived:
        stmt = stmt.where(Secret.archived.is_(False))
    if mcp:
        stmt = stmt.where(Secret.archived.is_(False), Secret.allow_mcp.is_(True))
    if category:
        stmt = stmt.where(Secret.category == category)
    if status:
        stmt = stmt.where(Secret.status == status.value)
    if q:
        like = f"%{q.lower()}%"
        # Tags and versions are matched through correlated EXISTS subqueries
        # instead of joins: a join yields one row per (version, tag) pair,
        # which inflates the total and makes offset/limit cut through
        # duplicates of the same secret.
        tag_match = (
            select(SecretTag.secret_id)
            .where(SecretTag.secret_id == Secret.id, func.lower(SecretTag.name).like(like))
            .exists()
        )
        version_match = (
            select(SecretVersion.id)
            .where(SecretVersion.secret_id == Secret.id, SecretVersion.search_text.ilike(like))
            .exists()
        )
        stmt = stmt.where(
            or_(
                func.lower(Secret.title).like(like),
                func.lower(Secret.purpose).like(like),
                func.lower(Secret.category).like(like),
                func.lower(Secret.source).like(like),
                func.lower(Secret.notes).like(like),
                tag_match,
                version_match,
            )
        )
    count_stmt = select(func.count()).select_from(stmt.subquery())
    total = db.scalar(count_stmt) or 0
    rows = db.scalars(
        stmt.options(selectinload(Secret.versions), selectinload(Secret.tags))
        .order_by(Secret.updated_at.desc())
        .offset(offset)
        .limit(limit)
    ).all()
    return [serialize_secret(row) for row in rows], total


def get_secret(db: Session, actor: Actor, secret_id: uuid.UUID, *, mcp: bool = False) -> SecretRead:
    """Return one of the acting user's secrets without revealing encrypted values.

    Requires the ``read`` scope. With *mcp* set, archived secrets and secrets
    without ``allow_mcp`` are reported as not found.

    Raises:
        AppError: ``secret_not_found`` (HTTP 404).
    """
    actor.require(Scope.read)
    secret = _load_secret(db, actor.user.id, secret_id)
    hidden_from_mcp = mcp and (secret.archived or not secret.allow_mcp)
    if hidden_from_mcp:
        raise AppError("secret_not_found", "Secret not found.", status_code=404)
    return serialize_secret(secret)


def reveal_secret(
    db: Session,
    actor: Actor,
    secret_id: uuid.UUID,
    *,
    version_id: uuid.UUID | None = None,
    mcp: bool = False,
) -> SecretReveal:
    """Return a secret with its encrypted values decrypted.

    Requires the ``reveal`` scope, is rate limited per actor, and writes a
    ``secret.revealed`` audit event. The current version is revealed unless
    *version_id* selects another version of the same secret. With *mcp* set,
    archived secrets and secrets without ``allow_mcp`` are reported as not
    found.

    Raises:
        AppError: ``secret_not_found`` or ``version_not_found`` (HTTP 404).
    """
    actor.require(Scope.reveal)
    check_actor_rate_limit(db, actor, "reveal")

    secret = _load_secret(db, actor.user.id, secret_id)
    if mcp and (secret.archived or not secret.allow_mcp):
        raise AppError("secret_not_found", "Secret not found.", status_code=404)

    version = _current_version(secret)
    if version_id:
        for candidate in secret.versions:
            if candidate.id == version_id:
                version = candidate
                break
        else:
            raise AppError("version_not_found", "Secret version not found.", status_code=404)

    audit(
        db,
        actor,
        action="secret.revealed",
        secret_id=secret.id,
        metadata={"title": secret.title, "version_number": version.version_number},
    )

    revealed = serialize_secret(secret, reveal=True, db=db, user=actor.user).model_dump()
    # Replace the current-version details with those of the selected version.
    revealed.update(
        version_id=version.id,
        version_number=version.version_number,
        fields=_public_fields(version.fields, reveal=True, db=db, user=actor.user),
    )
    return SecretReveal(**revealed)


def update_secret(
    db: Session, actor: Actor, secret_id: uuid.UUID, payload: SecretUpdate
) -> SecretRead:
    """Apply a partial update to one of the acting user's secrets.

    Requires the ``write`` scope. Payload attributes left as ``None`` are not
    touched. Changed metadata (including status and tags) is recorded as an
    old/new diff in a ``secret.metadata_updated`` audit event. Supplying
    ``fields`` always creates a new version numbered one above the current
    one and writes a ``secret.version_created`` audit event. ``updated_at``
    is refreshed on every call.

    Raises:
        AppError: ``secret_not_found`` (HTTP 404) when the secret does not
            belong to the user.
    """
    actor.require(Scope.write)
    secret = _load_secret(db, actor.user.id, secret_id)
    changed_metadata: dict[str, Any] = {}
    for field in [
        "title",
        "purpose",
        "category",
        "source",
        "notes",
        "archived",
        "allow_ui",
        "allow_rest_api",
        "allow_mcp",
    ]:
        value = getattr(payload, field)
        if value is not None and getattr(secret, field) != value:
            changed_metadata[field] = {"old": getattr(secret, field), "new": value}
            setattr(secret, field, value)
    if payload.status is not None and secret.status != payload.status.value:
        changed_metadata["status"] = {"old": secret.status, "new": payload.status.value}
        secret.status = payload.status.value
    if payload.tags is not None:
        # Round-trip through SecretCreate to get the tags in the same form a create would
        # store — presumably this applies its tag validation; confirm against the schema.
        normalized = SecretCreate(title="x", tags=payload.tags).tags
        # Capture the current tags before clearing the relationship; reading them
        # afterwards would always record an empty "old" list in the audit diff.
        previous_tags = _tags(secret)
        if normalized != previous_tags:
            secret.tags.clear()
            db.flush()
            for tag in normalized:
                db.add(SecretTag(secret_id=secret.id, user_id=actor.user.id, name=tag))
            changed_metadata["tags"] = {"old": previous_tags, "new": normalized}
    if payload.fields is not None:
        next_version = _current_version(secret).version_number + 1
        db.add(
            SecretVersion(
                secret_id=secret.id,
                version_number=next_version,
                fields=_store_fields(db, actor.user, payload.fields),
                search_text=_field_search_text(payload.fields),
            )
        )
        audit(
            db,
            actor,
            action="secret.version_created",
            secret_id=secret.id,
            metadata={"version": next_version},
        )
    if changed_metadata:
        audit(
            db,
            actor,
            action="secret.metadata_updated",
            secret_id=secret.id,
            metadata={"diff": changed_metadata},
        )
    secret.updated_at = utcnow()
    db.flush()
    # Reload the relationships so serialization sees the new tags and version.
    db.refresh(secret, attribute_names=["versions", "tags"])
    return serialize_secret(secret)


def delete_secret(db: Session, actor: Actor, secret_id: uuid.UUID) -> None:
    """Delete one of the acting user's secrets.

    Requires the ``write`` scope. A ``secret.deleted`` audit event holding
    the title, category and tags is written before the secret is deleted.

    Raises:
        AppError: ``secret_not_found`` (HTTP 404).
    """
    actor.require(Scope.write)
    secret = _load_secret(db, actor.user.id, secret_id)
    # Snapshot the identifying details while the row is still loaded.
    snapshot = {
        "title": secret.title,
        "category": secret.category,
        "tags": _tags(secret),
    }
    audit(
        db,
        actor,
        action="secret.deleted",
        secret_id=secret.id,
        metadata={"snapshot": snapshot},
    )
    db.delete(secret)


def list_versions(db: Session, actor: Actor, secret_id: uuid.UUID) -> list[SecretVersionRead]:
    """Return every version of a secret, with encrypted values left hidden.

    Requires the ``read`` scope.

    Raises:
        AppError: ``secret_not_found`` (HTTP 404).
    """
    actor.require(Scope.read)
    secret = _load_secret(db, actor.user.id, secret_id)

    history: list[SecretVersionRead] = []
    for entry in secret.versions:
        visible_fields = _public_fields(entry.fields, reveal=False)
        history.append(
            SecretVersionRead(
                id=entry.id,
                version_number=entry.version_number,
                created_at=entry.created_at,
                fields=visible_fields,
            )
        )
    return history


def create_api_token(db: Session, actor: Actor, payload: ApiTokenCreate):
    """Create an API token for the acting user.

    Requires the ``admin`` scope and is rate limited per actor. Only the hash
    of the token is stored; the full ``gcr_<public_id>_<secret>`` string is
    handed back to the caller.

    Returns:
        ``(row, full_token)``: the flushed ``ApiToken`` row and the plaintext
        token.
    """
    actor.require(Scope.admin)
    check_actor_rate_limit(db, actor, "api_token_create")

    public_id = py_secrets.token_hex(8)
    secret_part = crypto.token_secret()
    full_token = f"gcr_{public_id}_{secret_part}"
    granted_scopes = [scope.value for scope in payload.scopes]

    token_row = ApiToken(
        user_id=actor.user.id,
        public_id=public_id,
        name=payload.name,
        token_hash=crypto.token_hash(full_token),
        scopes=granted_scopes,
    )
    db.add(token_row)

    audit_details = {"name": payload.name, "scopes": token_row.scopes}
    audit(db, actor, action="api_token.created", metadata=audit_details)
    db.flush()
    return token_row, full_token


def authenticate_api_token(db: Session, token: str) -> Actor | None:
    """Resolve a plaintext API token to a "rest" actor.

    Returns ``None`` when no unrevoked token has a matching hash, or when the
    owning user is missing or disabled. On success the token's
    ``last_used_at`` is updated.
    """
    digest = crypto.token_hash(token)
    lookup = select(ApiToken).where(
        ApiToken.token_hash == digest, ApiToken.revoked_at.is_(None)
    )
    token_row = db.scalar(lookup)
    if token_row is None:
        return None

    owner = db.get(User, token_row.user_id)
    if owner is not None and owner.status != "disabled":
        token_row.last_used_at = utcnow()
        return Actor(user=owner, channel="rest", api_token=token_row)
    return None


def create_session(db: Session, user: User, ttl_seconds: int) -> SessionRecord:
    """Add a session for *user* that expires *ttl_seconds* from now.

    The session id is a random URL-safe token. The row is added to the
    session but not flushed.
    """
    lifetime = timedelta(seconds=ttl_seconds)
    record = SessionRecord(
        id=py_secrets.token_urlsafe(48),
        user_id=user.id,
        data={},
        expires_at=utcnow() + lifetime,
    )
    db.add(record)
    return record


def get_session_user(db: Session, session_id: str | None) -> User | None:
    """Return the user behind *session_id*.

    Returns ``None`` when the id is empty, the session is unknown or expired,
    or the user is missing or disabled.
    """
    if not session_id:
        return None

    record = db.get(SessionRecord, session_id)
    if record is None:
        return None
    if is_expired(record.expires_at, now=utcnow()):
        return None

    owner = db.get(User, record.user_id)
    if owner is not None and owner.status != "disabled":
        return owner
    return None


def delete_session(db: Session, session_id: str | None) -> None:
    """Delete the session *session_id* if it exists; an empty id is a no-op."""
    if not session_id:
        return
    record = db.get(SessionRecord, session_id)
    if record:
        db.delete(record)
