Skip to content

Commit 2a2065e

Browse files
authored
Merge pull request #196 from SpiritSeeker/vectorize-datatype-allowed
[DataType] vectorize checking if value is allowed for a datatype
2 parents 565366a + d299620 commit 2a2065e

File tree

3 files changed

+306
-56
lines changed

3 files changed

+306
-56
lines changed

src/qonnx/core/datatype.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,9 @@ def max(self):
7979
def allowed(self, value):
8080
"""Check whether given value is allowed for this DataType.
8181
82-
* value (float32): value to be checked"""
82+
* value (float32 | np.ndarray): value to be checked
83+
84+
Returns a boolean numpy array of the same shape as `value`"""
8385
pass
8486

8587
@abstractmethod
@@ -199,19 +201,19 @@ def allowed(self, value):
199201
bin_val = np.float32(value).view(np.uint32)
200202
exp = (bin_val & 0b01111111100000000000000000000000) >> fp32_mantissa_bitwidth
201203
mant = bin_val & 0b00000000011111111111111111111111
202-
exp_biased = exp - fp32_exponent_bias # bias the extracted raw exponent (assume not denormal)
203-
mant_normalized = mant + int((2**fp32_mantissa_bitwidth) * (exp != 0)) # append implicit 1
204+
exp_biased = np.array(exp).astype(int) - fp32_exponent_bias # bias the extracted raw exponent (assume not denormal)
205+
mant_normalized = mant + np.array((2**fp32_mantissa_bitwidth) * (exp != 0)).astype(int) # append implicit 1
204206
# for this value to be representable as this ArbPrecFloatType:
205207
# the value must be within the representable range
206-
range_ok = (value <= self.max()) and (value >= self.min())
208+
range_ok = np.logical_and(value <= self.max(), value >= self.min())
207209
# the mantissa must be within representable range:
208210
# no set bits in the mantissa beyond the allowed number of bits (assume value is not denormal in fp32)
209211
# compute bits of precision lost to tapered precision if denormal, clamp to: 0 <= dnm_shift <= nrm_mantissa_bitwidth
210-
dnm_shift = int(min(max(0, min_exponent - exp_biased), nrm_mantissa_bitwidth))
212+
dnm_shift = np.array(np.minimum(np.maximum(0, min_exponent - exp_biased), nrm_mantissa_bitwidth)).astype(int)
211213
available_bits = nrm_mantissa_bitwidth - dnm_shift # number of bits of precision available
212-
mantissa_mask = "0" * available_bits + "1" * (fp32_nrm_mantissa_bitwidth - available_bits)
213-
mantissa_ok = (mant_normalized & int(mantissa_mask, base=2)) == 0
214-
return bool(mantissa_ok and range_ok)
214+
mantissa_mask = (1 << (fp32_nrm_mantissa_bitwidth - available_bits)) - 1
215+
mantissa_ok = (mant_normalized & mantissa_mask) == 0
216+
return np.logical_and(mantissa_ok, range_ok)
215217

216218
def is_integer(self):
217219
return False
@@ -286,7 +288,9 @@ def max(self):
286288
return signed_max if self._signed else unsigned_max
287289

288290
def allowed(self, value):
289-
return (self.min() <= value) and (value <= self.max()) and float(value).is_integer()
291+
value_is_integer = (np.round(value) == value)
292+
value_is_bounded = np.logical_and(self.min() <= value, value <= self.max())
293+
return np.logical_and(value_is_integer, value_is_bounded)
290294

291295
def get_num_possible_values(self):
292296
return abs(self.min()) + abs(self.max()) + 1
@@ -334,7 +338,7 @@ def max(self):
334338
return +1
335339

336340
def allowed(self, value):
337-
return value in [-1, +1]
341+
return np.isin(value, [-1, +1])
338342

339343
def get_num_possible_values(self):
340344
return 2
@@ -366,7 +370,7 @@ def max(self):
366370
return +1
367371

368372
def allowed(self, value):
369-
return value in [-1, 0, +1]
373+
return np.isin(value, [-1, 0, +1])
370374

371375
def get_num_possible_values(self):
372376
return 3

src/qonnx/util/basic.py

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -288,13 +288,9 @@ def sanitize_quant_values(model, node_tensors, execution_context, check_values=F
288288
continue
289289
current_values = execution_context[tensor_name]
290290
updated_values = current_values
291-
has_to_be_rounded = False
292-
# TODO: vectorize with numpy
293-
for value in np.nditer(current_values):
294-
if not dtype.allowed(value):
295-
has_to_be_rounded = True
296-
break
297-
if has_to_be_rounded:
291+
is_allowed = dtype.allowed(current_values)
292+
is_allowed = is_allowed.all() if isinstance(is_allowed, np.ndarray) else is_allowed
293+
if not is_allowed:
298294
updated_values = np.round(current_values)
299295
warnings.warn(
300296
"The values of tensor {} can't be represented "
@@ -306,15 +302,15 @@ def sanitize_quant_values(model, node_tensors, execution_context, check_values=F
306302
if max_error <= get_execution_error_thresh():
307303
if check_values is True:
308304
# check again if values can now be represented with set finn datatype
309-
# TODO: vectorize with numpy
310-
for value in np.nditer(updated_values):
311-
if not dtype.allowed(value):
312-
raise Exception(
313-
"""Values can't be represented with set
314-
finn datatype ({}) for input {}""".format(
315-
dtype, tensor_name
316-
)
305+
is_allowed = dtype.allowed(updated_values)
306+
is_allowed = is_allowed.all() if isinstance(is_allowed, np.ndarray) else is_allowed
307+
if not is_allowed:
308+
raise Exception(
309+
"""Values can't be represented with set
310+
finn datatype ({}) for input {}""".format(
311+
dtype, tensor_name
317312
)
313+
)
318314
execution_context[tensor_name] = updated_values
319315
else:
320316
raise Exception(

0 commit comments

Comments
 (0)