from __future__ import annotations
import json
import numpy as np
import torch
import torch.nn.functional as F

def accuracy(logits: torch.Tensor, y: torch.Tensor) -> float:
    """Return the fraction of rows whose arg-max along dim 1 equals ``y``.

    Args:
        logits: per-class scores, with classes laid out along dim 1.
        y: target class indices, compared element-wise with the predictions.

    Returns:
        The mean of the per-row hit indicator, as a Python float.
    """
    hits = logits.argmax(dim=1).eq(y)
    return hits.to(torch.float32).mean().item()

def robustness_gap(clean_acc: float, adv_acc: float) -> float:
    """Return ``clean_acc - adv_acc`` coerced to a Python float.

    The result is positive when ``adv_acc`` is lower than ``clean_acc``.
    """
    gap = clean_acc - adv_acc
    return float(gap)

def stability_variance(probs_list: list[np.ndarray]) -> float:
    """Average, over elements, of the variance across repeated runs.

    Intended for the predicted class-1 probabilities collected from repeated
    noisy runs: one array per run, all of the same shape. The arrays are
    stacked on a new leading "run" axis. The variance is computed along that
    axis for every element, and those per-element variances are then averaged.
    An empty ``probs_list`` raises ``ValueError`` from ``np.stack``.
    """
    runs = np.stack(probs_list, axis=0)  # [k, n] for k runs of n values each
    per_item_var = runs.var(axis=0)
    return float(per_item_var.mean())

def softmax_probs(logits: torch.Tensor, class_index: int = 1) -> np.ndarray:
    """Return the softmax probability of one class for every row of ``logits``.

    Args:
        logits: per-class scores, with classes laid out along dim 1.
        class_index: column of the softmax output to return. Defaults to 1,
            which preserves the original "probability of class 1" behavior.

    Returns:
        A detached CPU NumPy array of the selected class's probabilities,
        with the class dimension (dim 1) removed.
    """
    # Softmax in fp16/bf16 loses precision, and ``Tensor.numpy()`` cannot
    # convert bfloat16 at all, so upcast reduced-precision inputs first.
    if logits.dtype in (torch.float16, torch.bfloat16):
        logits = logits.float()
    probs = F.softmax(logits, dim=1)[:, class_index]
    return probs.detach().cpu().numpy()

def save_metrics(path: str, metrics: dict) -> None:
    """Write ``metrics`` to ``path`` as 2-space-indented UTF-8 JSON.

    Any existing file at ``path`` is overwritten. NumPy scalars/arrays and
    torch tensors (e.g. an ``np.float32`` or a 0-d tensor that was never
    ``.item()``-ed) are converted to plain Python values instead of making
    serialization fail. Any other non-JSON value raises ``TypeError``.
    Serialization happens before the file is opened, so such an error leaves
    an existing file at ``path`` untouched rather than truncated.
    """
    def _to_builtin(obj):
        # Only numeric NumPy/torch values are converted deliberately; anything
        # else raises exactly like json's default encoder would.
        if isinstance(obj, np.generic):
            return obj.item()
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, torch.Tensor):
            return obj.detach().cpu().tolist()
        raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

    text = json.dumps(metrics, indent=2, default=_to_builtin)
    with open(path, "w", encoding="utf-8") as f:
        f.write(text)
