from __future__ import annotations
import os
import numpy as np
import torch
import torch.nn.functional as F
from .attacks import fgsm_attack, pgd_attack
from .metrics import accuracy, softmax_probs, stability_variance

def set_seed(seed: int):
    """Seed Python's ``random``, NumPy's global RNG and PyTorch's global RNG with ``seed``.

    Seeding happens in that order; nothing is returned.
    """
    import random

    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

def train_one_epoch(model, loader, opt, angle_sigma: float, output_sigma: float, device: str):
    """Run one optimisation pass over ``loader`` and return the mean training loss.

    The model is switched to training mode and called with the given noise levels on
    every batch. One optimizer step is taken per batch, using cross-entropy loss.

    Args:
        model: module whose forward accepts ``angle_sigma`` / ``output_sigma`` keywords.
        loader: iterable of ``(x, y)`` batches.
        opt: optimizer over ``model``'s parameters.
        angle_sigma: noise level forwarded to the model as ``angle_sigma``.
        output_sigma: noise level forwarded to the model as ``output_sigma``.
        device: device the batches are moved to.

    Returns:
        Per-sample mean cross-entropy over the samples actually seen this epoch, or
        ``nan`` if the loader yielded no samples.
    """
    model.train()
    total_loss = 0.0
    n_seen = 0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        opt.zero_grad()
        logits = model(x, angle_sigma=angle_sigma, output_sigma=output_sigma)
        loss = F.cross_entropy(logits, y)
        loss.backward()
        opt.step()
        batch_size = x.size(0)
        # cross_entropy returns the batch mean. Re-weight by batch size so the epoch
        # figure is a per-sample average even when the final batch is smaller.
        total_loss += loss.item() * batch_size
        n_seen += batch_size
    # Divide by what was actually iterated rather than len(loader.dataset).
    # drop_last or a subset sampler would otherwise under-report the loss.
    if n_seen == 0:
        return float("nan")
    return total_loss / n_seen

@torch.no_grad()
def eval_clean(model, loader, angle_sigma: float, output_sigma: float, device: str):
    """Return the accuracy of ``model`` on the unperturbed inputs from ``loader``.

    The model is put in eval mode but still receives the given noise levels, so this
    measures clean-input accuracy under the configured noise.

    Args:
        model: module whose forward accepts ``angle_sigma`` / ``output_sigma`` keywords.
        loader: iterable of ``(x, y)`` batches.
        angle_sigma: noise level forwarded to the model as ``angle_sigma``.
        output_sigma: noise level forwarded to the model as ``output_sigma``.
        device: device the batches are moved to.

    Returns:
        Accuracy averaged over samples (not over batches), or ``nan`` if the loader
        yielded no samples.
    """
    model.eval()
    weighted_sum = 0.0
    n_seen = 0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        logits = model(x, angle_sigma=angle_sigma, output_sigma=output_sigma)
        batch_size = x.size(0)
        # Assumes accuracy() returns the fraction correct within this batch (TODO: confirm).
        # Weighting by batch size stops a short final batch from being over-counted.
        weighted_sum += float(accuracy(logits, y)) * batch_size
        n_seen += batch_size
    if n_seen == 0:
        return float("nan")
    return weighted_sum / n_seen

def eval_adversarial(model, loader, attack: str, eps: float, pgd_alpha: float, pgd_steps: int,
                     angle_sigma: float, output_sigma: float, device: str):
    """Return the accuracy of ``model`` on adversarially perturbed inputs.

    Each batch is perturbed by the requested attack, which receives the same noise
    levels as the evaluation forward pass. The perturbed batch is then scored
    without gradients.

    Args:
        model: module whose forward accepts ``angle_sigma`` / ``output_sigma`` keywords.
        loader: iterable of ``(x, y)`` batches.
        attack: ``"fgsm"`` or ``"pgd"``; any other value evaluates the inputs unperturbed.
        eps: perturbation budget passed to the attack.
        pgd_alpha: PGD step size (only used when ``attack == "pgd"``).
        pgd_steps: PGD iteration count (only used when ``attack == "pgd"``).
        angle_sigma: noise level forwarded to the model and the attack.
        output_sigma: noise level forwarded to the model and the attack.
        device: device the batches are moved to.

    Returns:
        Accuracy averaged over samples (not over batches), or ``nan`` if the loader
        yielded no samples.
    """
    model.eval()
    weighted_sum = 0.0
    n_seen = 0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        # Kept outside no_grad: the attack helpers presumably need input gradients.
        if attack == "fgsm":
            x_adv = fgsm_attack(model, x, y, eps=eps, angle_sigma=angle_sigma, output_sigma=output_sigma)
        elif attack == "pgd":
            x_adv = pgd_attack(model, x, y, eps=eps, alpha=pgd_alpha, steps=pgd_steps,
                               angle_sigma=angle_sigma, output_sigma=output_sigma)
        else:
            x_adv = x
        with torch.no_grad():
            logits = model(x_adv, angle_sigma=angle_sigma, output_sigma=output_sigma)
            batch_size = x.size(0)
            # Assumes accuracy() returns the fraction correct within this batch (TODO: confirm).
            # Weighting by batch size stops a short final batch from being over-counted.
            weighted_sum += float(accuracy(logits, y)) * batch_size
            n_seen += batch_size
    if n_seen == 0:
        return float("nan")
    return weighted_sum / n_seen

@torch.no_grad()
def eval_stability(model, loader, repeats: int, angle_sigma: float, output_sigma: float, device: str):
    """Measure prediction stability under repeated noisy forward passes.

    The whole loader is run ``repeats`` times with the given noise levels. Each run's
    ``softmax_probs`` outputs are stacked along axis 0, and the list of per-run arrays
    is handed to ``stability_variance``. That score is intended as the variance of the
    predicted malicious probability.

    NOTE(review): per-sample variance only makes sense if ``loader`` yields samples in
    the same order on every pass (e.g. no shuffling). Confirm at the call sites.
    """
    model.eval()

    def _probs_for_one_pass():
        # One full sweep over the loader; labels are ignored.
        per_batch = [
            softmax_probs(model(batch_x.to(device), angle_sigma=angle_sigma, output_sigma=output_sigma))
            for batch_x, _ in loader
        ]
        return np.concatenate(per_batch, axis=0)

    runs = [_probs_for_one_pass() for _ in range(repeats)]
    return stability_variance(runs)
