Skip to content

Commit f85e044

Browse files
authored
Add support for multiple dynamic dimensions to Flatten layer on TF. (#21399)
With TensorFlow, when both the batch size and some other dimension are dynamic, we need to use the dynamic batch size that comes from `ops.shape` instead of `None` / `-1`; otherwise two `None`s / `-1`s would be passed to `ops.reshape`. Fixes #21380
1 parent e4bca84 commit f85e044

File tree

2 files changed

+23
-8
lines changed

2 files changed

+23
-8
lines changed

keras/src/layers/reshaping/flatten.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -40,18 +40,22 @@ def __init__(self, data_format=None, **kwargs):
4040
self._channels_first = self.data_format == "channels_first"
4141

4242
def call(self, inputs):
    """Flatten `inputs` to 2D `(batch_size, flattened_dim)`.

    Uses `ops.shape` (not `inputs.shape`) so that dynamic dimensions
    come back as backend scalar tensors rather than `None`. This lets
    the real (possibly dynamic) batch dimension be carried into the
    reshape target, so at most one `-1` is ever passed to
    `ops.reshape` even when both the batch size and other dimensions
    are dynamic.
    """
    input_shape = ops.shape(inputs)
    rank = len(input_shape)

    if self._channels_first and rank > 1:
        # Switch to channels-last format.
        inputs = ops.transpose(inputs, axes=(0, *range(2, rank), 1))

    # Transposing permutes only non-batch dims, so their product (and
    # the batch dim at index 0) is unaffected by the branch above.
    non_batch_dims = input_shape[1:]
    if any(not isinstance(d, int) for d in non_batch_dims):
        # At least one non-batch dim is dynamic (a backend tensor, not
        # an int): let reshape infer the flattened size. Only one -1 is
        # passed because the batch dim comes from ops.shape.
        flattened_dim = -1
    else:
        # All non-batch dims are static ints. math.prod(()) == 1, so a
        # rank-1 input flattens to (batch, 1) without a special case.
        flattened_dim = math.prod(non_batch_dims)

    return ops.reshape(inputs, (input_shape[0], flattened_dim))
5559

5660
def compute_output_shape(self, input_shape):
5761
non_batch_dims = input_shape[1:]

keras/src/layers/reshaping/flatten_test.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,10 @@
22
import pytest
33
from absl.testing import parameterized
44

5+
from conftest import skip_if_backend
56
from keras.src import backend
67
from keras.src import layers
8+
from keras.src import models
79
from keras.src import ops
810
from keras.src import testing
911

@@ -112,12 +114,21 @@ def test_flatten_with_scalar_channels(self):
112114
expected_output=expected_output,
113115
)
114116

115-
def test_flatten_symbolic_with_dynamic_batch_size(self):
    """A dynamic batch dim stays dynamic; static dims are folded."""
    inputs = layers.Input(batch_shape=(None, 2, 3))
    outputs = layers.Flatten()(inputs)
    self.assertEqual(outputs.shape, (None, 2 * 3))
119121

120-
def test_flatten_symbolic_with_dynamic_dimension(self):
    """A dynamic non-batch dim makes the flattened dim dynamic too."""
    inputs = layers.Input(batch_shape=(5, 2, None))
    outputs = layers.Flatten()(inputs)
    self.assertEqual(outputs.shape, (5, None))
126+
127+
@skip_if_backend("openvino", "Dynamic dimensions not supported by OpenVino")
def test_flatten_with_dynamic_batch_size_and_dynamic_dimensions(self):
    """Flatten must work when batch size AND other dims are dynamic.

    Regression test for keras-team/keras#21380: with two unknown
    dimensions, the reshape target used to contain two -1s.
    """

    # Feeding differently-shaped batches from a generator forces every
    # dimension (batch and non-batch) to be dynamic at trace time.
    def generator():
        yield (np.ones((3, 5, 7), dtype="float32"),)
        yield (np.ones((2, 7, 5), dtype="float32"),)

    model = models.Sequential([layers.Flatten()])
    model.predict(generator())

0 commit comments

Comments
 (0)