from __future__ import annotations

import math

import torch
import torch.nn as nn

from .quantum_pennylane import vqc_expectations, inject_angle_noise

class HybridQuantumNeuralClassifier(nn.Module):
    """Hybrid classical/quantum classifier.

    Pipeline::

        X -> classical projection -> angles (n_qubits) -> VQC expectations -> MLP -> logits

    A small MLP (``to_angles``) maps the last dimension of the input from
    ``n_features`` to one angle per qubit, squashed into [-pi, pi] by ``tanh``.
    The angles and the trainable circuit weights ``q_weights`` are passed to
    ``vqc_expectations``. Its output, whose last dimension must be
    ``n_qubits``, feeds the classical ``head``, which emits ``n_classes``
    logits.

    Args:
        n_features: Size of the last dimension of the input ``x``.
        n_qubits: Number of qubits. This is both the number of angles per
            sample and the width of the quantum output consumed by ``head``.
        n_layers: Leading dimension of ``q_weights`` (shape
            ``[n_layers, n_qubits, 3]``). Presumably the number of
            variational layers in the circuit; TODO confirm against
            ``vqc_expectations``.
        hidden_dim: Hidden width of both classical MLPs.
        n_classes: Number of output logits. Defaults to 2, matching the
            previously hard-coded binary head.

    Raises:
        ValueError: If ``n_features``, ``n_qubits``, ``hidden_dim`` or
            ``n_classes`` is smaller than 1.
    """

    def __init__(
        self,
        n_features: int,
        n_qubits: int,
        n_layers: int,
        hidden_dim: int = 32,
        n_classes: int = 2,
    ) -> None:
        # Reject sizes that would silently build a degenerate (zero-width) model.
        for arg_name, value in (
            ("n_features", n_features),
            ("n_qubits", n_qubits),
            ("hidden_dim", hidden_dim),
            ("n_classes", n_classes),
        ):
            if value < 1:
                raise ValueError(f"{arg_name} must be >= 1, got {value}")

        super().__init__()
        self.n_features = n_features
        self.n_qubits = n_qubits
        self.n_layers = n_layers
        self.n_classes = n_classes

        # Classical encoder: features -> one (pre-squash) angle per qubit.
        self.to_angles = nn.Sequential(
            nn.Linear(n_features, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, n_qubits),
        )

        # Trainable quantum weights [n_layers, n_qubits, 3]. The small init
        # scale keeps the circuit's starting weights close to zero.
        self.q_weights = nn.Parameter(0.01 * torch.randn(n_layers, n_qubits, 3))

        # Classical readout: n_qubits quantum outputs -> n_classes logits.
        self.head = nn.Sequential(
            nn.Linear(n_qubits, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, n_classes),
        )

    def _encode_angles(self, x: torch.Tensor) -> torch.Tensor:
        """Project ``x`` to ``n_qubits`` angles bounded to [-pi, pi] (noise-free)."""
        # tanh keeps the otherwise unbounded projection within [-pi, pi]
        # before the values are (presumably) used as gate rotation angles.
        return torch.tanh(self.to_angles(x)) * math.pi

    def forward(
        self,
        x: torch.Tensor,
        angle_sigma: float = 0.0,
        output_sigma: float = 0.0,
    ) -> torch.Tensor:
        """Compute class logits for ``x``.

        Args:
            x: Input tensor whose last dimension is ``n_features``.
            angle_sigma: Passed as ``sigma`` to ``inject_angle_noise``.
                Presumably the std of noise added to the encoded angles, so
                the angles may leave [-pi, pi] when it is non-zero; confirm
                in ``quantum_pennylane``.
            output_sigma: Forwarded unchanged to ``vqc_expectations``.

        Returns:
            Logits from ``head``; the last dimension is ``n_classes``.
        """
        angles = self._encode_angles(x)
        angles = inject_angle_noise(angles, sigma=angle_sigma)
        q = vqc_expectations(
            angles, self.q_weights, n_qubits=self.n_qubits, output_sigma=output_sigma
        )
        return self.head(q)
