from __future__ import annotations
import os
import matplotlib.pyplot as plt

def plot_performance(outdir: str, train_losses: list[float], clean_acc: float, adv_acc: float) -> None:
    """Save a training-loss curve and a clean-vs-adversarial accuracy bar chart.

    Two PNG files are written into ``outdir`` (created if it does not exist):

    * ``training_curve.png`` -- ``train_losses`` plotted against list index,
      with the x axis labelled "Epoch".
    * ``performance.png`` -- two bars, "Clean" and "Adversarial".

    Args:
        outdir: Directory that receives the image files.
        train_losses: Loss values in plotting order (presumably one per
            epoch, as the axis label suggests -- verify against caller).
        clean_acc: Height of the "Clean" bar.
        adv_acc: Height of the "Adversarial" bar.
    """
    os.makedirs(outdir, exist_ok=True)

    # Each figure is held by an explicit handle and closed in ``finally`` so
    # that a failure while drawing or saving cannot leave it registered (and
    # leaking) in pyplot's global figure manager.
    fig, ax = plt.subplots()
    try:
        ax.plot(train_losses)
        ax.set_xlabel("Epoch")
        ax.set_ylabel("Training Loss")
        ax.set_title("Training Curve")
        fig.tight_layout()
        fig.savefig(os.path.join(outdir, "training_curve.png"))
    finally:
        plt.close(fig)

    fig, ax = plt.subplots()
    try:
        ax.bar(["Clean", "Adversarial"], [clean_acc, adv_acc])
        ax.set_ylabel("Accuracy")
        ax.set_title("Performance Under Attack")
        fig.tight_layout()
        fig.savefig(os.path.join(outdir, "performance.png"))
    finally:
        plt.close(fig)

def plot_robustness(outdir: str, clean_acc: float, adv_acc: float, stability_var: float) -> None:
    """Save a bar chart of the robustness gap and the stability variance.

    Writes ``robustness.png`` into ``outdir`` (created if it does not exist).
    The chart has two bars: "Robustness Gap", computed here as
    ``clean_acc - adv_acc``, and "Stability Variance", which is
    ``stability_var`` plotted as given.

    Args:
        outdir: Directory that receives the image file.
        clean_acc: Accuracy value the gap is measured from.
        adv_acc: Accuracy value subtracted from ``clean_acc``.
        stability_var: Value shown for the "Stability Variance" bar.
    """
    os.makedirs(outdir, exist_ok=True)

    robustness_gap = clean_acc - adv_acc

    # Hold an explicit figure handle and close it in ``finally`` so a failure
    # while drawing or saving cannot leak the figure in pyplot's global state.
    fig, ax = plt.subplots()
    try:
        ax.bar(["Robustness Gap", "Stability Variance"], [robustness_gap, stability_var])
        ax.set_title("Robustness Metrics")
        fig.tight_layout()
        fig.savefig(os.path.join(outdir, "robustness.png"))
    finally:
        plt.close(fig)
