Skip to content

Commit 32aa040

Browse files
obachemcopybara-github
authored and committed
Adds 'shapes3d' data set.
PiperOrigin-RevId: 240356128
1 parent 51c55d9 commit 32aa040

File tree

6 files changed

+293
-0
lines changed

6 files changed

+293
-0
lines changed

tensorflow_datasets/image/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,5 +41,6 @@
4141
from tensorflow_datasets.image.open_images import OpenImagesV4
4242
from tensorflow_datasets.image.quickdraw import QuickdrawBitmap
4343
from tensorflow_datasets.image.rock_paper_scissors import RockPaperScissors
44+
from tensorflow_datasets.image.shapes3d import Shapes3d
4445
from tensorflow_datasets.image.svhn import SvhnCropped
4546
from tensorflow_datasets.image.voc import Voc2007

tensorflow_datasets/image/shapes3d.py

Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
# coding=utf-8
2+
# Copyright 2019 The TensorFlow Datasets Authors.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Shapes3D dataset."""
17+
18+
from __future__ import absolute_import
19+
from __future__ import division
20+
from __future__ import print_function
21+
22+
import tempfile
23+
24+
import h5py
25+
import numpy as np
26+
from six import moves
27+
import tensorflow as tf
28+
29+
import tensorflow_datasets.public_api as tfds
30+
31+
_CITATION = """\
32+
@misc{3dshapes18,
33+
title={3D Shapes Dataset},
34+
author={Burgess, Chris and Kim, Hyunjik},
35+
howpublished={https://github.com/deepmind/3dshapes-dataset/},
36+
year={2018}
37+
}
38+
"""
39+
40+
_URL = ("https://storage.googleapis.com/3d-shapes/3dshapes.h5")
41+
42+
# Human-readable description of the data set, passed to `DatasetInfo`.
# The scale range is documented as [0.75, 1.25]: those are the eight scale
# values the fake-data generator for this data set reproduces.
_DESCRIPTION = """\
3dshapes is a dataset of 3D shapes procedurally generated from 6 ground truth
independent latent factors. These factors are *floor colour*, *wall colour*, *object colour*,
*scale*, *shape* and *orientation*.

All possible combinations of these latents are present exactly once, generating N = 480000 total images.

### Latent factor values

* floor hue: 10 values linearly spaced in [0, 1]
* wall hue: 10 values linearly spaced in [0, 1]
* object hue: 10 values linearly spaced in [0, 1]
* scale: 8 values linearly spaced in [0.75, 1.25]
* shape: 4 values in [0, 1, 2, 3]
* orientation: 15 values linearly spaced in [-30, 30]

We varied one latent at a time (starting from orientation, then shape, etc), and sequentially stored the images in fixed order in the `images` array. The corresponding values of the factors are stored in the same order in the `labels` array.
"""
60+
61+
62+
class Shapes3d(tfds.core.GeneratorBasedBuilder):
    """Dataset builder for the 3D Shapes images and their latent factors."""

    VERSION = tfds.core.Version("0.1.0")

    def _info(self):
        """Returns the `DatasetInfo` with features, homepage and citation."""

        def scalar_float():
            # A fresh scalar float32 feature for one latent factor value.
            return tfds.features.Tensor(shape=[], dtype=tf.float32)

        # Every latent factor is exposed twice: as an integer class label
        # ("label_*") and as the float value read from the file ("value_*").
        feature_spec = {
            "image": tfds.features.Image(shape=(64, 64, 3)),
            "label_floor_hue": tfds.features.ClassLabel(num_classes=10),
            "label_wall_hue": tfds.features.ClassLabel(num_classes=10),
            "label_object_hue": tfds.features.ClassLabel(num_classes=10),
            "label_scale": tfds.features.ClassLabel(num_classes=8),
            "label_shape": tfds.features.ClassLabel(num_classes=4),
            "label_orientation": tfds.features.ClassLabel(num_classes=15),
            "value_floor_hue": scalar_float(),
            "value_wall_hue": scalar_float(),
            "value_object_hue": scalar_float(),
            "value_scale": scalar_float(),
            "value_shape": scalar_float(),
            "value_orientation": scalar_float(),
        }
        return tfds.core.DatasetInfo(
            builder=self,
            description=_DESCRIPTION,
            features=tfds.features.FeaturesDict(feature_spec),
            urls=["https://github.com/deepmind/3d-shapes"],
            citation=_CITATION,
        )

    def _split_generators(self, dl_manager):
        """Downloads the hdf5 file and exposes it as a single TRAIN split."""
        h5_path = dl_manager.download(_URL)

        # The data set ships without a predefined train/val/test split.
        train_split = tfds.core.SplitGenerator(
            name=tfds.Split.TRAIN,
            num_shards=1,
            gen_kwargs={"filepath": h5_path},
        )
        return [train_split]

    def _generate_examples(self, filepath):
        """Generate examples for the Shapes3d dataset.

        Args:
          filepath: path to the Shapes3d hdf5 file.

        Yields:
          Dictionaries with the image and, for each latent factor, its class
          label and its float value.
        """
        # Column order of the latent factors in the file's value array.
        label_keys = (
            "label_floor_hue",
            "label_wall_hue",
            "label_object_hue",
            "label_scale",
            "label_shape",
            "label_orientation",
        )
        value_keys = (
            "value_floor_hue",
            "value_wall_hue",
            "value_object_hue",
            "value_scale",
            "value_shape",
            "value_orientation",
        )

        # Walking several hdf5 data sets of one file in lockstep is slow, so
        # the whole content is pulled into memory before anything is yielded.
        image_array, values_array = _load_data(filepath)

        # The file only stores float factor values; the class labels are
        # derived from them one factor (column) at a time.
        labels_array = np.zeros_like(values_array, dtype=np.int64)
        for column, factor_values in enumerate(values_array.T):
            labels_array[:, column] = _discretize(factor_values)

        rows = moves.zip(image_array, labels_array, values_array)
        for image, labels, values in rows:
            example = {"image": image}
            example.update(
                (key, labels[i]) for i, key in enumerate(label_keys))
            example.update(
                (key, values[i]) for i, key in enumerate(value_keys))
            yield example
150+
151+
152+
def _load_data(filepath):
    """Reads the images and latent factor values into Numpy arrays.

    Args:
      filepath: path to the Shapes3d hdf5 file.

    Returns:
      Tuple `(image_array, values_array)` built from the file's "images" and
      "labels" data sets respectively.
    """
    with h5py.File(filepath, "r") as h5_file:
        # Despite its name, the "labels" data set holds the float factor
        # values, not class labels. Both data sets are materialised while the
        # file is still open.
        image_array, values_array = (
            np.array(h5_file[name]) for name in ("images", "labels"))
    return image_array, values_array
160+
161+
162+
163+
164+
def _discretize(a):
    """Discretizes array values to class labels.

    Every element is mapped to the position of its value among the sorted
    distinct values of `a`, so the smallest value becomes class 0, the next
    larger one class 1, and so on. Equal values share a label.

    Args:
      a: one-dimensional array-like of values.

    Returns:
      Integer Numpy array with one label per element of `a`, each in
      `[0, number of distinct values)`. Empty input gives an empty array.
      NaN entries, if any, are grouped the way `np.unique` groups them.
    """
    # `return_inverse` gives, for each element, its index into the sorted
    # unique values, which is exactly the class label we want. This replaces
    # the manual argsort / inverse-permutation / cumsum ranking.
    _, labels = np.unique(np.asarray(a), return_inverse=True)
    return labels
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
# coding=utf-8
2+
# Copyright 2019 The TensorFlow Datasets Authors.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
from tensorflow_datasets.image import shapes3d
17+
import tensorflow_datasets.testing as tfds_test
18+
19+
20+
class Shapes3dTest(tfds_test.DatasetBuilderTestCase):
    """Runs the shared dataset-builder test case against `Shapes3d`."""

    # Builder class under test.
    DATASET_CLASS = shapes3d.Shapes3d
    # NOTE(review): presumably the file name handed to the builder in place of
    # the real download; it matches the fake generator's output name.
    DL_EXTRACT_RESULT = "3dshapes.h5"
    # NOTE(review): presumably the expected number of examples per split; the
    # fake data file is generated with 5 images.
    SPLITS = {"train": 5}
24+
25+
26+
if __name__ == "__main__":
27+
tfds_test.test_main()
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
# coding=utf-8
2+
# Copyright 2019 The TensorFlow Datasets Authors.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
r"""Generate Shapes3d-like files, smaller and with random data.
17+
18+
"""
19+
20+
from __future__ import absolute_import
21+
from __future__ import division
22+
from __future__ import print_function
23+
24+
import os
25+
26+
from absl import app
27+
from absl import flags
28+
import h5py
29+
import numpy as np
30+
31+
from tensorflow_datasets.core.utils import py_utils
32+
from tensorflow_datasets.testing import test_utils
33+
34+
NUM_IMAGES = 5
35+
FACTOR_VALUES = [[0., 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],
36+
[0., 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],
37+
[0., 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],
38+
[
39+
0.75, 0.82142857, 0.89285714, 0.96428571, 1.03571429,
40+
1.10714286, 1.17857143, 1.25
41+
], [0., 1., 2., 3.],
42+
[
43+
-30., -25.71428571, -21.42857143, -17.14285714,
44+
-12.85714286, -8.57142857, -4.28571429, 0., 4.28571429,
45+
8.57142857, 12.85714286, 17.14285714, 21.42857143,
46+
25.71428571, 30.
47+
]]
48+
OUTPUT_NAME = "3dshapes.h5"
49+
50+
flags.DEFINE_string("tfds_dir", py_utils.tfds_dir(),
51+
"Path to tensorflow_datasets directory")
52+
FLAGS = flags.FLAGS
53+
54+
55+
def _create_fake_samples():
    """Builds a small, reproducible set of random images and factor values.

    Returns:
      Tuple with fake images and fake latent values. The values array has one
      row per image and one column per entry of `FACTOR_VALUES`.
    """
    # Fixed seed so regenerating the fake file gives the same content.
    rng = np.random.RandomState(0)
    fake_images = rng.randint(256, size=(NUM_IMAGES, 64, 64, 3))
    fake_images = fake_images.astype("uint8")

    # One random draw per image for every factor, in factor order.
    per_factor = [
        rng.choice(candidates, size=NUM_IMAGES)
        for candidates in FACTOR_VALUES
    ]

    # `per_factor` is factor-major; transpose it to one row per image.
    return fake_images, np.asarray(per_factor).T
68+
69+
70+
def _generate():
    """Generates a fake data set and writes it to the fake_examples directory."""
    target_dir = os.path.join(
        FLAGS.tfds_dir, "testing", "test_data", "fake_examples", "shapes3d")
    test_utils.remake_dir(target_dir)

    fake_images, fake_values = _create_fake_samples()
    h5_path = os.path.join(target_dir, OUTPUT_NAME)

    with h5py.File(h5_path, "w") as h5_file:
        # Images go in as unsigned bytes.
        h5_file.create_dataset(
            "images", fake_images.shape, "|u1").write_direct(fake_images)
        # Factor values go in as little-endian float64. The values array is a
        # transposed view, so it is made contiguous before `write_direct`.
        h5_file.create_dataset(
            "labels", fake_values.shape, "<f8").write_direct(
                np.ascontiguousarray(fake_values))
83+
84+
85+
def main(argv):
    """Script entry point: writes the fake data set.

    Args:
      argv: command-line arguments; anything beyond the first entry is
        rejected.

    Raises:
      app.UsageError: if more than one argument is given.
    """
    if len(argv) <= 1:
        _generate()
        return
    raise app.UsageError("Too many command-line arguments.")
89+
90+
91+
if __name__ == "__main__":
92+
app.run(main)
Binary file not shown.
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
https://storage.googleapis.com/3d-shapes/3dshapes.h5 267573662 0a0f6ed98baff276a50f3a081a7434d788da63cb135a98189b2a5b5769be1785

0 commit comments

Comments
 (0)