@@ -79,7 +79,9 @@ def max(self):
79
79
def allowed (self , value ):
80
80
"""Check whether given value is allowed for this DataType.
81
81
82
- * value (float32): value to be checked"""
82
+ * value (float32 | np.ndarray): value to be checked
83
+
84
+ Returns a boolean numpy array of the same shape as `value`"""
83
85
pass
84
86
85
87
@abstractmethod
@@ -199,19 +201,19 @@ def allowed(self, value):
199
201
bin_val = np .float32 (value ).view (np .uint32 )
200
202
exp = (bin_val & 0b01111111100000000000000000000000 ) >> fp32_mantissa_bitwidth
201
203
mant = bin_val & 0b00000000011111111111111111111111
202
- exp_biased = exp - fp32_exponent_bias # bias the extracted raw exponent (assume not denormal)
203
- mant_normalized = mant + int ((2 ** fp32_mantissa_bitwidth ) * (exp != 0 )) # append implicit 1
204
+ exp_biased = np . array ( exp ). astype ( int ) - fp32_exponent_bias # bias the extracted raw exponent (assume not denormal)
205
+ mant_normalized = mant + np . array ((2 ** fp32_mantissa_bitwidth ) * (exp != 0 )). astype ( int ) # append implicit 1
204
206
# for this value to be representable as this ArbPrecFloatType:
205
207
# the value must be within the representable range
206
- range_ok = (value <= self .max ()) and ( value >= self .min ())
208
+ range_ok = np . logical_and (value <= self .max (), value >= self .min ())
207
209
# the mantissa must be within representable range:
208
210
# no set bits in the mantissa beyond the allowed number of bits (assume value is not denormal in fp32)
209
211
# compute bits of precision lost to tapered precision if denormal, clamp to: 0 <= dnm_shift <= nrm_mantissa_bitwidth
210
- dnm_shift = int ( min ( max (0 , min_exponent - exp_biased ), nrm_mantissa_bitwidth ))
212
+ dnm_shift = np . array ( np . minimum ( np . maximum (0 , min_exponent - exp_biased ), nrm_mantissa_bitwidth )). astype ( int )
211
213
available_bits = nrm_mantissa_bitwidth - dnm_shift # number of bits of precision available
212
- mantissa_mask = "0" * available_bits + "1" * (fp32_nrm_mantissa_bitwidth - available_bits )
213
- mantissa_ok = (mant_normalized & int ( mantissa_mask , base = 2 ) ) == 0
214
- return bool (mantissa_ok and range_ok )
214
+ mantissa_mask = ( 1 << (fp32_nrm_mantissa_bitwidth - available_bits )) - 1
215
+ mantissa_ok = (mant_normalized & mantissa_mask ) == 0
216
+ return np . logical_and (mantissa_ok , range_ok )
215
217
216
218
def is_integer (self ):
217
219
return False
@@ -286,7 +288,9 @@ def max(self):
286
288
return signed_max if self ._signed else unsigned_max
287
289
288
290
def allowed (self , value ):
289
- return (self .min () <= value ) and (value <= self .max ()) and float (value ).is_integer ()
291
+ value_is_integer = (np .round (value ) == value )
292
+ value_is_bounded = np .logical_and (self .min () <= value , value <= self .max ())
293
+ return np .logical_and (value_is_integer , value_is_bounded )
290
294
291
295
def get_num_possible_values (self ):
292
296
return abs (self .min ()) + abs (self .max ()) + 1
@@ -334,7 +338,7 @@ def max(self):
334
338
return + 1
335
339
336
340
def allowed (self , value ):
337
- return value in [- 1 , + 1 ]
341
+ return np . isin ( value , [- 1 , + 1 ])
338
342
339
343
def get_num_possible_values (self ):
340
344
return 2
@@ -366,7 +370,7 @@ def max(self):
366
370
return + 1
367
371
368
372
def allowed (self , value ):
369
- return value in [- 1 , 0 , + 1 ]
373
+ return np . isin ( value , [- 1 , 0 , + 1 ])
370
374
371
375
def get_num_possible_values (self ):
372
376
return 3
0 commit comments