Skip to content

Commit 744b8be

Browse files
authored
Fix symbolic call of logsumexp with int axis. (#21428)
Using `keras.ops.math.logsumexp` with an int for `axis` in a functional model would throw an error.
1 parent df58ec9 commit 744b8be

File tree

2 files changed

+6
-2
lines changed

2 files changed

+6
-2
lines changed

keras/src/ops/math_test.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -179,8 +179,10 @@ def test_in_top_k(self):
179179

180180
def test_logsumexp(self):
181181
x = KerasTensor((None, 2, 3), dtype="float32")
182-
result = kmath.logsumexp(x)
183-
self.assertEqual(result.shape, ())
182+
self.assertEqual(kmath.logsumexp(x).shape, ())
183+
self.assertEqual(kmath.logsumexp(x, axis=1).shape, (None, 3))
184+
self.assertEqual(kmath.logsumexp(x, axis=(1, 2)).shape, (None,))
185+
self.assertEqual(kmath.logsumexp(x, keepdims=True).shape, (1, 1, 1))
184186

185187
def test_extract_sequences(self):
186188
# Defined dimension

keras/src/ops/operation_utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -375,6 +375,8 @@ def reduce_shape(shape, axis=None, keepdims=False):
375375
return tuple([1 for _ in shape])
376376
else:
377377
return tuple([])
378+
elif isinstance(axis, int):
379+
axis = (axis,)
378380

379381
if keepdims:
380382
for ax in axis:

0 commit comments

Comments
 (0)