Skip to content

Commit a9ff3a7

Browse files
authored
Implement kaiser function in keras.ops (#21303)
* Add kaiser window for ops * Add test cases * update excluded_concrete_tests
1 parent 3bedb9a commit a9ff3a7

File tree

12 files changed

+99
-0
lines changed

12 files changed

+99
-0
lines changed

keras/api/_tf_keras/keras/ops/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,7 @@
195195
from keras.src.ops.numpy import isfinite as isfinite
196196
from keras.src.ops.numpy import isinf as isinf
197197
from keras.src.ops.numpy import isnan as isnan
198+
from keras.src.ops.numpy import kaiser as kaiser
198199
from keras.src.ops.numpy import left_shift as left_shift
199200
from keras.src.ops.numpy import less as less
200201
from keras.src.ops.numpy import less_equal as less_equal

keras/api/_tf_keras/keras/ops/numpy/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@
8585
from keras.src.ops.numpy import isfinite as isfinite
8686
from keras.src.ops.numpy import isinf as isinf
8787
from keras.src.ops.numpy import isnan as isnan
88+
from keras.src.ops.numpy import kaiser as kaiser
8889
from keras.src.ops.numpy import left_shift as left_shift
8990
from keras.src.ops.numpy import less as less
9091
from keras.src.ops.numpy import less_equal as less_equal

keras/api/ops/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,7 @@
195195
from keras.src.ops.numpy import isfinite as isfinite
196196
from keras.src.ops.numpy import isinf as isinf
197197
from keras.src.ops.numpy import isnan as isnan
198+
from keras.src.ops.numpy import kaiser as kaiser
198199
from keras.src.ops.numpy import left_shift as left_shift
199200
from keras.src.ops.numpy import less as less
200201
from keras.src.ops.numpy import less_equal as less_equal

keras/api/ops/numpy/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@
8585
from keras.src.ops.numpy import isfinite as isfinite
8686
from keras.src.ops.numpy import isinf as isinf
8787
from keras.src.ops.numpy import isnan as isnan
88+
from keras.src.ops.numpy import kaiser as kaiser
8889
from keras.src.ops.numpy import left_shift as left_shift
8990
from keras.src.ops.numpy import less as less
9091
from keras.src.ops.numpy import less_equal as less_equal

keras/src/backend/jax/numpy.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,11 @@ def hamming(x):
4747
return jnp.hamming(x)
4848

4949

50+
def kaiser(x, beta):
51+
x = convert_to_tensor(x)
52+
return jnp.kaiser(x, beta)
53+
54+
5055
def bincount(x, weights=None, minlength=0, sparse=False):
5156
# Note: bincount is never tracable / jittable because the output shape
5257
# depends on the values in x.

keras/src/backend/numpy/numpy.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -315,6 +315,11 @@ def hamming(x):
315315
return np.hamming(x).astype(config.floatx())
316316

317317

318+
def kaiser(x, beta):
319+
x = convert_to_tensor(x)
320+
return np.kaiser(x, beta).astype(config.floatx())
321+
322+
318323
def bincount(x, weights=None, minlength=0, sparse=False):
319324
if sparse:
320325
raise ValueError("Unsupported value `sparse=True` with numpy backend")

keras/src/backend/openvino/excluded_concrete_tests.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ NumpyDtypeTest::test_array
1111
NumpyDtypeTest::test_bartlett
1212
NumpyDtypeTest::test_blackman
1313
NumpyDtypeTest::test_hamming
14+
NumpyDtypeTest::test_kaiser
1415
NumpyDtypeTest::test_bitwise
1516
NumpyDtypeTest::test_ceil
1617
NumpyDtypeTest::test_concatenate
@@ -75,6 +76,7 @@ NumpyOneInputOpsCorrectnessTest::test_array
7576
NumpyOneInputOpsCorrectnessTest::test_bartlett
7677
NumpyOneInputOpsCorrectnessTest::test_blackman
7778
NumpyOneInputOpsCorrectnessTest::test_hamming
79+
NumpyOneInputOpsCorrectnessTest::test_kaiser
7880
NumpyOneInputOpsCorrectnessTest::test_bitwise_invert
7981
NumpyOneInputOpsCorrectnessTest::test_conj
8082
NumpyOneInputOpsCorrectnessTest::test_correlate
@@ -148,4 +150,5 @@ NumpyOneInputOpsDynamicShapeTest::test_angle
148150
NumpyOneInputOpsDynamicShapeTest::test_bartlett
149151
NumpyOneInputOpsDynamicShapeTest::test_blackman
150152
NumpyOneInputOpsDynamicShapeTest::test_hamming
153+
NumpyOneInputOpsDynamicShapeTest::test_kaiser
151154
NumpyOneInputOpsStaticShapeTest::test_angle

keras/src/backend/openvino/numpy.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -481,6 +481,10 @@ def hamming(x):
481481
)
482482

483483

484+
def kaiser(x, beta):
485+
raise NotImplementedError("`kaiser` is not supported with openvino backend")
486+
487+
484488
def bincount(x, weights=None, minlength=0, sparse=False):
485489
if x is None:
486490
raise ValueError("input x is None")

keras/src/backend/tensorflow/numpy.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,11 @@ def hamming(x):
151151
return tf.signal.hamming_window(x, periodic=False)
152152

153153

154+
def kaiser(x, beta):
155+
x = convert_to_tensor(x, dtype=tf.int32)
156+
return tf.signal.kaiser_window(x, beta=beta)
157+
158+
154159
def bincount(x, weights=None, minlength=0, sparse=False):
155160
x = convert_to_tensor(x)
156161
dtypes_to_resolve = [x.dtype]

keras/src/backend/torch/numpy.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -440,6 +440,11 @@ def hamming(x):
440440
return torch.signal.windows.hamming(x)
441441

442442

443+
def kaiser(x, beta):
444+
x = convert_to_tensor(x)
445+
return torch.signal.windows.kaiser(x, beta=beta)
446+
447+
443448
def bincount(x, weights=None, minlength=0, sparse=False):
444449
if sparse:
445450
raise ValueError("Unsupported value `sparse=True` with torch backend")

keras/src/ops/numpy.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1270,6 +1270,44 @@ def hamming(x):
12701270
return backend.numpy.hamming(x)
12711271

12721272

1273+
class Kaiser(Operation):
1274+
def __init__(self, beta):
1275+
super().__init__()
1276+
self.beta = beta
1277+
1278+
def call(self, x):
1279+
return backend.numpy.kaiser(x, self.beta)
1280+
1281+
def compute_output_spec(self, x):
1282+
return KerasTensor(x.shape, dtype=backend.floatx())
1283+
1284+
1285+
@keras_export(["keras.ops.kaiser", "keras.ops.numpy.kaiser"])
1286+
def kaiser(x, beta):
1287+
"""Kaiser window function.
1288+
1289+
The Kaiser window is defined as:
1290+
`w[n] = I0(beta * sqrt(1 - (2n / (N - 1) - 1)^2)) / I0(beta)`
1291+
where I0 is the modified zeroth-order Bessel function of the first kind.
1292+
1293+
Args:
1294+
x: Scalar or 1D Tensor. The window length.
1295+
beta: Float. Shape parameter for the Kaiser window.
1296+
1297+
Returns:
1298+
A 1D tensor containing the Kaiser window values.
1299+
1300+
Example:
1301+
>>> x = keras.ops.convert_to_tensor(5)
1302+
>>> keras.ops.kaiser(x, beta=14.0)
1303+
array([7.7268669e-06, 1.6493219e-01, 1.0000000e+00, 1.6493219e-01,
1304+
7.7268669e-06], dtype=float32)
1305+
"""
1306+
if any_symbolic_tensors((x,)):
1307+
return Kaiser(beta).symbolic_call(x)
1308+
return backend.numpy.kaiser(x, beta)
1309+
1310+
12731311
class Bincount(Operation):
12741312
def __init__(self, weights=None, minlength=0, sparse=False):
12751313
super().__init__()

keras/src/ops/numpy_test.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1232,6 +1232,11 @@ def test_hamming(self):
12321232
x = np.random.randint(1, 100 + 1)
12331233
self.assertEqual(knp.hamming(x).shape[0], x)
12341234

1235+
def test_kaiser(self):
1236+
x = np.random.randint(1, 100 + 1)
1237+
beta = float(np.random.randint(10, 20 + 1))
1238+
self.assertEqual(knp.kaiser(x, beta).shape[0], x)
1239+
12351240
def test_bitwise_invert(self):
12361241
x = KerasTensor((None, 3))
12371242
self.assertEqual(knp.bitwise_invert(x).shape, (None, 3))
@@ -3620,6 +3625,13 @@ def test_hamming(self):
36203625

36213626
self.assertAllClose(knp.Hamming()(x), np.hamming(x))
36223627

3628+
def test_kaiser(self):
3629+
x = np.random.randint(1, 100 + 1)
3630+
beta = float(np.random.randint(10, 20 + 1))
3631+
self.assertAllClose(knp.kaiser(x, beta), np.kaiser(x, beta))
3632+
3633+
self.assertAllClose(knp.Kaiser(beta)(x), np.kaiser(x, beta))
3634+
36233635
@parameterized.named_parameters(
36243636
named_product(sparse_input=(False, True), sparse_arg=(False, True))
36253637
)
@@ -5627,6 +5639,24 @@ def test_hamming(self, dtype):
56275639
expected_dtype,
56285640
)
56295641

5642+
@parameterized.named_parameters(named_product(dtype=ALL_DTYPES))
5643+
def test_kaiser(self, dtype):
5644+
import jax.numpy as jnp
5645+
5646+
x = knp.ones((), dtype=dtype)
5647+
beta = knp.ones((), dtype=dtype)
5648+
x_jax = jnp.ones((), dtype=dtype)
5649+
beta_jax = jnp.ones((), dtype=dtype)
5650+
expected_dtype = standardize_dtype(jnp.kaiser(x_jax, beta_jax).dtype)
5651+
5652+
self.assertEqual(
5653+
standardize_dtype(knp.kaiser(x, beta).dtype), expected_dtype
5654+
)
5655+
self.assertEqual(
5656+
standardize_dtype(knp.Kaiser(beta).symbolic_call(x).dtype),
5657+
expected_dtype,
5658+
)
5659+
56305660
@parameterized.named_parameters(named_product(dtype=INT_DTYPES))
56315661
def test_bincount(self, dtype):
56325662
import jax.numpy as jnp

0 commit comments

Comments
 (0)