Skip to content

Commit 8a6e83b

Browse files
authored
Fix Embedding.compute_output_spec with a non-KerasTensor input. (#21192)
The `ragged` attribute exists only with `KerasTensor`s. Minor fix of a unit tests that was using the same local variable for two nested loops.
1 parent 72cc27f commit 8a6e83b

File tree

2 files changed

+4
-3
lines changed

2 files changed

+4
-3
lines changed

keras/src/layers/core/embedding.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,8 +159,9 @@ def compute_output_shape(self, input_shape):
159159

160160
def compute_output_spec(self, inputs):
161161
output_shape = (*inputs.shape, self.output_dim)
162+
ragged = getattr(inputs, "ragged", False)
162163
return KerasTensor(
163-
output_shape, dtype=self.compute_dtype, ragged=inputs.ragged
164+
output_shape, dtype=self.compute_dtype, ragged=ragged
164165
)
165166

166167
def enable_lora(

keras/src/trainers/data_adapters/generator_data_adapter_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,8 +100,8 @@ def test_basic_flow(self, use_sample_weight, generator_type):
100100
self.assertEqual(by.shape, (2, 2))
101101
if use_sample_weight:
102102
self.assertIsInstance(bsw, expected_class)
103-
for i in range(by.shape[0]):
104-
sample_order.append(by[i, 0])
103+
for j in range(by.shape[0]):
104+
sample_order.append(by[j, 0])
105105
self.assertAllClose(sample_order, list(range(34)))
106106

107107
def test_with_different_shapes(self):

0 commit comments

Comments
 (0)