import argparse
import math
from pathlib import Path

import torch
from PIL import Image, ImageOps
from transformers import CLIPModel, CLIPProcessor


# ============================================================
# PROMPTS
# ============================================================

# English prompts describing a genuine photograph of a physical scene.
REAL_LABELS_EN = [
    "a photo of a real scene in a physical environment",
    "a photo of a real indoor room",
    "a photo of a real outdoor scene",
    "a photo of real objects in a room",
    "a real photograph taken by a person in the real world",
]

# English prompts describing a reproduced image: screenshots, photos of
# screens, photos of prints, photos of photos.
RECAPTURE_LABELS_EN = [
    "a screenshot",
    "a photo that is entirely a computer screen",
    "a photo that is entirely a phone screen",
    "a photo of a printed photo",
    "a photo of printed paper",
    "a photo of a photo",
]

# German counterparts of REAL_LABELS_EN (selected via --lang de).
REAL_LABELS_DE = [
    "ein echtes foto in einer realen umgebung",
    "ein echtes foto in einem raum",
    "ein echtes foto draußen",
    "ein foto von echten objekten in einem raum",
    "ein echtes foto, das von einer person aufgenommen wurde",
]

# German counterparts of RECAPTURE_LABELS_EN (selected via --lang de).
RECAPTURE_LABELS_DE = [
    "ein screenshot",
    "ein foto, das vollständig einen computerbildschirm zeigt",
    "ein foto, das vollständig einen handybildschirm zeigt",
    "ein foto eines ausgedruckten fotos",
    "ein foto von bedrucktem papier",
    "ein foto von einem foto",
]


# ============================================================
# UTILS
# ============================================================

def softmax(xs):
    """Return the softmax of *xs* as a list of probabilities.

    The inputs are shifted by their maximum before exponentiation so the
    computation stays numerically stable for large logits.
    """
    peak = max(xs)
    weights = [math.exp(v - peak) for v in xs]
    total = sum(weights)
    return [w / total for w in weights]


# ============================================================
# MAIN
# ============================================================

def main():
    """CLI entry point: classify one image as a real photo vs. a recapture.

    Scores the image against English or German CLIP prompts, prints the
    top-ranked labels, the aggregate real/recapture scores, and a final
    verdict. Borderline cases are deliberately classified as recaptures.
    """
    ap = argparse.ArgumentParser(
        description="Detect whether an image is a real photo or a recapture "
                    "(screenshot, photo of a screen, photo of a print) using CLIP."
    )
    ap.add_argument("--image", "-i", required=True, help="path to the input image")
    ap.add_argument("--lang", choices=["en", "de"], default="en",
                    help="language of the CLIP text prompts")
    ap.add_argument("--model", default="openai/clip-vit-large-patch14",
                    help="CLIP model id or local path (loaded with local_files_only)")
    args = ap.parse_args()

    if args.lang == "en":
        real_labels = REAL_LABELS_EN
        recapture_labels = RECAPTURE_LABELS_EN
    else:
        real_labels = REAL_LABELS_DE
        recapture_labels = RECAPTURE_LABELS_DE

    # Order matters: the first len(real_labels) entries (and therefore the
    # first len(real_labels) logits/probabilities) belong to the "real"
    # prompts, the remainder to the "recapture" prompts.
    labels = real_labels + recapture_labels

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = CLIPModel.from_pretrained(
        args.model,
        local_files_only=True,
    ).to(device)

    processor = CLIPProcessor.from_pretrained(
        args.model,
        local_files_only=True,
        use_fast=False,
    )

    # Apply the EXIF orientation tag so rotated phone photos are scored upright.
    image = ImageOps.exif_transpose(Image.open(args.image)).convert("RGB")

    inputs = processor(
        text=labels,
        images=image,
        return_tensors="pt",
        padding=True,
    ).to(device)

    with torch.no_grad():
        logits = model(**inputs).logits_per_image[0].cpu().tolist()

    probs = softmax(logits)

    # Split probabilities by position rather than by list membership: a
    # membership test would silently double-count any prompt that appeared
    # in both label lists, and costs O(n) per lookup besides.
    n_real = len(real_labels)
    real_probs = probs[:n_real]
    recapture_probs = probs[n_real:]

    real_score = sum(real_probs)
    recapture_score = sum(recapture_probs)
    best_real = max(real_probs, default=0.0)
    best_recapture = max(recapture_probs, default=0.0)

    ranked = sorted(zip(labels, probs), key=lambda x: x[1], reverse=True)

    print("\nTop predictions:")
    for l, p in ranked[:5]:
        print(f"  {p:6.3f}  {l}")

    print("\nScores:")
    print(f"  real_score      = {real_score:.3f}")
    print(f"  recapture_score = {recapture_score:.3f}")
    print(f"  best_real       = {best_real:.3f}")
    print(f"  best_recapture  = {best_recapture:.3f}")

    # ========================================================
    # DECISION (the single rule that matters)
    # ========================================================
    # Final decision — no "unclear" outcomes; borderline cases count as fake.

    if best_real >= 0.20:
        # 1) A real environment is clearly recognizable -> real photo.
        verdict = "ECHTES FOTO (reale Umgebung vorhanden)"
    elif best_recapture >= 0.40:
        # 2) Clear recapture signal -> recapture.
        verdict = "RECAPTURE (keine reale Umgebung, nur Reproduktion)"
    else:
        # 3) Everything else -> recapture (borderline = rather fake).
        verdict = "RECAPTURE (Grenzfall, eher fake)"

    print(f"\nVerdict: {verdict}\n")


# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
