@@ -1018,16 +1018,21 @@ def initialize(self):
1018
1018
dims = inp .dim_names
1019
1019
self .add_output_variable (shape , dims )
1020
1020
1021
- gamma = self .get_attr ('gamma_data' )
1022
- beta = self .get_attr ('beta_data ' )
1023
- mean = self .get_attr ('mean_data ' )
1024
- var = self .get_attr ('variance_data' )
1025
-
1026
- scale = gamma / np . sqrt ( var + self . get_attr ( 'epsilon' ))
1027
- bias = beta - scale * mean
1021
+ if self .get_attr ('scale_data' ) is None :
1022
+ gamma = self .get_attr ('gamma_data ' )
1023
+ var = self .get_attr ('variance_data ' )
1024
+ scale = gamma / np . sqrt ( var + self .get_attr ('epsilon' ) )
1025
+ self . add_weights_variable ( name = 'scale' , var_name = 's{index}' , data = scale )
1026
+ else :
1027
+ self . add_weights_variable ( name = ' scale' , var_name = 's{index}' )
1028
1028
1029
- self .add_weights_variable (name = 'scale' , var_name = 's{index}' , data = scale )
1030
- self .add_weights_variable (name = 'bias' , var_name = 'b{index}' , data = bias )
1029
+ if self .get_attr ('bias_data' ) is None :
1030
+ beta = self .get_attr ('beta_data' )
1031
+ mean = self .get_attr ('mean_data' )
1032
+ bias = beta - scale * mean
1033
+ self .add_weights_variable (name = 'bias' , var_name = 'b{index}' , data = bias )
1034
+ else :
1035
+ self .add_weights_variable (name = 'bias' , var_name = 'b{index}' )
1031
1036
1032
1037
1033
1038
# TODO: discuss whether this should be renamed to soemthing more descriptive, and whether the class hierarchy makes sense
0 commit comments