Skip to content

Commit 99f7eb7

Browse files
authored
Merge branch 'main' into sepconv_io_parallel
2 parents 252958f + 7855db2 commit 99f7eb7

File tree

8 files changed

+15
-6
lines changed

8 files changed

+15
-6
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ exclude: (^hls4ml\/templates\/(vivado|quartus)\/(ap_types|ac_types)\/|^test/pyte
22

33
repos:
44
- repo: https://github.com/psf/black
5-
rev: 24.4.0
5+
rev: 24.4.2
66
hooks:
77
- id: black
88
language_version: python3

docs/advanced/extension.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,10 @@ For concreteness, let's say our custom layer ``KReverse`` is implemented in Kera
3535
def call(self, inputs):
3636
return tf.reverse(inputs, axis=[-1])
3737
38+
def get_config(self):
39+
return super().get_config()
40+
41+
Make sure you define a ``get_config()`` method for your custom layer as this is needed for correct parsing.
3842
We can define the equivalent layer in hls4ml ``HReverse``, which inherits from ``hls4ml.model.layers.Layer``.
3943

4044
.. code-block:: Python

hls4ml/writer/catapult_writer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -884,7 +884,7 @@ def write_yml(self, model):
884884
"""
885885

886886
def keras_model_representer(dumper, keras_model):
887-
model_path = model.config.get_output_dir() + '/keras_model.h5'
887+
model_path = model.config.get_output_dir() + '/keras_model.keras'
888888
keras_model.save(model_path)
889889
return dumper.represent_scalar('!keras_model', model_path)
890890

hls4ml/writer/quartus_writer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1322,7 +1322,7 @@ def write_yml(self, model):
13221322
"""
13231323

13241324
def keras_model_representer(dumper, keras_model):
1325-
model_path = model.config.get_output_dir() + '/keras_model.h5'
1325+
model_path = model.config.get_output_dir() + '/keras_model.keras'
13261326
keras_model.save(model_path)
13271327
return dumper.represent_scalar('!keras_model', model_path)
13281328

hls4ml/writer/vivado_writer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -686,7 +686,7 @@ def write_yml(self, model):
686686
"""
687687

688688
def keras_model_representer(dumper, keras_model):
689-
model_path = model.config.get_output_dir() + '/keras_model.h5'
689+
model_path = model.config.get_output_dir() + '/keras_model.keras'
690690
keras_model.save(model_path)
691691
return dumper.represent_scalar('!keras_model', model_path)
692692

test/pytest/ci-template.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
.pytest:
22
stage: test
3-
image: gitlab-registry.cern.ch/fastmachinelearning/hls4ml-testing:0.4.base
3+
image: gitlab-registry.cern.ch/fastmachinelearning/hls4ml-testing:0.5.5.base
44
tags:
55
- k8s-default
66
before_script:
77
- source ~/.bashrc
8+
- git config --global --add safe.directory /builds/fastmachinelearning/hls4ml
89
- git submodule update --init --recursive hls4ml/templates/catapult/
910
- if [ $EXAMPLEMODEL == 1 ]; then git submodule update --init example-models; fi
1011
- conda activate hls4ml-testing

test/pytest/test_extensions.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,10 @@ def __init__(self):
1919
def call(self, inputs):
2020
return tf.reverse(inputs, axis=[-1])
2121

22+
def get_config(self):
23+
# Breaks serialization and parsing in hls4ml if not defined
24+
return super().get_config()
25+
2226

2327
# hls4ml layer implementation
2428
class HReverse(hls4ml.model.layers.Layer):

test/pytest/test_weight_writer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,5 +29,5 @@ def test_weight_writer(k, i, f):
2929
print(w_paths[0])
3030
assert len(w_paths) == 1
3131
w_loaded = np.loadtxt(w_paths[0], delimiter=',').reshape(1, 1)
32-
print(f'{w[0,0]:.14}', f'{w_loaded[0,0]:.14}')
32+
print(f'{w[0, 0]:.14}', f'{w_loaded[0, 0]:.14}')
3333
assert np.all(w == w_loaded)

0 commit comments

Comments
 (0)