@@ -84,7 +84,7 @@ def __init__(self, corruption_type=None, severity=1, **kwargs):
84
84
experiments = {tfds .core .Experiment .S3 : False })
85
85
_SUPPORTED_VERSIONS = [
86
86
# Will be made canonical in near future.
87
- tfds .core .Version ('3.0.0 ' ),
87
+ tfds .core .Version ('3.0.1 ' ),
88
88
]
89
89
# Version history:
90
90
# 3.0.0: Fix colorization (all RGB) and format (all jpeg); use TAR_STREAM.
@@ -176,13 +176,21 @@ def _generate_examples_validation(self, archive, labels):
176
176
logging .warning ('Overwriting cv2 RNG seed.' )
177
177
tfds .core .lazy_imports .cv2 .setRNGSeed (357 )
178
178
179
- for example in super (Imagenet2012Corrupted ,
180
- self )._generate_examples_validation (archive , labels ):
179
+ gen_fn = super (Imagenet2012Corrupted , self )._generate_examples_validation
180
+ for example in gen_fn (archive , labels ):
181
+
182
+ if self .version .implements (tfds .core .Experiment .S3 ):
183
+ key , example = example # Unpack S3 key
184
+
181
185
with tf .Graph ().as_default ():
182
186
tf_img = tf .image .decode_jpeg (example ['image' ].read (), channels = 3 )
183
187
image_np = tfds .as_numpy (tf_img )
184
188
example ['image' ] = self ._get_corrupted_example (image_np )
185
- yield example
189
+
190
+ if self .version .implements (tfds .core .Experiment .S3 ):
191
+ yield key , example
192
+ else :
193
+ yield example
186
194
# Reset the seeds back to their original values.
187
195
np .random .set_state (numpy_st0 )
188
196
0 commit comments