@@ -209,16 +209,12 @@ def pad_tensor_to_multiple_of(ndarray, pad_to_dims, val=0, distr_pad=False):
209
209
210
210
211
211
def calculate_matvec_accumulator_range (matrix : np .ndarray , vec_dt : DataType ):
212
- """Calculate the minimum and maximum possible result (accumulator) values
213
- for a dot product x * A, given matrix A of dims (MW, MH), and vector (1, MW)
214
- with datatype vec_dt. Returns (acc_min, acc_max).
215
- """
216
- max_weight = abs (matrix ).sum (axis = 0 ).max ()
217
- max_input = max (abs (vec_dt .min ()), abs (vec_dt .max ()))
218
- max_value = max_input * max_weight
219
- # If either the weight and input datatypes are signed, then the minimum
220
- # value that their accumulated product can be is -max_value. Else, it's 0.
221
- min_value = - max_value if (matrix .min () < 0 ) or vec_dt .signed () else 0
212
+ """Calculate the minimum and maximum possible result (accumulator) values for a dot product x * A,
213
+ given matrix A of dims (MW, MH), and vector (1, MW) with datatype vec_dt. Returns (acc_min, acc_max)."""
214
+ max_vectors = np .where (matrix > 0 , vec_dt .max (), vec_dt .min ())
215
+ min_vectors = np .where (matrix > 0 , vec_dt .min (), vec_dt .max ())
216
+ max_value = (matrix * max_vectors ).sum (axis = 0 ).max ()
217
+ min_value = (matrix * min_vectors ).sum (axis = 0 ).min ()
222
218
return (min_value , max_value )
223
219
224
220
0 commit comments