Skip to content

Commit d74d5a6

Browse files
authored
Fix parsing einsum with a single input (#1311)
1 parent a07d4f1 commit d74d5a6

File tree

2 files changed

+45
-1
lines changed

2 files changed

+45
-1
lines changed

hls4ml/converters/pytorch/core.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,10 @@ def parse_einsum_layer(operation, layer_name, input_names, input_shapes, node, c
166166

167167
layer = {}
168168

169-
if len(input_names) != 2:
169+
if len(input_names) == 1:
170+
input_names += input_names
171+
input_shapes += input_shapes
172+
elif len(input_names) > 2:
170173
raise Exception('Only einsum operations with two inputs are supported')
171174
layer['class_name'] = 'Einsum'
172175
layer['name'] = layer_name

test/pytest/test_pytorch_api.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -895,6 +895,18 @@ def forward(self, x, y):
895895
return torch.einsum('bij,bjk->bik', x, y)
896896

897897

898+
class EinsumSingleInput(nn.Module):
899+
def __init__(self, input_dim=8):
900+
super().__init__()
901+
self.input_dim = input_dim
902+
self.linear = nn.Linear(self.input_dim, self.input_dim)
903+
904+
def forward(self, x):
905+
"""using torch einsum to get the dot product"""
906+
out = self.linear(x)
907+
return torch.einsum("ij,ij->i", out, out)
908+
909+
898910
@pytest.mark.parametrize('backend', ['Vivado', 'Vitis'])
899911
@pytest.mark.parametrize('io_type', ['io_parallel'])
900912
def test_einsum_outer_product(backend, io_type):
@@ -953,3 +965,32 @@ def test_einsum_batch_matmul(backend, io_type):
953965
hls_prediction = np.reshape(hls_model.predict([X_input, Y_input]), pytorch_prediction.shape)
954966

955967
np.testing.assert_allclose(hls_prediction, pytorch_prediction, rtol=1e-2, atol=0.01)
968+
969+
970+
@pytest.mark.parametrize('backend', ['Vivado', 'Vitis'])
971+
@pytest.mark.parametrize('io_type', ['io_parallel'])
972+
def test_einsum_single_input(backend, io_type):
973+
974+
model = EinsumSingleInput()
975+
model.eval()
976+
977+
X_input = np.random.rand(3, 8)
978+
979+
pytorch_prediction = model(torch.Tensor(X_input)).detach().numpy()
980+
981+
config = config_from_pytorch_model(
982+
model,
983+
[(None, 8)],
984+
default_precision='ap_fixed<16,6>',
985+
channels_last_conversion="internal",
986+
transpose_outputs=False,
987+
)
988+
output_dir = str(test_root_path / f'hls4mlprj_pytorch_einsum_single_input_{backend}_{io_type}')
989+
990+
hls_model = convert_from_pytorch_model(model, hls_config=config, output_dir=output_dir, backend=backend, io_type=io_type)
991+
992+
hls_model.compile()
993+
994+
hls_prediction = np.reshape(hls_model.predict(X_input), pytorch_prediction.shape)
995+
996+
np.testing.assert_allclose(hls_prediction, pytorch_prediction, rtol=1e-2, atol=0.01)

0 commit comments

Comments
 (0)