Skip to content

Commit 6fd736c

Browse files
authored
Implement corrcoef function in keras.ops (#21372)
* Add corrcoef for ops * Update method for complex case * Add init.py for corrcoef * Update code for test case * Update excluded_concrete_tests.txt for openvino * Update axis for corrcoef on tf * update docstrings
1 parent fa31fa7 commit 6fd736c

File tree

12 files changed

+128
-0
lines changed

12 files changed

+128
-0
lines changed

keras/api/_tf_keras/keras/ops/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,7 @@
155155
from keras.src.ops.numpy import conj as conj
156156
from keras.src.ops.numpy import conjugate as conjugate
157157
from keras.src.ops.numpy import copy as copy
158+
from keras.src.ops.numpy import corrcoef as corrcoef
158159
from keras.src.ops.numpy import correlate as correlate
159160
from keras.src.ops.numpy import cos as cos
160161
from keras.src.ops.numpy import cosh as cosh

keras/api/_tf_keras/keras/ops/numpy/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
from keras.src.ops.numpy import conj as conj
4545
from keras.src.ops.numpy import conjugate as conjugate
4646
from keras.src.ops.numpy import copy as copy
47+
from keras.src.ops.numpy import corrcoef as corrcoef
4748
from keras.src.ops.numpy import correlate as correlate
4849
from keras.src.ops.numpy import cos as cos
4950
from keras.src.ops.numpy import cosh as cosh

keras/api/ops/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,7 @@
155155
from keras.src.ops.numpy import conj as conj
156156
from keras.src.ops.numpy import conjugate as conjugate
157157
from keras.src.ops.numpy import copy as copy
158+
from keras.src.ops.numpy import corrcoef as corrcoef
158159
from keras.src.ops.numpy import correlate as correlate
159160
from keras.src.ops.numpy import cos as cos
160161
from keras.src.ops.numpy import cosh as cosh

keras/api/ops/numpy/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
from keras.src.ops.numpy import conj as conj
4545
from keras.src.ops.numpy import conjugate as conjugate
4646
from keras.src.ops.numpy import copy as copy
47+
from keras.src.ops.numpy import corrcoef as corrcoef
4748
from keras.src.ops.numpy import correlate as correlate
4849
from keras.src.ops.numpy import cos as cos
4950
from keras.src.ops.numpy import cosh as cosh

keras/src/backend/jax/numpy.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1338,6 +1338,11 @@ def logical_xor(x1, x2):
13381338
return jnp.logical_xor(x1, x2)
13391339

13401340

1341+
def corrcoef(x):
1342+
x = convert_to_tensor(x)
1343+
return jnp.corrcoef(x)
1344+
1345+
13411346
def correlate(x1, x2, mode="valid"):
13421347
x1 = convert_to_tensor(x1)
13431348
x2 = convert_to_tensor(x2)

keras/src/backend/numpy/numpy.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1256,6 +1256,19 @@ def logical_xor(x1, x2):
12561256
return np.logical_xor(x1, x2)
12571257

12581258

1259+
def corrcoef(x):
1260+
if x.dtype in ["int64", "float64"]:
1261+
dtype = "float64"
1262+
elif x.dtype in ["bfloat16", "float16"]:
1263+
dtype = x.dtype
1264+
else:
1265+
dtype = config.floatx()
1266+
1267+
x = convert_to_tensor(x)
1268+
1269+
return np.corrcoef(x).astype(dtype)
1270+
1271+
12591272
def correlate(x1, x2, mode="valid"):
12601273
dtype = dtypes.result_type(
12611274
getattr(x1, "dtype", type(x1)),

keras/src/backend/openvino/excluded_concrete_tests.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ NumpyDtypeTest::test_kaiser
1616
NumpyDtypeTest::test_bitwise
1717
NumpyDtypeTest::test_ceil
1818
NumpyDtypeTest::test_concatenate
19+
NumpyDtypeTest::test_corrcoef
1920
NumpyDtypeTest::test_correlate
2021
NumpyDtypeTest::test_cross
2122
NumpyDtypeTest::test_cumprod
@@ -80,6 +81,7 @@ NumpyOneInputOpsCorrectnessTest::test_hanning
8081
NumpyOneInputOpsCorrectnessTest::test_kaiser
8182
NumpyOneInputOpsCorrectnessTest::test_bitwise_invert
8283
NumpyOneInputOpsCorrectnessTest::test_conj
84+
NumpyOneInputOpsCorrectnessTest::test_corrcoef
8385
NumpyOneInputOpsCorrectnessTest::test_correlate
8486
NumpyOneInputOpsCorrectnessTest::test_cumprod
8587
NumpyOneInputOpsCorrectnessTest::test_diag
@@ -149,6 +151,7 @@ NumpyTwoInputOpsCorrectnessTest::test_vdot
149151
NumpyOneInputOpsDynamicShapeTest::test_angle
150152
NumpyOneInputOpsDynamicShapeTest::test_bartlett
151153
NumpyOneInputOpsDynamicShapeTest::test_blackman
154+
NumpyOneInputOpsDynamicShapeTest::test_corrcoef
152155
NumpyOneInputOpsDynamicShapeTest::test_hamming
153156
NumpyOneInputOpsDynamicShapeTest::test_hanning
154157
NumpyOneInputOpsDynamicShapeTest::test_kaiser

keras/src/backend/openvino/numpy.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1666,6 +1666,12 @@ def logical_xor(x1, x2):
16661666
return OpenVINOKerasTensor(ov_opset.logical_xor(x1, x2).output(0))
16671667

16681668

1669+
def corrcoef(x):
1670+
raise NotImplementedError(
1671+
"`corrcoef` is not supported with openvino backend"
1672+
)
1673+
1674+
16691675
def correlate(x1, x2, mode="valid"):
16701676
raise NotImplementedError(
16711677
"`correlate` is not supported with openvino backend"

keras/src/backend/tensorflow/numpy.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2786,6 +2786,38 @@ def logical_xor(x1, x2):
27862786
return tf.math.logical_xor(x1, x2)
27872787

27882788

2789+
def corrcoef(x):
2790+
dtype = x.dtype
2791+
if dtype in ["bool", "int8", "int16", "int32", "uint8", "uint16", "uint32"]:
2792+
dtype = config.floatx()
2793+
x = convert_to_tensor(x, dtype)
2794+
2795+
if tf.rank(x) == 0:
2796+
return tf.constant(float("nan"), dtype=config.floatx())
2797+
2798+
mean = tf.reduce_mean(x, axis=-1, keepdims=True)
2799+
x_centered = x - mean
2800+
2801+
num_samples = tf.cast(tf.shape(x)[-1], x.dtype)
2802+
cov_matrix = tf.matmul(x_centered, x_centered, adjoint_b=True) / (
2803+
num_samples - 1
2804+
)
2805+
2806+
diag = tf.linalg.diag_part(cov_matrix)
2807+
stddev = tf.sqrt(tf.math.real(diag))
2808+
2809+
outer_std = tf.tensordot(stddev, stddev, axes=0)
2810+
outer_std = tf.cast(outer_std, cov_matrix.dtype)
2811+
correlation = cov_matrix / outer_std
2812+
2813+
correlation_clipped = tf.clip_by_value(tf.math.real(correlation), -1.0, 1.0)
2814+
if correlation.dtype.is_complex:
2815+
imag_clipped = tf.clip_by_value(tf.math.imag(correlation), -1.0, 1.0)
2816+
return tf.complex(correlation_clipped, imag_clipped)
2817+
else:
2818+
return correlation_clipped
2819+
2820+
27892821
def correlate(x1, x2, mode="valid"):
27902822
x1 = convert_to_tensor(x1)
27912823
x2 = convert_to_tensor(x2)

keras/src/backend/torch/numpy.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1742,6 +1742,17 @@ def logical_xor(x1, x2):
17421742
return torch.logical_xor(x1, x2)
17431743

17441744

1745+
def corrcoef(x):
1746+
x = convert_to_tensor(x)
1747+
1748+
if standardize_dtype(x.dtype) == "bool":
1749+
x = cast(x, config.floatx())
1750+
elif standardize_dtype(x.dtype) == "int64":
1751+
x = cast(x, "float64")
1752+
1753+
return torch.corrcoef(x)
1754+
1755+
17451756
def correlate(x1, x2, mode="valid"):
17461757
x1 = convert_to_tensor(x1)
17471758
x2 = convert_to_tensor(x2)

0 commit comments

Comments
 (0)