from __future__ import annotations
import numpy as np
from qiskit import QuantumCircuit
from qiskit_aer import AerSimulator
from qiskit_aer.noise import NoiseModel, depolarizing_error
from qiskit.quantum_info import Statevector

def build_param_circuit(x_angles: np.ndarray, weights: np.ndarray) -> QuantumCircuit:
    """Build a data-encoding + layered-rotation circuit for noisy simulation.

    The circuit is assembled gate-by-gate from concrete Python floats, so it is
    not differentiable end-to-end.

    Layout:
      1. Encoding: for each qubit ``i``, ``RY(x_i)`` followed by ``RZ(0.5 * x_i)``.
      2. ``n_layers`` repetitions of: per-qubit ``RX``, ``RY``, ``RZ`` with
         angles ``weights[l, i, 0..2]``, then a CX chain ``0->1->...->n-1``
         closed into a ring by ``CX(n-1 -> 0)`` (ring closure is skipped for a
         single qubit, where it would be a CX on one wire).

    Args:
        x_angles: Encoding angles indexed by qubit, expected shape
            ``[n_qubits]``. Only the first ``n_qubits`` entries are read.
        weights: Rotation angles of shape ``[n_layers, n_qubits, 3]``; the
            last axis holds the (rx, ry, rz) angles.

    Returns:
        A ``QuantumCircuit`` with ``n_qubits`` qubits and no classical bits.

    Raises:
        ValueError: if ``weights`` is not 3-dimensional (shape unpacking fails).
    """
    n_layers, n_qubits, _ = weights.shape
    qc = QuantumCircuit(n_qubits)

    # Angle encoding of the input features.
    for i in range(n_qubits):
        qc.ry(float(x_angles[i]), i)
        qc.rz(float(0.5 * x_angles[i]), i)

    for l in range(n_layers):
        # Trainable single-qubit rotations.
        for i in range(n_qubits):
            qc.rx(float(weights[l, i, 0]), i)
            qc.ry(float(weights[l, i, 1]), i)
            qc.rz(float(weights[l, i, 2]), i)
        # Nearest-neighbour CX chain.
        for i in range(n_qubits - 1):
            qc.cx(i, i + 1)
        # Close the ring. For a single qubit this would be CX(0, 0), which
        # Qiskit rejects as duplicate qubit arguments, so skip it there.
        if n_qubits > 1:
            qc.cx(n_qubits - 1, 0)
    return qc

def noise_model_depolarizing(p: float = 0.01, n_qubits: int = 3) -> NoiseModel:
    """Create a uniform depolarizing noise model for the rotation/CX gate set.

    Single-qubit gates ``rx``, ``ry`` and ``rz`` get a 1-qubit depolarizing
    error with parameter ``p``. ``cx`` gets a 2-qubit depolarizing error with
    parameter ``2 * p``, capped at ``0.25``. Errors are attached to all qubits.

    Args:
        p: Single-qubit depolarizing parameter.
        n_qubits: Accepted for call-site compatibility; not used when building
            the model (errors are registered for every qubit).

    Returns:
        A new ``NoiseModel`` instance.
    """
    two_qubit_param = min(2*p, 0.25)
    one_qubit_error = depolarizing_error(p, 1)
    two_qubit_error = depolarizing_error(two_qubit_param, 2)

    model = NoiseModel()
    gate_errors = [
        (one_qubit_error, "rx"),
        (one_qubit_error, "ry"),
        (one_qubit_error, "rz"),
        (two_qubit_error, "cx"),
    ]
    for error, gate_name in gate_errors:
        model.add_all_qubit_quantum_error(error, gate_name)
    return model

def expectations_noisy(
    x_angles: np.ndarray,
    weights: np.ndarray,
    p: float = 0.01,
    shots: int = 2048,
    seed: int | None = None,
) -> np.ndarray:
    """Return Pauli-Z expectations per qubit using Aer noise simulation.

    Builds the circuit with ``build_param_circuit``, measures every qubit,
    samples it on an ``AerSimulator`` configured with
    ``noise_model_depolarizing(p)``, and converts the resulting counts into
    per-qubit <Z> estimates.

    Args:
        x_angles: Encoding angles forwarded to ``build_param_circuit``.
        weights: Layer weights forwarded to ``build_param_circuit``.
        p: Single-qubit depolarizing parameter for the noise model.
        shots: Number of measurement shots.
        seed: Optional simulator seed (passed as ``seed_simulator``) for
            reproducible sampling; ``None`` keeps Aer's default seeding.

    Returns:
        Float array of shape ``[n_qubits]`` where entry ``i`` is the estimated
        <Z> on qubit ``i`` (in [-1, 1]).
    """
    qc = build_param_circuit(x_angles, weights)
    n_qubits = qc.num_qubits
    nm = noise_model_depolarizing(p=p, n_qubits=n_qubits)
    sim = AerSimulator(noise_model=nm)

    # Measure a copy so the caller-facing unitary circuit stays measurement-free.
    meas = qc.copy()
    meas.measure_all()

    run_options = {"shots": shots}
    if seed is not None:
        run_options["seed_simulator"] = seed
    counts = sim.run(meas, **run_options).result().get_counts()
    return _z_expectations_from_counts(counts, n_qubits)


def _z_expectations_from_counts(counts, n_qubits: int) -> np.ndarray:
    """Convert a bitstring->count mapping into per-qubit <Z> estimates."""
    exp = np.zeros(n_qubits, dtype=float)
    # Normalise by what was actually observed rather than the requested shots.
    total = sum(counts.values())
    if total == 0:
        return exp
    for bitstring, count in counts.items():
        # Qiskit prints clbit 0 as the rightmost character; reverse so index i
        # is clbit i, which measure_all() maps to qubit i.
        bits = bitstring[::-1]
        signs = np.array(
            [1.0 if bits[i] == "0" else -1.0 for i in range(n_qubits)]
        )
        exp += signs * (count / total)
    return exp
