Skip to content

Commit 6cdf842

Browse files
authored
Add support for einsum operation to pytorch parser (requires 1116) (#1273)
* add einsum support to pytorch parser
* use _validate_einsum_expr to extract output shape
1 parent 03953ad commit 6cdf842

File tree

2 files changed

+103
-0
lines changed

2 files changed

+103
-0
lines changed

hls4ml/converters/pytorch/core.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import numpy as np
22

33
from hls4ml.converters.pytorch_to_hls import pytorch_handler
4+
from hls4ml.utils.einsum_utils import _validate_einsum_expr
45

56

67
@pytorch_handler('Constant')
@@ -157,3 +158,29 @@ def parse_batchnorm_layer(operation, layer_name, input_names, input_shapes, node
157158
layer['n_filt'] = input_shapes[0][1] # Always channel first for Pytorch
158159

159160
return layer, [shape for shape in input_shapes[0]]
161+
162+
163+
@pytorch_handler('einsum')
def parse_einsum_layer(operation, layer_name, input_names, input_shapes, node, class_object, data_reader, config):
    """Build the layer dictionary for a two-operand ``einsum`` node.

    Args:
        operation: Operation name this handler was dispatched for; must contain ``'einsum'``.
        layer_name: Name stored under ``layer['name']``.
        input_names: Names of the operands; exactly two are accepted.
        input_shapes: One shape (list or tuple) per operand. Entry 0 of each shape is
            treated as the batch dimension and replaced by a dummy ``1``.
        node: Graph node of the operation; ``node.args[0]`` is used as the einsum equation.
        class_object: Unused by this handler.
        data_reader: Unused by this handler.
        config: Unused by this handler.

    Returns:
        A ``(layer, output_shape)`` tuple. ``layer`` holds ``class_name``, ``name``,
        ``inputs``, ``inp0_shape``, ``inp1_shape``, ``equation`` and ``out_shape``;
        ``output_shape`` is a list copy of the first operand's shape.

    Raises:
        AssertionError: If ``operation`` does not contain ``'einsum'``.
        Exception: If the number of inputs is not exactly two.
    """
    assert 'einsum' in operation

    if len(input_names) != 2:
        raise Exception('Only einsum operations with two inputs are supported')

    layer = {
        'class_name': 'Einsum',
        'name': layer_name,
        'inputs': input_names,
    }

    # Need to set batch size to a real value instead of 'None'. Using '1' as dummy value.
    # Building fresh tuples leaves the caller's shapes untouched and works whether the
    # shapes arrive as lists or as (immutable) tuples.
    layer['inp0_shape'] = (1, *input_shapes[0][1:])
    layer['inp1_shape'] = (1, *input_shapes[1][1:])

    layer['equation'], layer['out_shape'] = _validate_einsum_expr(node.args[0], layer['inp0_shape'], layer['inp1_shape'])

    # NOTE(review): the shape handed back to the parser is the first operand's shape,
    # not the einsum result shape stored in layer['out_shape'] - confirm this is
    # intended before chaining further layers after an einsum.
    return layer, list(input_shapes[0])

test/pytest/test_pytorch_api.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -877,3 +877,79 @@ def forward(self, x):
877877
rtol = 0
878878
atol = 5.0e-2
879879
np.testing.assert_allclose(hls_prediction, pytorch_prediction, rtol=rtol, atol=atol)
880+
881+
882+
class EinsumOuterProduct(nn.Module):
    """Module computing a per-sample outer product of two 2-D inputs via ``torch.einsum``."""

    def forward(self, x, y):
        """Return the ``'bi,bj->bij'`` einsum of ``x`` and ``y``."""
        equation = 'bi,bj->bij'
        return torch.einsum(equation, x, y)
888+
889+
890+
class EinsumBatchMatMul(nn.Module):
    """Module computing a batched matrix product of two 3-D inputs via ``torch.einsum``."""

    def forward(self, x, y):
        """Return the ``'bij,bjk->bik'`` einsum of ``x`` and ``y``."""
        equation = 'bij,bjk->bik'
        return torch.einsum(equation, x, y)
896+
897+
898+
@pytest.mark.parametrize('backend', ['Vivado', 'Vitis'])
@pytest.mark.parametrize('io_type', ['io_parallel'])
def test_einsum_outer_product(backend, io_type):
    """Compare the converted outer-product einsum model against the PyTorch result."""
    batch, n_i, n_j = 3, 4, 5

    torch_model = EinsumOuterProduct()
    torch_model.eval()

    # Random operands; the first one is drawn first, then the second.
    lhs = np.random.rand(batch, n_i)
    rhs = np.random.rand(batch, n_j)

    reference = torch_model(torch.Tensor(lhs), torch.Tensor(rhs)).detach().numpy()

    # Input shapes are given with an unspecified (None) batch dimension.
    hls_config = config_from_pytorch_model(
        torch_model,
        [(None, n_i), (None, n_j)],
        default_precision='ap_fixed<16,6>',
        channels_last_conversion="internal",
        transpose_outputs=False,
    )
    project_dir = str(test_root_path / f'hls4mlprj_pytorch_einsum_outer_product_{backend}_{io_type}')

    converted = convert_from_pytorch_model(
        torch_model,
        hls_config=hls_config,
        output_dir=project_dir,
        backend=backend,
        io_type=io_type,
    )
    converted.compile()

    # Reshape the prediction to the reference layout before comparing.
    raw_prediction = converted.predict([lhs, rhs])
    hls_result = np.reshape(raw_prediction, reference.shape)

    np.testing.assert_allclose(hls_result, reference, rtol=1e-2, atol=0.01)
926+
927+
928+
@pytest.mark.parametrize('backend', ['Vivado', 'Vitis'])
@pytest.mark.parametrize('io_type', ['io_parallel'])
def test_einsum_batch_matmul(backend, io_type):
    """Compare the converted batched-matmul einsum model against the PyTorch result."""
    batch, n_i, n_j, n_k = 3, 2, 5, 4

    torch_model = EinsumBatchMatMul()
    torch_model.eval()

    # Random operands; the first one is drawn first, then the second.
    lhs = np.random.rand(batch, n_i, n_j)
    rhs = np.random.rand(batch, n_j, n_k)

    reference = torch_model(torch.Tensor(lhs), torch.Tensor(rhs)).detach().numpy()

    # Input shapes are given with an unspecified (None) batch dimension.
    hls_config = config_from_pytorch_model(
        torch_model,
        [(None, n_i, n_j), (None, n_j, n_k)],
        default_precision='ap_fixed<16,6>',
        channels_last_conversion="internal",
        transpose_outputs=False,
    )
    project_dir = str(test_root_path / f'hls4mlprj_pytorch_einsum_batch_matmul_{backend}_{io_type}')

    converted = convert_from_pytorch_model(
        torch_model,
        hls_config=hls_config,
        output_dir=project_dir,
        backend=backend,
        io_type=io_type,
    )
    converted.compile()

    # Reshape the prediction to the reference layout before comparing.
    raw_prediction = converted.predict([lhs, rhs])
    hls_result = np.reshape(raw_prediction, reference.shape)

    np.testing.assert_allclose(hls_result, reference, rtol=1e-2, atol=0.01)

0 commit comments

Comments
 (0)