
Commit 1654c1c

Expose alpha and theta type for parametrized activations (#1069)
* update parametrized activations for Xilinx
* update quartus and catapult
* fix pre-commit
* fix non-parametrized version of elu
* update comment on parametrized activation precision
1 parent bd69272 commit 1654c1c
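
In practical terms, this commit lets the precision of the activation parameter (the alpha of LeakyReLU/ELU/PReLU, or the theta of ThresholdedReLU) be set and inferred separately from the data path. A minimal sketch of how a user might configure it through the Python API; the layer name, shapes, and chosen precision below are hypothetical, not taken from the commit:

import hls4ml
from tensorflow import keras

# Toy model with a parametrized activation; names/shapes are made up for illustration.
model = keras.Sequential([keras.layers.Dense(8, input_shape=(16,)), keras.layers.PReLU(name='p_re_lu')])

config = hls4ml.utils.config_from_keras_model(model, granularity='name')
# With this commit, parametrized activations expose a 'param' precision that maps to
# the param_t typedef in the generated config struct (alpha here, theta for ThresholdedReLU).
config['LayerName']['p_re_lu']['Precision']['param'] = 'ap_fixed<16,6>'

hls_model = hls4ml.converters.convert_from_keras_model(
    model, hls_config=config, output_dir='hls_prj', backend='Vivado'
)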

File tree

15 files changed: +173 additions, -70 deletions


example-models

hls4ml/backends/catapult/passes/core_templates.py

Lines changed: 26 additions & 3 deletions
@@ -115,6 +115,15 @@ def format(self, node):
     typedef {table_t.name} table_t;
 }};\n"""
 
+param_activ_config_template = """struct {type}_config{index} : nnet::activ_config {{
+    static const unsigned n_in = {n_in};
+    static const unsigned table_size = {table_size};
+    static const unsigned io_type = nnet::{iotype};
+    static const unsigned reuse_factor = {reuse};
+    typedef {table_t.name} table_t;
+    typedef {param_t.name} param_t;
+}};\n"""
+
 hard_activ_config_template = """struct {type}_config{index} {{
     static const unsigned n_in = {n_in};
     static const {slope_t.name} slope;
@@ -140,14 +149,16 @@ def format(self, node):
 }};\n"""
 
 activ_function_template = 'nnet::{activation}<{input_t}, {output_t}, {config}>({input}, {output});'
-param_activ_function_template = 'nnet::{activation}<{input_t}, {output_t}, {config}>({input}, {param}, {output});'
+param_activ_function_template = (
+    'nnet::{activation}<{input_t}, {param_t.name}, {output_t}, {config}>({input}, {param}, {output});'
+)
 
 activ_include_list = ['nnet_utils/nnet_activation.h', 'nnet_utils/nnet_activation_stream.h']
 
 
 class ActivationConfigTemplate(LayerConfigTemplate):
     def __init__(self):
-        super().__init__((Activation, ParametrizedActivation, PReLU))
+        super().__init__(Activation)
         self.template = activ_config_template
 
     def format(self, node):
@@ -157,6 +168,18 @@ def format(self, node):
         return self.template.format(**params)
 
 
+class ParamActivationConfigTemplate(LayerConfigTemplate):
+    def __init__(self):
+        super().__init__((ParametrizedActivation, PReLU))
+        self.template = param_activ_config_template
+
+    def format(self, node):
+        params = self._default_config_params(node)
+        params['type'] = node.get_attr('activation')
+
+        return self.template.format(**params)
+
+
 class HardActivationConfigTemplate(LayerConfigTemplate):
     def __init__(self):
         super().__init__(HardActivation)
@@ -210,7 +233,7 @@ def __init__(self):
     def format(self, node):
         params = self._default_function_params(node)
         params['activation'] = node.get_attr('activation').lower()
-        params['param'] = node.get_weights('alpha').name
+        params['param'] = node.get_weights('param').name
         params['config'] = '{}_config{}'.format(node.get_attr('activation'), node.index)
 
         return self.template.format(**params)
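
To see what the new function-call template expands to, here is a small, self-contained sketch of the line param_activ_function_template now emits; the same template text appears in the quartus and vivado files below. All layer names, indices, and types are hypothetical, not taken from the commit:

# Template string copied from the diff above; note the extra {param_t.name}
# template argument between the input and output types.
param_activ_function_template = (
    'nnet::{activation}<{input_t}, {param_t.name}, {output_t}, {config}>({input}, {param}, {output});'
)

class _NamedType:
    # Stand-in for hls4ml's NamedType: only the .name attribute is used here.
    def __init__(self, name):
        self.name = name

line = param_activ_function_template.format(
    activation='leaky_relu',
    input_t='layer3_t',
    param_t=_NamedType('leaky_relu_param_t'),
    output_t='result_t',
    config='leaky_relu_config4',
    input='layer3_out',
    param='a4',
    output='layer4_out',
)
print(line)
# nnet::leaky_relu<layer3_t, leaky_relu_param_t, result_t, leaky_relu_config4>(layer3_out, a4, layer4_out);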

hls4ml/backends/quartus/passes/core_templates.py

Lines changed: 26 additions & 3 deletions
@@ -125,6 +125,15 @@ def format(self, node):
     typedef {table_t.name} table_t;
 }};\n"""
 
+param_activ_config_template = """struct {type}_config{index} : nnet::activ_config {{
+    static const unsigned n_in = {n_in};
+    static const unsigned table_size = {table_size};
+    static const unsigned io_type = nnet::{iotype};
+    static const unsigned reuse_factor = {reuse};
+    typedef {table_t.name} table_t;
+    typedef {param_t.name} param_t;
+}};\n"""
+
 hard_activ_config_template = """struct {type}_config{index} {{
     static const unsigned n_in = {n_in};
     static const {slope_t.name} slope;
@@ -146,14 +155,16 @@ def format(self, node):
 }};\n"""
 
 activ_function_template = 'nnet::{activation}<{input_t}, {output_t}, {config}>({input}, {output});'
-param_activ_function_template = 'nnet::{activation}<{input_t}, {output_t}, {config}>({input}, {param}, {output});'
+param_activ_function_template = (
+    'nnet::{activation}<{input_t}, {param_t.name}, {output_t}, {config}>({input}, {param}, {output});'
+)
 
 activ_include_list = ['nnet_utils/nnet_activation.h', 'nnet_utils/nnet_activation_stream.h']
 
 
 class ActivationConfigTemplate(LayerConfigTemplate):
     def __init__(self):
-        super().__init__((Activation, ParametrizedActivation, PReLU, UnaryLUT))
+        super().__init__((Activation, UnaryLUT))
         self.template = activ_config_template
 
     def format(self, node):
@@ -163,6 +174,18 @@ def format(self, node):
         return self.template.format(**params)
 
 
+class ParamActivationConfigTemplate(LayerConfigTemplate):
+    def __init__(self):
+        super().__init__((ParametrizedActivation, PReLU))
+        self.template = param_activ_config_template
+
+    def format(self, node):
+        params = self._default_config_params(node)
+        params['type'] = node.get_attr('activation')
+
+        return self.template.format(**params)
+
+
 class HardActivationConfigTemplate(LayerConfigTemplate):
     def __init__(self):
         super().__init__(HardActivation)
@@ -216,7 +239,7 @@ def __init__(self):
     def format(self, node):
         params = self._default_function_params(node)
         params['activation'] = node.get_attr('activation').lower()
-        params['param'] = node.get_weights('alpha').name
+        params['param'] = node.get_weights('param').name
         params['config'] = '{}_config{}'.format(node.get_attr('activation'), node.index)
 
         return self.template.format(**params)

hls4ml/backends/vivado/passes/core_templates.py

Lines changed: 26 additions & 3 deletions
@@ -116,6 +116,15 @@ def format(self, node):
     typedef {table_t.name} table_t;
 }};\n"""
 
+param_activ_config_template = """struct {type}_config{index} : nnet::activ_config {{
+    static const unsigned n_in = {n_in};
+    static const unsigned table_size = {table_size};
+    static const unsigned io_type = nnet::{iotype};
+    static const unsigned reuse_factor = {reuse};
+    typedef {table_t.name} table_t;
+    typedef {param_t.name} param_t;
+}};\n"""
+
 hard_activ_config_template = """struct {type}_config{index} {{
     static const unsigned n_in = {n_in};
     static const {slope_t.name} slope;
@@ -138,14 +147,16 @@ def format(self, node):
 }};\n"""
 
 activ_function_template = 'nnet::{activation}<{input_t}, {output_t}, {config}>({input}, {output});'
-param_activ_function_template = 'nnet::{activation}<{input_t}, {output_t}, {config}>({input}, {param}, {output});'
+param_activ_function_template = (
+    'nnet::{activation}<{input_t}, {param_t.name}, {output_t}, {config}>({input}, {param}, {output});'
+)
 
 activ_include_list = ['nnet_utils/nnet_activation.h', 'nnet_utils/nnet_activation_stream.h']
 
 
 class ActivationConfigTemplate(LayerConfigTemplate):
     def __init__(self):
-        super().__init__((Activation, ParametrizedActivation, PReLU, UnaryLUT))
+        super().__init__((Activation, UnaryLUT))
         self.template = activ_config_template
 
     def format(self, node):
@@ -155,6 +166,18 @@ def format(self, node):
         return self.template.format(**params)
 
 
+class ParamActivationConfigTemplate(LayerConfigTemplate):
+    def __init__(self):
+        super().__init__((ParametrizedActivation, PReLU))
+        self.template = param_activ_config_template
+
+    def format(self, node):
+        params = self._default_config_params(node)
+        params['type'] = node.get_attr('activation')
+
+        return self.template.format(**params)
+
+
 class HardActivationConfigTemplate(LayerConfigTemplate):
     def __init__(self):
         super().__init__(HardActivation)
@@ -208,7 +231,7 @@ def __init__(self):
     def format(self, node):
        params = self._default_function_params(node)
         params['activation'] = node.get_attr('activation').lower()
-        params['param'] = node.get_weights('alpha').name
+        params['param'] = node.get_weights('param').name
         params['config'] = '{}_config{}'.format(node.get_attr('activation'), node.index)
 
         return self.template.format(**params)

hls4ml/converters/keras/core.py

Lines changed: 1 addition & 1 deletion
@@ -71,7 +71,7 @@ def parse_activation_layer(keras_layer, input_names, input_shapes, data_reader):
     elif layer['class_name'] == 'ReLU':
         layer['class_name'] = 'Activation'
     elif layer['class_name'] == 'PReLU':
-        layer['alpha_data'] = get_weights_data(data_reader, layer['name'], 'alpha')
+        layer['param_data'] = get_weights_data(data_reader, layer['name'], 'alpha')
 
     if layer['class_name'] == 'Activation' and layer['activation'] == 'softmax':
         layer['class_name'] = 'Softmax'
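
Note that the Keras-side weight is still named 'alpha'; only the hls4ml-side key changes to 'param_data', matching the renamed 'param' weight variable in PReLU.initialize below. A quick sketch, assuming TensorFlow/Keras is installed (exact weight names may vary by version):

from tensorflow import keras

# Toy model with a PReLU layer; its learnable weight is exposed as '.../alpha:0',
# which is why the converter still asks get_weights_data(...) for 'alpha'.
model = keras.Sequential([keras.layers.PReLU(input_shape=(4,))])
print([w.name for w in model.layers[0].weights])  # e.g. ['p_re_lu/alpha:0']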

hls4ml/converters/pytorch/core.py

Lines changed: 1 addition & 1 deletion
@@ -55,7 +55,7 @@ def parse_activation_layer(operation, layer_name, input_names, input_shapes, nod
     if layer['class_name'] == 'ELU':
         layer['activ_param'] = class_object.alpha
     if layer['class_name'] == 'PReLU':
-        layer['alpha_data'] = class_object.weight.data.numpy()
+        layer['param_data'] = class_object.weight.data.numpy()
     if layer['class_name'] == 'Threshold':
         layer['activ_param'] = class_object.threshold
         layer['class_name'] = 'ThresholdedReLU'
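
The PyTorch case is analogous: nn.PReLU keeps its learnable slope in .weight (a single value by default), and that tensor is now recorded under 'param_data'. A minimal sketch, assuming PyTorch is available:

import torch

prelu = torch.nn.PReLU()  # num_parameters=1, init=0.25 by default
print(prelu.weight.data.numpy())  # e.g. [0.25], which the converter stores as 'param_data'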

hls4ml/model/layers.py

Lines changed: 19 additions & 1 deletion
@@ -845,6 +845,17 @@ def initialize(self):
 
 
 class ParametrizedActivation(Activation):
+    _expected_attributes = [
+        Attribute('n_in'),
+        Attribute('activation', value_type=str),
+        TypeAttribute('param'),
+    ]
+
+    def initialize(self):
+        super().initialize()
+        param_t = NamedType(*reversed(self.model.config.get_precision(self, 'param')))
+        self.set_attr('param_t', param_t)
+
     def _get_act_function_name(self):
         act = self.get_attr('activation').lower()
         if act == 'leakyrelu':
@@ -882,9 +893,16 @@ def initialize(self):
 
 
 class PReLU(Activation):
+    _expected_attributes = [
+        Attribute('n_in'),
+        Attribute('activation', value_type=str),
+        WeightAttribute('param'),
+        TypeAttribute('param'),
+    ]
+
     def initialize(self):
         super().initialize()
-        self.add_weights_variable(name='alpha', var_name='a{index}')
+        self.add_weights_variable(name='param', var_name='a{index}')
 
 
 class Softmax(Activation):
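
For context on the NamedType(*reversed(...)) line above: the reversed(...) call suggests config.get_precision returns a (precision, type_name) pair while NamedType expects (name, precision); this ordering is an assumption here, check hls4ml/model/types.py to confirm. A tiny illustration with made-up values:

# Hypothetical return value of model.config.get_precision(layer, 'param').
precision, type_name = 'ap_fixed<16,6>', 'prelu1_param_t'

# Equivalent to NamedType(*reversed((precision, type_name))): name first, precision second.
name_first = tuple(reversed((precision, type_name)))
assert name_first == ('prelu1_param_t', 'ap_fixed<16,6>')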

hls4ml/model/optimizer/passes/infer_precision.py

Lines changed: 16 additions & 0 deletions
@@ -84,6 +84,9 @@ def _infer_precision(self, node, types_to_infer):
         if node_class in ['SimpleRNN', 'LSTM', 'GRU']:
             return self._infer_rnn_precision(node, types_to_infer)
 
+        if node_class in ['ParametrizedActivation']:
+            return self._infer_par_act_precision(node, types_to_infer)
+
         # What about quantized activation layer? Setting it to 'auto' manually will break it here. We should prevent
         # this in config_from_* functions
 
@@ -557,3 +560,16 @@ def _infer_rnn_precision(self, node, types_to_infer):
             inferred_types.append(f'{weightvar}_t')
 
         return inferred_types
+
+    def _infer_par_act_precision(self, node, types_to_infer):
+        inferred_types = []
+
+        # For threshold relu, set the parameter precision to be the input precision by default;
+        # for other parametrized activations, just allow the default precision to be used.
+        # Can override these values in the configuration by explicitly setting them.
+        if 'param_t' in inferred_types and self.get_attr('activation').lower() == 'thresholdedrelu':
+            in_type = node.get_input_variable().type.precision
+            node.attributes['param_t'].type = in_type
+            inferred_types.append('param_t')
+
+        return inferred_types

hls4ml/templates/catapult/nnet_utils/nnet_activation.h

Lines changed: 10 additions & 10 deletions
@@ -686,8 +686,8 @@ void hard_tanh(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in]) {
 // *************************************************
 // Leaky RELU Activation
 // *************************************************
-template <class data_T, class res_T, typename CONFIG_T>
-void leaky_relu(data_T data[CONFIG_T::n_in], data_T alpha, res_T res[CONFIG_T::n_in]) {
+template <class data_T, class param_T, class res_T, typename CONFIG_T>
+void leaky_relu(data_T data[CONFIG_T::n_in], param_T alpha, res_T res[CONFIG_T::n_in]) {
     //#pragma HLS PIPELINE
 
     data_T datareg;
@@ -703,8 +703,8 @@ void leaky_relu(data_T data[CONFIG_T::n_in], data_T alpha, res_T res[CONFIG_T::n
 // *************************************************
 // Thresholded RELU Activation
 // *************************************************
-template <class data_T, class res_T, typename CONFIG_T>
-void thresholded_relu(data_T data[CONFIG_T::n_in], data_T theta, res_T res[CONFIG_T::n_in]) {
+template <class data_T, class param_T, class res_T, typename CONFIG_T>
+void thresholded_relu(data_T data[CONFIG_T::n_in], param_T theta, res_T res[CONFIG_T::n_in]) {
     //#pragma HLS PIPELINE
 
     data_T datareg;
@@ -917,8 +917,8 @@ template <typename CONFIG_T, int N_TABLE> void init_elu_table(typename CONFIG_T:
 
 #ifndef USE_AC_MATH
 
-template <class data_T, class res_T, typename CONFIG_T>
-void elu(data_T data[CONFIG_T::n_in], const res_T alpha, res_T res[CONFIG_T::n_in]) {
+template <class data_T, class param_T, class res_T, typename CONFIG_T>
+void elu(data_T data[CONFIG_T::n_in], const param_T alpha, res_T res[CONFIG_T::n_in]) {
     // Initialize the lookup table
 #ifdef __HLS_SYN__
     bool initialized = false;
@@ -953,8 +953,8 @@ void elu(data_T data[CONFIG_T::n_in], const res_T alpha, res_T res[CONFIG_T::n_i
 
 #else
 
-template <class data_T, class res_T, typename CONFIG_T>
-void elu(data_T data[CONFIG_T::n_in], const res_T alpha, res_T res[CONFIG_T::n_in]) {
+template <class data_T, class param_T, class res_T, typename CONFIG_T>
+void elu(data_T data[CONFIG_T::n_in], const param_T alpha, res_T res[CONFIG_T::n_in]) {
     for (int ii = 0; ii < CONFIG_T::n_in; ii++) {
         ac_math::ac_elu_pwl(data[ii], res[ii], alpha);
     }
@@ -1045,8 +1045,8 @@ template <class data_T, class res_T, typename CONFIG_T> void selu(data_T data[CO
 // *************************************************
 // PReLU Activation
 // *************************************************
-template <class data_T, class res_T, typename CONFIG_T>
-void prelu(data_T data[CONFIG_T::n_in], data_T alpha[CONFIG_T::n_in], res_T res[CONFIG_T::n_in]) {
+template <class data_T, class param_T, class res_T, typename CONFIG_T>
+void prelu(data_T data[CONFIG_T::n_in], param_T alpha[CONFIG_T::n_in], res_T res[CONFIG_T::n_in]) {
     //#pragma HLS PIPELINE
 
     data_T datareg;

hls4ml/templates/catapult/nnet_utils/nnet_activation_stream.h

Lines changed: 10 additions & 10 deletions
@@ -545,8 +545,8 @@ template <class data_T, class res_T, typename CONFIG_T> void hard_tanh(ac_channe
 // *************************************************
 // Leaky RELU Activation
 // *************************************************
-template <class data_T, class res_T, typename CONFIG_T>
-void leaky_relu(ac_channel<data_T> &data, typename data_T::value_type alpha, ac_channel<res_T> &res) {
+template <class data_T, class param_T, class res_T, typename CONFIG_T>
+void leaky_relu(ac_channel<data_T> &data, param_T alpha, ac_channel<res_T> &res) {
 LeakyReLUActLoop:
     for (int i = 0; i < CONFIG_T::n_in / res_T::size; i++) {
         //#pragma HLS PIPELINE
@@ -571,8 +571,8 @@ void leaky_relu(ac_channel<data_T> &data, typename data_T::value_type alpha, ac_
 // Thresholded RELU Activation
 // *************************************************
 
-template <class data_T, class res_T, typename CONFIG_T>
-void thresholded_relu(ac_channel<data_T> &data, typename data_T::value_type theta, ac_channel<res_T> &res) {
+template <class data_T, class param_T, class res_T, typename CONFIG_T>
+void thresholded_relu(ac_channel<data_T> &data, param_T theta, ac_channel<res_T> &res) {
 ThresholdedReLUActLoop:
     for (int i = 0; i < CONFIG_T::n_in / res_T::size; i++) {
         //#pragma HLS PIPELINE
@@ -720,8 +720,8 @@ template <class data_T, class res_T, typename CONFIG_T> void softsign(ac_channel
 
 #ifndef USE_AC_MATH
 
-template <class data_T, class res_T, typename CONFIG_T>
-void elu(ac_channel<data_T> &data, typename data_T::value_type alpha, ac_channel<res_T> &res) {
+template <class data_T, class param_T, class res_T, typename CONFIG_T>
+void elu(ac_channel<data_T> &data, param_T alpha, ac_channel<res_T> &res) {
     // Initialize the lookup table
 #ifdef __HLS_SYN__
     bool initialized = false;
@@ -763,8 +763,8 @@ void elu(ac_channel<data_T> &data, typename data_T::value_type alpha, ac_channel
 }
 
 #else
-template <class data_T, class res_T, typename CONFIG_T>
-void elu(ac_channel<data_T> &data, typename data_T::value_type alpha, ac_channel<res_T> &res) {
+template <class data_T, class param_T, class res_T, typename CONFIG_T>
+void elu(ac_channel<data_T> &data, param_T alpha, ac_channel<res_T> &res) {
 EluActLoop:
     for (int i = 0; i < CONFIG_T::n_in / res_T::size; i++) {
         data_T in_data = data.read();
@@ -845,8 +845,8 @@ template <class data_T, class res_T, typename CONFIG_T> void selu(ac_channel<dat
 // *************************************************
 // PReLU Activation
 // *************************************************
-template <class data_T, class res_T, typename CONFIG_T>
-void prelu(ac_channel<data_T> &data, typename data_T::value_type alpha[CONFIG_T::n_in], ac_channel<res_T> &res) {
+template <class data_T, class param_T, class res_T, typename CONFIG_T>
+void prelu(ac_channel<data_T> &data, const param_T alpha[CONFIG_T::n_in], ac_channel<res_T> &res) {
 PReLUActLoop:
     for (int i = 0; i < CONFIG_T::n_in / res_T::size; i++) {
         //#pragma HLS PIPELINE

0 commit comments
