Automatic type inference for param_t in Parametrised Activations #1139
base: main
Changes from all commits: 11601cd, 72026fb, 10ec7a2, 29f0831, ecf5c2c, 49e5a75, baba0f3, 0808580, 79f7372, a4f5fa5
```diff
@@ -1,4 +1,5 @@
 import math
+import struct
 from typing import Iterable

 import numpy as np
```
```diff
@@ -561,15 +562,34 @@ def _infer_rnn_precision(self, node, types_to_infer):

         return inferred_types

-    def _infer_par_act_precision(self, node, types_to_infer):
+    def _infer_const_precision(self, node, type_to_infer, attr_name):
         inferred_types = []

-        # For threshold relu, set the parameter precision to be the input precision by default;
-        # for other parametrized activations, just allow the default precision to be used.
-        # Can override these values in the configuration by explicitly setting them.
-        if 'param_t' in inferred_types and self.get_attr('activation').lower() == 'thresholdedrelu':
-            in_type = node.get_input_variable().type.precision
-            node.attributes['param_t'].type = in_type
-            inferred_types.append('param_t')
+        def get_man_exp(f):
+            f = np.abs(f)
+            s = struct.pack('>f', f)
```
Review thread on `s = struct.pack('>f', f)`:

Reviewer: Use calculations: based on the value, you can easily determine how many bits you need. Going to structs is hard to follow.

Author: I agree that in general; do you have an algorithm in mind that you can point me to?
```diff
+            l_float = struct.unpack('>l', s)[0]
+            bits = f'{l_float:032b}'
+            m = bits[-23:]
+            e = bits[-23 - 8 : -23]
+            return m, e
+
+        param = node.get_attr(attr_name)
+        m, e = get_man_exp(param)
+        I_pos = int(e, 2) - 127 + 1  # -127 removes the exponent bias, +1 accounts for the implicit leading 1
+        try:
+            W_bits = m.rindex('1') + 2  # +1 because rindex is 0-based, +1 for the implicit leading 1 of the mantissa
+        except Exception:
+            W_bits = 1  # the value is a power of 2, so 1 bit is enough; I_pos places that bit at the right position
+        if param < 0 and W_bits > 1:  # for power-of-2 values the increment is not needed
+            I_pos += 1
+            W_bits += 1
+        node.attributes[type_to_infer].precision = FixedPrecisionType(W_bits, I_pos, True if param < 0 else False)
+        inferred_types.append(type_to_infer)

         return inferred_types
+
+    def _infer_par_act_precision(self, node, types_to_infer):
+        inferred_types = []
+        if 'param_t' in types_to_infer:
+            inferred_types.extend(self._infer_const_precision(node, 'param_t', 'activ_param'))
+        return inferred_types
```
Review comment: Struct is much too low level for what we are doing here. We have a Python float. We should use it, not look at the bits.
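For reference, a struct-free variant along the lines the reviewers suggest could lean on `math.frexp`, which decomposes a Python float into significand and exponent without touching the bit encoding. The sketch below only illustrates that idea under the same width/integer conventions as the PR; the function name and packaging are invented here, not taken from hls4ml.

```python
import math


def infer_const_fixed(param):
    # Sketch: derive (width, integer_bits, signed) for a float constant using
    # float arithmetic instead of struct-packed IEEE 754 bit fields.
    signed = param < 0
    f = abs(param)
    # frexp returns (mant, exp) with f == mant * 2**exp and 0.5 <= mant < 1,
    # so exp already equals the PR's I_pos (one above the leading 1 bit).
    mant, exp = math.frexp(f)
    I_pos = exp
    # Count significand bits by shifting them out one at a time.
    W_bits = 0
    while mant != 0.0:
        mant *= 2
        mant -= int(mant)  # drop the bit that crossed the binary point
        W_bits += 1
    W_bits = max(W_bits, 1)
    if signed and W_bits > 1:  # sign bit, skipped for exact powers of two
        I_pos += 1
        W_bits += 1
    return W_bits, I_pos, signed


assert infer_const_fixed(0.75) == (2, 0, False)  # matches the struct-based version
assert infer_const_fixed(2.0) == (1, 2, False)
```

One behavioural difference to be aware of: the PR narrows the parameter to float32 before counting bits, while `frexp` sees the full double-precision value, so a parameter like 0.1 would come out as 53 mantissa bits here; narrowing first (e.g. via `float(np.float32(param))`) would restore parity.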