[DataType] vectorize checking if value is allowed for a datatype #196

Merged
26 changes: 15 additions & 11 deletions src/qonnx/core/datatype.py
@@ -79,7 +79,9 @@ def max(self):
def allowed(self, value):
"""Check whether given value is allowed for this DataType.

- * value (float32): value to be checked"""
+ * value (float32 | np.ndarray): value to be checked
+
+ Returns a boolean numpy array of the same shape as `value`"""
pass

@abstractmethod
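For orientation, a minimal sketch of how the vectorized contract can be exercised (plain numpy plus the public DataType lookup; the INT4 entry and the exact printed output are assumptions for illustration, not part of this diff):

```python
import numpy as np
from qonnx.core.datatype import DataType

# allowed() now accepts whole arrays and returns an element-wise boolean mask
# instead of a single Python bool.
values = np.asarray([-8.0, 3.0, 7.5, 9.0], dtype=np.float32)
mask = DataType["INT4"].allowed(values)
print(mask)        # expected: [ True  True False False]
print(mask.all())  # expected: False, since 7.5 is fractional and 9.0 is out of range
```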
@@ -199,19 +201,19 @@ def allowed(self, value):
bin_val = np.float32(value).view(np.uint32)
exp = (bin_val & 0b01111111100000000000000000000000) >> fp32_mantissa_bitwidth
mant = bin_val & 0b00000000011111111111111111111111
- exp_biased = exp - fp32_exponent_bias # bias the extracted raw exponent (assume not denormal)
- mant_normalized = mant + int((2**fp32_mantissa_bitwidth) * (exp != 0)) # append implicit 1
+ exp_biased = np.array(exp).astype(int) - fp32_exponent_bias # bias the extracted raw exponent (assume not denormal)
+ mant_normalized = mant + np.array((2**fp32_mantissa_bitwidth) * (exp != 0)).astype(int) # append implicit 1
# for this value to be representable as this ArbPrecFloatType:
# the value must be within the representable range
- range_ok = (value <= self.max()) and (value >= self.min())
+ range_ok = np.logical_and(value <= self.max(), value >= self.min())
# the mantissa must be within representable range:
# no set bits in the mantissa beyond the allowed number of bits (assume value is not denormal in fp32)
# compute bits of precision lost to tapered precision if denormal, clamp to: 0 <= dnm_shift <= nrm_mantissa_bitwidth
- dnm_shift = int(min(max(0, min_exponent - exp_biased), nrm_mantissa_bitwidth))
+ dnm_shift = np.array(np.minimum(np.maximum(0, min_exponent - exp_biased), nrm_mantissa_bitwidth)).astype(int)
available_bits = nrm_mantissa_bitwidth - dnm_shift # number of bits of precision available
mantissa_mask = "0" * available_bits + "1" * (fp32_nrm_mantissa_bitwidth - available_bits)
mantissa_ok = (mant_normalized & int(mantissa_mask, base=2)) == 0
return bool(mantissa_ok and range_ok)
mantissa_mask = (1 << (fp32_nrm_mantissa_bitwidth - available_bits)) - 1
mantissa_ok = (mant_normalized & mantissa_mask) == 0
return np.logical_and(mantissa_ok, range_ok)

def is_integer(self):
return False
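The mantissa test above reduces to: after aligning to the implicit leading one, any bit set below the target format's precision means the value cannot be stored exactly. A self-contained sketch of that masking idea (hypothetical constants for a 4-bit mantissa target; the range and denormal handling from the diff are omitted):

```python
import numpy as np

FP32_MANT_BITS = 23
TARGET_MANT_BITS = 4  # hypothetical target precision, not a qonnx constant

def mantissa_exact(values):
    """Element-wise check that the fp32 mantissa of `values` fits in TARGET_MANT_BITS."""
    bits = np.asarray(values, dtype=np.float32).view(np.uint32)
    mant = bits & ((1 << FP32_MANT_BITS) - 1)
    # any set bit among the discarded low-order bits would require rounding
    drop_mask = (1 << (FP32_MANT_BITS - TARGET_MANT_BITS)) - 1
    return (mant & drop_mask) == 0

print(mantissa_exact([1.5, 1.0625, 1.03125]))  # expected: [ True  True False]
```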
@@ -286,7 +288,9 @@ def max(self):
return signed_max if self._signed else unsigned_max

def allowed(self, value):
- return (self.min() <= value) and (value <= self.max()) and float(value).is_integer()
+ value_is_integer = (np.round(value) == value)
+ value_is_bounded = np.logical_and(self.min() <= value, value <= self.max())
+ return np.logical_and(value_is_integer, value_is_bounded)

def get_num_possible_values(self):
return abs(self.min()) + abs(self.max()) + 1
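The integer test now uses `np.round(value) == value`, which broadcasts over arrays, whereas the old `float(value).is_integer()` only works on scalars. A quick illustration of the two building blocks (plain numpy; the bounds of -8 and 7 are assumed purely for the example):

```python
import numpy as np

x = np.array([-4.0, 2.5, 3.0, 9.0])
value_is_integer = np.round(x) == x                 # [ True False  True  True]
value_is_bounded = np.logical_and(-8 <= x, x <= 7)  # assumed example bounds
print(np.logical_and(value_is_integer, value_is_bounded))  # [ True False  True False]
```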
@@ -334,7 +338,7 @@ def max(self):
return +1

def allowed(self, value):
- return value in [-1, +1]
+ return np.isin(value, [-1, +1])

def get_num_possible_values(self):
return 2
@@ -366,7 +370,7 @@ def max(self):
return +1

def allowed(self, value):
- return value in [-1, 0, +1]
+ return np.isin(value, [-1, 0, +1])

def get_num_possible_values(self):
return 3
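`np.isin` keeps the membership test shape-preserving, so the bipolar and ternary checks now also return element-wise masks. A minimal sketch (plain numpy):

```python
import numpy as np

x = np.array([[-1.0, 0.0], [1.0, 2.0]])
print(np.isin(x, [-1, +1]))     # bipolar:  [[ True False] [ True False]]
print(np.isin(x, [-1, 0, +1]))  # ternary:  [[ True  True] [ True False]]
```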
26 changes: 11 additions & 15 deletions src/qonnx/util/basic.py
@@ -288,13 +288,9 @@ def sanitize_quant_values(model, node_tensors, execution_context, check_values=F
continue
current_values = execution_context[tensor_name]
updated_values = current_values
- has_to_be_rounded = False
- # TODO: vectorize with numpy
- for value in np.nditer(current_values):
- if not dtype.allowed(value):
- has_to_be_rounded = True
- break
- if has_to_be_rounded:
+ is_allowed = dtype.allowed(current_values)
+ is_allowed = is_allowed.all() if isinstance(is_allowed, np.ndarray) else is_allowed
+ if not is_allowed:
updated_values = np.round(current_values)
warnings.warn(
"The values of tensor {} can't be represented "
@@ -306,15 +302,15 @@ def sanitize_quant_values(model, node_tensors, execution_context, check_values=F
if max_error <= get_execution_error_thresh():
if check_values is True:
# check again if values can now be represented with set finn datatype
- # TODO: vectorize with numpy
- for value in np.nditer(updated_values):
- if not dtype.allowed(value):
- raise Exception(
- """Values can't be represented with set
- finn datatype ({}) for input {}""".format(
- dtype, tensor_name
- )
+ is_allowed = dtype.allowed(updated_values)
+ is_allowed = is_allowed.all() if isinstance(is_allowed, np.ndarray) else is_allowed
+ if not is_allowed:
+ raise Exception(
+ """Values can't be represented with set
+ finn datatype ({}) for input {}""".format(
+ dtype, tensor_name
+ )
)
execution_context[tensor_name] = updated_values
else:
raise Exception(
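The sanitizer now queries the datatype once for the whole tensor and collapses the resulting mask with `.all()`, instead of looping with `np.nditer`. A rough sketch of that pattern, under the same assumption used in the diff that `allowed` may return either a scalar bool or a boolean array:

```python
import numpy as np

def all_allowed(dtype, values):
    """Collapse dtype.allowed(values) to one bool, whether it is a scalar or a mask."""
    is_allowed = dtype.allowed(values)
    return bool(is_allowed.all()) if isinstance(is_allowed, np.ndarray) else bool(is_allowed)

# hypothetical usage mirroring sanitize_quant_values:
# if not all_allowed(dtype, current_values):
#     current_values = np.round(current_values)
```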