@@ -124,9 +124,6 @@ def _overlap_sequences(x, sequence_stride):
124
124
125
125
x = _overlap_sequences (x , sequence_stride )
126
126
127
- if backend .backend () in {"numpy" , "jax" }:
128
- x = np .nan_to_num (x )
129
-
130
127
start = 0 if center is False else fft_length // 2
131
128
if length is not None :
132
129
end = start + length
@@ -862,6 +859,9 @@ def test_istft(
862
859
truncated_len = int (output .shape [- 1 ] * 0.05 )
863
860
output = output [..., truncated_len :- truncated_len ]
864
861
ref = ref [..., truncated_len :- truncated_len ]
862
+ # Nans are handled differently in different backends, so zero them out.
863
+ output = np .nan_to_num (backend .convert_to_numpy (output ), nan = 0.0 )
864
+ ref = np .nan_to_num (ref , nan = 0.0 )
865
865
self .assertAllClose (output , ref , atol = 1e-5 , rtol = 1e-5 )
866
866
867
867
# Test N-D case.
@@ -891,6 +891,9 @@ def test_istft(
891
891
truncated_len = int (output .shape [- 1 ] * 0.05 )
892
892
output = output [..., truncated_len :- truncated_len ]
893
893
ref = ref [..., truncated_len :- truncated_len ]
894
+ # Nans are handled differently in different backends, so zero them out.
895
+ output = np .nan_to_num (backend .convert_to_numpy (output ), nan = 0.0 )
896
+ ref = np .nan_to_num (ref , nan = 0.0 )
894
897
self .assertAllClose (output , ref , atol = 1e-5 , rtol = 1e-5 )
895
898
896
899
def test_rsqrt (self ):
0 commit comments