Skip to content

Commit 157f653

Browse files
authored
Implement cbrt function in keras.ops (#21453)
* Add cbrt implementation for ops * Add test case for cbrt * Update excluded_concrete_tests.txt
1 parent e5c45f0 commit 157f653

File tree

12 files changed

+109
-0
lines changed

12 files changed

+109
-0
lines changed

keras/api/_tf_keras/keras/ops/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,7 @@
149149
from keras.src.ops.numpy import bitwise_xor as bitwise_xor
150150
from keras.src.ops.numpy import blackman as blackman
151151
from keras.src.ops.numpy import broadcast_to as broadcast_to
152+
from keras.src.ops.numpy import cbrt as cbrt
152153
from keras.src.ops.numpy import ceil as ceil
153154
from keras.src.ops.numpy import clip as clip
154155
from keras.src.ops.numpy import concatenate as concatenate

keras/api/_tf_keras/keras/ops/numpy/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
from keras.src.ops.numpy import bitwise_xor as bitwise_xor
3939
from keras.src.ops.numpy import blackman as blackman
4040
from keras.src.ops.numpy import broadcast_to as broadcast_to
41+
from keras.src.ops.numpy import cbrt as cbrt
4142
from keras.src.ops.numpy import ceil as ceil
4243
from keras.src.ops.numpy import clip as clip
4344
from keras.src.ops.numpy import concatenate as concatenate

keras/api/ops/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,7 @@
149149
from keras.src.ops.numpy import bitwise_xor as bitwise_xor
150150
from keras.src.ops.numpy import blackman as blackman
151151
from keras.src.ops.numpy import broadcast_to as broadcast_to
152+
from keras.src.ops.numpy import cbrt as cbrt
152153
from keras.src.ops.numpy import ceil as ceil
153154
from keras.src.ops.numpy import clip as clip
154155
from keras.src.ops.numpy import concatenate as concatenate

keras/api/ops/numpy/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
from keras.src.ops.numpy import bitwise_xor as bitwise_xor
3939
from keras.src.ops.numpy import blackman as blackman
4040
from keras.src.ops.numpy import broadcast_to as broadcast_to
41+
from keras.src.ops.numpy import cbrt as cbrt
4142
from keras.src.ops.numpy import ceil as ceil
4243
from keras.src.ops.numpy import clip as clip
4344
from keras.src.ops.numpy import concatenate as concatenate

keras/src/backend/jax/numpy.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -499,6 +499,11 @@ def broadcast_to(x, shape):
499499
return jnp.broadcast_to(x, shape)
500500

501501

502+
def cbrt(x):
503+
x = convert_to_tensor(x)
504+
return jnp.cbrt(x)
505+
506+
502507
@sparse.elementwise_unary(linear=False)
503508
def ceil(x):
504509
x = convert_to_tensor(x)

keras/src/backend/numpy/numpy.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -414,6 +414,18 @@ def broadcast_to(x, shape):
414414
return np.broadcast_to(x, shape)
415415

416416

417+
def cbrt(x):
418+
x = convert_to_tensor(x)
419+
420+
dtype = standardize_dtype(x.dtype)
421+
if dtype in ["bool", "int8", "int16", "int32", "uint8", "uint16", "uint32"]:
422+
dtype = config.floatx()
423+
elif dtype == "int64":
424+
dtype = "float64"
425+
426+
return np.cbrt(x).astype(dtype)
427+
428+
417429
def ceil(x):
418430
x = convert_to_tensor(x)
419431
if standardize_dtype(x.dtype) == "int64":

keras/src/backend/openvino/excluded_concrete_tests.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ NumpyDtypeTest::test_hamming
1414
NumpyDtypeTest::test_hanning
1515
NumpyDtypeTest::test_kaiser
1616
NumpyDtypeTest::test_bitwise
17+
NumpyDtypeTest::test_cbrt
1718
NumpyDtypeTest::test_ceil
1819
NumpyDtypeTest::test_concatenate
1920
NumpyDtypeTest::test_corrcoef
@@ -81,6 +82,7 @@ NumpyOneInputOpsCorrectnessTest::test_hamming
8182
NumpyOneInputOpsCorrectnessTest::test_hanning
8283
NumpyOneInputOpsCorrectnessTest::test_kaiser
8384
NumpyOneInputOpsCorrectnessTest::test_bitwise_invert
85+
NumpyOneInputOpsCorrectnessTest::test_cbrt
8486
NumpyOneInputOpsCorrectnessTest::test_conj
8587
NumpyOneInputOpsCorrectnessTest::test_corrcoef
8688
NumpyOneInputOpsCorrectnessTest::test_correlate
@@ -153,12 +155,14 @@ NumpyTwoInputOpsCorrectnessTest::test_vdot
153155
NumpyOneInputOpsDynamicShapeTest::test_angle
154156
NumpyOneInputOpsDynamicShapeTest::test_bartlett
155157
NumpyOneInputOpsDynamicShapeTest::test_blackman
158+
NumpyOneInputOpsDynamicShapeTest::test_cbrt
156159
NumpyOneInputOpsDynamicShapeTest::test_corrcoef
157160
NumpyOneInputOpsDynamicShapeTest::test_deg2rad
158161
NumpyOneInputOpsDynamicShapeTest::test_hamming
159162
NumpyOneInputOpsDynamicShapeTest::test_hanning
160163
NumpyOneInputOpsDynamicShapeTest::test_kaiser
161164
NumpyOneInputOpsStaticShapeTest::test_angle
165+
NumpyOneInputOpsStaticShapeTest::test_cbrt
162166
NumpyOneInputOpsStaticShapeTest::test_deg2rad
163167
CoreOpsBehaviorTests::test_associative_scan_invalid_arguments
164168
CoreOpsBehaviorTests::test_scan_invalid_arguments

keras/src/backend/openvino/numpy.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -545,6 +545,10 @@ def broadcast_to(x, shape):
545545
return OpenVINOKerasTensor(ov_opset.broadcast(x, target_shape).output(0))
546546

547547

548+
def cbrt(x):
549+
raise NotImplementedError("`cbrt` is not supported with openvino backend")
550+
551+
548552
def ceil(x):
549553
x = get_ov_output(x)
550554
return OpenVINOKerasTensor(ov_opset.ceil(x).output(0))

keras/src/backend/tensorflow/numpy.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1089,6 +1089,18 @@ def broadcast_to(x, shape):
10891089
return tf.broadcast_to(x, shape)
10901090

10911091

1092+
def cbrt(x):
1093+
x = convert_to_tensor(x)
1094+
1095+
dtype = standardize_dtype(x.dtype)
1096+
if dtype == "int64":
1097+
x = tf.cast(x, "float64")
1098+
elif dtype not in ["bfloat16", "float16", "float64"]:
1099+
x = tf.cast(x, config.floatx())
1100+
1101+
return tf.sign(x) * tf.pow(tf.abs(x), 1.0 / 3.0)
1102+
1103+
10921104
@sparse.elementwise_unary
10931105
def ceil(x):
10941106
x = convert_to_tensor(x)

keras/src/backend/torch/numpy.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -540,6 +540,18 @@ def broadcast_to(x, shape):
540540
return torch.broadcast_to(x, shape)
541541

542542

543+
def cbrt(x):
544+
x = convert_to_tensor(x)
545+
546+
dtype = standardize_dtype(x.dtype)
547+
if dtype == "bool":
548+
x = cast(x, "int32")
549+
elif dtype == "int64":
550+
x = cast(x, "float64")
551+
552+
return torch.sign(x) * torch.abs(x) ** (1.0 / 3.0)
553+
554+
543555
def ceil(x):
544556
x = convert_to_tensor(x)
545557
ori_dtype = standardize_dtype(x.dtype)

0 commit comments

Comments
 (0)