Skip to content

Commit b539552

Browse files
authored
Fix bitwise left_shift and right_shift result dtype... (#21034)
when the second argument is a constant int. Previously, `convert_to_tensor` was applied to the second argument, turning it into an `int32` or `int64` tensor. The result dtype was then computed taking that dtype into account, which could promote the result to a wider dtype than the first argument's. The expected behavior is that when the second argument is a constant Python `int`, the result dtype is the same as the dtype of the first argument. All underlying backend implementations already support this correctly.
1 parent 04ebf72 commit b539552

File tree

6 files changed

+44
-24
lines changed

6 files changed

+44
-24
lines changed

keras/src/backend/jax/numpy.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -439,7 +439,8 @@ def bitwise_xor(x, y):
439439

440440
def bitwise_left_shift(x, y):
441441
x = convert_to_tensor(x)
442-
y = convert_to_tensor(y)
442+
if not isinstance(y, int):
443+
y = convert_to_tensor(y)
443444
return jnp.left_shift(x, y)
444445

445446

@@ -449,7 +450,8 @@ def left_shift(x, y):
449450

450451
def bitwise_right_shift(x, y):
451452
x = convert_to_tensor(x)
452-
y = convert_to_tensor(y)
453+
if not isinstance(y, int):
454+
y = convert_to_tensor(y)
453455
return jnp.right_shift(x, y)
454456

455457

keras/src/backend/numpy/numpy.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -355,7 +355,8 @@ def bitwise_xor(x, y):
355355

356356
def bitwise_left_shift(x, y):
357357
x = convert_to_tensor(x)
358-
y = convert_to_tensor(y)
358+
if not isinstance(y, int):
359+
y = convert_to_tensor(y)
359360
return np.left_shift(x, y)
360361

361362

@@ -365,7 +366,8 @@ def left_shift(x, y):
365366

366367
def bitwise_right_shift(x, y):
367368
x = convert_to_tensor(x)
368-
y = convert_to_tensor(y)
369+
if not isinstance(y, int):
370+
y = convert_to_tensor(y)
369371
return np.right_shift(x, y)
370372

371373

keras/src/backend/tensorflow/numpy.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1007,10 +1007,11 @@ def bitwise_xor(x, y):
10071007

10081008
def bitwise_left_shift(x, y):
10091009
x = convert_to_tensor(x)
1010-
y = convert_to_tensor(y)
1011-
dtype = dtypes.result_type(x.dtype, y.dtype)
1012-
x = tf.cast(x, dtype)
1013-
y = tf.cast(y, dtype)
1010+
if not isinstance(y, int):
1011+
y = convert_to_tensor(y)
1012+
dtype = dtypes.result_type(x.dtype, y.dtype)
1013+
x = tf.cast(x, dtype)
1014+
y = tf.cast(y, dtype)
10141015
return tf.bitwise.left_shift(x, y)
10151016

10161017

@@ -1020,10 +1021,11 @@ def left_shift(x, y):
10201021

10211022
def bitwise_right_shift(x, y):
10221023
x = convert_to_tensor(x)
1023-
y = convert_to_tensor(y)
1024-
dtype = dtypes.result_type(x.dtype, y.dtype)
1025-
x = tf.cast(x, dtype)
1026-
y = tf.cast(y, dtype)
1024+
if not isinstance(y, int):
1025+
y = convert_to_tensor(y)
1026+
dtype = dtypes.result_type(x.dtype, y.dtype)
1027+
x = tf.cast(x, dtype)
1028+
y = tf.cast(y, dtype)
10271029
return tf.bitwise.right_shift(x, y)
10281030

10291031

keras/src/backend/torch/numpy.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -479,7 +479,8 @@ def bitwise_xor(x, y):
479479

480480
def bitwise_left_shift(x, y):
481481
x = convert_to_tensor(x)
482-
y = convert_to_tensor(y)
482+
if not isinstance(y, int):
483+
y = convert_to_tensor(y)
483484
return torch.bitwise_left_shift(x, y)
484485

485486

@@ -489,7 +490,8 @@ def left_shift(x, y):
489490

490491
def bitwise_right_shift(x, y):
491492
x = convert_to_tensor(x)
492-
y = convert_to_tensor(y)
493+
if not isinstance(y, int):
494+
y = convert_to_tensor(y)
493495
return torch.bitwise_right_shift(x, y)
494496

495497

keras/src/ops/numpy.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1417,7 +1417,10 @@ def call(self, x, y):
14171417
return backend.numpy.bitwise_left_shift(x, y)
14181418

14191419
def compute_output_spec(self, x, y):
1420-
dtype = dtypes.result_type(x.dtype, y.dtype)
1420+
if isinstance(y, int):
1421+
dtype = x.dtype
1422+
else:
1423+
dtype = dtypes.result_type(x.dtype, y.dtype)
14211424
return KerasTensor(x.shape, dtype=dtype)
14221425

14231426

@@ -1451,7 +1454,10 @@ def call(self, x, y):
14511454
return backend.numpy.left_shift(x, y)
14521455

14531456
def compute_output_spec(self, x, y):
1454-
dtype = dtypes.result_type(x.dtype, y.dtype)
1457+
if isinstance(y, int):
1458+
dtype = x.dtype
1459+
else:
1460+
dtype = dtypes.result_type(x.dtype, y.dtype)
14551461
return KerasTensor(x.shape, dtype=dtype)
14561462

14571463

@@ -1483,7 +1489,10 @@ def call(self, x, y):
14831489
return backend.numpy.bitwise_right_shift(x, y)
14841490

14851491
def compute_output_spec(self, x, y):
1486-
dtype = dtypes.result_type(x.dtype, y.dtype)
1492+
if isinstance(y, int):
1493+
dtype = x.dtype
1494+
else:
1495+
dtype = dtypes.result_type(x.dtype, y.dtype)
14871496
return KerasTensor(x.shape, dtype=dtype)
14881497

14891498

@@ -1517,7 +1526,10 @@ def call(self, x, y):
15171526
return backend.numpy.right_shift(x, y)
15181527

15191528
def compute_output_spec(self, x, y):
1520-
dtype = dtypes.result_type(x.dtype, y.dtype)
1529+
if isinstance(y, int):
1530+
dtype = x.dtype
1531+
else:
1532+
dtype = dtypes.result_type(x.dtype, y.dtype)
15211533
return KerasTensor(x.shape, dtype=dtype)
15221534

15231535

keras/src/ops/numpy_test.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6229,16 +6229,16 @@ def test_bitwise_xor(self, dtypes):
62296229
self.assertDType(knp.BitwiseXor().symbolic_call(x1, x2), expected_dtype)
62306230

62316231
@parameterized.named_parameters(
6232-
named_product(dtypes=itertools.combinations(INT_DTYPES, 2))
6232+
named_product(dtypes=itertools.product(INT_DTYPES, INT_DTYPES + [None]))
62336233
)
62346234
def test_bitwise_left_shift(self, dtypes):
62356235
import jax.numpy as jnp
62366236

62376237
dtype1, dtype2 = dtypes
62386238
x1 = knp.ones((1,), dtype=dtype1)
6239-
x2 = knp.ones((1,), dtype=dtype2)
6239+
x2 = knp.ones((1,), dtype=dtype2) if dtype2 else 1
62406240
x1_jax = jnp.ones((1,), dtype=dtype1)
6241-
x2_jax = jnp.ones((1,), dtype=dtype2)
6241+
x2_jax = jnp.ones((1,), dtype=dtype2) if dtype2 else 1
62426242
expected_dtype = standardize_dtype(jnp.left_shift(x1_jax, x2_jax).dtype)
62436243

62446244
self.assertDType(knp.bitwise_left_shift(x1, x2), expected_dtype)
@@ -6249,16 +6249,16 @@ def test_bitwise_left_shift(self, dtypes):
62496249
# left_shift is same as bitwise_left_shift
62506250

62516251
@parameterized.named_parameters(
6252-
named_product(dtypes=itertools.combinations(INT_DTYPES, 2))
6252+
named_product(dtypes=itertools.product(INT_DTYPES, INT_DTYPES + [None]))
62536253
)
62546254
def test_bitwise_right_shift(self, dtypes):
62556255
import jax.numpy as jnp
62566256

62576257
dtype1, dtype2 = dtypes
62586258
x1 = knp.ones((1,), dtype=dtype1)
6259-
x2 = knp.ones((1,), dtype=dtype2)
6259+
x2 = knp.ones((1,), dtype=dtype2) if dtype2 else 1
62606260
x1_jax = jnp.ones((1,), dtype=dtype1)
6261-
x2_jax = jnp.ones((1,), dtype=dtype2)
6261+
x2_jax = jnp.ones((1,), dtype=dtype2) if dtype2 else 1
62626262
expected_dtype = standardize_dtype(
62636263
jnp.right_shift(x1_jax, x2_jax).dtype
62646264
)

0 commit comments

Comments
 (0)