Skip to content

Commit dd6cd9f

Browse files
Conchylicultorcopybara-github
authored andcommitted
Fix imagenet corrupted with S3
PiperOrigin-RevId: 258476639
1 parent 3bc311e commit dd6cd9f

File tree

1 file changed

+12
-4
lines changed

1 file changed

+12
-4
lines changed

tensorflow_datasets/image/imagenet2012_corrupted.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def __init__(self, corruption_type=None, severity=1, **kwargs):
8484
experiments={tfds.core.Experiment.S3: False})
8585
_SUPPORTED_VERSIONS = [
8686
# Will be made canonical in near future.
87-
tfds.core.Version('3.0.0'),
87+
tfds.core.Version('3.0.1'),
8888
]
8989
# Version history:
9090
# 3.0.0: Fix colorization (all RGB) and format (all jpeg); use TAR_STREAM.
@@ -176,13 +176,21 @@ def _generate_examples_validation(self, archive, labels):
176176
logging.warning('Overwriting cv2 RNG seed.')
177177
tfds.core.lazy_imports.cv2.setRNGSeed(357)
178178

179-
for example in super(Imagenet2012Corrupted,
180-
self)._generate_examples_validation(archive, labels):
179+
gen_fn = super(Imagenet2012Corrupted, self)._generate_examples_validation
180+
for example in gen_fn(archive, labels):
181+
182+
if self.version.implements(tfds.core.Experiment.S3):
183+
key, example = example # Unpack S3 key
184+
181185
with tf.Graph().as_default():
182186
tf_img = tf.image.decode_jpeg(example['image'].read(), channels=3)
183187
image_np = tfds.as_numpy(tf_img)
184188
example['image'] = self._get_corrupted_example(image_np)
185-
yield example
189+
190+
if self.version.implements(tfds.core.Experiment.S3):
191+
yield key, example
192+
else:
193+
yield example
186194
# Reset the seeds back to their original values.
187195
np.random.set_state(numpy_st0)
188196

0 commit comments

Comments
 (0)