from __future__ import annotations

import os
import argparse
import json

import numpy as np
import torch

from hqai.config import TrainConfig
from hqai.data import build_datasets, loaders
from hqai.model import HybridQuantumNeuralClassifier
from hqai.train import (
    set_seed,
    train_one_epoch,
    eval_clean,
    eval_adversarial,
    eval_stability,
)
from hqai.metrics import robustness_gap, save_metrics
from hqai.plots import plot_performance, plot_robustness


def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the command-line options for the end-to-end run.

    Args:
        argv: Argument strings to parse, without the program name. When
            ``None`` (the default) argparse falls back to ``sys.argv[1:]``,
            which is what the zero-argument call has always done. Passing an
            explicit list lets callers and tests drive the parser without
            touching process-global state.

    Returns:
        Namespace with ``epochs`` (int), ``attack`` ("none", "fgsm" or
        "pgd"), ``seed`` (int), ``device`` ("auto", "cpu" or "cuda") and
        ``outdir`` (str, or ``None`` when not supplied).

    Raises:
        SystemExit: Raised by argparse on invalid arguments or ``--help``.
    """
    p = argparse.ArgumentParser(
        description="Run HQAI CyberLab end-to-end: data -> train -> evaluate -> plots"
    )
    p.add_argument("--epochs", type=int, default=10, help="Training epochs")
    p.add_argument(
        "--attack",
        type=str,
        default="pgd",
        choices=["none", "fgsm", "pgd"],
        help="Adversarial attack used during evaluation",
    )
    p.add_argument("--seed", type=int, default=1337, help="Random seed")
    p.add_argument(
        "--device",
        type=str,
        default="auto",
        choices=["auto", "cpu", "cuda"],
        help="Compute device selection",
    )
    p.add_argument(
        "--outdir",
        type=str,
        default=None,
        help="Output directory (default: ./outputs at repo root)",
    )
    # argparse treats ``None`` as "read sys.argv[1:]", so existing callers
    # that invoke parse_args() with no arguments are unaffected.
    return p.parse_args(argv)


def _resolve_device(device_arg: str) -> str:
    if device_arg == "auto":
        return "cuda" if torch.cuda.is_available() else "cpu"
    return device_arg


def main() -> None:
    """Run the pipeline once: seed, data, training, evaluation, outputs.

    Options come from ``parse_args()``. After training for ``cfg.epochs``
    epochs, the model is evaluated three ways (clean, adversarial,
    stability). The resulting metrics are handed to ``save_metrics`` as
    ``metrics.json`` in the output directory, the two plotting helpers are
    called with that directory, and the metrics are printed as indented JSON.
    """
    cli = parse_args()
    set_seed(cli.seed)
    target_device = _resolve_device(cli.device)

    # Only the epoch count is overridden here; everything else comes from
    # TrainConfig's own defaults.
    cfg = TrainConfig(epochs=cli.epochs)

    # The first element returned by build_datasets is not used by this script.
    _, (X_train, y_train), (X_test, y_test) = build_datasets(cfg)
    train_loader, test_loader = loaders(cfg, X_train, y_train, X_test, y_test)

    model = HybridQuantumNeuralClassifier(cfg).to(target_device)

    # One loss per epoch, coerced to a plain float.
    train_losses = [
        float(train_one_epoch(cfg, model, train_loader, device=target_device))
        for _ in range(cfg.epochs)
    ]

    # Evaluations run in a fixed order: clean, adversarial, then stability.
    clean_acc = float(eval_clean(cfg, model, test_loader, device=target_device))
    adv_acc = float(
        eval_adversarial(
            cfg, model, test_loader, attack=cli.attack, device=target_device
        )
    )
    stability_var = float(
        eval_stability(cfg, model, test_loader, device=target_device)
    )

    gap = float(robustness_gap(clean_acc, adv_acc))
    notes_text = (
        "Clean evaluation uses no injected noise; "
        "stability uses quantum-aware noise injection."
    )

    # Key order is kept stable because it determines the printed JSON layout.
    metrics = {
        "epochs": int(cfg.epochs),
        "attack": str(cli.attack),
        "seed": int(cli.seed),
        "device": str(target_device),
        "clean_accuracy": clean_acc,
        "adversarial_accuracy": adv_acc,
        "stability_variance": stability_var,
        "robustness_gap": gap,
        "noise_model": cfg.noise_model,
        "angle_sigma": cfg.angle_sigma,
        "output_sigma": cfg.output_sigma,
        "notes": notes_text,
    }

    # Without --outdir, fall back to an "outputs" folder two directory levels
    # above this file's directory.
    destination = cli.outdir or os.path.join(
        os.path.dirname(__file__), "..", "..", "outputs"
    )
    outdir = os.path.abspath(destination)
    os.makedirs(outdir, exist_ok=True)

    metrics_path = os.path.join(outdir, "metrics.json")
    save_metrics(metrics_path, metrics)
    plot_performance(outdir, train_losses, clean_acc, adv_acc)
    plot_robustness(outdir, clean_acc, adv_acc, stability_var)

    print("\nSaved outputs to:", outdir)
    print(json.dumps(metrics, indent=2))


# Run the pipeline only when this file is executed as a script; importing the
# module does not call main().
if __name__ == "__main__":
    main()
