from __future__ import annotations
import numpy as np
import pandas as pd
from dataclasses import dataclass
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import torch
from torch.utils.data import Dataset, DataLoader

@dataclass
class DatasetBundle:
    """Train/validation splits together with the scaler used to transform them.

    ``build_datasets`` fills this with already-standardized feature matrices;
    the fitted ``scaler`` is kept so new data can be transformed the same way.
    Field order matters: instances are constructed positionally.
    """

    X_train: np.ndarray  # training features (scaled when built by build_datasets)
    y_train: np.ndarray  # training labels, aligned row-for-row with X_train
    X_val: np.ndarray  # validation features (scaled with the training statistics)
    y_val: np.ndarray  # validation labels, aligned row-for-row with X_val
    scaler: StandardScaler  # scaler fitted on the training split

class NumpyDataset(Dataset):
    """Map-style torch dataset backed by in-memory feature and label arrays.

    Both arrays are copied into tensors once, at construction time: features
    as ``float32`` and labels as ``torch.long`` (int64).
    """

    def __init__(self, X: np.ndarray, y: np.ndarray):
        features = torch.tensor(X, dtype=torch.float32)
        targets = torch.tensor(y, dtype=torch.long)
        self.X = features
        self.y = targets

    def __len__(self) -> int:
        # One sample per entry along the leading dimension of the features.
        return self.X.size(0)

    def __getitem__(self, idx: int):
        sample = self.X[idx]
        target = self.y[idx]
        return sample, target

def make_synthetic_ids(
    n: int = 4000,
    n_features: int = 6,
    seed: int = 7,
    attack_ratio: float = 0.35,
) -> tuple[np.ndarray, np.ndarray]:
    """Create a small, explainable synthetic IDS-like dataset.

    Every row starts as standard-normal noise. Each row is independently
    labelled malicious (1) with probability ``attack_ratio``, otherwise benign
    (0). Malicious rows are shifted by a positive offset (drawn once per call
    from [0.8, 1.8) for each "suspicious" dimension) plus extra noise with
    std 0.25; benign rows only get extra noise with std 0.15.

    With the default six features the columns are meant to represent generic
    telemetry-style signals: burstiness, entropy proxy, port rarity proxy,
    auth failures, process anomaly, lateral-move indicator.

    Args:
        n: number of rows to generate.
        n_features: number of feature columns.
        seed: seed for ``np.random.default_rng``; equal arguments give equal output.
        attack_ratio: per-row probability of the malicious label.

    Returns:
        ``(X, y)``: ``X`` of shape ``(n, n_features)`` and integer labels ``y``
        of shape ``(n,)`` where 0 = benign and 1 = malicious.
    """
    gen = np.random.default_rng(seed)
    features = gen.normal(0.0, 1.0, size=(n, n_features))
    labels = (gen.random(n) < attack_ratio).astype(int)

    # Dimensions that carry the attack signal: 0, 1, 3, 5 when at least six
    # features exist, otherwise the first (up to) three columns.
    if n_features >= 6:
        hot_dims = [0, 1, 3, 5]
    else:
        hot_dims = list(range(min(3, n_features)))
    offset = np.zeros(n_features)
    offset[hot_dims] = gen.uniform(0.8, 1.8, size=len(hot_dims))

    # The RNG draws below must stay in this order (malicious noise, then
    # benign noise) so a given seed keeps producing the same data.
    attack = labels == 1
    features[attack] = features[attack] + offset + gen.normal(
        0.0, 0.25, size=(labels.sum(), n_features)
    )

    benign = labels == 0
    features[benign] = features[benign] + gen.normal(
        0.0, 0.15, size=(benign.sum(), n_features)
    )

    return features, labels

def build_datasets(
    n: int = 4000,
    n_features: int = 6,
    seed: int = 7,
    test_size: float = 0.25,
    attack_ratio: float = 0.35,
) -> DatasetBundle:
    """Generate the synthetic IDS data, split it, and standardize it.

    Args:
        n: number of samples to generate.
        n_features: number of feature columns.
        seed: used both for data generation and for ``train_test_split``'s
            ``random_state``, so the whole bundle is reproducible.
        test_size: fraction of samples placed in the validation split.
        attack_ratio: per-row probability of the malicious label, forwarded to
            ``make_synthetic_ids`` (default matches that function's default).

    Returns:
        A ``DatasetBundle`` with standardized train/val features, their labels,
        and the ``StandardScaler`` fitted on the training split.

    Raises:
        ValueError: propagated from ``train_test_split`` when a stratified
            split is impossible (e.g. a class with fewer than two samples).
    """
    X, y = make_synthetic_ids(
        n=n, n_features=n_features, seed=seed, attack_ratio=attack_ratio
    )
    # Stratify so both splits keep the same benign/malicious proportion.
    X_train, X_val, y_train, y_val = train_test_split(
        X, y, test_size=test_size, random_state=seed, stratify=y
    )
    # Fit scaling statistics on the training split only, so no validation
    # information leaks into the preprocessing.
    scaler = StandardScaler()
    X_train = scaler.fit_transform(X_train)
    X_val = scaler.transform(X_val)
    return DatasetBundle(X_train, y_train, X_val, y_val, scaler)

def loaders(bundle: DatasetBundle, batch_size: int = 64, seed: int | None = None):
    """Wrap the bundle's train and validation splits in PyTorch DataLoaders.

    Args:
        bundle: the splits to wrap.
        batch_size: samples per batch for both loaders.
        seed: optional. When given, the training loader shuffles with its own
            ``torch.Generator`` seeded with this value, making batch order
            reproducible. When ``None`` (default), shuffling uses torch's
            global RNG.

    Returns:
        ``(train_loader, val_loader)``. The training loader shuffles; the
        validation loader keeps the original order. Neither drops the final
        partial batch.
    """
    train_ds = NumpyDataset(bundle.X_train, bundle.y_train)
    val_ds = NumpyDataset(bundle.X_val, bundle.y_val)

    generator = None
    if seed is not None:
        generator = torch.Generator()
        generator.manual_seed(seed)

    train_loader = DataLoader(
        train_ds,
        batch_size=batch_size,
        shuffle=True,
        drop_last=False,
        generator=generator,
    )
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, drop_last=False)
    return train_loader, val_loader

def export_csv(bundle: DatasetBundle, path: str):
    """Write the bundle's train and validation splits to two CSV files.

    A trailing ``.csv`` on ``path`` is stripped and ``_train.csv`` /
    ``_val.csv`` are appended. For example, ``out/data.csv`` produces
    ``out/data_train.csv`` and ``out/data_val.csv``, and ``out/data`` produces
    the same pair. Only the final suffix is touched, so a ``.csv`` elsewhere
    in the path (e.g. a directory name) is preserved. ``os.PathLike`` values
    are accepted as well as ``str``.

    Each file has feature columns ``f0 .. f{k-1}`` followed by a ``label``
    column, and no index column.
    """
    import os

    raw = os.fspath(path)
    suffix = ".csv"
    # Only strip a trailing suffix: str.replace would rewrite every ".csv" in
    # the path, and without a suffix both splits would target the same file.
    base = raw[: -len(suffix)] if raw.endswith(suffix) else raw

    def _frame(features, labels) -> pd.DataFrame:
        # Build one split's table: generated feature names plus the label column.
        df = pd.DataFrame(features, columns=[f"f{i}" for i in range(features.shape[1])])
        df["label"] = labels
        return df

    _frame(bundle.X_train, bundle.y_train).to_csv(f"{base}_train.csv", index=False)
    _frame(bundle.X_val, bundle.y_val).to_csv(f"{base}_val.csv", index=False)
