Skip to content

Commit 532a84b

Browse files
fix(transformers/ut): fix compute_diffs Division by zero check (#1351)
* compute diffs Division by zero check * Update tests/modeling_test_utils.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --------- Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent 41b8d04 commit 532a84b

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

tests/modeling_test_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -379,7 +379,7 @@ def compute_diffs(pt_outputs: Union[torch.Tensor, np.ndarray], ms_outputs: Union
379379
# dist(x, y) := ||x - y|| / ||y||, where ||·|| means Frobenius norm
380380

381381
# adaption for tensor with all zeros element
382-
eps = 1e-9 if np.all(m.astype(np.float32) == 0) and np.all(p.astype(np.float32) == 0) else 0
382+
eps = 1e-9 if np.isclose(np.linalg.norm(p), 0, atol=1e-9) else 0
383383
d = np.linalg.norm(p - m) / (np.linalg.norm(p) + eps)
384384
diffs.append(d)
385385

0 commit comments

Comments
 (0)