Skip to content

Commit 58b7913

Browse files
authored
Merge branch 'main' into GRUv1
2 parents 229dc7b + 7982c87 commit 58b7913

File tree

3 files changed

+99
-36
lines changed

3 files changed

+99
-36
lines changed

hls4ml/converters/pytorch_to_hls.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ def decorator(function):
9595
'avg_pool1d': 'AvgPool1d',
9696
'avg_pool2d': 'AvgPool2d',
9797
'flatten': 'Flatten',
98+
'view': 'View',
9899
}
99100

100101

test/pytest/generate_ci_yaml.py

Lines changed: 36 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
import glob
21
import itertools
32
import os
3+
from pathlib import Path
44

55
import yaml
66

@@ -9,6 +9,7 @@
99
in the pytests directory to parallelise the CI jobs.
1010
'''
1111

12+
1213
template = """
1314
pytest.{}:
1415
extends: .pytest
@@ -19,6 +20,14 @@
1920

2021
n_test_files_per_yml = int(os.environ.get('N_TESTS_PER_YAML', 4))
2122

23+
BLACKLIST = {'test_reduction'}
24+
25+
26+
def path_to_name(test_path):
27+
path = Path(test_path)
28+
name = path.stem.replace('test_', '')
29+
return name
30+
2231

2332
def batched(iterable, chunk_size):
2433
iterator = iter(iterable)
@@ -32,41 +41,32 @@ def uses_example_model(test_filename):
3241
return 'example-models' in content
3342

3443

35-
yml = None
36-
tests = glob.glob('test_*.py')
37-
for test_batch in batched(tests, n_test_files_per_yml):
38-
name = '+'.join([test.replace('test_', '').replace('.py', '') for test in test_batch])
39-
test_files = ' '.join(list(test_batch))
40-
uses_example_models = int(any([uses_example_model(test) for test in test_batch]))
41-
42-
new_yml = yaml.safe_load(template.format(name, test_files, uses_example_models))
43-
if yml is None:
44-
yml = new_yml
45-
else:
46-
yml.update(new_yml)
47-
48-
# hls4ml Optimization API
49-
tests = glob.glob('test_optimization/test_*.py')
50-
for test in tests:
51-
name = test.replace('test_optimization/', '').replace('test_', '').replace('.py', '')
52-
new_yml = yaml.safe_load(template.format(name, f'test_optimization/test_{name}.py', int(uses_example_model(test))))
53-
if yml is None:
54-
yml = new_yml
55-
else:
56-
yml.update(new_yml)
57-
58-
tests = glob.glob('test_optimization/test_keras/test_*.py')
59-
for test in tests:
60-
# For now, skip Keras Surgeon [conflicting versions]
61-
if 'test_reduction' not in test:
62-
name = test.replace('test_optimization/test_keras/', '').replace('test_', '').replace('.py', '')
63-
new_yml = yaml.safe_load(
64-
template.format(name, f'test_optimization/test_keras/test_{name}.py', int(uses_example_model(test)))
65-
)
44+
def generate_test_yaml(test_root='.'):
45+
test_root = Path(test_root)
46+
test_paths = [path for path in test_root.glob('**/test_*.py') if path.stem not in BLACKLIST]
47+
for path in test_paths:
48+
print(path.name)
49+
need_example_models = [uses_example_model(path) for path in test_paths]
50+
51+
idxs = list(range(len(need_example_models)))
52+
idxs = sorted(idxs, key=lambda i: f'{need_example_models[i]}_{path_to_name(test_paths[i])}')
53+
54+
yml = None
55+
for batch_idxs in batched(idxs, n_test_files_per_yml):
56+
batch_paths: list[Path] = [test_paths[i] for i in batch_idxs]
57+
names = [path_to_name(path) for path in batch_paths]
58+
name = '+'.join(names)
59+
test_files = ' '.join([str(path.relative_to(test_root)) for path in batch_paths])
60+
batch_need_example_model = int(any([need_example_models[i] for i in batch_idxs]))
61+
diff_yml = yaml.safe_load(template.format(name, test_files, batch_need_example_model))
6662
if yml is None:
67-
yml = new_yml
63+
yml = diff_yml
6864
else:
69-
yml.update(new_yml)
65+
yml.update(diff_yml)
66+
return yml
67+
7068

71-
yamlfile = open('pytests.yml', 'w')
72-
yaml.safe_dump(yml, yamlfile)
69+
if __name__ == '__main__':
70+
yml = generate_test_yaml(Path(__file__).parent)
71+
with open('pytests.yml', 'w') as yamlfile:
72+
yaml.safe_dump(yml, yamlfile)

test/pytest/test_pytorch_api.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -808,3 +808,65 @@ def forward(self, x):
808808
hls_prediction = hls_model.predict(hls_input).flatten()
809809

810810
np.testing.assert_allclose(hls_prediction, pytorch_prediction, rtol=0, atol=5e-2)
811+
812+
813+
@pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Quartus'])
814+
@pytest.mark.parametrize('io_type', ['io_parallel', 'io_stream'])
815+
def test_view(backend, io_type):
816+
817+
class TestModel(nn.Module):
818+
def __init__(self, n_in, n_out, size_in):
819+
super().__init__()
820+
self.view_mult = n_out * size_in
821+
822+
self.conv1 = nn.Conv1d(
823+
n_in,
824+
n_out,
825+
kernel_size=3,
826+
padding=1,
827+
bias=False,
828+
)
829+
830+
def forward(self, x):
831+
z = self.conv1(x)
832+
z = z.view(-1, self.view_mult)
833+
return z
834+
835+
n_in = 2
836+
n_out = 4
837+
size_in = 128
838+
n_batch = 100
839+
840+
model = TestModel(n_in, n_out, size_in)
841+
model = model.to(memory_format=torch.channels_last)
842+
model.eval()
843+
844+
X_input = np.random.rand(n_batch, n_in, size_in)
845+
pytorch_prediction = model(torch.Tensor(X_input)).detach().numpy()
846+
847+
# X_input is channels last
848+
X_input = np.ascontiguousarray(X_input.transpose(0, 2, 1))
849+
config = config_from_pytorch_model(model, inputs_channel_last=True, transpose_outputs=False)
850+
851+
output_dir = str(test_root_path / f'hls4mlprj_pytorch_view_{backend}_{io_type}')
852+
hls_model = convert_from_pytorch_model(
853+
model,
854+
(None, n_in, size_in),
855+
hls_config=config,
856+
output_dir=output_dir,
857+
backend=backend,
858+
io_type=io_type,
859+
)
860+
861+
hls_model.compile()
862+
863+
# reshape hls prediction to channels last, then transpose, then reshape
864+
# to match .view
865+
hls_prediction = np.reshape(
866+
np.transpose(np.reshape(hls_model.predict(X_input), (n_batch, size_in, n_out)), (0, 2, 1)),
867+
(n_batch, size_in * n_out),
868+
)
869+
870+
rtol = 0
871+
atol = 5.0e-2
872+
np.testing.assert_allclose(hls_prediction, pytorch_prediction, rtol=rtol, atol=atol)

0 commit comments

Comments
 (0)