Skip to content

Commit 3a11132

Browse files
authored
Fix nan istft flakiness in different backends in math_test.py (#21419)
1 parent 6ff1cdb commit 3a11132

File tree

1 file changed

+6
-3
lines changed

1 file changed

+6
-3
lines changed

keras/src/ops/math_test.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -124,9 +124,6 @@ def _overlap_sequences(x, sequence_stride):
124124

125125
x = _overlap_sequences(x, sequence_stride)
126126

127-
if backend.backend() in {"numpy", "jax"}:
128-
x = np.nan_to_num(x)
129-
130127
start = 0 if center is False else fft_length // 2
131128
if length is not None:
132129
end = start + length
@@ -862,6 +859,9 @@ def test_istft(
862859
truncated_len = int(output.shape[-1] * 0.05)
863860
output = output[..., truncated_len:-truncated_len]
864861
ref = ref[..., truncated_len:-truncated_len]
862+
# Nans are handled differently in different backends, so zero them out.
863+
output = np.nan_to_num(backend.convert_to_numpy(output), nan=0.0)
864+
ref = np.nan_to_num(ref, nan=0.0)
865865
self.assertAllClose(output, ref, atol=1e-5, rtol=1e-5)
866866

867867
# Test N-D case.
@@ -891,6 +891,9 @@ def test_istft(
891891
truncated_len = int(output.shape[-1] * 0.05)
892892
output = output[..., truncated_len:-truncated_len]
893893
ref = ref[..., truncated_len:-truncated_len]
894+
# Nans are handled differently in different backends, so zero them out.
895+
output = np.nan_to_num(backend.convert_to_numpy(output), nan=0.0)
896+
ref = np.nan_to_num(ref, nan=0.0)
894897
self.assertAllClose(output, ref, atol=1e-5, rtol=1e-5)
895898

896899
def test_rsqrt(self):

0 commit comments

Comments
 (0)