Skip to content

Commit 785c9b0

Browse files
Updated random_grayscale.py compute_output_spec (#21312)
* Update nn.py * Update nn.py * Update nn.py * Update nn.py * Update nn.py Corrected indentation in doc string * Update nn.py * Update random_grayscale.py Fixed issue with passing a single image without batch dimension. * Update keras/src/layers/preprocessing/image_preprocessing/random_grayscale.py Co-authored-by: Jyotinder Singh <33001894+JyotinderSingh@users.noreply.github.com> * Update random_grayscale_test.py Test case for unbatched inputs * code reformat * Update random_grayscale_test.py Testcase for checking both unbatched and batched single image inputs. * changed compute_output_spec There was a bug, and it was causing cycle in graph. --------- Co-authored-by: Jyotinder Singh <33001894+JyotinderSingh@users.noreply.github.com>
1 parent 7d7a6bb commit 785c9b0

File tree

1 file changed

+7
-1
lines changed

1 file changed

+7
-1
lines changed

keras/src/layers/preprocessing/image_preprocessing/random_grayscale.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from keras.src import backend
2+
from keras.src import tree
23
from keras.src.api_export import keras_export
34
from keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import ( # noqa: E501
45
BaseImagePreprocessingLayer,
@@ -96,7 +97,12 @@ def compute_output_shape(self, input_shape):
9697
return input_shape
9798

9899
def compute_output_spec(self, inputs, **kwargs):
99-
return inputs
100+
return tree.map_structure(
101+
lambda x: backend.KerasTensor(
102+
x.shape, dtype=x.dtype, sparse=x.sparse
103+
),
104+
inputs,
105+
)
100106

101107
def transform_bounding_boxes(self, bounding_boxes, **kwargs):
102108
return bounding_boxes

0 commit comments

Comments
 (0)