Skip to content

Commit 3b51724

Browse files
pierrot0copybara-github
authored andcommitted
Starcraft dataset: create an S3 version (non default).
PiperOrigin-RevId: 281695849
1 parent 01eb029 commit 3b51724

File tree

2 files changed

+10
-2
lines changed

2 files changed

+10
-2
lines changed

tensorflow_datasets/video/starcraft.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,8 @@ def __init__(self, map_name, resolution, size_in_gb, **kwargs):
5656
super(StarcraftVideoConfig, self).__init__(
5757
version=tfds.core.Version(
5858
"0.1.2", experiments={tfds.core.Experiment.S3: False}),
59+
supported_versions=[tfds.core.Version(
60+
"1.0.0", "New split API (https://tensorflow.org/datasets/splits)")],
5961
**kwargs)
6062
self.map_name = map_name
6163
self.resolution = resolution
@@ -205,17 +207,19 @@ def _parse_single_video(self, example_proto):
205207
def _generate_examples(self, files):
206208
logging.info("Reading data from %s.", ",".join(files))
207209
with tf.Graph().as_default():
208-
ds = tf.data.TFRecordDataset(files)
210+
ds = tf.data.TFRecordDataset(sorted(files))
209211
ds = ds.map(
210212
self._parse_single_video,
211213
num_parallel_calls=tf.data.experimental.AUTOTUNE)
212214
iterator = tf.compat.v1.data.make_one_shot_iterator(ds).get_next()
213215
with tf.compat.v1.Session() as sess:
214216
sess.run(tf.compat.v1.global_variables_initializer())
215217
try:
218+
i = 0
216219
while True:
217220
video = sess.run(iterator)
218-
yield {"rgb_screen": video}
221+
yield i, {"rgb_screen": video}
222+
i += 1
219223

220224
except tf.errors.OutOfRangeError:
221225
# End of file.

tensorflow_datasets/video/starcraft_test.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,5 +60,9 @@ class StarcraftVideoDataset128Test(testing.DatasetBuilderTestCase):
6060
}
6161

6262

63+
class StarcraftVideoDataset128S3Test(StarcraftVideoDataset128Test):
64+
VERSION = "experimental_latest"
65+
66+
6367
if __name__ == "__main__":
6468
testing.test_main()

0 commit comments

Comments
 (0)