Commit 2c771cc
Enable sparse support in TorchConnector and other minor updates (#571)
* minor updates
* update imports
* updates
* working version
* some test optimizations
* unpack
* some tests
* more tests
* fix black, lint, mypy
* add reno
* fix copyright, spell, black
* fix copyright, fix initial_weights
* no sparse support on 3.7
* more skips
* more skips
* add quotes
* Update qiskit_machine_learning/kernels/quantum_kernel.py
* rollback spelling changes
* update reno
* code review
* update tests
* fix copyright
* update reno
* Update releasenotes/notes/add-sparse-torch-connector-a3b9e3d50b405a01.yaml
* update reno more

Co-authored-by: Steve Wood <40241007+woodsp-ibm@users.noreply.github.com>
1 parent 9596dc4 commit 2c771cc

File tree

5 files changed: +435 -161 lines changed

.pylintdict

Lines changed: 2 additions & 0 deletions
@@ -66,8 +66,10 @@ discretize
 discretized
 discriminative
 distro
+dok
 dt
 eigenstates
+einsum
 endian
 entangler
 estimatorqnn

qiskit_machine_learning/connectors/torch_connector.py

Lines changed: 131 additions & 59 deletions
@@ -1,6 +1,6 @@
 # This code is part of Qiskit.
 #
-# (C) Copyright IBM 2021, 2022.
+# (C) Copyright IBM 2021, 2023.
 #
 # This code is licensed under the Apache License, Version 2.0. You may
 # obtain a copy of this license in the LICENSE.txt file in the root directory
@@ -11,18 +11,24 @@
 # that they have been altered from the originals.

 """A connector to use Qiskit (Quantum) Neural Networks as PyTorch modules."""
+from __future__ import annotations
+
+import sys
+from typing import Tuple, Any, cast

-from typing import Tuple, Any, Optional, cast, Union
 import numpy as np

 import qiskit_machine_learning.optionals as _optionals
-from ..neural_networks import NeuralNetwork
 from ..exceptions import QiskitMachineLearningError
+from ..neural_networks import NeuralNetwork

 if _optionals.HAS_TORCH:
-    from torch import Tensor, sparse_coo_tensor, einsum
+    import torch
+
+    # imports for inheritance and type hints
+    from torch import Tensor
     from torch.autograd import Function
-    from torch.nn import Module, Parameter as TorchParam
+    from torch.nn import Module
 else:

     class Function:  # type: ignore
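The import block above uses the package's optional-dependency guard: torch is imported only when available, and a stub Function class keeps the module importable otherwise. A minimal sketch of the same pattern, not part of this commit (the HAS_TORCH flag is simplified here to a plain importlib check):

    import importlib.util

    # simplified stand-in for _optionals.HAS_TORCH
    HAS_TORCH = importlib.util.find_spec("torch") is not None

    if HAS_TORCH:
        from torch.autograd import Function
    else:

        class Function:  # stub so the module imports without torch installed
            pass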
@@ -75,6 +81,7 @@ def forward(  # type: ignore

             Raises:
                 QiskitMachineLearningError: Invalid input data.
+                RuntimeError: if connector is configured as sparse and the network is not sparse.
             """

             # validate input shape
@@ -94,15 +101,30 @@ def forward(  # type: ignore
             result = neural_network.forward(
                 input_data.detach().cpu().numpy(), weights.detach().cpu().numpy()
             )
-            if neural_network.sparse and sparse:
-                _optionals.HAS_SPARSE.require_now("COO")
-                # pylint: disable=import-error
-                from sparse import SparseArray, COO
+            if ctx.sparse:
+                if neural_network.sparse:
+                    _optionals.HAS_SPARSE.require_now("SparseArray")
+                    # pylint: disable=import-error
+                    from sparse import SparseArray, COO

-                result = cast(COO, cast(SparseArray, result).asformat("coo"))
-                result_tensor = sparse_coo_tensor(result.coords, result.data)
+                    # todo: replace output type from DOK to COO?
+                    result = cast(COO, cast(SparseArray, result).asformat("coo"))
+                    result_tensor = torch.sparse_coo_tensor(result.coords, result.data)
+                else:
+                    raise RuntimeError(
+                        "TorchConnector configured as sparse, the network must be sparse as well"
+                    )
             else:
-                result_tensor = Tensor(result)
+                # connector is dense
+                if neural_network.sparse:
+                    # convert to dense
+                    _optionals.HAS_SPARSE.require_now("SparseArray")
+                    from sparse import SparseArray
+
+                    # cast is required by mypy
+                    result = cast(SparseArray, result).todense()
+                result_tensor = torch.from_numpy(result)
+                result_tensor = result_tensor.to(input_data.dtype)

             # if the input was not a batch, then remove the batch-dimension from the result,
             # since the neural network will always treat input as a batch and cast to a
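The sparse branch above converts the network's pydata/sparse output to COO format and rebuilds it as a PyTorch sparse tensor from its coords/data arrays. A minimal sketch of that conversion, using a made-up output array:

    import numpy as np
    import sparse  # pydata/sparse, an optional dependency
    import torch

    # hypothetical QNN output: mostly-zero probabilities for a batch of 2 samples
    dense = np.array([[0.0, 0.7, 0.3], [0.5, 0.0, 0.5]])
    coo = sparse.COO.from_numpy(dense)

    # coords has shape (ndim, nnz); data has shape (nnz,)
    tensor = torch.sparse_coo_tensor(coo.coords, coo.data, size=coo.shape)
    print(tensor)
    print(tensor.to_dense())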
@@ -124,6 +146,8 @@ def backward(ctx: Any, grad_output: Tensor) -> Tuple:  # type: ignore
                 grad_output: previous gradient
             Raises:
                 QiskitMachineLearningError: Invalid input data.
+                RuntimeError: if connector is configured as sparse and the network is not sparse.
+
             Returns:
                 gradients for the first two arguments and None for the others
             """
@@ -132,10 +156,6 @@ def backward(ctx: Any, grad_output: Tensor) -> Tuple:  # type: ignore
             input_data, weights = ctx.saved_tensors
             neural_network = ctx.neural_network

-            # if sparse output is requested return None, since PyTorch does not support it yet.
-            if neural_network.sparse and ctx.sparse:
-                return None, None, None, None
-
             # validate input shape
             if input_data.shape[-1] != neural_network.num_inputs:
                 raise QiskitMachineLearningError(
@@ -152,46 +172,84 @@ def backward(ctx: Any, grad_output: Tensor) -> Tuple:  # type: ignore
                 input_data.detach().cpu().numpy(), weights.detach().cpu().numpy()
             )
             if input_grad is not None:
-                if neural_network.sparse:
-                    input_grad = sparse_coo_tensor(input_grad.coords, input_grad.data)
-
-                    # cast to dense here, since PyTorch does not support sparse output yet.
-                    # this should only happen if the network returns sparse output but the
-                    # connector is configured to return dense output.
-                    input_grad = input_grad.to_dense()  # this should be eventually removed
-                    input_grad = input_grad.to(grad_output.dtype)
+                if ctx.sparse:
+                    if neural_network.sparse:
+                        _optionals.HAS_SPARSE.require_now("Sparse")
+                        import sparse
+                        from sparse import COO
+
+                        grad_output = grad_output.detach().cpu()
+                        grad_coo = COO(grad_output.indices(), grad_output.values())
+
+                        # Takes gradients from previous layer in backward pass (i.e. later layer in
+                        # forward pass) j for each observation i in the batch. Multiplies this with
+                        # the gradient from this point on backwards with respect to each input k.
+                        # Sums over all j to get total gradient of output w.r.t. each input k and
+                        # batch index i. This operation should preserve the batch dimension to be
+                        # able to do back-prop in a batched manner.
+                        # PyTorch does not support sparse einsum, so we rely on Sparse.
+                        # pylint: disable=no-member
+                        input_grad = sparse.einsum("ij,ijk->ik", grad_coo, input_grad)
+
+                        # return sparse gradients
+                        input_grad = torch.sparse_coo_tensor(input_grad.coords, input_grad.data)
+                    else:
+                        # this exception should never happen
+                        raise RuntimeError(
+                            "TorchConnector configured as sparse, "
+                            "the network must be sparse as well"
+                        )
                 else:
-                    input_grad = Tensor(input_grad).to(grad_output.dtype)
-
-                # Takes gradients from previous layer in backward pass (i.e. later layer in forward
-                # pass) j for each observation i in the batch. Multiplies this with the gradient
-                # from this point on backwards with respect to each input k. Sums over all j
-                # to get total gradient of output w.r.t. each input k and batch index i.
-                # This operation should preserve the batch dimension to be able to do back-prop in
-                # a batched manner.
-                input_grad = einsum("ij,ijk->ik", grad_output.detach().cpu(), input_grad)
+                    # connector is dense
+                    if neural_network.sparse:
+                        # convert to dense
+                        input_grad = input_grad.todense()
+                    input_grad = torch.from_numpy(input_grad)
+                    input_grad = input_grad.to(grad_output.dtype)
+                    # same as above
+                    input_grad = torch.einsum("ij,ijk->ik", grad_output.detach().cpu(), input_grad)

                 # place the resulting tensor to the device where they were stored
                 input_grad = input_grad.to(input_data.device)

             if weights_grad is not None:
-                if neural_network.sparse:
-                    weights_grad = sparse_coo_tensor(weights_grad.coords, weights_grad.data)
-
-                    # cast to dense here, since PyTorch does not support sparse output yet.
-                    # this should only happen if the network returns sparse output but the
-                    # connector is configured to return dense output.
-                    weights_grad = weights_grad.to_dense()  # this should be eventually removed
-                    weights_grad = weights_grad.to(grad_output.dtype)
+                if ctx.sparse:
+                    if neural_network.sparse:
+                        import sparse
+                        from sparse import COO
+
+                        grad_output = grad_output.detach().cpu()
+                        grad_coo = COO(grad_output.indices(), grad_output.values())
+
+                        # Takes gradients from previous layer in backward pass (i.e. later layer in
+                        # forward pass) j for each observation i in the batch. Multiplies this with
+                        # the gradient from this point on backwards with respect to each
+                        # parameter k. Sums over all i and j to get total gradient of output
+                        # w.r.t. each parameter k. The weights' dimension is independent of the
+                        # batch size.
+                        # pylint: disable=no-member
+                        weights_grad = sparse.einsum("ij,ijk->k", grad_coo, weights_grad)
+
+                        # return sparse gradients
+                        weights_grad = torch.sparse_coo_tensor(
+                            weights_grad.coords, weights_grad.data
+                        )
+                    else:
+                        # this exception should never happen
+                        raise RuntimeError(
+                            "TorchConnector configured as sparse, "
+                            "the network must be sparse as well"
+                        )
                 else:
-                    weights_grad = Tensor(weights_grad).to(grad_output.dtype)
-
-                # Takes gradients from previous layer in backward pass (i.e. later layer in forward
-                # pass) j for each observation i in the batch. Multiplies this with the gradient
-                # from this point on backwards with respect to each parameter k. Sums over all i and
-                # j to get total gradient of output w.r.t. each parameter k.
-                # The weights' dimension is independent of the batch size.
-                weights_grad = einsum("ij,ijk->k", grad_output.detach().cpu(), weights_grad)
+                    if neural_network.sparse:
+                        # convert to dense
+                        weights_grad = weights_grad.todense()
+                    weights_grad = torch.from_numpy(weights_grad)
+                    weights_grad = weights_grad.to(grad_output.dtype)
+                    # same as above
+                    weights_grad = torch.einsum(
+                        "ij,ijk->k", grad_output.detach().cpu(), weights_grad
+                    )

                 # place the resulting tensor to the device where they were stored
                 weights_grad = weights_grad.to(weights.device)
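The two einsum contractions implement the chain rule in batched form: "ij,ijk->ik" keeps the batch index i for input gradients, while "ij,ijk->k" also sums over the batch for the shared weights. A dense toy example with made-up shapes illustrates both contractions:

    import torch

    batch, n_outputs, n_inputs, n_weights = 4, 3, 2, 5

    grad_output = torch.rand(batch, n_outputs)            # upstream gradient, indices (i, j)
    input_jac = torch.rand(batch, n_outputs, n_inputs)    # d output_j / d input_k per sample i
    weight_jac = torch.rand(batch, n_outputs, n_weights)  # d output_j / d weight_k per sample i

    # sum over outputs j, keep the batch dimension i
    input_grad = torch.einsum("ij,ijk->ik", grad_output, input_jac)
    # sum over batch i and outputs j: weights are shared across the batch
    weights_grad = torch.einsum("ij,ijk->k", grad_output, weight_jac)

    print(input_grad.shape)    # torch.Size([4, 2])
    print(weights_grad.shape)  # torch.Size([5])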
@@ -202,8 +260,8 @@ def backward(ctx: Any, grad_output: Tensor) -> Tuple:  # type: ignore
     def __init__(
         self,
         neural_network: NeuralNetwork,
-        initial_weights: Optional[Union[np.ndarray, Tensor]] = None,
-        sparse: Optional[bool] = None,
+        initial_weights: np.ndarray | Tensor | None = None,
+        sparse: bool | None = None,
     ):
         """
         Args:
@@ -216,15 +274,29 @@ def __init__(
             sparse: Whether this connector should return sparse output or not. If sparse is set
                 to None, then the setting from the given neural network is used. Note that sparse
                 output is only returned if the underlying neural network also returns sparse output,
-                otherwise it will be dense independent of the setting. Also note that PyTorch
-                currently does not support sparse back propagation, i.e., if sparse is set to True,
-                the backward pass of this module will return None.
+                otherwise an error will be raised. Sparse support requires Python 3.8 or higher.
+
+        Raises:
+            QiskitMachineLearningError: If the connector is configured as sparse and the underlying
+                network is not sparse, or if the Python version is below 3.8.
         """
         super().__init__()
         self._neural_network = neural_network
+        if sparse is None:
+            sparse = self._neural_network.sparse
+        if sparse and sys.version_info < (3, 8):
+            raise QiskitMachineLearningError("Sparse is supported on python 3.8+")
+
         self._sparse = sparse

-        weight_param = TorchParam(Tensor(neural_network.num_weights))
+        if self._sparse and not self._neural_network.sparse:
+            # connector is sparse while the underlying neural network is not
+            raise QiskitMachineLearningError(
+                "TorchConnector configured as sparse, the network must be sparse as well"
+            )
+
+        weight_param = torch.nn.Parameter(torch.zeros(neural_network.num_weights))
         # Register param. in graph following PyTorch naming convention
         self.register_parameter("weight", weight_param)
         # If `weight_param` is assigned to `self._weights` after registration,
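The constructor now resolves the effective sparsity up front: inherit from the network when sparse is None, reject Python 3.7, and reject a sparse connector over a dense network. A standalone distillation of that decision logic (the helper name is illustrative, not part of the commit):

    from __future__ import annotations

    import sys

    def resolve_sparse(connector_sparse: bool | None, network_sparse: bool) -> bool:
        """Illustrative re-statement of TorchConnector's sparsity checks."""
        if connector_sparse is None:
            # inherit the setting from the underlying network
            connector_sparse = network_sparse
        if connector_sparse and sys.version_info < (3, 8):
            raise RuntimeError("Sparse is supported on python 3.8+")
        if connector_sparse and not network_sparse:
            raise RuntimeError(
                "TorchConnector configured as sparse, the network must be sparse as well"
            )
        return connector_sparse

    print(resolve_sparse(None, True))   # True: inherited from the network
    print(resolve_sparse(False, True))  # False: dense connector over a sparse network is allowed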
@@ -237,7 +309,7 @@ def __init__(
         if initial_weights is None:
             self._weights.data.uniform_(-1, 1)
         else:
-            self._weights.data = Tensor(initial_weights)
+            self._weights.data = torch.tensor(initial_weights, dtype=torch.float)

     @property
     def neural_network(self) -> NeuralNetwork:
@@ -250,11 +322,11 @@ def weight(self) -> Tensor:
         return self._weights

     @property
-    def sparse(self) -> Optional[bool]:
+    def sparse(self) -> bool | None:
         """Returns whether this connector returns sparse output or not."""
         return self._sparse

-    def forward(self, input_data: Optional[Tensor] = None) -> Tensor:
+    def forward(self, input_data: Tensor | None = None) -> Tensor:
         """Forward pass.

         Args:
@@ -263,7 +335,7 @@ def forward(self, input_data: Optional[Tensor] = None) -> Tensor:
         Returns:
             Result of forward pass of this model.
         """
-        input_ = input_data if input_data is not None else Tensor([])
+        input_ = input_data if input_data is not None else torch.zeros(0)
         return TorchConnector._TorchNNFunction.apply(
             input_, self._weights, self._neural_network, self._sparse
         )
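TorchConnector._TorchNNFunction.apply follows PyTorch's custom autograd.Function protocol: forward computes the result and stashes what backward needs on ctx. A minimal, self-contained example of the same protocol (a toy squaring function, unrelated to QNNs):

    import torch
    from torch.autograd import Function

    class Square(Function):
        @staticmethod
        def forward(ctx, x):
            ctx.save_for_backward(x)  # stash inputs for the backward pass
            return x * x

        @staticmethod
        def backward(ctx, grad_output):
            (x,) = ctx.saved_tensors
            return 2 * x * grad_output  # chain rule: d(x^2)/dx = 2x

    x = torch.tensor([3.0], requires_grad=True)
    y = Square.apply(x)
    y.backward()
    print(x.grad)  # tensor([6.])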
releasenotes/notes/add-sparse-torch-connector-a3b9e3d50b405a01.yaml

Lines changed: 50 additions & 0 deletions (new file)

@@ -0,0 +1,50 @@
+---
+features:
+  - |
+    The PyTorch connector :class:`~qiskit_machine_learning.connectors.TorchConnector` now fully
+    supports sparse output in both the forward and backward passes. To enable sparse support,
+    the underlying quantum neural network must itself be sparse. In this case, if the ``sparse``
+    property of the connector is not set, the connector inherits sparsity from the network. If
+    the connector is set to be sparse but the network is not, an exception is raised. You may
+    also set the connector to be dense even if the network is sparse.
+
+    This snippet illustrates how to create a sparse instance of the connector.
+
+    .. code-block:: python
+
+        import torch
+        from qiskit import QuantumCircuit
+        from qiskit.circuit.library import ZFeatureMap, RealAmplitudes
+
+        from qiskit_machine_learning.connectors import TorchConnector
+        from qiskit_machine_learning.neural_networks import SamplerQNN
+
+        num_qubits = 2
+        fmap = ZFeatureMap(num_qubits, reps=1)
+        ansatz = RealAmplitudes(num_qubits, reps=1)
+        qc = QuantumCircuit(num_qubits)
+        qc.compose(fmap, inplace=True)
+        qc.compose(ansatz, inplace=True)
+
+        qnn = SamplerQNN(
+            circuit=qc,
+            input_params=fmap.parameters,
+            weight_params=ansatz.parameters,
+            sparse=True,
+        )
+
+        connector = TorchConnector(qnn)
+
+        output = connector(torch.tensor([[1., 2.]]))
+        print(output)
+
+        loss = torch.sparse.sum(output)
+        loss.backward()
+
+        grad = connector.weight.grad
+        print(grad)
+
+    In a hybrid setup, where a PyTorch-based neural network has both classical and quantum
+    layers, sparse operations should not be mixed with dense ones, otherwise PyTorch may raise
+    exceptions.
+
+    Sparse support requires Python 3.8 or higher.
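As the note says, the connector may also be forced dense over a sparse network. Continuing the snippet above (assuming the same qnn), a dense connector converts the sparse network output before returning it:

    # continuing from the snippet above: qnn is the sparse SamplerQNN
    dense_connector = TorchConnector(qnn, sparse=False)

    dense_output = dense_connector(torch.tensor([[1., 2.]]))
    dense_output.sum().backward()  # plain torch.sum works on the dense output
    print(dense_connector.weight.grad)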

test/connectors/test_torch.py

Lines changed: 6 additions & 1 deletion
@@ -1,6 +1,6 @@
 # This code is part of Qiskit.
 #
-# (C) Copyright IBM 2022.
+# (C) Copyright IBM 2022, 2023.
 #
 # This code is licensed under the Apache License, Version 2.0. You may
 # obtain a copy of this license in the LICENSE.txt file in the root directory
@@ -98,3 +98,8 @@ def assertLogs(self, logger=None, level=None):
     def assertListEqual(self, list1, list2, msg=None):
         """Assert list equal."""
         raise Exception("Abstract method")
+
+    @abstractmethod
+    def assertRaises(self, expected_exception):
+        """Assert raises an exception."""
+        raise Exception("Abstract method")
