Skip to content

Commit 6b74cb0

Browse files
authored
Implement hanning function in keras.ops (#21318)
* Add hanning window for ops * Add test cases * update excluded_concrete_tests.txt
1 parent 6810267 commit 6b74cb0

File tree

11 files changed

+84
-0
lines changed

11 files changed

+84
-0
lines changed

keras/api/_tf_keras/keras/ops/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,7 @@
186186
from keras.src.ops.numpy import greater as greater
187187
from keras.src.ops.numpy import greater_equal as greater_equal
188188
from keras.src.ops.numpy import hamming as hamming
189+
from keras.src.ops.numpy import hanning as hanning
189190
from keras.src.ops.numpy import histogram as histogram
190191
from keras.src.ops.numpy import hstack as hstack
191192
from keras.src.ops.numpy import identity as identity

keras/api/_tf_keras/keras/ops/numpy/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@
7676
from keras.src.ops.numpy import greater as greater
7777
from keras.src.ops.numpy import greater_equal as greater_equal
7878
from keras.src.ops.numpy import hamming as hamming
79+
from keras.src.ops.numpy import hanning as hanning
7980
from keras.src.ops.numpy import histogram as histogram
8081
from keras.src.ops.numpy import hstack as hstack
8182
from keras.src.ops.numpy import identity as identity

keras/api/ops/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,7 @@
186186
from keras.src.ops.numpy import greater as greater
187187
from keras.src.ops.numpy import greater_equal as greater_equal
188188
from keras.src.ops.numpy import hamming as hamming
189+
from keras.src.ops.numpy import hanning as hanning
189190
from keras.src.ops.numpy import histogram as histogram
190191
from keras.src.ops.numpy import hstack as hstack
191192
from keras.src.ops.numpy import identity as identity

keras/api/ops/numpy/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@
7676
from keras.src.ops.numpy import greater as greater
7777
from keras.src.ops.numpy import greater_equal as greater_equal
7878
from keras.src.ops.numpy import hamming as hamming
79+
from keras.src.ops.numpy import hanning as hanning
7980
from keras.src.ops.numpy import histogram as histogram
8081
from keras.src.ops.numpy import hstack as hstack
8182
from keras.src.ops.numpy import identity as identity

keras/src/backend/jax/numpy.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,11 @@ def hamming(x):
4747
return jnp.hamming(x)
4848

4949

50+
def hanning(x):
51+
x = convert_to_tensor(x)
52+
return jnp.hanning(x)
53+
54+
5055
def kaiser(x, beta):
5156
x = convert_to_tensor(x)
5257
return jnp.kaiser(x, beta)

keras/src/backend/numpy/numpy.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -315,6 +315,11 @@ def hamming(x):
315315
return np.hamming(x).astype(config.floatx())
316316

317317

318+
def hanning(x):
319+
x = convert_to_tensor(x)
320+
return np.hanning(x).astype(config.floatx())
321+
322+
318323
def kaiser(x, beta):
319324
x = convert_to_tensor(x)
320325
return np.kaiser(x, beta).astype(config.floatx())

keras/src/backend/openvino/excluded_concrete_tests.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ NumpyDtypeTest::test_array
1111
NumpyDtypeTest::test_bartlett
1212
NumpyDtypeTest::test_blackman
1313
NumpyDtypeTest::test_hamming
14+
NumpyDtypeTest::test_hanning
1415
NumpyDtypeTest::test_kaiser
1516
NumpyDtypeTest::test_bitwise
1617
NumpyDtypeTest::test_ceil
@@ -76,6 +77,7 @@ NumpyOneInputOpsCorrectnessTest::test_array
7677
NumpyOneInputOpsCorrectnessTest::test_bartlett
7778
NumpyOneInputOpsCorrectnessTest::test_blackman
7879
NumpyOneInputOpsCorrectnessTest::test_hamming
80+
NumpyOneInputOpsCorrectnessTest::test_hanning
7981
NumpyOneInputOpsCorrectnessTest::test_kaiser
8082
NumpyOneInputOpsCorrectnessTest::test_bitwise_invert
8183
NumpyOneInputOpsCorrectnessTest::test_conj
@@ -150,5 +152,6 @@ NumpyOneInputOpsDynamicShapeTest::test_angle
150152
NumpyOneInputOpsDynamicShapeTest::test_bartlett
151153
NumpyOneInputOpsDynamicShapeTest::test_blackman
152154
NumpyOneInputOpsDynamicShapeTest::test_hamming
155+
NumpyOneInputOpsDynamicShapeTest::test_hanning
153156
NumpyOneInputOpsDynamicShapeTest::test_kaiser
154157
NumpyOneInputOpsStaticShapeTest::test_angle

keras/src/backend/tensorflow/numpy.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,11 @@ def hamming(x):
151151
return tf.signal.hamming_window(x, periodic=False)
152152

153153

154+
def hanning(x):
155+
x = convert_to_tensor(x, dtype=tf.int32)
156+
return tf.signal.hann_window(x, periodic=False)
157+
158+
154159
def kaiser(x, beta):
155160
x = convert_to_tensor(x, dtype=tf.int32)
156161
return tf.signal.kaiser_window(x, beta=beta)

keras/src/backend/torch/numpy.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -440,6 +440,11 @@ def hamming(x):
440440
return torch.signal.windows.hamming(x)
441441

442442

443+
def hanning(x):
444+
x = convert_to_tensor(x)
445+
return torch.signal.windows.hann(x)
446+
447+
443448
def kaiser(x, beta):
444449
x = convert_to_tensor(x)
445450
return torch.signal.windows.kaiser(x, beta=beta)

keras/src/ops/numpy.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1270,6 +1270,37 @@ def hamming(x):
12701270
return backend.numpy.hamming(x)
12711271

12721272

1273+
class Hanning(Operation):
1274+
def call(self, x):
1275+
return backend.numpy.hanning(x)
1276+
1277+
def compute_output_spec(self, x):
1278+
return KerasTensor(x.shape, dtype=backend.floatx())
1279+
1280+
1281+
@keras_export(["keras.ops.hanning", "keras.ops.numpy.hanning"])
1282+
def hanning(x):
1283+
"""Hanning window function.
1284+
1285+
The Hanning window is defined as:
1286+
`w[n] = 0.5 - 0.5 * cos(2 * pi * n / (N - 1))` for `0 <= n <= N - 1`.
1287+
1288+
Args:
1289+
x: Scalar or 1D Tensor. The window length.
1290+
1291+
Returns:
1292+
A 1D tensor containing the Hanning window values.
1293+
1294+
Example:
1295+
>>> x = keras.ops.convert_to_tensor(5)
1296+
>>> keras.ops.hanning(x)
1297+
array([0. , 0.5, 1. , 0.5, 0. ], dtype=float32)
1298+
"""
1299+
if any_symbolic_tensors((x,)):
1300+
return Hanning().symbolic_call(x)
1301+
return backend.numpy.hanning(x)
1302+
1303+
12731304
class Kaiser(Operation):
12741305
def __init__(self, beta):
12751306
super().__init__()

keras/src/ops/numpy_test.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1232,6 +1232,10 @@ def test_hamming(self):
12321232
x = np.random.randint(1, 100 + 1)
12331233
self.assertEqual(knp.hamming(x).shape[0], x)
12341234

1235+
def test_hanning(self):
1236+
x = np.random.randint(1, 100 + 1)
1237+
self.assertEqual(knp.hanning(x).shape[0], x)
1238+
12351239
def test_kaiser(self):
12361240
x = np.random.randint(1, 100 + 1)
12371241
beta = float(np.random.randint(10, 20 + 1))
@@ -3625,6 +3629,12 @@ def test_hamming(self):
36253629

36263630
self.assertAllClose(knp.Hamming()(x), np.hamming(x))
36273631

3632+
def test_hanning(self):
3633+
x = np.random.randint(1, 100 + 1)
3634+
self.assertAllClose(knp.hanning(x), np.hanning(x))
3635+
3636+
self.assertAllClose(knp.Hanning()(x), np.hanning(x))
3637+
36283638
def test_kaiser(self):
36293639
x = np.random.randint(1, 100 + 1)
36303640
beta = float(np.random.randint(10, 20 + 1))
@@ -5639,6 +5649,22 @@ def test_hamming(self, dtype):
56395649
expected_dtype,
56405650
)
56415651

5652+
@parameterized.named_parameters(named_product(dtype=ALL_DTYPES))
5653+
def test_hanning(self, dtype):
5654+
import jax.numpy as jnp
5655+
5656+
x = knp.ones((), dtype=dtype)
5657+
x_jax = jnp.ones((), dtype=dtype)
5658+
expected_dtype = standardize_dtype(jnp.hanning(x_jax).dtype)
5659+
5660+
self.assertEqual(
5661+
standardize_dtype(knp.hanning(x).dtype), expected_dtype
5662+
)
5663+
self.assertEqual(
5664+
standardize_dtype(knp.Hanning().symbolic_call(x).dtype),
5665+
expected_dtype,
5666+
)
5667+
56425668
@parameterized.named_parameters(named_product(dtype=ALL_DTYPES))
56435669
def test_kaiser(self, dtype):
56445670
import jax.numpy as jnp

0 commit comments

Comments
 (0)