import os
import sys
import cv2
import numpy as np
from scipy.fft import fft2, fftshift
from scipy.stats import entropy

# Noiseprint (PyTorch port), imported from the bundled noiseprint_pytorch/ directory
sys.path.append(os.path.join(os.path.dirname(__file__), "noiseprint_pytorch"))
from Noiseprint import getNoiseprint  # erwartet Bildpfad

# -----------------------------
# Image helpers
# -----------------------------
def load_image(path, *, size=(512, 512)):
    """Load an image as a normalised grayscale array for the moiré analysis.

    Args:
        path: Path of the image file, passed to ``cv2.imread``.
        size: Target ``(width, height)`` in pixels (OpenCV ``dsize`` order).
            Defaults to 512x512; the aspect ratio is not preserved.

    Returns:
        ``float32`` array of shape ``(height, width)`` with values in [0, 1].

    Raises:
        ValueError: if OpenCV cannot read or decode the file.
    """
    img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    if img is None:
        raise ValueError(f"Bild konnte nicht geladen werden: {path}")

    target_w, target_h = size
    src_h, src_w = img.shape[:2]
    # Shrinking with the default bilinear interpolation aliases fine texture
    # into spurious periodic patterns -- the very signal the moiré feature
    # measures. INTER_AREA averages source pixels when downscaling; it is
    # intended for shrinking, so bilinear is kept when the image is enlarged.
    if target_w < src_w or target_h < src_h:
        interpolation = cv2.INTER_AREA
    else:
        interpolation = cv2.INTER_LINEAR
    img = cv2.resize(img, (target_w, target_h), interpolation=interpolation)

    # IMREAD_GRAYSCALE yields 8-bit data, so /255 maps it to [0, 1].
    return img.astype(np.float32) / 255.0

# -----------------------------
# Noiseprint extraction
# -----------------------------
def extract_noiseprint(image_path):
    """Compute the Noiseprint residual of the image stored at *image_path*.

    The path is forwarded unchanged to ``getNoiseprint`` (which expects an
    image path rather than decoded pixels), and its result is returned as-is.

    NOTE(review): callers treat the result as a numpy array; its exact type,
    dtype and shape are determined by ``getNoiseprint`` -- confirm there.
    """
    return getNoiseprint(image_path)

# -----------------------------
# Forensic features
# -----------------------------
def moire_score(img, *, peak_ratio=0.8, lowfreq_radius=0.05):
    """Return how strongly the image spectrum is dominated by isolated peaks.

    Moiré from re-photographing a screen or print appears as a few sharp,
    periodic peaks in the 2-D magnitude spectrum away from the origin. The
    score is the fraction of the remaining (non-low-frequency) spectral
    magnitude carried by coefficients stronger than ``peak_ratio`` times the
    strongest one. Broadband natural content gives a small value; a pure
    periodic pattern gives a value close to 1.

    Why the DC term and low frequencies are removed first: for a
    non-negative image the raw spectrum maximum is always the DC term, and
    natural images concentrate their energy at low frequencies. Normalising
    by that maximum and then averaging only values above ``peak_ratio`` is
    bounded to [peak_ratio, 1] and therefore cannot discriminate anything.

    Args:
        img: 2-D grayscale image (any real dtype).
        peak_ratio: Relative magnitude (of the strongest remaining
            coefficient) above which a coefficient counts as a peak.
        lowfreq_radius: Radius, in normalised frequency units
            (cycles/pixel, centre of the shifted spectrum = 0), of the disk
            around the origin that is ignored.

    Returns:
        Float in [0, 1]; 0.0 when no spectral content remains (e.g. a
        constant image).

    NOTE(review): this score has a different scale from the previous
    "mean of normalised peaks"; the moiré threshold used by
    detect_recapture should be re-calibrated against it.
    """
    centered = np.asarray(img, dtype=np.float64)
    # Subtracting the mean zeroes the DC coefficient exactly.
    centered = centered - centered.mean()

    spectrum = np.abs(fftshift(fft2(centered)))
    height, width = spectrum.shape

    # fftshift places the zero frequency at index n // 2 along each axis.
    rows, cols = np.ogrid[:height, :width]
    radius = np.hypot((rows - height // 2) / height, (cols - width // 2) / width)
    spectrum[radius < lowfreq_radius] = 0.0

    total = spectrum.sum()
    if total <= 0.0:
        return 0.0

    strongest = spectrum.max()
    peaks = spectrum[spectrum > peak_ratio * strongest]
    return float(peaks.sum() / total)

def noise_energy(noise):
    """Return the population variance of the Noiseprint residual.

    Serves as a proxy for residual noise energy: a low value means the
    residual is comparatively flat.
    """
    # ddof=0 -> divide by N (population variance), numpy's default.
    variance = np.var(noise, ddof=0)
    return variance

def noise_entropy(noise):
    """Return the Shannon entropy (in nats) of the residual's value histogram.

    Values are binned into 256 equal-width bins spanning their min..max
    range. ``scipy.stats.entropy`` normalises the bin heights to a
    probability distribution, so the density scaling does not affect the
    result.
    """
    densities, _bin_edges = np.histogram(noise, bins=256, density=True)
    # Small floor added to every bin (including empty ones) before normalisation.
    smoothed = densities + 1e-12
    return entropy(smoothed)

# -----------------------------
# Recapture Detection
# -----------------------------
def detect_recapture(
    image_path,
    *,
    moire_threshold=0.12,
    energy_threshold=0.0005,
    entropy_threshold=4.0,
    min_flags=2,
):
    """Classify an image as a re-photographed copy ("RECAPTURED") or "REAL".

    Three heuristic features are computed; each one crossing its threshold
    raises one flag:

    * moiré score of the grayscale image (from ``load_image``) above
      ``moire_threshold``;
    * Noiseprint variance below ``energy_threshold``;
    * Noiseprint histogram entropy below ``entropy_threshold``.

    The image is reported as "RECAPTURED" when at least ``min_flags`` flags
    are set. The defaults are the original hard-coded values.
    NOTE(review): how these thresholds were calibrated is not documented
    here; they are exposed so they can be tuned without editing the code.

    Args:
        image_path: Path of the image to analyse.
        moire_threshold: Moiré score above which the moiré flag is raised.
        energy_threshold: Noise variance below which the energy flag is raised.
        entropy_threshold: Noise entropy below which the entropy flag is raised.
        min_flags: Number of raised flags needed for "RECAPTURED".

    Returns:
        dict with keys ``result`` ("RECAPTURED" or "REAL"), ``moire_score``
        (rounded to 4 decimals), ``noise_energy`` (6 decimals),
        ``noise_entropy`` (3 decimals) and ``flags`` (int, 0-3).

    Raises:
        ValueError: propagated from ``load_image`` when the image cannot be read.
    """
    # Grayscale image for the spectral (moiré) feature.
    img = load_image(image_path)
    # Noiseprint residual for the two noise features; the helper takes the
    # path because getNoiseprint expects an image path.
    noise = extract_noiseprint(image_path)

    s_moire = moire_score(img)
    s_energy = noise_energy(noise)
    s_entropy = noise_entropy(noise)

    # Explicit int() keeps ``flags`` a plain Python int even if the
    # comparisons yield numpy bools. A NaN score compares False, so it never
    # raises a flag.
    flags = (
        int(s_moire > moire_threshold)
        + int(s_energy < energy_threshold)
        + int(s_entropy < entropy_threshold)
    )

    result = "RECAPTURED" if flags >= min_flags else "REAL"

    return {
        "result": result,
        "moire_score": round(float(s_moire), 4),
        "noise_energy": round(float(s_energy), 6),
        "noise_entropy": round(float(s_entropy), 3),
        "flags": flags,
    }

# -----------------------------
# CLI Entry
# -----------------------------
if __name__ == "__main__":
    # Command-line entry point: analyse exactly one image and print the
    # result dict to stdout; usage and error messages go to stderr.
    if len(sys.argv) != 2:
        print("Usage: python recapture_detector.py <image.jpg>", file=sys.stderr)
        sys.exit(1)

    image_path = sys.argv[1]
    if not os.path.isfile(image_path):
        print(f"Bild nicht gefunden: {image_path}", file=sys.stderr)
        sys.exit(1)

    try:
        res = detect_recapture(image_path)
    except ValueError as err:
        # load_image raises ValueError when the file exists but cannot be
        # decoded; report it as a clean error instead of a traceback.
        print(err, file=sys.stderr)
        sys.exit(1)

    print(res)
