Skip to content

Commit 2c3ca88

Browse files
Added Downsampled_ImageNet .py files
1 parent 8175693 commit 2c3ca88

File tree

3 files changed

+180
-0
lines changed

3 files changed

+180
-0
lines changed

tensorflow_datasets/image/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from tensorflow_datasets.image.colorectal_histology import ColorectalHistologyLarge
3131
from tensorflow_datasets.image.cycle_gan import CycleGAN
3232
from tensorflow_datasets.image.diabetic_retinopathy_detection import DiabeticRetinopathyDetection
33+
from tensorflow_datasets.image.downsampled_imagenet import DownsampledImagenet
3334
from tensorflow_datasets.image.dsprites import Dsprites
3435
from tensorflow_datasets.image.dtd import Dtd
3536
from tensorflow_datasets.image.flowers import TFFlowers
Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
# coding=utf-8
2+
# Copyright 2019 The TensorFlow Datasets Authors.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Downsampled Imagenet dataset."""
17+
18+
from __future__ import absolute_import
19+
from __future__ import division
20+
from __future__ import print_function
21+
22+
import os
23+
24+
import tensorflow as tf
25+
26+
from tensorflow_datasets.core import api_utils
27+
import tensorflow_datasets.public_api as tfds
28+
29+
# From https://arxiv.org/abs/1601.06759
# BibTeX entry for the paper this dataset accompanies.
# NOTE: the backslash of the BibTeX umlaut accent is written as `\\` below.
# Inside a regular (non-raw) string literal a lone backslash followed by a
# double quote is an escape sequence that yields only the quote, which would
# silently drop the backslash from the citation text. A raw string is not an
# option here because the literal relies on the leading backslash-newline
# continuation.
_CITATION = """\
@article{DBLP:journals/corr/OordKK16,
author = {A{\\"{a}}ron van den Oord and
Nal Kalchbrenner and
Koray Kavukcuoglu},
title = {Pixel Recurrent Neural Networks},
journal = {CoRR},
volume = {abs/1601.06759},
year = {2016},
url = {http://arxiv.org/abs/1601.06759},
archivePrefix = {arXiv},
eprint = {1601.06759},
timestamp = {Mon, 13 Aug 2018 16:46:29 +0200},
biburl = {https://dblp.org/rec/bib/journals/corr/OordKK16},
bibsource = {dblp computer science bibliography, https://dblp.org}
}
"""
47+
48+
# Free-text summary of the dataset.
_DESCRIPTION = (
    "Dataset with images of 2 resolutions (see config name for information "
    "on the resolution).\n"
    "It is used for density estimation and generative modeling experiments.\n"
)

# Common prefix of every archive URL.
_DL_URL = "http://image-net.org/small/"

# Supported resolutions; each one doubles as a builder config name.
_DATA_OPTIONS = ["32x32", "64x64"]

# Resolution -> tar archive URL, one mapping per split.
_DL_URLS_TRAIN = dict(
    (res, "%strain_%s.tar" % (_DL_URL, res)) for res in _DATA_OPTIONS)
_DL_URLS_VALIDATION = dict(
    (res, "%svalid_%s.tar" % (_DL_URL, res)) for res in _DATA_OPTIONS)
59+
60+
61+
class DownsampledImagenetConfig(tfds.core.BuilderConfig):
  """Builder configuration selecting one Downsampled Imagenet resolution."""

  @api_utils.disallow_positional_args
  def __init__(self, data=None, **kwargs):
    """Creates a config for a single resolution.

    Args:
      data: `str`, the resolution to build; must be an entry of
        `_DATA_OPTIONS`.
      **kwargs: remaining keyword arguments, passed unchanged to the parent
        `BuilderConfig` constructor.

    Raises:
      ValueError: if `data` is not listed in `_DATA_OPTIONS` (this includes
        the default `None`).
    """
    is_supported = data in _DATA_OPTIONS
    if not is_supported:
      # Reject before touching the parent constructor.
      raise ValueError("data must be one of %s" % _DATA_OPTIONS)

    super(DownsampledImagenetConfig, self).__init__(**kwargs)
    self.data = data
77+
78+
79+
class DownsampledImagenet(tfds.core.GeneratorBasedBuilder):
80+
"""Downsampled Imagenet dataset."""
81+
82+
BUILDER_CONFIGS = [
83+
DownsampledImagenetConfig( # pylint: disable=g-complex-comprehension
84+
name=config_name,
85+
description=("A dataset consisting of Train and Validation images of " + config_name + " resolution."),
86+
version="0.1.0",
87+
data=config_name,
88+
) for config_name in _DATA_OPTIONS
89+
]
90+
91+
def _info(self):
92+
return tfds.core.DatasetInfo(
93+
builder=self,
94+
description=_DESCRIPTION,
95+
features=tfds.features.FeaturesDict({
96+
"image": tfds.features.Image(),
97+
}),
98+
supervised_keys=None,
99+
urls=[
100+
"http://image-net.org/small/download.php"
101+
],
102+
)
103+
104+
def _split_generators(self, dl_manager):
105+
"""Returns SplitGenerators."""
106+
107+
train_url = _DL_URLS_TRAIN[self.builder_config.name]
108+
valid_url = _DL_URLS_VALIDATION[self.builder_config.name]
109+
110+
extracted_paths = dl_manager.download_and_extract({
111+
"train_images": train_url,
112+
"valid_images": valid_url,
113+
})
114+
115+
return [
116+
tfds.core.SplitGenerator(
117+
name=tfds.Split.TRAIN,
118+
num_shards=10,
119+
gen_kwargs={
120+
"path": os.path.join(extracted_paths["train_images"], "train_"+self.builder_config.name),
121+
}),
122+
tfds.core.SplitGenerator(
123+
name=tfds.Split.VALIDATION,
124+
num_shards=1,
125+
gen_kwargs={
126+
"path": os.path.join(extracted_paths["valid_images"], "valid_"+self.builder_config.name),
127+
}),
128+
]
129+
130+
def _generate_examples(self, path):
131+
images = tf.io.gfile.listdir(path)
132+
133+
for image in images:
134+
yield {
135+
"image": os.path.join(path, image),
136+
}
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
# coding=utf-8
2+
# Copyright 2019 The TensorFlow Datasets Authors.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Tests for downsampled_imagenet dataset module."""
17+
18+
from __future__ import absolute_import
19+
from __future__ import division
20+
from __future__ import print_function
21+
22+
from tensorflow_datasets import testing
23+
from tensorflow_datasets.image import downsampled_imagenet
24+
25+
import tensorflow_datasets as tfds
26+
27+
28+
class DownsampledImagenetTest(testing.DatasetBuilderTestCase):
  """Dataset builder test case for `DownsampledImagenet`."""

  # Builder under test.
  DATASET_CLASS = downsampled_imagenet.DownsampledImagenet

  # Builder config names to run the test case against.
  BUILDER_CONFIG_NAMES_TO_TEST = ["32x32", "64x64"]

  # Per-split counts; presumably the number of examples the test data is
  # expected to yield for each split -- confirm against the test harness.
  SPLITS = {
      tfds.Split.TRAIN: 5,
      tfds.Split.VALIDATION: 2,
  }

  # Uses the same keys the builder passes to `download_and_extract`.
  DL_EXTRACT_RESULT = dict(
      train_images="train_images",
      valid_images="valid_images",
  )
41+
42+
# Run the tests through the tensorflow_datasets testing entry point when this
# module is executed directly.
if __name__ == "__main__":
  testing.test_main()

0 commit comments

Comments
 (0)