Skip to content

Commit 47e4357

Browse files
authored
Merge pull request #70 from i-colbert/fix/matvec_accumulator_range
Fix (utils): Updating matvec accumulator range calculation
2 parents 0412069 + 7f96161 commit 47e4357

File tree

1 file changed

+6
-10
lines changed

1 file changed

+6
-10
lines changed

src/qonnx/util/basic.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -209,16 +209,12 @@ def pad_tensor_to_multiple_of(ndarray, pad_to_dims, val=0, distr_pad=False):
209209

210210

211211
def calculate_matvec_accumulator_range(matrix: np.ndarray, vec_dt: DataType):
212-
"""Calculate the minimum and maximum possible result (accumulator) values
213-
for a dot product x * A, given matrix A of dims (MW, MH), and vector (1, MW)
214-
with datatype vec_dt. Returns (acc_min, acc_max).
215-
"""
216-
max_weight = abs(matrix).sum(axis=0).max()
217-
max_input = max(abs(vec_dt.min()), abs(vec_dt.max()))
218-
max_value = max_input * max_weight
219-
# If either the weight and input datatypes are signed, then the minimum
220-
# value that their accumulated product can be is -max_value. Else, it's 0.
221-
min_value = -max_value if (matrix.min() < 0) or vec_dt.signed() else 0
212+
"""Calculate the minimum and maximum possible result (accumulator) values for a dot product x * A,
213+
given matrix A of dims (MW, MH), and vector (1, MW) with datatype vec_dt. Returns (acc_min, acc_max)."""
214+
max_vectors = np.where(matrix > 0, vec_dt.max(), vec_dt.min())
215+
min_vectors = np.where(matrix > 0, vec_dt.min(), vec_dt.max())
216+
max_value = (matrix * max_vectors).sum(axis=0).max()
217+
min_value = (matrix * min_vectors).sum(axis=0).min()
222218
return (min_value, max_value)
223219

224220

0 commit comments

Comments
 (0)