import argparse
import math
from pathlib import Path

import torch
from PIL import Image
from transformers import CLIPModel, CLIPProcessor


# =========================
# PROMPT GROUPS (FIXED)
# =========================

# Prompts describing a genuine, directly photographed physical scene.
REAL_CONTEXT = [
    "a photo of a real scene in a physical environment",
    "a photo of a real object on a table in a room",
]

# Prompts for screens photographed as ordinary objects; they are NOT counted
# towards the deceptive score.
BENIGN_OBJECTS = [
    "a photo of a computer screen",
    "a photo of a phone screen",
]

# Prompts describing a re-captured or otherwise non-original depiction; their
# probabilities are summed into the deceptive score.
DECEPTIVE = [
    "a screenshot",
    "a photo of a photo",
    "a photo of a printed photo",
    "a photo of a printed paper",
]

# Fixed order (important!): the model scores are paired with these labels
# by position, so the concatenation order must stay stable.
LABELS = [*REAL_CONTEXT, *BENIGN_OBJECTS, *DECEPTIVE]


# =========================
# HELPER FUNCTION
# =========================

def softmax(xs):
    """Return the softmax of ``xs`` as a list of floats.

    Args:
        xs: Any iterable of real numbers (list, tuple, generator, ...).

    Returns:
        A list with one probability per input value, in input order. An
        empty input yields an empty list.
    """
    # Materialize once: the values are traversed twice (max + exp), which
    # would silently produce an empty result for a one-shot iterator.
    values = list(xs)
    if not values:
        return []

    # Subtract the maximum before exponentiating so large logits cannot
    # overflow math.exp; the result is mathematically unchanged.
    peak = max(values)
    exps = [math.exp(v - peak) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


# =========================
# MAIN
# =========================

# Decision thresholds applied to the summed probability of the DECEPTIVE prompts.
_BLOCK_THRESHOLD = 0.50
_ACCEPT_THRESHOLD = 0.10


def _verdict_for(deceptive_score):
    """Map the summed deceptive probability to the final verdict string."""
    if deceptive_score >= _BLOCK_THRESHOLD:
        return "BLOCK (täuschende Darstellung)"
    if deceptive_score <= _ACCEPT_THRESHOLD:
        return "ACCEPT (echtes Foto)"
    return "UNSICHER (Grenzfall)"


def main():
    """Score one image against the fixed prompt list and print a verdict.

    Parses the command line, loads the CLIP model and processor named by
    ``--model``, computes a probability for every prompt in ``LABELS``,
    prints the ``--topk`` highest-ranked prompts, the summed probability of
    the ``DECEPTIVE`` prompts, and the resulting verdict.
    """
    # The text must be passed as ``description``; the first positional
    # parameter of ArgumentParser is ``prog`` (the program name in usage).
    ap = argparse.ArgumentParser(
        description="CLIP Recapture Detection – Final Policy"
    )
    ap.add_argument("--image", "-i", required=True, help="Pfad zum Bild")
    ap.add_argument("--model", default="openai/clip-vit-large-patch14")
    ap.add_argument("--topk", type=int, default=5)
    args = ap.parse_args()

    # A negative value would slice from the end of the ranking and silently
    # drop the lowest-ranked entries instead of limiting the output.
    if args.topk < 0:
        ap.error("--topk must be >= 0")

    device = "cuda" if torch.cuda.is_available() else "cpu"

    model = CLIPModel.from_pretrained(args.model).to(device)
    processor = CLIPProcessor.from_pretrained(args.model)

    img_path = Path(args.image)
    # Close the underlying file as soon as the RGB copy has been created.
    with Image.open(img_path) as raw_image:
        image = raw_image.convert("RGB")

    inputs = processor(
        text=LABELS,
        images=image,
        return_tensors="pt",
        padding=True,
    ).to(device)

    with torch.no_grad():
        # Only one image is passed in, so take the first row of the
        # image-to-text logits: one value per prompt in LABELS order.
        logits = model(**inputs).logits_per_image[0].cpu().tolist()

    probs = softmax(logits)
    ranked = sorted(zip(LABELS, probs), key=lambda pair: pair[1], reverse=True)

    print(f"\nImage: {img_path}\n")
    print("Top predictions:")
    for label, p in ranked[: args.topk]:
        print(f"  {p:6.3f}  {label}")

    # =========================
    # SCORE COMPUTATION
    # =========================

    # Total probability mass assigned to the deceptive prompts.
    deceptive_score = sum(p for label, p in ranked if label in DECEPTIVE)

    print("\nScores:")
    print(f"  deceptive_score = {deceptive_score:.3f}")

    # =========================
    # FINAL POLICY
    # =========================

    verdict = _verdict_for(deceptive_score)

    print(f"\nVerdict: {verdict}\n")


# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
