diff --git a/_static/demonstration_assets/bp_catalyst/css.svg b/_static/demonstration_assets/bp_catalyst/css.svg new file mode 100644 index 0000000000..2d26c8b23c --- /dev/null +++ b/_static/demonstration_assets/bp_catalyst/css.svg @@ -0,0 +1,4 @@ + + + +
Decoder
Error


Decoder

\ No newline at end of file diff --git a/_static/demonstration_assets/bp_catalyst/qec.svg b/_static/demonstration_assets/bp_catalyst/qec.svg new file mode 100644 index 0000000000..a26de1e73d --- /dev/null +++ b/_static/demonstration_assets/bp_catalyst/qec.svg @@ -0,0 +1,4 @@ + + + +
Decoder
Error


\ No newline at end of file diff --git a/demonstrations/tutorial_bp_catalyst.metadata.json b/demonstrations/tutorial_bp_catalyst.metadata.json new file mode 100644 index 0000000000..f0a7b747ed --- /dev/null +++ b/demonstrations/tutorial_bp_catalyst.metadata.json @@ -0,0 +1,102 @@ +{ + "title": "Decoding Quantum Errors on the Steane code with Belief Propagation & Catalyst", + "authors": [ + { + "username": "tomginsberg" + } + ], + "dateOfPublication": "2025-05-14T00:00:00+00:00", + "dateOfLastModification": "2025-05-14T00:00:00+00:00", + "categories": ["Quantum Computing"], + "tags": ["quantum error correction", "css codes", "belief propagation"], + "previewImages": [ + { + "type": "thumbnail", + "uri": "/_static/demonstration_assets/neutral_atoms/thumbnail_tutorial_neutral_atoms.png" + }, + { + "type": "large_thumbnail", + "uri": "/_static/demo_thumbnails/large_demo_thumbnails/thumbnail_large_neutral_atoms.png" + } + + ], + "seoDescription": "Learn how to decode CSS codes using belief propagation with JAX, Pennylane and Catalyst", + "doi": "", + "references": [ + { + "id": "Pesah2023", + "type": "webpage", + "title": "The stabilizer trilogy I — Stabilizer codes", + "authors": "Arthur Pesah", + "year": "2023", + "url": "https://arthurpesah.me/blog/2023-01-31-stabilizer-formalism-1/" + }, + { + "id": "Wiberg2001", + "type": "phdthesis", + "title": "Codes and Decoding on General Graphs", + "authors": "Niclas Wiberg", + "year": "2001", + "url": "https://www.essrl.wustl.edu/~jao/itrg/wiberg.pdf" + }, + { + "id": "Panteleev2019", + "type": "preprint", + "title": "Degenerate Quantum LDPC Codes With Good Finite Length Performance", + "authors": "Pavel Panteleev", + "year": "2019", + "url": "https://arxiv.org/abs/1904.02703v3" + }, + { + "id": "Hillmann2024", + "type": "preprint", + "title": "Localized statistics decoding: A parallel decoding algorithm for quantum low-density parity-check codes", + "authors": "Timo Hillmann", + "year": "2024", + "url": "https://arxiv.org/abs/2406.18655v1" + }, + { + "id": "Wolanski2024", + "type": "preprint", + "title": "Ambiguity Clustering: an accurate and efficient decoder for qLDPC codes", + "authors": "Stasiu Wolanski", + "year": "2024", + "url": "https://arxiv.org/abs/2406.14527v2" + }, + { + "id": "Loeliger2004", + "type": "article", + "title": "An introduction to factor graphs", + "authors": "Hans-Andrea Loeliger", + "year": "2004", + "url": "https://www.isiweb.ee.ethz.ch/papers/arch/aloe-2004-spmagffg.pdf" + }, + { + "id": "Barber2012", + "type": "book", + "title": "Bayesian Reasoning and Machine Learning", + "authors": "David Barber", + "year": "2012", + "url": "http://web4.cs.ucl.ac.uk/staff/D.Barber/textbook/180325.pdf#page=107.50" + } + ], + "basedOnPapers": [], + "referencedByPapers": [], + "relatedContent": [ + { + "type": "demonstration", + "id": "tutorial_magic_state_distillation", + "weight": 1 + }, + { + "type": "demonstration", + "id": "tutorial_mcm_introduction", + "weight": 1 + }, + { + "type": "demonstration", + "id": "tutorial_qjit_compile_grovers_algorithm_with_catalyst", + "weight": 1 + } + ] +} \ No newline at end of file diff --git a/demonstrations/tutorial_bp_catalyst.py b/demonstrations/tutorial_bp_catalyst.py new file mode 100644 index 0000000000..794677a3d8 --- /dev/null +++ b/demonstrations/tutorial_bp_catalyst.py @@ -0,0 +1,1095 @@ +r"""Decoding Quantum Errors on the Steane code with Belief Propagation and Catalyst +============================================================================= + +*Learn how to build, simulate, and decode the Steane code using JAX and Catalyst, blending quantum +circuits with fast classical decoders in a seamless workflow.* + +-------------- + +Introduction +------------ + +This tutorial walks you through a simplified error correction cycle using the Steane $[[7,1,3]]$ +code. You’ll encode a logical qubit, introduce noise, extract syndromes, and apply decoding using +two different strategies: a simple lookup table and a belief propagation (BP) decoder. Both decoders +are implemented in JAX and JIT-compiled by Catalyst, allowing everything to run inside a single +``@qml.qjit`` circuit. + +Why is this exciting? Quantum error correction (QEC) is essential for building reliable quantum +computers, but it requires more than just quantum operations. Fast classical feedback is needed as well. +Catalyst addresses this by fusing the classical and quantum workflows, giving us one unified, +hardware-agnostic kernel that runs on CPUs, GPUs, and beyond. + +What We’ll Build +---------------- + +By the end of this tutorial, you’ll have: + +- **Encoded** a logical $|0⟩$ using the Steane code. +- **Simulated noise** using configurable bit-flip and phase-flip channels. +- **Extracted syndromes** via ancilla-assisted stabilizer measurements. +- **Decoded errors** using: + + - a **Lookup Table (LUT)** decoder + - a **Belief Propagation (BP)** decoder + +- **Benchmarked performance** across different physical error rates. +""" + +###################################################################### +# Understanding Quantum Error Correction +# -------------------------------------- +# +# At its core, QEC protects quantum information through redundant encoding. Stabilizer codes are a +# foundational tool for this purpose. Each stabilizer is a multi-qubit Pauli operator (composed of I, +# X, Y, Z gates) that defines parity constraints the code space must satisfy. +# +# Symplectic Representation +# ~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# To formalize stabilizers and errors, we use **symplectic vectors**. For :math:`n` physical qubits, +# any :math:`n`-qubit Pauli operator can be represented as a binary vector of length :math:`2n`: +# +# .. math:: +# +# +# (v | u) = (v_1, v_2, ..., v_n | u_1, u_2, ..., u_n), +# +# where: +# +# - :math:`v_i = 1` if the Pauli has an X on qubit :math:`i` (0 otherwise), +# - :math:`u_i = 1` if the Pauli has a Z on qubit :math:`i` (0 otherwise), +# - and :math:`v_i = u_i = 1` if the Pauli has a Y on qubit :math:`i`, since :math:`X \cdot Z \propto Y`. +# +# For example, the operator :math:`XZIY` on 4 qubits corresponds to :math:`(1,0,0, 1 | 0,1,0, 1)`. +# The stabilizer group is then generated by a set of such symplectic vectors, forming a **stabilizer +# matrix** of size :math:`m \times 2n` (with :math:`m` generators). +# +# The **commutation condition** between two Pauli operators :math:`(v|u)` and :math:`(v'|u')` is +# captured by their **symplectic inner product**: +# +# .. math:: +# +# +# v \cdot u' + u \cdot v' \pmod{2}. +# +# For a valid stabilizer code, all generators must commute, which implies that the symplectic inner product between any +# two rows of the stabilizer matrix must be zero. +# +# Measurement and Syndrome +# ~~~~~~~~~~~~~~~~~~~~~~~~ +# +# When an error :math:`e` occurs, represented as a symplectic vector :math:`(v_e | u_e)`, the syndrome +# is calculated by taking the symplectic inner product of :math:`e` with each stabilizer generator. +# This produces a syndrome bit :math:`s_i` for each generator: +# +# .. math:: +# +# +# s_i = v_g^{(i)} \cdot u_e + u_g^{(i)} \cdot v_e \pmod{2}, +# +# where :math:`(v_g^{(i)} | u_g^{(i)})` is the :math:`i`-th generator. +# +# Below, we show the basic quantum circuit for extracting a syndrome value. The ancilla qubit is +# initialized in the :math:`|+\rangle` state, and controlled operations are applied based on the +# stabilizer generators. Finally, the ancilla is measured in the :math:`X`-basis to obtain the +# syndrome value. Check out Arthur Pesah’s excellent blog post series [#Pesah]_ on Stabilizer codes +# for a deeper introduction. +# + +from typing import Callable, Optional, Dict, Union, Sequence + +import pennylane as qml + +syndromes = ["XXZIZIX", "XXIIZZI"] +dev = qml.device("lightning.qubit", wires := max(map(len, syndromes)) + 1) + + +@qml.qnode(device=dev) +def ancilla_assisted_syndrome_extraction(syndromes: list[str]): + ancilla = wires - 1 + for i, syndrome in enumerate(syndromes): + qml.Hadamard(ancilla) + for i, s in enumerate(syndrome): + if s == "X": + qml.CNOT(wires=[ancilla, i]) + elif s == "Z": + qml.CZ(wires=[ancilla, i]) + qml.Hadamard(ancilla) + qml.measure(ancilla) + qml.Barrier() + ancilla += 1 + + +print(qml.draw(ancilla_assisted_syndrome_extraction, show_all_wires=True)(syndromes)) + +###################################################################### +# Decoding: The Classical Half of QEC +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# Once you have syndrome bits from your stabilizer measurements, you need to figure out what error +# likely occurred-this is the job of the decoder. Formally, given the syndrome, you’re solving for +# the most probable error, usually called the maximum likelihood estimate (MLE) for the error. +# +# However, exact MLE decoding depends on the precise information of your noise model and is generally +# computationally intractable (NP-Hard) because :math:`n` one-bit syndrome measurements can take on +# :math:`2^n` unique values. In practice, we rely on approximate methods tuned to assumptions about the +# noise model. +# +# .. figure:: ../_static/demonstration_assets/bp_catalyst/css.svg +# :align: center +# :width: 70% +# :alt: A complete quantum error correction (QEC) cycle. A logical state (:math: `|\psi\rangle_L`), experiences an error before stabilizers (:math: `\mathcal{S}`) are measured. The resulting syndrome is decoded classically to produce a correction (:math: `\mathcal{R}`), which is applied to restore the logical state +# +# CSS Codes: Simplifying the Structure +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# CSS (Calderbank–Shor–Steane) codes are a special class of stabilizer code where the generators are split +# into **X-type** and **Z-type** operators. Their symplectic vectors look like this: +# +# - **X-type generator**: :math:`(v | 0)` (only Xs) +# - **Z-type generator**: :math:`(0 | u)` (only Zs) +# +# This allows us to represent the stabilizers with two :math:`m \times n` **matrices**: +# +# - :math:`H_X` for X-type generators +# - :math:`H_Z` for Z-type generators +# +# The commutation condition to ensure that all generators are simultaneously observable is: +# +# .. math:: +# +# +# H_X H_Z^T = 0 \pmod{2}, +# +# which ensures that all X and Z stabilizers commute pairwise. When measuring syndromes: +# +# - X-type stabilizers detect **Z errors** via :math:`s_X = H_X e_Z^T \pmod{2}`. +# - Z-type stabilizers detect **X errors** via :math:`s_Z = H_Z e_X^T \pmod{2}`. +# +# This separation makes decoding modular, allowing you to handle X and Z errors independently. When we +# introduce the Steane code later, you’ll see these matrices explicitly and how they simplify syndrome +# calculation and decoding. See a similar diagram below for the CSS code cycle structure. +# +# .. figure ../_static/demonstration_assets/bp_catalyst/css.svg +# :align: center +# :width: 70% +# :alt: The Error Correction Cycle on a CSS Code +# + +###################################################################### +# The Steane Code +# --------------- +# +# The Steane code is one of the simplest quantum error correcting codes, a CSS code built from two +# classical Hamming codes. It encodes one logical qubit into seven physical qubits and can correct +# any single-qubit error. Traditionally, the error correcting ability of a code is referred as the +# distance or :math:`d` and the number of errors a code can correct is +# :math:`\lfloor (d-1)/2 \rfloor`. Since the Steane code can correct a single error, it is said +# to have distance :math:`3`. This code uses six stabilizer generators: +# +# .. math:: +# +# +# H_X = \begin{bmatrix} +# 0 & 0 & 0 & 1 & 1 & 1 & 1 \\ +# 0 & 1 & 1 & 0 & 0 & 1 & 1 \\ +# 1 & 0 & 1 & 0 & 1 & 0 & 1 +# \end{bmatrix}, \quad +# H_Z = H_X. +# +# We’ll start by implementing two decoding strategies: +# +# - **Lookup Table (LUT)**: Pre-compute minimal corrections for every syndrome (possible for small +# codes like this one). +# - **Belief Propagation (BP)**: An iterative message-passing algorithm that operates on the code’s Tanner graph +# (a bipartite graph representing the relationships between qubits and stabilizers). It approximates +# the marginal probabilities of errors on each qubit, offering greater scalability for larger, sparser codes. +# +# We’ll begin with the LUT decoder due to its simplicity and then explore BP, which is more flexible for +# larger or sparser codes. +# + +###################################################################### +# Lookup‑table (LUT) decoding +# --------------------------- +# +# For the Steane code, with :math:`3` :math:`X` and :math:`3` :math:`Z` stabilizer +# generators, there are :math:`2^3=8` possible syndromes for both :math:`X` and :math:`Z`. We can +# create a small table that maps each three‑bit syndrome to a weight‑1 error. +# + +import jax.numpy as jnp +from itertools import combinations +from jax.typing import ArrayLike +import jax +from tabulate import tabulate + + +def lookup_decoder(matrix: ArrayLike, max_weight: int = 1): + m, n = matrix.shape + lut = jnp.zeros((1 << m, n), dtype=jnp.int8) + + # fill table with the lowest‑weight correction for each syndrome + # we do this by iterating over all possible weight one errors and computing their corresponding syndromes + for w in range(1, max_weight + 1): + # iterate over all possible weight-w errors + for qs in combinations(range(n), w): + err = jnp.zeros(n, dtype=jnp.int8).at[jnp.array(qs)].set(1) # error mask + syn = (matrix @ err) % 2 # syndrome for this error + idx = jnp.dot(syn, 1 << jnp.arange(m, dtype=jnp.int8)) # syndrome bits to base 10 index + lut = lut.at[idx].set(err) + + @jax.jit + def _decode(syndrome: ArrayLike): + # convert the syndrome to base 10 and look it up in the table + idx = jnp.dot(syndrome, 1 << jnp.arange(m)) + return lut[idx] + + return _decode + + +H_steane= jnp.array( + [[0, 0, 0, 1, 1, 1, 1], [0, 1, 1, 0, 0, 1, 1], [1, 0, 1, 0, 1, 0, 1]], dtype=int +) +lut_steane= lookup_decoder(H_steane) + +# we see that the steane code has a nice property where counting up in binary shifts the error to the right +table_data = [] +for i in range(8): + decoded = lut_steane(jnp.array([int(x) for x in f"{i:03b}"])) + table_data.append([f"{i:03b}", "".join(map(str, decoded))]) + +print(tabulate(table_data, headers=["Syndrome", "LUT Error"])) + +###################################################################### +# While this approach is optimal for small codes, it rapidly becomes infeasible for larger examples. +# For instance, the distance-:math:`30` rotated surface code, which encodes only :math:`1` logical qubits, has :math:`450` +# stabilizers for both :math:`X` and :math:`Z`. A full lookup table decoder for just one check type +# would require approximately :math:`2.9\times 10^{35}` entries. +# + +###################################################################### +# Belief-Propagation (BP) Decoder +# ------------------------------- +# +# Belief propagation is an iterative message-passing algorithm used to decode errors by working on the +# **Tanner graph** [#Wiberg]_ of the code. This graph has two types of nodes: +# +# - **Variable nodes** represent the physical qubits, which may or may not have experienced an error. +# These correspond to the bits of the error vector :math:`e = (e_1, e_2, \dots, e_n)`. +# - **Check nodes** represent stabilizers, which enforce parity constraints on subsets of qubits. Each +# check node corresponds to a row of the parity-check matrix :math:`H`. +# +# There is an edge between a check node :math:`c` and a variable node :math:`v` if and only if +# :math:`H_{cv} = 1`, meaning that qubit :math:`v` participates in stabilizer :math:`c`. +# +# The goal is to estimate the probability that each qubit has been flipped (i.e., that +# :math:`e_v = 1`), given the observed syndrome bits :math:`s_c`. BP updates there beliefs iteratively by +# exchanging messages between variable and check nodes. +# +# The Sum-Product Algorithm +# ~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# The BP decoder is based on the **sum-product algorithm**, which computes marginal probabilities over +# the graph. Here’s the procedure: +# +# 1. **Initialization** +# +# Each variable node :math:`v` sends an initial message to its neighboring checks that reflects the +# *intrinsic* belief about whether an error has occurred. This is the log-likelihood ratio (LLR) +# based on the physical error rate :math:`p`: +# +# .. math:: +# +# +# L_0 = \log\frac{1 - p}{p} +# +# This expresses the prior belief: if :math:`p` is small (e.g., 0.01), then :math:`L_0` is +# positive, favoring no error; if :math:`p` is close to 0.5, :math:`L_0` is near zero (no strong +# prior). In general :math:`p` is a parameter of the algorithm that can be tuned to your specific +# noise source. +# +# 2. **Variable-to-Check Update** +# +# Each variable node updates its message to a neighboring check :math:`c` by combining its +# intrinsic belief with the incoming messages from other connected checks: +# +# .. math:: +# +# +# m_{v \to c} = L_0 + \sum_{c' \in N(v) \setminus c} m_{c' \to v} +# +# Here: +# +# - :math:`m_{v \to c}` is the message from variable :math:`v` to check :math:`c`. +# - :math:`N(v)` is the set of checks connected to variable :math:`v`. +# - :math:`m_{c' \to v}` are messages received from neighboring checks other than :math:`c`. +# +# 3. **Check-to-Variable Update** +# +# Each check node updates its message to a neighboring variable :math:`v` based on the syndrome bit +# :math:`s_c` and the incoming messages from the other variables connected to it: +# +# .. math:: +# +# +# m_{c \to v} = (-1)^{s_c} \; 2 \, \operatorname{arctanh} \biggl( \prod_{v' \in N(c) \setminus v} \tanh\frac{m_{v' \to c}}{2} \biggr) +# +# Here: +# +# - :math:`m_{c \to v}` is the message from check :math:`c` to variable :math:`v`. +# - :math:`s_c` is the syndrome bit for check :math:`c` (0 if the stabilizer is satisfied, 1 if +# violated). +# - :math:`N(c)` is the set of variables connected to check :math:`c`. +# - The :math:`\tanh` and :math:`\operatorname{arctanh}` functions implement the **sum-product +# rule** for combining binary parity checks derived from classical probability theory. +# +# **What’s going on?** This formula indicates that if the product of incoming :math:`\tanh` terms is +# close to +1 or -1, it means there is a strong belief about whether the parity is satisfied or +# violated. The :math:`\operatorname{arctanh}` converts that back into an LLR-style message. The +# :math:`(-1)^{s_c}` factor flips the sign if the syndrome is 1, signaling that a parity error was +# detected. +# +# 4. **Iteration** +# +# Steps 2 and 3 are repeated for a fixed number of iterations (e.g., 10–20) or until the messages +# converge (i.e., stop changing significantly). Traditional theory and heuristics in error +# correction say to repeat :math:`BP` roughly on the order of :math:`O(n)`. +# +# 5. **Decision Rule** +# +# After the iterations, each variable node computes its **posterior LLR** by summing its intrinsic +# belief and all incoming messages: +# +# .. math:: +# +# +# L_v = L_0 + \sum_{c \in N(v)} m_{c \to v} +# +# The decoder then makes a hard decision: +# +# - If :math:`L_v < 0`, it guesses :math:`e_v = 1` (error detected). +# - If :math:`L_v > 0`, it guesses :math:`e_v = 0` (no error). +# +# Why This Works +# ~~~~~~~~~~~~~~ +# +# Belief propagation is exact on **tree-like graphs**, where no cycles exist. However, even on Tanner +# graphs, which are never tree-like, it provides a good approximation to the maximum-likelihood decoder +# by using only local, iterative computations. Nevertheless, its performance can degrade when the Tanner +# graph contains many short cycles—a common characteristic of many popular quantum codes, which can lead to poor +# convergence. In practice, further extensions like BP-OSD [#Panteleev]_, BP-LSD [#Hillmann]_ or +# Ambiguity Clustering [#Wolanski]_ are used to fix these issues. +# +# See the following summary article [#Loeliger]_ as well as Chapter 5 in Bayesian Reasoning and +# Machine Learning [#Barber]_ for a deeper dive into message passing algorithms on graphs. +# +# BP in JAX +# ~~~~~~~~~~~~~ +# +# Below, we implement a BP decoder using Jax broken down into it’s core components. +# + +###################################################################### +# Before we can pass messages, we need to establish the connectivity between nodes. The ``_build_graph`` function scans the +# parity‑check matrix once and records, for every variable node, which checks touch it and vice‑versa. +# We convert the neighbour lists to tuples so they become immutable, hashable static data. JAX can +# then embed their values as compile‑time constants in the XLA program and reliably reuse the compiled +# kernel multiple times. A cool thing about JAX/XLA is that when using simple static parameters like the ones +# below, the individual integers it contains are baked into the XLA program as compile‑time constants, +# so we can truly compile a high performance decoder for our specific parity check matrix. +# + + +def _build_graph( + pcm: ArrayLike, +) -> tuple[tuple[tuple[int, ...], ...], tuple[tuple[int, ...], ...]]: + """ + Pre‑compute variable‑node and check‑node neighbors. + + Returns + ------- + var_neighbors : tuple[tuple[int, ...], ...] # length = n + check_neighbors : tuple[tuple[int, ...], ...] # length = m + """ + m, n = pcm.shape + vars_, checks_ = [[] for _ in range(n)], [[] for _ in range(m)] + + for c in range(m): + for v in range(n): + if pcm[c, v]: + vars_[v].append(c) + checks_[c].append(v) + + return tuple(map(tuple, vars_)), tuple(map(tuple, checks_)) + + +###################################################################### +# A nice way to visulaize this Tanner graph is using the ``networkx`` package. Below is an example on +# the Steane code. +# + +import matplotlib.pyplot as plt +import networkx as nx + +vars, checks = _build_graph(H_steane) +G = nx.Graph() +num_vars = len(vars) +num_checks = len(checks) + +# build the nx graph object from our vars and checks +for v in range(num_vars): + G.add_node(f"v{v}", bipartite=0) +for c in range(num_checks): + G.add_node(f"c{c}", bipartite=1) +for c in range(num_checks): + for v in checks[c]: + G.add_edge(f"c{c}", f"v{v}") + +pos = nx.bipartite_layout(G, nodes=[f"v{i}" for i in range(num_vars)]) + +plt.figure(figsize=(10, 7)) +nx.draw(G, pos, with_labels=True, node_color="skyblue", node_size=500, font_weight="bold") +plt.title("Bipartite Graph for H_steane", fontsize=16) +plt.show() + +###################################################################### +# The ``_c2v_update`` helper function performs one full sweep of check‑to‑variable updates (step 3 of the +# sum‑product algorithm). It takes the previous messages, the syndrome, the neighbor tables, and two +# scalars (``L_int`` for the intrinsic log‑likelihood ratio and ``eps`` for numerical safety). It +# loops only over existing edges, multiplies the relevant :math:`\operatorname{tanh}` terms, clips +# the product, applies :math:`\operatorname{arctanh}`, and writes the new message into the next +# matrix. +# + + +def _c2v_update( + m_c2v_prev: ArrayLike, + syndrome: ArrayLike, + var_nei: tuple[tuple[int, ...], ...], + check_nei: tuple[tuple[int, ...], ...], + L_int: float, + eps: float, +) -> ArrayLike: + """ + Compute the next round of check‑to‑variable messages. + """ + m, n = m_c2v_prev.shape + m_c2v_next = jnp.zeros_like(m_c2v_prev) + + # Loop over checks (outer) then their vars (inner) + for c in range(m): + Vc = check_nei[c] + if len(Vc) < 2: + continue # degree‑1 checks carry no new info + + for v in Vc: + prod = 1.0 + # product over all v' ≠ v in this check + for v_p in Vc: + if v_p == v: + continue + incoming = L_int + for c_p in var_nei[v_p]: + if c_p != c: + incoming += m_c2v_prev[c_p, v_p] + prod *= jnp.tanh(incoming / 2.0) + + prod = jnp.clip(prod, -1.0 + eps, 1.0 - eps) + msg = ((-1) ** syndrome[c]) * 2.0 * jnp.arctanh(prod) + m_c2v_next = m_c2v_next.at[c, v].set(msg) + + return m_c2v_next + + +###################################################################### +# Once the main loop finishes, we still need a hard decision. The function ``_posterior_llrs`` folds every final +# check‑to‑variable message for bit ``v`` into its intrinsic LLR, yielding the posterior belief for +# that bit. A negative value means “error likely,” a positive value means “no error.” +# + +def _posterior_llrs( + m_c2v_final: ArrayLike, var_nei: tuple[tuple[int, ...], ...], L_int: float +) -> ArrayLike: + """ + Combine intrinsic LLR with all incoming messages. + """ + n = m_c2v_final.shape[1] + llr = jnp.full(n, L_int) + for v in range(n): + for c in var_nei[v]: + llr = llr.at[v].add(m_c2v_final[c, v]) + return llr + + +###################################################################### +# ``build_bp_decoder`` serves as the main entry point for compiling our decoder. It takes the parity‑check +# matrix and channel error rate, builds the parity graph, pre‑computes the intrinsic LLR, and returns +# a JIT‑compiled function ``_decode``. +# +# Inside ``_decode``, the following steps are executed: +# +# 1. All messages are zero-initialized. +# 2. ``_c2v_update`` is called inside a ``jax.lax.fori_loop`` for ``max_iter`` rounds. +# 3. Final messages are converted to posterior LLRs with ``_posterior_llrs``. +# 4. A binary error vector is output by thresholding the LLRs at zero. +# +# Because the whole ``_decode`` body is wrapped in ``@jax.jit``, the first call compiles everything +# into an XLA kernel; subsequent calls run at full device speed. +# + + +def build_bp_decoder( + parity_check_matrix: ArrayLike, + error_rate: float, + max_iter: int = 10, + epsilon: float = 1e-9, +) -> Callable[[ArrayLike], ArrayLike]: + """ + Return a JIT‑compiled BP decoder for the given code and channel. + + Parameters + ---------- + parity_check_matrix : array‑like (m, n) + error_rate : float # BSC crossover probability p + max_iter : int + epsilon : float # numerical safety margin + """ + pcm = jnp.asarray(parity_check_matrix, dtype=jnp.int32) + m, n = pcm.shape + L_int = jnp.log((1.0 - error_rate) / error_rate) + + var_nei, check_nei = _build_graph(pcm) + + @jax.jit + def _decode(syndrome: ArrayLike) -> ArrayLike: + syndrome = jnp.asarray(syndrome, dtype=jnp.int32) + + # Initialise all messages to zero + m_c2v = jnp.zeros((m, n), dtype=jnp.float32) + + # BP loop + def body(_, msgs): + return _c2v_update(msgs, syndrome, var_nei, check_nei, L_int, epsilon) + + m_c2v = jax.lax.fori_loop(0, max_iter, body, m_c2v) + + # Hard decision from posterior LLRs + llr = _posterior_llrs(m_c2v, var_nei, L_int) + return (llr < 0).astype(jnp.int32) + + # optionally we can force our decoder to compile right away by calling it on a test input + _decode(jnp.zeros(m, dtype=jnp.int32)) + + return _decode + + +###################################################################### +# Let’s test the performance of the BP decoder on the Steane code compared to the LUT decoder. +# + +bp_steane = build_bp_decoder(H_steane, error_rate=0.05, max_iter=7) + +n_bits = H_steane.shape[0] +correct = 0 +total_syndromes = 2**n_bits + +table_data = [] +headers = ["Syndrome", "BP Estimated Error", "LUT Exact Error", "Match"] + +for i in range(total_syndromes): + syndrome_binary_string = f"{i:0{n_bits}b}" + s_array = jnp.array([int(x) for x in syndrome_binary_string]) + + # Get error patterns from BP decoder and LUT + bp_pattern = bp_steane(s_array) + lut_pattern = lut_steane(s_array) + match = jnp.all(bp_pattern == lut_pattern) + + bp_pattern_str = "".join(map(str, bp_pattern.tolist())) + lut_pattern_str = "".join(map(str, lut_pattern.tolist())) + + table_data.append([syndrome_binary_string, bp_pattern_str, lut_pattern_str, str(match)]) + + # Increment correct count if patterns match + if match: + correct += 1 + +print(tabulate(table_data, headers=headers)) + +# Calculate and print the BP accuracy +accuracy = (correct / total_syndromes) * 100 if total_syndromes > 0 else 0 +print(f"\nBP Accuracy: {accuracy:.2f}%") + +###################################################################### +# Before diving into the code, let’s test our belief‑propagation (BP) decoder on a bigger example: the +# `n‑bit repetition code `__. This code stores each +# logical bit by repeating it :math:`n` times (e.g. :math:`0 \mapsto 00\ldots0` and +# :math:`1 \mapsto 11\ldots1`). Its parity‑check matrix consists of :math:`n-1` rows, each enforcing +# that two neighbouring bits are equal. Below, we measure how often the BP decoder corrects random +# errors on a 50‑bit repetition code and compare its success rate to an optimal maximum‑likelihood +# (ML) decoder, which simply picks the lower‑weight error pattern consistent with the observed +# syndrome. +# + + +def rep_code(n: int) -> ArrayLike: + """ + Build the (n − 1) × n parity‑check matrix H for the [n, 1] repetition code. + + Each row enforces equality between two neighboring bits: + H[i] has 1s in positions i and i+1, zeros elsewhere. + """ + # First row: parity check on bits 0 and 1 → [1, 1, 0, 0, …, 0] + first_row = jnp.zeros(n, dtype=jnp.int8).at[jnp.array([0, 1])].set(1) + rows = [first_row] + + # Remaining rows: slide the two‑bit “window” to the right + for _ in range(n - 2): + rows.append(jnp.roll(rows[-1], 1)) # shift previous row by 1 position + + return jnp.stack(rows) # shape = (n‑1, n) + + +@jax.jit +def ml_rep_decoder(syndrome: ArrayLike) -> ArrayLike: + """ + Minimum‑weight decoder for the repetition code. + + Parameters + ---------- + syndrome : ArrayLike, shape (n‑1,) + The syndrome s = H e (mod 2). + + Returns + ------- + error : ArrayLike, shape (n,) + A lowest‑weight error vector consistent with `syndrome`. + """ + # Candidate 1: assume e[0] = 0, then recover the rest via cumulative XOR. + # e[k+1] = e[k] ⊕ s[k] ⇒ e = [0, cumsum(s) mod 2] + e0 = jnp.concatenate((jnp.array([0], dtype=jnp.int32), jnp.mod(jnp.cumsum(syndrome), 2))) + + # Candidate 2: flip every bit (equivalent to choosing e[0] = 1). + e1 = (e0 + 1) & 1 # fast “add‑one then mod 2” + + # Compare Hamming weights. + w0, w1 = jnp.sum(e0), jnp.sum(e1) + + # Return the lighter candidate (ties resolved in favour of e0). + return jax.lax.cond(w0 <= w1, lambda _: e0, lambda _: e1, operand=None) + + +###################################################################### +# We run a short experiment on a :math:`50` bit repetition code. We sample 10,000 random syndromes +# vectors and compute the accuracy of our BP decoder compared to our baseline ``ml_rep_decoder`` +# + +H_rep = rep_code(n := 50) +bp_rep = build_bp_decoder(parity_check_matrix=H_rep, error_rate=0.1, max_iter=n) + +# sample random syndromes +N = 10_000 +key = jax.random.PRNGKey(0) +syndromes = jax.random.randint(key, shape=(N, n - 1), minval=0, maxval=2) + +# use jax to map the decoder over the syndromes +# since our decoders are jit compiled jax functions they can be used with jax.vmap +success_rate = jnp.mean( + jnp.all(jax.vmap(ml_rep_decoder)(syndromes) == jax.vmap(bp_rep)(syndromes), axis=1) +) + +print(f"Decoding success rate: {success_rate * 100:.2f}%") + +###################################################################### +# Catalyst hybrid kernel +# ------------------------ +# +# Now that we understand a good chunk of theory behind CSS codes, the Steane code and decoding +# algorithms, let’s put this into action with Catalyst! +# +# Catalyst lets us build hybrid quantum-classical workflows, compiling both quantum operations and +# classical decoding logic into a single, efficient kernel. We’ll start with a quantum-classical +# circuit to prepare the logical zero state :math:`|0\rangle_L` for our Steane code. This method is +# also general for initializing logical zero states for any CSS codes. +# +# Start with a :math:`+1` eigenstate (or stabilizer state) of all the :math:`Z` type stabilizers. +# The :math:`|0\ldots 0\rangle` is always stabilized by any :math:`Z` type Pauli operator, making it a +# suitable choice. +# +# Then, for each X-type generator: +# - Prepare an ancilla qubit in the :math:`|+\rangle` state. +# - Measure X-type stabilizers using CNOT operations onto an ancilla. +# - Measure in the :math:`X` basis. +# +# Next: +# - Use measurement outcomes (syndromes) to determine necessary corrections using our decoder. +# - Apply Z-type corrections based on decoding results. +# +# This procedure uses projective measurements to force the data qubits to be in the :math:`+1` +# eigenstate of our :math:`X` type generators. Since the state was already a :math:`+1` eigenstate +# of our :math:`Z` type generators, and by virtue of the CSS code all :math:`X` and :math:`Z` +# generators simultaneously commute, we are left with a state in the :math:`+1` eigenspace of all the +# generators. +# + +import pennylane as qml +from jax import random +import catalyst + +r, n = H_steane.shape +n_wires = n + r + +dev = qml.device("lightning.qubit", wires=n_wires) + + +def measure_x_stabilizers(H: ArrayLike): + """ + Measure all X type stabilizers based on the parity check matrix X then apply Z type corrections from our decoder + :param H: Parity check X matrix + """ + r, n = H.shape + + # Encode logical |0> + # (Hadamard on ancillas, controlled X stabilizers) + for a in range(r): + qml.H(wires=n + a) + for a, row in enumerate(H): + for q, x in enumerate(row): + if x: + qml.CNOT(wires=[n + a, q]) + for a in range(r): + qml.H(wires=n + a) + + # Measure + reset ancillas (X stabilizers) + sx = jnp.stack([catalyst.measure(n + a) for a in range(r)]) + for a, bit in enumerate(sx): + if bit: + qml.PauliX(wires=n + a) # reset ancilla + + # Z‑correction + # Since the BP and LUT decoder + # we're both perfect on the Steane code + # well use the LUT for simplicity + rec_z = lut_steane(sx) + for q, bit in enumerate(rec_z): + if bit: + qml.PauliZ(wires=q) + + +@qml.qjit(autograph=True) +@qml.qnode(dev) +def encode_zero_steane(): + measure_x_stabilizers(H_steane) + return qml.state() + + +###################################################################### +# A simple utility function to display the state +# + +from pprint import pprint + +def state_vector_to_dict( + sv: ArrayLike, + wires: Optional[Sequence[int]], + eps: float = 1e-8, + probability: bool = False, + display: bool = True, +) -> Dict[str, Union[float, complex]]: + """ + Convert a state vector into {bitstring: amplitude | probability}. + """ + n_qubits = int(jnp.log2(len(sv))) + + + out: Dict[str, Union[float, complex]] = {} + + for idx, amp in enumerate(sv): + mag = jnp.abs(amp) ** 2 if probability else jnp.abs(amp) + if mag <= eps: + continue + + bitstring = f"{idx:0{n_qubits}b}" + key = "".join(b for i, b in enumerate(bitstring) if wires is None or i in wires) + + if probability: + out[key] = out.get(key, 0.0) + float(mag) + else: + out[key] = amp.item() + + if display: + pprint(out) + + return out + + +###################################################################### +# We run the ``encode_zero`` function and see that we recover the `correct logical zero state for the +# Steane code `__: +# +# .. math:: \begin{aligned}|\overline{0}\rangle= & \frac{1}{\sqrt{8}}(|0000000\rangle+|1010101\rangle+|0110011\rangle+|1100110\rangle \\ & +|0001111\rangle+|1011010\rangle+|0111100\rangle+|1101001\rangle)\end{aligned} +# + +sv_clean = encode_zero_steane() +state_vector_to_dict(sv_clean, display=True, wires=range(n)) + +###################################################################### +# Simulating Errors and Full Correction +# ------------------------------------- +# +# We’re now ready to wrap everything together: +# +# - Prepare the zero state. +# - Simulate noise using a depolarizing channel. +# - Perform one complete round of stabilizer measurements and corrections. +# + + +def noise_channel(n: int, p_err: float, key: random.PRNGKey): + """ + Apply a single‑qubit Pauli noise channel independently to each of `n` qubits. + + For every qubit the channel does: + 0 → I with probability 1 - p_err + 1 → X with probability p_err / 3 + 2 → Z with probability p_err / 3 + 3 → Y with probability p_err / 3 + """ + probs = jnp.array([1.0 - p_err, p_err / 3, p_err / 3, p_err / 3]) + outcomes = random.choice(key, 4, shape=(n,), p=probs) + + for idx, outcome in enumerate(outcomes): + if outcome == 1: + qml.X(wires=idx) + elif outcome == 2: + qml.Z(wires=idx) + elif outcome == 3: + qml.Y(wires=idx) + + +# this is a helper function to get the specific error we used in a given round based on the key +def get_error(n: int, p_err: float, key: random.PRNGKey): + err = [] + probs = jnp.array([1.0 - p_err, p_err / 3, p_err / 3, p_err / 3]) + outcomes = random.choice(key, 4, shape=(n,), p=probs) + + for idx, outcome in enumerate(outcomes): + if outcome == 1: + err.append(qml.X(wires=idx)) + elif outcome == 2: + err.append(qml.Z(wires=idx)) + elif outcome == 3: + err.append(qml.Y(wires=idx)) + return qml.ops.prod(*err) + + +###################################################################### +# Similar to ``measure_x_stabilizers``, however, we now apply CNOT from data to an ancilla prepared +# in the :math:`|0\rangle` state and perform a :math:`Z`-basis measurement. +# + + +def measure_z_stabilizers(H): + r, n = H.shape + for a, row in enumerate(H): + for q, x in enumerate(row): + if x: + qml.CNOT(wires=[q, n + a]) + + sz = jnp.stack([catalyst.measure(n + a) for a in range(r)]) + for a, bit in enumerate(sz): + if bit: + qml.PauliX(wires=n + a) + + rec_x = lut_steane(sz) + for q, bit in enumerate(rec_x): + if bit: + qml.PauliX(wires=q) + + +###################################################################### +# Now, let's run the ``qec_round`` using state preparation, followed by one round of noise injection and one round of :math:`X` +# and :math:`Z` correction. We’ll print the error that occurred in our noisy channel and demonstrate that the output state closely resembles the noiseless state we observed previously. +# + + +@qml.qjit(autograph=True) +@qml.qnode(dev, interface="jax") +def qec_round(H: ArrayLike, p_err=1e-3, key=random.PRNGKey(0)): + """One round of Steane code QEC with LUT decoding.""" + + measure_x_stabilizers(H) # prepare 0 state + noise_channel(n, p_err, key) # inject IID pauli noise + measure_x_stabilizers(H) # correct X errors + measure_z_stabilizers(H) # correct Z errors + + return qml.state() + + +p_err = 0.1 +key = random.PRNGKey(10) +print(f"Running Steane Code QEC Round with error: {get_error(n, p_err=p_err, key=key)}") +state_vector_to_dict(qec_round(H_steane, p_err=p_err, key=key), display=True, wires=range(n)) + +###################################################################### +# If we increase the likelihood of errors, we are more likely to end up with an error pattern that +# can’t be corrected. +# + +p_err = 0.3 +key = random.PRNGKey(8) +print(f"Running Steane Code QEC Round with error: {get_error(n, p_err=p_err, key=key)}") +state_vector_to_dict(qec_round(H_steane, p_err=p_err, key=key), display=True, wires=range(n)) + +###################################################################### +# Benchmarking logical vs. physical error rates +# --------------------------------------------- +# +# In the final section of this demo, we will compute the average performance of our Steane code error +# correction circuit for a range of possible error rates. We’ll define logical error rates by +# comparing the state vector from our noisy simulation with the clean state vector ``sv_clean`` of +# the Steane code logical zero state. +# + +import pandas as pd +import seaborn as sns +import matplotlib.pyplot as plt +from tqdm import tqdm + +clean_idx = jnp.where(sv_clean)[0] + + +def logical_error(sv): + st = sv[clean_idx] + return 1 - jnp.all(jnp.isclose(st / st[0], 1)) + + +###################################################################### +# Simulate 1000 noisy shots for several noise levels, we’ll use ``jax.vmap`` to efficiently map our +# catalyst kernel over a set of random keys. +# + + +@jax.jit +def single_trial_error(key, p_err, H_steane): + """Performs one QEC round and checks for a logical error.""" + round_output = qec_round(H_steane, p_err, key) + err = logical_error(round_output) + return err + + +batch_trial_errors = jax.vmap(single_trial_error, in_axes=(0, None, None)) + +N = 1000 +p_rng = 2 ** jnp.arange(-5, -1.75, 0.25, dtype=jnp.float32) +res = [] +master_key = random.PRNGKey(0) + +for p in tqdm(p_rng): + keys_for_batch, master_key = random.split(master_key) + all_keys = random.split(keys_for_batch, N) + + errors_batch = batch_trial_errors(all_keys, p, H_steane) + + p_value = p.item() + for idx, err in enumerate(errors_batch): + res.append({"p": p_value, "seed": idx, "err": err.item()}) + +df = pd.DataFrame(res) + +###################################################################### +# Plot the results using seaborn +# + +p_rng_min = p_rng[0] +p_rng_max = p_rng[-1] + +sns.set_theme(style="whitegrid", context="talk") + +plt.figure(figsize=(10, 7)) +sns.lineplot( + data=df, + x="p", + y="err", + marker="o", + markersize=8, + linewidth=2.5, + label="Simulated Logical Error Rate", +) + +plt.plot( + [p_rng_min, p_rng_max], + [p_rng_min, p_rng_max], + linestyle="--", + color="gray", + linewidth=1.5, + label="$p_{physical} = p_{logical}$", # Label for legend +) +plt.xlabel("Physical Error Rate ($p$)", fontsize=16) +plt.ylabel("Logical Error Rate ($P_L$)", fontsize=16) +plt.xscale("log", base=2) +plt.yscale("log", base=2) +plt.title("Logical vs. Physical Error Rate", fontsize=18, pad=20) +plt.legend(fontsize=14) +plt.grid(True, which="both", ls="--", c="lightgray", alpha=0.7) # 'both' for major and minor ticks +sns.despine() +plt.tight_layout() +plt.show() + +###################################################################### +# Conclusion and Limitations +# -------------------------- +# +# In this tutorial, we successfully built, simulated, and decoded a simple quantum error correction +# cycle using the Steane code. We demonstrated encoding a logical qubit, introduced errors through +# noise simulation, and performed error correction using stabilizer measurements combined with +# classical decoding. Performance was benchmarked by measuring the logical versus physical error +# rates. +# +# However, our approach relied on a significant simplifying assumption known as the code capacity model, +# where errors are assumed to occur at only one stage of the circuit, with otherwise perfect encoding and +# syndrome extraction. A more realistic approach—called circuit-level noise—accounts for +# potential errors at every gate and measurement within the circuit. This model significantly +# complicates decoding, as it requires mapping every possible error location not only in space but +# also across multiple syndrome measurement rounds, forming a complex space-time hypergraph. Decoding +# then involves interpreting error events over both spatial and temporal dimensions. +# +# Nevertheless, the fundamental decoding principles explored here, particularly the Belief Propagation +# algorithm, remain highly relevant. BP is flexible enough to operate effectively on more comprehensive +# circuit-level decoding hypergraphs. +# + +###################################################################### +# References +# ---------- +# +# .. [#Pesah] Pesah, Arthur. “The stabilizer trilogy I — Stabilizer codes.” Arthur Pesah, 31 +# Jan. 2023, https://arthurpesah.me/blog/2023-01-31-stabilizer-formalism-1/. +# +# .. [#Wiberg] Wiberg, Niclas. (2001). Codes and Decoding on General Graphs. +# https://www.essrl.wustl.edu/~jao/itrg/wiberg.pdf +# +# .. [#Panteleev] Panteleev, Pavel. “Degenerate Quantum LDPC Codes With Good Finite Length +# Performance.” arXiv.org, 04 Apr. 2019, https://arxiv.org/abs/1904.02703v3. +# +# .. [#Hillmann] Hillmann, Timo. “Localized statistics decoding: A parallel decoding algorithm for +# quantum low-density parity-check codes.” arXiv.org, 26 Jun. 2024, +# https://arxiv.org/abs/2406.18655v1. +# +# .. [#Wolanski] Wolanski, Stasiu. “Ambiguity Clustering: an accurate and efficient decoder for qLDPC +# codes.” arXiv.org, 20 Jun. 2024, https://arxiv.org/abs/2406.14527v2. +# +# .. [#Loeliger] Loeliger, Hans-Andrea. “An introduction to factor graphs” in IEEE Signal Processing +# Magazine, vol. 21, no. 1, pp. 28-41, Jan. 2004, +# https://www.isiweb.ee.ethz.ch/papers/arch/aloe-2004-spmagffg.pdf. +# +# .. [#Barber] Barber, David. “Bayesian Reasoning and Machine Learning”. Cambridge University Press, +# USA. 2012, http://web4.cs.ucl.ac.uk/staff/D.Barber/textbook/180325.pdf#page=107.50 +# + +###################################################################### +# About the author +# ---------------- +#