#!/usr/bin/env python3
import argparse
import math
from pathlib import Path

import cv2
import torch
from PIL import Image, ImageOps
from transformers import CLIPModel, CLIPProcessor
from nudenet import NudeDetector
from ultralytics import YOLO


# ============================================================
# PATHS
# ============================================================

# Absolute directory of this script; bundled model files are looked up here.
BASE_DIR = Path(__file__).resolve().parents[0]
# YuNet face-detection ONNX weights loaded by detect_faces().
FACE_MODEL_PATH = BASE_DIR.joinpath("face_detection_yunet_2023mar.onnx")


# ============================================================
# CONFIG
# ============================================================

# Minimum NudeNet score for a detection in NUDE_CLASSES to block the image.
NUDE_THRESHOLD = 0.6
# Minimum YuNet face score (also passed to the detector as score_threshold).
FACE_CONF_THRESHOLD = 0.5
# Minimum YOLO confidence for a class-0 box in detect_persons().
PERSON_CONF_THRESHOLD = 0.4

# Thresholds on CLIP softmax probabilities produced by run_clip():
#   clip_person >= CLIP_PERSON_HINT_THRESHOLD -> fallback "person" hint
#   clip_print  >= CLIP_PRINT_HINT_THRESHOLD  -> candidate printed object ...
#   real_env    <  CLIP_REAL_ENV_THRESHOLD    -> ... only if the real-scene signal is weak
CLIP_PERSON_HINT_THRESHOLD = 0.30
CLIP_PRINT_HINT_THRESHOLD = 0.30
CLIP_REAL_ENV_THRESHOLD = 0.20

# NudeNet class labels that count as nudity. A frozenset so this constant
# cannot be mutated accidentally at runtime (membership tests are unchanged).
NUDE_CLASSES = frozenset({
    "FEMALE_BREAST_EXPOSED",
    "FEMALE_GENITALIA_EXPOSED",
    "MALE_GENITALIA_EXPOSED",
    "BUTTOCKS_EXPOSED",
    "ANUS_EXPOSED",
})


# ============================================================
# CLIP PROMPTS (PURE)
# ============================================================

# CLIP zero-shot prompts for a genuine, first-hand photo of a physical scene.
# Their summed / maximum softmax probabilities become real_score / best_real
# in run_clip().
# NOTE: run_clip() matches several of these strings verbatim (clip_person,
# real_env); if you reword a prompt here, update those matches as well.
REAL_LABELS = [
    "a photo of a real scene in a physical environment",
    "a photo of a real indoor room",
    "a photo of a real outdoor scene",
    "a photo of real objects in a room",
    "a real photograph taken by a person in the real world",
]

# CLIP zero-shot prompts for recaptured content (screenshots, photographed
# screens, photographed prints). Aggregated into recapture_score /
# best_recapture in run_clip().
# NOTE: "a photo of a printed photo" and "a photo of printed paper" are
# matched verbatim in run_clip() to compute clip_print.
RECAPTURE_LABELS = [
    "a screenshot",
    "a photo that is entirely a computer screen",
    "a photo that is entirely a phone screen",
    "a photo of a printed photo",
    "a photo of printed paper",
    "a photo of a photo",
]


# ============================================================
# UTILS
# ============================================================

def softmax(xs):
    """Return the softmax of the numbers in *xs* as a list of floats.

    The largest value is subtracted before exponentiating so large logits do
    not overflow ``math.exp``. Empty input raises ``ValueError`` (from ``max``).
    """
    peak = max(xs)
    weights = list(map(math.exp, (value - peak for value in xs)))
    total = sum(weights)
    return [weight / total for weight in weights]


def box_iou(a, b):
    """Return the intersection-over-union of two ``(x1, y1, x2, y2)`` boxes.

    Returns ``0`` when the union area is not positive (e.g. two empty boxes).
    """
    (ax1, ay1, ax2, ay2), (bx1, by1, bx2, by2) = a, b

    overlap_w = max(0, min(ax2, bx2) - max(ax1, bx1))
    overlap_h = max(0, min(ay2, by2) - max(ay1, by1))
    intersection = overlap_w * overlap_h

    union = (ax2 - ax1) * (ay2 - ay1) + (bx2 - bx1) * (by2 - by1) - intersection
    if union > 0:
        return intersection / union
    return 0


# ============================================================
# CLIP
# ============================================================

def run_clip(image_path, model, processor, device):
    """Zero-shot classify an image against the REAL and RECAPTURE prompts.

    One softmax is taken over all prompts together (REAL_LABELS followed by
    RECAPTURE_LABELS), so every probability below is relative to that full
    set.

    Args:
        image_path: path to the image file; EXIF orientation is applied and
            the image is converted to RGB before CLIP sees it.
        model: CLIP model whose output exposes ``logits_per_image``.
        processor: matching CLIP processor.
        device: torch device string the processed inputs are moved to.

    Returns:
        dict with:
          ranked          - list of (label, prob) pairs, highest prob first
          real_score      - summed prob of all REAL_LABELS
          recapture_score - summed prob of all RECAPTURE_LABELS
          best_real       - highest single REAL prob
          best_recapture  - highest single RECAPTURE prob
          clip_person     - prob of the "taken by a person" prompt
          clip_print      - max prob of the two "printed" prompts
          real_env        - max prob of the scene/indoor/outdoor prompts
    """
    labels = REAL_LABELS + RECAPTURE_LABELS

    # A context manager releases the file handle as soon as the pixels are
    # loaded (convert() forces the load), instead of leaving it open until
    # garbage collection.
    with Image.open(image_path) as raw:
        image = ImageOps.exif_transpose(raw).convert("RGB")

    inputs = processor(
        text=labels,
        images=image,
        return_tensors="pt",
        padding=True,
    ).to(device)

    with torch.no_grad():
        logits = model(**inputs).logits_per_image[0].cpu().tolist()

    probs = softmax(logits)
    ranked = sorted(zip(labels, probs), key=lambda pair: pair[1], reverse=True)

    # Aggregate in ranked order, as before, so float sums are bit-identical.
    real_probs = [p for label, p in ranked if label in REAL_LABELS]
    recapture_probs = [p for label, p in ranked if label in RECAPTURE_LABELS]

    # Direct label -> probability lookup for the individual hint prompts;
    # missing prompts fall back to 0.0, matching the previous loop.
    prob_of = dict(zip(labels, probs))
    clip_person = prob_of.get(
        "a real photograph taken by a person in the real world", 0.0
    )
    clip_print = max(
        prob_of.get(label, 0.0)
        for label in ("a photo of printed paper", "a photo of a printed photo")
    )
    real_env = max(
        prob_of.get(label, 0.0)
        for label in (
            "a photo of a real scene in a physical environment",
            "a photo of a real outdoor scene",
            "a photo of a real indoor room",
        )
    )

    return {
        "ranked": ranked,
        "real_score": sum(real_probs),
        "recapture_score": sum(recapture_probs),
        "best_real": max(real_probs),
        "best_recapture": max(recapture_probs),
        "clip_person": clip_person,
        "clip_print": clip_print,
        "real_env": real_env,
    }


def decide_real_vs_recapture(
    ctx,
    *,
    best_real_threshold=0.20,
    best_recapture_threshold=0.40,
    score_ratio=1.15,
):
    """Classify a run_clip() result as a real photo or a recapture.

    Rules, applied in order:
      1. any single REAL prompt >= best_real_threshold        -> "REAL"
      2. any single RECAPTURE prompt >= best_recapture_threshold -> "RECAPTURE"
      3. real_score >= recapture_score * score_ratio          -> "REAL"
      4. otherwise                                            -> "RECAPTURE"

    Args:
        ctx: mapping with ``best_real``, ``best_recapture``, ``real_score``
            and ``recapture_score`` (as returned by run_clip()).
        best_real_threshold, best_recapture_threshold, score_ratio:
            decision thresholds; defaults are the previously hard-coded values.

    Returns:
        ``"REAL"`` or ``"RECAPTURE"`` (previously misspelled ``"RECATURE"``).
    """
    if ctx["best_real"] >= best_real_threshold:
        return "REAL"
    if ctx["best_recapture"] >= best_recapture_threshold:
        return "RECAPTURE"
    if ctx["real_score"] >= ctx["recapture_score"] * score_ratio:
        return "REAL"
    return "RECAPTURE"


# ============================================================
# DETECTORS
# ============================================================

def detect_faces(img):
    """Detect faces with OpenCV's YuNet detector (model at FACE_MODEL_PATH).

    Args:
        img: image array as returned by ``cv2.imread``; the detector input
            size is taken from its height/width.

    Returns:
        List of integer ``(x1, y1, x2, y2)`` boxes whose face score is at
        least FACE_CONF_THRESHOLD (empty list when nothing is found).
    """
    h, w = img.shape[:2]
    # input_size given here already configures the detector for this image,
    # so a separate setInputSize() call is unnecessary.
    net = cv2.FaceDetectorYN.create(
        model=str(FACE_MODEL_PATH),
        config="",
        input_size=(w, h),
        score_threshold=FACE_CONF_THRESHOLD,
        nms_threshold=0.3,
        top_k=5000,
    )
    _, faces = net.detect(img)
    if faces is None:
        return []

    boxes = []
    for face in faces:
        # YuNet rows are [x, y, w, h, 5 landmark (x, y) pairs, score]: the
        # score is column 14. Column 4 is the right-eye x coordinate, so
        # reading it as the score wrongly dropped faces near the left edge.
        if float(face[14]) < FACE_CONF_THRESHOLD:
            continue
        x, y, bw, bh = map(int, face[:4])
        boxes.append((x, y, x + bw, y + bh))
    return boxes


def detect_persons(img, yolo):
    """Run *yolo* on *img* and return the class-0 detections as boxes.

    Presumably class 0 is "person" for COCO-trained weights such as
    yolov8n.pt (assumption, not checked here). Only boxes with confidence
    >= PERSON_CONF_THRESHOLD are kept. Each box is an ``(x1, y1, x2, y2)``
    tuple with the float coordinates truncated to int.
    """
    boxes = yolo(img, verbose=False)[0].boxes
    return [
        tuple(map(int, xyxy.tolist()))
        for xyxy, label, score in zip(boxes.xyxy, boxes.cls, boxes.conf)
        if int(label) == 0 and score >= PERSON_CONF_THRESHOLD
    ]


# ============================================================
# FINAL DECISION
# ============================================================

def _load_clip(device):
    """Load the locally cached CLIP ViT-L/14 model (moved to *device*) and processor."""
    model = CLIPModel.from_pretrained(
        "openai/clip-vit-large-patch14",
        local_files_only=True,
    ).to(device)
    processor = CLIPProcessor.from_pretrained(
        "openai/clip-vit-large-patch14",
        local_files_only=True,
        use_fast=False,
    )
    return model, processor


def _print_clip_debug(ctx, verdict):
    """Print the top-5 CLIP predictions, aggregate scores and verdict to stdout."""
    print("\nTop predictions:")
    for label, prob in ctx["ranked"][:5]:
        print(f"  {prob:6.3f}  {label}")

    print("\nScores:")
    print(f"  real_score      = {ctx['real_score']:.3f}")
    print(f"  recapture_score = {ctx['recapture_score']:.3f}")
    print(f"  best_real       = {ctx['best_real']:.3f}")
    print(f"  best_recapture  = {ctx['best_recapture']:.3f}")

    print(f"\nVerdict: {'ECHTES FOTO' if verdict == 'REAL' else 'RECAPTURE'}\n")


def should_block(image_path):
    """Decide whether the image at *image_path* must be blocked.

    Pipeline (first matching rule wins):
      1. NudeNet hit in NUDE_CLASSES with score >= NUDE_THRESHOLD -> block.
      2. CLIP says the image is not a real photo                  -> allow.
      3. No face found by YuNet                                   -> allow.
      4. CLIP flags a printed object in a weak real scene         -> allow.
      5. YOLO finds a person                                      -> block.
      6. CLIP "taken by a person" prob >= threshold               -> block.
      7. Otherwise                                                -> allow.

    CLIP debug information is printed to stdout along the way.

    Returns:
        ``(blocked, reason)``, where *reason* is one of NUDITY_DETECTED,
        ALLOWED_NON_REAL, ALLOWED_REAL_NO_PERSON, REAL_PERSON_DETECTED,
        REAL_PERSON_DETECTED_CLIP_HINT.

    Raises:
        RuntimeError: if OpenCV cannot read the image.
    """
    img = cv2.imread(image_path)
    if img is None:
        raise RuntimeError(f"Bild konnte nicht geladen werden: {image_path}")

    # 1) Nudity always takes priority.
    nude_hits = NudeDetector().detect(image_path) or []
    if any(
        d.get("class") in NUDE_CLASSES and d.get("score", 0) >= NUDE_THRESHOLD
        for d in nude_hits
    ):
        return True, "NUDITY_DETECTED"

    # 2) CLIP: real photo vs. screenshot / photographed screen or print.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, processor = _load_clip(device)
    ctx = run_clip(image_path, model, processor, device)
    verdict = decide_real_vs_recapture(ctx)
    _print_clip_debug(ctx, verdict)

    if verdict != "REAL":
        return False, "ALLOWED_NON_REAL"

    # 3) Face gate (YuNet).
    if not detect_faces(img):
        return False, "ALLOWED_REAL_NO_PERSON"

    # 4) Printed object: only when the print signal is high AND the real-scene
    #    signal is weak. Both person checks below are skipped in that case,
    #    so return before loading/running YOLO at all.
    is_print_object = (
        ctx["clip_print"] >= CLIP_PRINT_HINT_THRESHOLD
        and ctx["real_env"] < CLIP_REAL_ENV_THRESHOLD
    )
    if is_print_object:
        return False, "ALLOWED_REAL_NO_PERSON"

    # 5) YOLO person detection.
    if detect_persons(img, YOLO("yolov8n.pt")):
        return True, "REAL_PERSON_DETECTED"

    # 6) CLIP person hint as fallback when YOLO found nothing.
    if ctx["clip_person"] >= CLIP_PERSON_HINT_THRESHOLD:
        return True, "REAL_PERSON_DETECTED_CLIP_HINT"

    return False, "ALLOWED_REAL_NO_PERSON"


# ============================================================
# CLI
# ============================================================

def main():
    """Command-line entry point: classify one image and print the decision."""
    parser = argparse.ArgumentParser()
    parser.add_argument("image")
    image = parser.parse_args().image

    blocked, reason = should_block(image)
    status = "⛔ BLOCKIERT" if blocked else "✅ ERLAUBT"
    print(status, ":", reason)


if __name__ == "__main__":
    main()
