Skip to content
Open
4 changes: 3 additions & 1 deletion .github/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

### New features

* Adds the `beamsplitter_stable` function to the `fock_gradients` module [(#405)](https://github.com/XanaduAI/thewalrus/pull/405).

### Breaking changes

### Improvements
Expand All @@ -16,7 +18,7 @@

This release contains contributions from (in alphabetical order):

L.G. Helt
L.G. Helt, F. Miatto

---

Expand Down
24 changes: 22 additions & 2 deletions thewalrus/decompositions.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,8 +217,28 @@ def takagi(A, svd_order=True, rtol=1e-16):
return l[::-1], U[:, ::-1]
return l, U

u, d, v = np.linalg.svd(A)
U = u @ sqrtm((v @ np.conjugate(u)).T)
u, d, vh = np.linalg.svd(A)
z = vh @ u.conj()
# Use Schur decomposition for unitary matrix (which is normal)
# For normal matrices, Schur form is diagonal with eigenvalues on diagonal
T, Q = schur(z, output="complex")
z_eigvals = np.diag(T)
# Get sorted z angles in [0, 2π)
z_angles = np.sort(np.unique(np.mod(np.angle(z_eigvals), 2 * np.pi)))
# Get midpoint of largest arc
z_diffs = np.diff(z_angles, append=z_angles[0] + 2 * np.pi)
idx = np.argmax(z_diffs)
mid = z_angles[idx] + 0.5 * z_diffs[idx]
# Get shift angle in (-pi, pi]
shift_angle = np.mod(-mid, 2 * np.pi) - np.pi
# Rotate eigenvalues to shift midpoint of largest arc to ±pi
z_eigvals_shifted = z_eigvals * np.exp(1j * shift_angle)
# Compute sqrt of eigenvalues directly (avoiding sqrtm)
sqrt_z_eigvals_shifted = np.sqrt(z_eigvals_shifted)
# Reconstruct: sqrtm(z_shifted) = Q @ diag(sqrt(eigvals)) @ Q.H
sqrt_z_shifted = Q @ np.diag(sqrt_z_eigvals_shifted) @ Q.conj().T
# Undo rotation from ±pi
U = u @ sqrt_z_shifted @ np.diag(np.exp(-0.5j * shift_angle * np.ones(n)))
if svd_order is False:
return d[::-1], U[:, ::-1]
return d, U
Expand Down
125 changes: 96 additions & 29 deletions thewalrus/fock_gradients.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
displacement
squeezing
beamsplitter
beamsplitter_stable
two_mode_squeezing
mzgate
grad_displacement
Expand All @@ -39,6 +40,7 @@
Code details
------------
"""

import numpy as np

from numba import jit
Expand Down Expand Up @@ -301,55 +303,120 @@ def grad_two_mode_squeezing(T, r, theta): # pragma: no cover
return grad_r, grad_theta


@jit(nopython=True)
def beamsplitter(theta, phi, cutoff, dtype=np.complex128): # pragma: no cover
def beamsplitter(theta, phi, cutoff, dtype=np.complex128):
r"""Calculates the Fock representation of the beamsplitter.

Args:
theta (float): transmissivity angle of the beamsplitter. The transmissivity is :math:`t=\cos(\theta)`
phi (float): reflection phase of the beamsplitter
cutoff (int): Fock ladder cutoff
cutoff (int or tuple): Fock ladder cutoff. If int, uses (cutoff, cutoff, cutoff, cutoff) shape.
If tuple, uses the provided shape directly.
dtype (data type): Specifies the data type used for the calculation

Returns:
array[float]: The Fock representation of the gate
"""
sqrt = np.sqrt(np.arange(cutoff, dtype=dtype))
# Determine shape based on whether cutoff is an int or tuple
if isinstance(cutoff, tuple):
shape = cutoff
elif isinstance(cutoff, int):
shape = (cutoff, cutoff, cutoff, cutoff)
else:
raise ValueError(f"Invalid cutoff type: {type(cutoff)}")

return _beamsplitter_stable(theta, phi, shape, dtype)


SQRT = np.sqrt(np.arange(1000))
_SQRT = np.sqrt(np.arange(1000))
_SQRT[0] = 1.0 # to avoid division by zero
INV_SQRT = 1 / _SQRT


@jit(nopython=True)
def _beamsplitter_stable(
theta, phi, shape, dtype=np.complex128
): # pragma: no cover # pylint: disable=too-many-branches
r"""
Stable implementation of the Fock representation of the beamsplitter that
averages contributions from all available pivots for ecah amplitude.
It is numerically stable up to arbitrary cutoffs (or you will likely
run out of memory before incurring in numerical issues).
The shape order is (out_0, out_1, in_0, in_1), assuming it acts on modes 0 and 1.

Args:
theta (float): beamsplitter angle
phi (float): beamsplitter phase
shape (tuple[int, int, int, int]): shape of the Fock representation
dtype (data type): Specifies the data type used for the calculation

Returns:
array (ComplexTensor): The Fock representation of the gate
"""
ct = np.cos(theta)
st = np.sin(theta) * np.exp(1j * phi)
R = np.array(
[
[0, 0, ct, -np.conj(st)],
[0, 0, st, ct],
[ct, st, 0, 0],
[-np.conj(st), ct, 0, 0],
]
)
stc = np.conj(st)

Z = np.zeros((cutoff, cutoff, cutoff, cutoff), dtype=dtype)
Z[0, 0, 0, 0] = 1.0
M, N, P, Q = shape
G = np.zeros(shape, dtype=dtype)
G[0, 0, 0, 0] = 1.0 + 0.0j

# rank 3
for m in range(cutoff):
for n in range(cutoff - m):
for m in range(M):
for n in range(min(N, P - m)):
p = m + n
if 0 < p < cutoff:
Z[m, n, p, 0] = (
R[0, 2] * sqrt[m] / sqrt[p] * Z[m - 1, n, p - 1, 0]
+ R[1, 2] * sqrt[n] / sqrt[p] * Z[m, n - 1, p - 1, 0]
val = 0
pivots = 0
if m > 0: # pivot at (m-1, n, p, 0)
val += ct * SQRT[p] * INV_SQRT[m] * G[m - 1, n, p - 1, 0]
pivots += 1
if n > 0: # pivot at (m, n-1, p, 0)
val += st * SQRT[p] * INV_SQRT[n] * G[m, n - 1, p - 1, 0]
pivots += 1
if p > 0: # pivot at (m, n, p-1, 0)
val += (
ct * SQRT[m] * INV_SQRT[p] * G[m - 1, n, p - 1, 0]
+ st * SQRT[n] * INV_SQRT[p] * G[m, n - 1, p - 1, 0]
)
pivots += 1
if m > 0 or n > 0 or p > 0:
G[m, n, p, 0] = val / pivots

# rank 4
for m in range(cutoff):
for n in range(cutoff):
for p in range(cutoff):
for m in range(M):
for n in range(N):
for p in range(max(0, m + n - Q), min(P, m + n)):
q = m + n - p
if 0 < q < cutoff:
Z[m, n, p, q] = (
R[0, 3] * sqrt[m] / sqrt[q] * Z[m - 1, n, p, q - 1]
+ R[1, 3] * sqrt[n] / sqrt[q] * Z[m, n - 1, p, q - 1]
)
return Z
if 0 < q < Q:
val = 0
pivots = 0
if m > 0:
val += (
ct * SQRT[p] * INV_SQRT[m] * G[m - 1, n, p - 1, q]
- stc * SQRT[q] * INV_SQRT[m] * G[m - 1, n, p, q - 1]
)
pivots += 1
if n > 0:
val += (
st * SQRT[p] * INV_SQRT[n] * G[m, n - 1, p - 1, q]
+ ct * SQRT[q] * INV_SQRT[n] * G[m, n - 1, p, q - 1]
)
pivots += 1
if p > 0:
val += (
ct * SQRT[m] * INV_SQRT[p] * G[m - 1, n, p - 1, q]
+ st * SQRT[n] * INV_SQRT[p] * G[m, n - 1, p - 1, q]
)
pivots += 1
if q > 0:
val += (
-stc * SQRT[m] * INV_SQRT[q] * G[m - 1, n, p, q - 1]
+ ct * SQRT[n] * INV_SQRT[q] * G[m, n - 1, p, q - 1]
)
pivots += 1
if m > 0 or n > 0 or p > 0 or q > 0:
G[m, n, p, q] = val / pivots
return G


@jit(nopython=True)
Expand Down
14 changes: 7 additions & 7 deletions thewalrus/tests/test_decompositions.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,16 +256,16 @@ def test_transform(self, passive, create_transform, tol):


@pytest.mark.parametrize("n", [5, 10, 50])
@pytest.mark.parametrize("datatype", [np.complex128, np.float64])
@pytest.mark.parametrize("imag_part", [0, 1e-10, 1])
@pytest.mark.parametrize("svd_order", [True, False])
def test_takagi(n, datatype, svd_order):
"""Checks the correctness of the Takagi decomposition function"""
if datatype is np.complex128:
A = np.random.rand(n, n) + 1j * np.random.rand(n, n)
if datatype is np.float64:
A = np.random.rand(n, n)
def test_takagi(n, imag_part, svd_order):
"""Checks the correctness of the Takagi decomposition function for generic random matrices"""
A = np.random.rand(n, n)
if imag_part > 0:
A = A + 1j * imag_part * np.random.rand(n, n)
A += A.T
r, U = takagi(A, svd_order=svd_order)
assert np.allclose(np.eye(n, n), U @ U.T.conj())
assert np.allclose(A, U @ np.diag(r) @ U.T)
assert np.all(r >= 0)
if svd_order is True:
Expand Down
19 changes: 16 additions & 3 deletions thewalrus/tests/test_fock_gradients.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,9 +286,9 @@ def test_S2_selection_rules(tol):

def test_beamsplitter_values(tol):
r"""Test that the representation of an interferometer in the single
excitation manifold is precisely the unitary matrix that represents it
mode in space. This test in particular checks that the BS gate is
consistent with strawberryfields
excitation manifold is precisely the unitary matrix that represents it.
This test in particular checks that the BS gate is consistent
with strawberryfields
"""
nmodes = 2
vec_list = np.identity(nmodes, dtype=int).tolist()
Expand All @@ -306,6 +306,19 @@ def test_beamsplitter_values(tol):
assert np.allclose(U, U_rec, atol=tol, rtol=0)


def test_beamsplitter_stability():
r"""Tests the stability of the beamsplitter operation"""
theta = np.random.rand()
phi = np.random.rand()
cutoff = 70
stable = beamsplitter(theta, phi, cutoff)
assert np.isclose(np.max(np.abs(stable)), 1.0, atol=1e-5, rtol=0)
assert stable.dtype == np.complex128
stable = beamsplitter(theta, phi, cutoff, dtype=np.complex64)
assert np.isclose(np.max(np.abs(stable)), 1.0, atol=1e-5, rtol=0)
assert stable.dtype == np.complex64


def test_mzgate_values(tol):
r"""Test that the representation of an interferometer in the single
excitation manifold is precisely the unitary matrix that represents it
Expand Down