Skip to content

Commit 6f284ec

Browse files
committed
support kv3 parsed batchnorm
1 parent bf6f5a0 commit 6f284ec

File tree

1 file changed

+14
-9
lines changed

1 file changed

+14
-9
lines changed

hls4ml/model/layers.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1018,16 +1018,21 @@ def initialize(self):
10181018
dims = inp.dim_names
10191019
self.add_output_variable(shape, dims)
10201020

1021-
gamma = self.get_attr('gamma_data')
1022-
beta = self.get_attr('beta_data')
1023-
mean = self.get_attr('mean_data')
1024-
var = self.get_attr('variance_data')
1025-
1026-
scale = gamma / np.sqrt(var + self.get_attr('epsilon'))
1027-
bias = beta - scale * mean
1021+
if self.get_attr('scale_data') is None:
1022+
gamma = self.get_attr('gamma_data')
1023+
var = self.get_attr('variance_data')
1024+
scale = gamma / np.sqrt(var + self.get_attr('epsilon'))
1025+
self.add_weights_variable(name='scale', var_name='s{index}', data=scale)
1026+
else:
1027+
self.add_weights_variable(name='scale', var_name='s{index}')
10281028

1029-
self.add_weights_variable(name='scale', var_name='s{index}', data=scale)
1030-
self.add_weights_variable(name='bias', var_name='b{index}', data=bias)
1029+
if self.get_attr('bias_data') is None:
1030+
beta = self.get_attr('beta_data')
1031+
mean = self.get_attr('mean_data')
1032+
bias = beta - scale * mean
1033+
self.add_weights_variable(name='bias', var_name='b{index}', data=bias)
1034+
else:
1035+
self.add_weights_variable(name='bias', var_name='b{index}')
10311036

10321037

10331038
# TODO: discuss whether this should be renamed to soemthing more descriptive, and whether the class hierarchy makes sense

0 commit comments

Comments
 (0)