Skip to content

Commit ed55394

Browse files
authored
Merge pull request #1006 from fastmachinelearning/pre-commit-and-keras
Fix pre-commit warning and change '.h5' to '.keras' for written output
2 parents 46ec22b + ce33496 commit ed55394

File tree

4 files changed

+4
-4
lines changed

4 files changed

+4
-4
lines changed

hls4ml/writer/catapult_writer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -884,7 +884,7 @@ def write_yml(self, model):
884884
"""
885885

886886
def keras_model_representer(dumper, keras_model):
887-
model_path = model.config.get_output_dir() + '/keras_model.h5'
887+
model_path = model.config.get_output_dir() + '/keras_model.keras'
888888
keras_model.save(model_path)
889889
return dumper.represent_scalar('!keras_model', model_path)
890890

hls4ml/writer/quartus_writer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1322,7 +1322,7 @@ def write_yml(self, model):
13221322
"""
13231323

13241324
def keras_model_representer(dumper, keras_model):
1325-
model_path = model.config.get_output_dir() + '/keras_model.h5'
1325+
model_path = model.config.get_output_dir() + '/keras_model.keras'
13261326
keras_model.save(model_path)
13271327
return dumper.represent_scalar('!keras_model', model_path)
13281328

hls4ml/writer/vivado_writer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -686,7 +686,7 @@ def write_yml(self, model):
686686
"""
687687

688688
def keras_model_representer(dumper, keras_model):
689-
model_path = model.config.get_output_dir() + '/keras_model.h5'
689+
model_path = model.config.get_output_dir() + '/keras_model.keras'
690690
keras_model.save(model_path)
691691
return dumper.represent_scalar('!keras_model', model_path)
692692

test/pytest/test_weight_writer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,5 +29,5 @@ def test_weight_writer(k, i, f):
2929
print(w_paths[0])
3030
assert len(w_paths) == 1
3131
w_loaded = np.loadtxt(w_paths[0], delimiter=',').reshape(1, 1)
32-
print(f'{w[0,0]:.14}', f'{w_loaded[0,0]:.14}')
32+
print(f'{w[0, 0]:.14}', f'{w_loaded[0, 0]:.14}')
3333
assert np.all(w == w_loaded)

0 commit comments

Comments
 (0)