Skip to content

Commit 7982c87

Browse files
JanFSchultevloncar
andauthored
Add View to layer name map for pytorch parser (#1039)
* add view to layer name map in pytoch converter * trigger pre-commit * add test for view in pytorch * Use unique output directory for pytorch 'view' tests --------- Co-authored-by: Vladimir Loncar <vloncar@users.noreply.github.com>
1 parent 5c0c4e6 commit 7982c87

File tree

2 files changed

+63
-0
lines changed

2 files changed

+63
-0
lines changed

hls4ml/converters/pytorch_to_hls.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ def decorator(function):
9595
'avg_pool1d': 'AvgPool1d',
9696
'avg_pool2d': 'AvgPool2d',
9797
'flatten': 'Flatten',
98+
'view': 'View',
9899
}
99100

100101

test/pytest/test_pytorch_api.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -810,3 +810,65 @@ def forward(self, x):
810810
hls_prediction = hls_model.predict(hls_input).flatten()
811811

812812
np.testing.assert_allclose(hls_prediction, pytorch_prediction, rtol=0, atol=5e-2)
813+
814+
815+
@pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Quartus'])
816+
@pytest.mark.parametrize('io_type', ['io_parallel', 'io_stream'])
817+
def test_view(backend, io_type):
818+
819+
class TestModel(nn.Module):
820+
def __init__(self, n_in, n_out, size_in):
821+
super().__init__()
822+
self.view_mult = n_out * size_in
823+
824+
self.conv1 = nn.Conv1d(
825+
n_in,
826+
n_out,
827+
kernel_size=3,
828+
padding=1,
829+
bias=False,
830+
)
831+
832+
def forward(self, x):
833+
z = self.conv1(x)
834+
z = z.view(-1, self.view_mult)
835+
return z
836+
837+
n_in = 2
838+
n_out = 4
839+
size_in = 128
840+
n_batch = 100
841+
842+
model = TestModel(n_in, n_out, size_in)
843+
model = model.to(memory_format=torch.channels_last)
844+
model.eval()
845+
846+
X_input = np.random.rand(n_batch, n_in, size_in)
847+
pytorch_prediction = model(torch.Tensor(X_input)).detach().numpy()
848+
849+
# X_input is channels last
850+
X_input = np.ascontiguousarray(X_input.transpose(0, 2, 1))
851+
config = config_from_pytorch_model(model, inputs_channel_last=True, transpose_outputs=False)
852+
853+
output_dir = str(test_root_path / f'hls4mlprj_pytorch_view_{backend}_{io_type}')
854+
hls_model = convert_from_pytorch_model(
855+
model,
856+
(None, n_in, size_in),
857+
hls_config=config,
858+
output_dir=output_dir,
859+
backend=backend,
860+
io_type=io_type,
861+
)
862+
863+
hls_model.compile()
864+
865+
# reshape hls prediction to channels last, then transpose, then reshape
866+
# to match .view
867+
hls_prediction = np.reshape(
868+
np.transpose(np.reshape(hls_model.predict(X_input), (n_batch, size_in, n_out)), (0, 2, 1)),
869+
(n_batch, size_in * n_out),
870+
)
871+
872+
rtol = 0
873+
atol = 5.0e-2
874+
np.testing.assert_allclose(hls_prediction, pytorch_prediction, rtol=rtol, atol=atol)

0 commit comments

Comments
 (0)