Skip to content

Commit 1b5735d

Browse files
akolesnikoffcopybara-github
authored andcommitted
Add DMLab dataset to TFDS.
PiperOrigin-RevId: 279786128
1 parent 100ef56 commit 1b5735d

File tree

8 files changed

+182
-0
lines changed

8 files changed

+182
-0
lines changed

docs/release_notes.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,5 @@
1313
([guide](https://github.com/tensorflow/datasets/tree/master/docs/decode.md)).
1414
* Add `duke_ultrasound` dataset of ultrasound phantoms and invivo liver images
1515
from the [MimickNet paper](https://arxiv.org/abs/1908.05782)
16+
* Add Dmlab dataset from the
17+
[VTAB benchmark](https://arxiv.org/abs/1910.04867).

tensorflow_datasets/image/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
from tensorflow_datasets.image.cycle_gan import CycleGAN
4343
from tensorflow_datasets.image.deep_weeds import DeepWeeds
4444
from tensorflow_datasets.image.diabetic_retinopathy_detection import DiabeticRetinopathyDetection
45+
from tensorflow_datasets.image.dmlab import Dmlab
4546
from tensorflow_datasets.image.downsampled_imagenet import DownsampledImagenet
4647
from tensorflow_datasets.image.dsprites import Dsprites
4748
from tensorflow_datasets.image.dtd import Dtd

tensorflow_datasets/image/dmlab.py

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
# coding=utf-8
2+
# Copyright 2019 The TensorFlow Datasets Authors.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Dmlab dataset."""
17+
from __future__ import absolute_import
18+
from __future__ import division
19+
from __future__ import print_function
20+
21+
import io
22+
23+
import os
24+
from absl import logging
25+
import tensorflow as tf
26+
27+
import tensorflow_datasets.public_api as tfds
28+
29+
_URL = "https://storage.googleapis.com/akolesnikov-dmlab-tfds/dmlab.tar.gz"
30+
31+
32+
class Dmlab(tfds.core.GeneratorBasedBuilder):
33+
"""Dmlab dataset."""
34+
35+
VERSION = tfds.core.Version("1.0.0")
36+
37+
def _info(self):
38+
return tfds.core.DatasetInfo(
39+
builder=self,
40+
description=(r"""
41+
The Dmlab dataset contains frames observed by the agent acting in the
42+
DeepMind Lab environment, which are annotated by the distance between
43+
the agent and various objects present in the environment. The goal is to
44+
is to evaluate the ability of a visual model to reason about distances
45+
from the visual input in 3D environments. The Dmlab dataset consists of
46+
360x480 color images in 6 classes. The classes are
47+
{close, far, very far} x {positive reward, negative reward}
48+
respectively."""),
49+
features=tfds.features.FeaturesDict({
50+
"image": tfds.features.Image(shape=(360, 480, 3),
51+
encoding_format="jpeg"),
52+
"filename": tfds.features.Text(),
53+
"label": tfds.features.ClassLabel(num_classes=6),
54+
}),
55+
homepage="https://github.com/google-research/task_adaptation",
56+
citation=r"""@article{zhai2019visual,
57+
title={The Visual Task Adaptation Benchmark},
58+
author={Xiaohua Zhai and Joan Puigcerver and Alexander Kolesnikov and
59+
Pierre Ruyssen and Carlos Riquelme and Mario Lucic and
60+
Josip Djolonga and Andre Susano Pinto and Maxim Neumann and
61+
Alexey Dosovitskiy and Lucas Beyer and Olivier Bachem and
62+
Michael Tschannen and Marcin Michalski and Olivier Bousquet and
63+
Sylvain Gelly and Neil Houlsby},
64+
year={2019},
65+
eprint={1910.04867},
66+
archivePrefix={arXiv},
67+
primaryClass={cs.CV},
68+
url = {https://arxiv.org/abs/1910.04867}
69+
}""",
70+
supervised_keys=("image", "label")
71+
)
72+
73+
def _split_generators(self, dl_manager):
74+
path = dl_manager.download_and_extract(_URL)
75+
76+
return [
77+
tfds.core.SplitGenerator(
78+
name=tfds.Split.TRAIN,
79+
gen_kwargs={
80+
"images_dir_path": path,
81+
"split_name": "train",
82+
}),
83+
tfds.core.SplitGenerator(
84+
name=tfds.Split.VALIDATION,
85+
gen_kwargs={
86+
"images_dir_path": path,
87+
"split_name": "validation",
88+
}),
89+
tfds.core.SplitGenerator(
90+
name=tfds.Split.TEST,
91+
gen_kwargs={
92+
"images_dir_path": path,
93+
"split_name": "test",
94+
}),
95+
]
96+
97+
def _parse_single_image(self, example_proto):
98+
"""Parses single video from the input tfrecords.
99+
100+
Args:
101+
example_proto: tfExample proto with a single video.
102+
103+
Returns:
104+
dict with all frames, positions and actions.
105+
"""
106+
107+
feature_map = {
108+
"image": tf.io.FixedLenFeature(shape=[], dtype=tf.string),
109+
"filename": tf.io.FixedLenFeature(shape=[], dtype=tf.string),
110+
"label": tf.io.FixedLenFeature(shape=[], dtype=tf.int64),
111+
}
112+
113+
parse_single = tf.io.parse_single_example(example_proto, feature_map)
114+
115+
return parse_single
116+
117+
def _generate_examples(self, images_dir_path, split_name):
118+
path_glob = os.path.join(images_dir_path,
119+
"dmlab-{}.tfrecord*".format(split_name))
120+
files = tf.io.gfile.glob(path_glob)
121+
122+
logging.info("Reading data from %s.", ",".join(files))
123+
with tf.Graph().as_default():
124+
ds = tf.data.TFRecordDataset(files)
125+
ds = ds.map(
126+
self._parse_single_image,
127+
num_parallel_calls=tf.data.experimental.AUTOTUNE)
128+
iterator = tf.compat.v1.data.make_one_shot_iterator(ds).get_next()
129+
with tf.compat.v1.Session() as sess:
130+
sess.run(tf.compat.v1.global_variables_initializer())
131+
try:
132+
while True:
133+
result = sess.run(iterator)
134+
yield result["filename"], {
135+
"image": io.BytesIO(result["image"]),
136+
"filename": result["filename"],
137+
"label": result["label"],
138+
}
139+
140+
except tf.errors.OutOfRangeError:
141+
return
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
# coding=utf-8
2+
# Copyright 2019 The TensorFlow Datasets Authors.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Tests for DMlab dataset."""
17+
18+
from __future__ import absolute_import
19+
from __future__ import division
20+
from __future__ import print_function
21+
22+
from tensorflow_datasets import testing
23+
from tensorflow_datasets.image import dmlab
24+
25+
26+
class DmlabDatasetTest(testing.DatasetBuilderTestCase):
27+
DATASET_CLASS = dmlab.Dmlab
28+
29+
SPLITS = {
30+
"train": 2,
31+
"test": 2,
32+
"validation": 2,
33+
}
34+
35+
36+
if __name__ == "__main__":
37+
testing.test_main()
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
https://storage.googleapis.com/akolesnikov-dmlab-tfds/dmlab.tar.gz 3017022789 638b18fa69a5d61bbc310cd0b87ac603a39f41de0fdb07b6e77de274a24480a4

0 commit comments

Comments
 (0)