#!/usr/bin/env python3
import argparse
import functools
import math
from pathlib import Path

import cv2
import torch
from PIL import Image, ImageOps
from transformers import CLIPModel, CLIPProcessor
from nudenet import NudeDetector
from ultralytics import YOLO


# ============================================================
# PATHS
# ============================================================

# Absolute directory containing this script; model files are resolved from here.
BASE_DIR = Path(__file__).resolve().parent
# YuNet face-detection ONNX model, expected to sit next to the script.
FACE_MODEL_PATH = BASE_DIR.joinpath("face_detection_yunet_2023mar.onnx")


# ============================================================
# CONFIG (GDPR / DSGVO compliance)
# ============================================================

# Minimum detector scores: a hit counts only when score >= threshold.
NUDE_THRESHOLD = 0.6
FACE_CONF_THRESHOLD = 0.5
PERSON_CONF_THRESHOLD = 0.4

# GDPR-relevant geometry thresholds.
PERSON_MIN_AREA_RATIO = 0.02      # person box must cover >= 2 % of the image area
FACE_INSIDE_RATIO = 0.6           # >= 60 % of the face box must lie inside a person box

# CLIP thresholds for the "printed object" hint: a shot is treated as a
# print when the print-prompt probability is >= CLIP_PRINT_HINT_THRESHOLD
# and the real-environment probability is < CLIP_REAL_ENV_THRESHOLD.
CLIP_PRINT_HINT_THRESHOLD = 0.35
CLIP_REAL_ENV_THRESHOLD = 0.20

# NudeNet detection classes that trigger an immediate block.
NUDE_CLASSES = {
    "FEMALE_BREAST_EXPOSED",
    "FEMALE_GENITALIA_EXPOSED",
    "MALE_GENITALIA_EXPOSED",
    "BUTTOCKS_EXPOSED",
    "ANUS_EXPOSED",
}


# ============================================================
# CLIP PROMPTS
# ============================================================

# Prompts describing a genuine, directly photographed scene.
# NOTE: run_clip matches several of these strings verbatim; edit both together.
REAL_LABELS = [
    "a photo of a real scene in a physical environment",
    "a photo of a real indoor room",
    "a photo of a real outdoor scene",
    "a photo of real objects in a room",
    "a real photograph taken by a person in the real world",
]

# Prompts describing a recapture: screenshots, photographed screens, and
# photographed prints. run_clip also matches the two "printed" prompts by text.
RECAPTURE_LABELS = [
    "a screenshot",
    "a photo that is entirely a computer screen",
    "a photo that is entirely a phone screen",
    "a photo of a printed photo",
    "a photo of printed paper",
    "a photo of a photo",
]


# ============================================================
# UTILS
# ============================================================

def softmax(xs):
    """Return the softmax of a sequence of numbers as a list.

    The maximum is subtracted before exponentiating so large logits do not
    overflow ``math.exp``.

    Args:
        xs: any iterable of numbers. It is materialized once, so generators
            work too (previously ``max()`` exhausted them and ``[]`` came back).

    Returns:
        List of probabilities in the same order as ``xs`` that sum to ~1.0.
        Returns an empty list for empty input instead of raising ``ValueError``.
    """
    values = list(xs)
    if not values:
        return []
    peak = max(values)
    exps = [math.exp(v - peak) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


def box_area_ratio(box, img):
    """Return the fraction of the image area covered by ``box``.

    Args:
        box: (x1, y1, x2, y2) box in pixel coordinates.
        img: image exposing ``.shape`` with (height, width, ...) in its first
            two entries, as returned by ``cv2.imread``.

    Returns:
        box_area / image_area, or 0 when the image has no area. Width and height
        are clamped separately, so an inverted box has area 0. Clamping only the
        product would turn a box inverted on both axes into a positive area.
    """
    x1, y1, x2, y2 = box
    width = max(0, x2 - x1)
    height = max(0, y2 - y1)
    img_area = img.shape[0] * img.shape[1]
    return (width * height) / img_area if img_area > 0 else 0


def face_inside_person(face, person, min_ratio=FACE_INSIDE_RATIO):
    """Check whether enough of a face box lies within a person box.

    Args:
        face: (x1, y1, x2, y2) face box.
        person: (x1, y1, x2, y2) person box.
        min_ratio: minimum fraction of the face area that must overlap the
            person box (defaults to FACE_INSIDE_RATIO).

    Returns:
        True if intersection_area / face_area >= min_ratio. Returns False when
        the face area (x2 - x1) * (y2 - y1) is not positive.
        NOTE(review): a face box inverted on both axes yields a positive
        product here; its overlap is 0, so it is rejected unless min_ratio <= 0.
    """
    left_f, top_f, right_f, bottom_f = face
    left_p, top_p, right_p, bottom_p = person

    face_area = (right_f - left_f) * (bottom_f - top_f)
    if face_area <= 0:
        return False

    overlap_w = max(0, min(right_f, right_p) - max(left_f, left_p))
    overlap_h = max(0, min(bottom_f, bottom_p) - max(top_f, top_p))
    return (overlap_w * overlap_h) / face_area >= min_ratio


# ============================================================
# CLIP
# ============================================================

def run_clip(image_path, model, processor, device):
    """Score an image against the REAL and RECAPTURE CLIP prompts.

    All prompts share one softmax, so REAL and RECAPTURE probabilities
    compete and together sum to ~1.

    Args:
        image_path: path to the image file (anything ``PIL.Image.open`` accepts).
        model: a loaded ``CLIPModel``.
        processor: the matching ``CLIPProcessor``.
        device: torch device string the model lives on (e.g. "cuda" or "cpu").

    Returns:
        dict with:
          - "ranked": list of (prompt, probability) pairs, highest first
          - "real_score" / "recapture_score": summed probability per group
          - "best_real" / "best_recapture": highest single probability per group
          - "clip_print": max probability of the two "printed" prompts
          - "real_env": max probability of the three scene/room prompts
    """
    labels = REAL_LABELS + RECAPTURE_LABELS
    print_labels = ("a photo of printed paper", "a photo of a printed photo")
    env_labels = (
        "a photo of a real scene in a physical environment",
        "a photo of a real outdoor scene",
        "a photo of a real indoor room",
    )

    # Context manager: the file handle used to be left open until garbage
    # collection. Preprocessing happens inside the block, so pixel data is
    # read before the handle closes.
    with Image.open(image_path) as raw:
        image = ImageOps.exif_transpose(raw).convert("RGB")
        inputs = processor(
            text=labels,
            images=image,
            return_tensors="pt",
            padding=True,
        ).to(device)

    with torch.no_grad():
        logits = model(**inputs).logits_per_image[0].cpu().tolist()

    probs = softmax(logits)
    ranked = sorted(zip(labels, probs), key=lambda item: item[1], reverse=True)

    # Both lists follow the ranked order, so the sums match the original exactly.
    real_probs = [p for label, p in ranked if label in REAL_LABELS]
    recapture_probs = [p for label, p in ranked if label in RECAPTURE_LABELS]

    clip_print = max(
        (p for label, p in ranked if label in print_labels), default=0.0
    )
    real_env = max(
        (p for label, p in ranked if label in env_labels), default=0.0
    )

    return {
        "ranked": ranked,
        "real_score": sum(real_probs),
        "recapture_score": sum(recapture_probs),
        "best_real": max(real_probs),
        "best_recapture": max(recapture_probs),
        "clip_print": clip_print,
        "real_env": real_env,
    }


def decide_real_vs_recapture(ctx, *, real_min=0.20, recapture_min=0.40,
                             real_margin=1.15):
    """Classify a CLIP score context as a real scene or a recapture.

    The rules are applied in order and the first match wins:
      1. best single REAL probability >= real_min            -> "REAL"
      2. best single RECAPTURE probability >= recapture_min  -> "RECAPTURE"
      3. real_score >= recapture_score * real_margin         -> "REAL"
      4. otherwise                                           -> "RECAPTURE"

    Args:
        ctx: dict with "best_real", "best_recapture", "real_score" and
            "recapture_score" (as produced by ``run_clip``).
        real_min: threshold for rule 1 (previously hard-coded 0.20).
        recapture_min: threshold for rule 2 (previously hard-coded 0.40).
        real_margin: factor by which the summed REAL score must exceed the
            summed RECAPTURE score for rule 3 (previously hard-coded 1.15).

    Returns:
        "REAL" or "RECAPTURE".
    """
    if ctx["best_real"] >= real_min:
        return "REAL"
    if ctx["best_recapture"] >= recapture_min:
        return "RECAPTURE"
    if ctx["real_score"] >= ctx["recapture_score"] * real_margin:
        return "REAL"
    return "RECAPTURE"


# ============================================================
# DETECTORS
# ============================================================

def detect_faces(img):
    """Detect faces with OpenCV's YuNet detector (``cv2.FaceDetectorYN``).

    Args:
        img: image array as returned by ``cv2.imread``; ``shape[:2]`` is (h, w).

    Returns:
        List of integer (x1, y1, x2, y2) face boxes, clipped to the image
        bounds, for detections scoring >= FACE_CONF_THRESHOLD.

    Raises:
        FileNotFoundError: if the YuNet ONNX model file is missing, rather than
            an opaque ``cv2.error`` from the loader.
    """
    if not FACE_MODEL_PATH.is_file():
        raise FileNotFoundError(f"Face model not found: {FACE_MODEL_PATH}")

    h, w = img.shape[:2]
    net = cv2.FaceDetectorYN.create(
        model=str(FACE_MODEL_PATH),
        config="",
        input_size=(w, h),
        score_threshold=FACE_CONF_THRESHOLD,
        nms_threshold=0.3,
        top_k=5000,
    )
    net.setInputSize((w, h))
    _, faces = net.detect(img)

    results = []
    if faces is None:
        return results

    for row in faces:
        # Per the OpenCV FaceDetectorYN docs, each row is 15 values:
        # [x, y, w, h, 5 landmark (x, y) pairs, score]. The score is index 14.
        # Index 4 (used previously) is the right-eye x coordinate.
        if float(row[14]) < FACE_CONF_THRESHOLD:
            continue
        x, y, bw, bh = map(int, row[:4])
        # Boxes may extend past the frame. Clip them so face_inside_person
        # measures the visible face: the off-frame part can never overlap a
        # person box and would otherwise deflate the ratio.
        x1, y1 = max(0, x), max(0, y)
        x2, y2 = min(w, x + bw), min(h, y + bh)
        if x2 > x1 and y2 > y1:
            results.append((x1, y1, x2, y2))
    return results


def detect_persons(img, yolo):
    """Run the YOLO model and return confident person boxes.

    Args:
        img: image array passed straight to ``yolo``.
        yolo: an ultralytics ``YOLO`` model instance.

    Returns:
        List of integer (x1, y1, x2, y2) tuples for detections of class id 0
        (presumably "person" for the COCO-trained yolov8n weights; verify for
        other weights) with confidence >= PERSON_CONF_THRESHOLD.
    """
    boxes = yolo(img, verbose=False)[0].boxes
    return [
        tuple(map(int, xyxy.tolist()))
        for xyxy, class_id, score in zip(boxes.xyxy, boxes.cls, boxes.conf)
        if int(class_id) == 0 and score >= PERSON_CONF_THRESHOLD
    ]


def filter_large_persons(persons, img):
    """Drop person boxes covering less than PERSON_MIN_AREA_RATIO of the image.

    Args:
        persons: iterable of (x1, y1, x2, y2) person boxes.
        img: the image the boxes refer to (used for its area).

    Returns:
        New list with the boxes that pass, in their original order.
    """
    def is_large(box):
        return box_area_ratio(box, img) >= PERSON_MIN_AREA_RATIO

    return list(filter(is_large, persons))


def person_confirmed(faces, persons):
    """Return True if any face lies sufficiently inside any person box.

    Pairs are checked face-major and evaluation stops at the first match.
    "Sufficiently" is defined by ``face_inside_person`` with its default ratio.
    """
    return any(
        face_inside_person(face, person)
        for face in faces
        for person in persons
    )


# ============================================================
# FINAL DECISION
# ============================================================

@functools.lru_cache(maxsize=None)
def _get_nude_detector():
    """Create the NudeDetector on first use and reuse it afterwards."""
    return NudeDetector()


@functools.lru_cache(maxsize=None)
def _get_clip():
    """Load the local clip-vit-large-patch14 model and processor once.

    Returns:
        (model, processor, device), where device is "cuda" when available and
        "cpu" otherwise; the model is moved to that device.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = CLIPModel.from_pretrained(
        "openai/clip-vit-large-patch14",
        local_files_only=True,
    ).to(device)
    processor = CLIPProcessor.from_pretrained(
        "openai/clip-vit-large-patch14",
        local_files_only=True,
        use_fast=False,
    )
    return model, processor, device


@functools.lru_cache(maxsize=None)
def _get_yolo():
    """Load the yolov8n.pt person detector once."""
    return YOLO("yolov8n.pt")


def _nudity_detected(image_path):
    """Return True if NudeNet reports a NUDE_CLASSES hit at/above NUDE_THRESHOLD."""
    hits = _get_nude_detector().detect(image_path) or []
    return any(
        d.get("class") in NUDE_CLASSES and d.get("score", 0) >= NUDE_THRESHOLD
        for d in hits
    )


def _print_clip_debug(ctx, verdict):
    """Print the top-5 CLIP prompts, aggregate scores and the verdict to stdout."""
    print("\nTop predictions:")
    for label, prob in ctx["ranked"][:5]:
        print(f"  {prob:6.3f}  {label}")

    print("\nScores:")
    print(f"  real_score      = {ctx['real_score']:.3f}")
    print(f"  recapture_score = {ctx['recapture_score']:.3f}")
    print(f"  best_real       = {ctx['best_real']:.3f}")
    print(f"  best_recapture  = {ctx['best_recapture']:.3f}")
    print(f"\nVerdict: {'ECHTES FOTO' if verdict == 'REAL' else 'RECAPTURE'}\n")


def should_block(image_path):
    """Decide whether an image must be blocked.

    Pipeline (the first decisive step wins):
      1. NudeNet: any NUDE_CLASSES detection scoring >= NUDE_THRESHOLD blocks.
      2. CLIP: images judged a recapture (screenshot, screen, print) are allowed.
      3. Real photos are blocked when a detected face lies inside a large
         enough person box, unless CLIP flags the shot as a printed object.

    The NudeNet, CLIP and YOLO models are loaded lazily on first use and
    cached for the lifetime of the process. Repeated calls no longer reload
    the multi-GB CLIP model.

    Args:
        image_path: path to the image file (str or pathlib.Path).

    Returns:
        (blocked, reason), where reason is one of "NUDITY_DETECTED",
        "ALLOWED_NON_REAL", "REAL_PERSON_DETECTED" or "ALLOWED_REAL_NO_PERSON".

    Raises:
        RuntimeError: if OpenCV cannot read the image.

    Side effects:
        Prints CLIP and detection diagnostics to stdout.
    """
    # cv2.imread only accepts str; normalizing here also lets callers pass Path.
    path = str(image_path)
    img = cv2.imread(path)
    if img is None:
        raise RuntimeError("Bild konnte nicht geladen werden")

    # 1) Nudity has absolute priority; block before any other analysis.
    if _nudity_detected(path):
        return True, "NUDITY_DETECTED"

    # 2) CLIP: only real (non-recaptured) photos continue to the person check.
    model, processor, device = _get_clip()
    ctx = run_clip(path, model, processor, device)
    verdict = decide_real_vs_recapture(ctx)
    _print_clip_debug(ctx, verdict)

    if verdict != "REAL":
        return False, "ALLOWED_NON_REAL"

    # 3) Person confirmation: a face must sit inside a sufficiently large person.
    faces = detect_faces(img)
    persons = filter_large_persons(detect_persons(img, _get_yolo()), img)
    confirmed = person_confirmed(faces, persons)

    # A photo of a print (not a real scene) is not treated as a real person.
    is_print_object = (
        ctx["clip_print"] >= CLIP_PRINT_HINT_THRESHOLD
        and ctx["real_env"] < CLIP_REAL_ENV_THRESHOLD
    )

    print("Faces:", faces)
    print("Persons (filtered):", persons)
    print("Confirmed Person:", confirmed)
    print("Print Object:", is_print_object)

    if confirmed and not is_print_object:
        return True, "REAL_PERSON_DETECTED"
    return False, "ALLOWED_REAL_NO_PERSON"


# ============================================================
# CLI
# ============================================================

def main():
    """Command-line entry point: classify one image and print the decision.

    Takes a single positional argument (the image path), runs ``should_block``
    and prints the blocked/allowed status followed by the reason code.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("image")
    image_arg = parser.parse_args().image

    blocked, reason = should_block(image_arg)
    status = "⛔ BLOCKIERT" if blocked else "✅ ERLAUBT"
    print(status, ":", reason)


if __name__ == "__main__":
    main()
