Skip to content

Commit 38979b3

Browse files
update div2k main, test data
1 parent 650d6d8 commit 38979b3

File tree

46 files changed

+83
-18
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

46 files changed

+83
-18
lines changed

tensorflow_datasets/image/div2k.py

Lines changed: 16 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -64,20 +64,16 @@ def __init__(self, data, **kwargs):
6464
if data not in _DATA_OPTIONS:
6565
raise ValueError("data must be one of %s" % _DATA_OPTIONS)
6666

67-
name = kwargs.get("name")
68-
if name is None:
69-
name = data
67+
name = kwargs.get("name", data)
7068
kwargs["name"] = name
7169

72-
description = kwargs.get("description")
73-
if description is None:
74-
description = "Uses %s data." % data
70+
description = kwargs.get("description", "Uses %s data." % data)
7571
kwargs["description"] = description
7672

7773
super(Div2kConfig, self).__init__(**kwargs)
7874
self.data = data
7975

80-
def download_urls(self):
76+
def download_urls():
8177
"""Returns train and validation download urls for this config."""
8278
urls = {
8379
"train_lr_url": _DL_URLS["train_"+self.data],
@@ -91,14 +87,15 @@ def _make_builder_configs():
9187
configs = []
9288
for data in _DATA_OPTIONS:
9389
configs.append(Div2kConfig(
94-
version=tfds.core.Version("1.0.0"),
90+
version=tfds.core.Version("2.0.0"),
9591
data=data))
9692
return configs
9793

9894
class Div2k(tfds.core.GeneratorBasedBuilder):
9995
"""DIV2K dataset: DIVerse 2K resolution high quality images"""
10096

10197
BUILDER_CONFIGS = _make_builder_configs()
98+
VERSION = tfds.core.Version("2.0.0")
10299

103100
def _info(self):
104101
return tfds.core.DatasetInfo(
@@ -108,14 +105,15 @@ def _info(self):
108105
"lr": tfds.features.Image(),
109106
"hr": tfds.features.Image(),
110107
}),
108+
#homepage=_DL_URL,
111109
citation=_CITATION,
112110
)
113111

114112
def _split_generators(self, dl_manager):
115113
"""Returns SplitGenerators."""
116114

117115
extracted_paths = dl_manager.download_and_extract(
118-
self.builder_config.download_urls())
116+
self.builder_config.download_urls)
119117

120118
return [
121119
tfds.core.SplitGenerator(
@@ -139,12 +137,12 @@ def _generate_examples(self, lr_path, hr_path):
139137
if not tf.io.gfile.listdir(hr_path)[0].endswith(".png"):
140138
hr_path = os.path.join(hr_path, tf.io.gfile.listdir(hr_path)[0])
141139

142-
for root, dirs, files in tf.io.gfile.walk(lr_path):
143-
if len(files) == 0:
144-
continue
145-
for file in files:
146-
yield root + file, {
147-
"lr": os.path.join(root, file),
148-
"hr": os.path.join(hr_path, re.search(r'\d{4}',
149-
str(file)).group(0) + ".png")
150-
}
140+
for root, _, files in tf.io.gfile.walk(lr_path):
141+
if len(files):
142+
for file in files:
143+
yield root + file, {
144+
"lr": os.path.join(root, file),
145+
"hr": os.path.join(hr_path,
146+
re.search(r'\d{4}',
147+
str(file)).group(0) + ".png")
148+
}

tensorflow_datasets/testing/div2k.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
"""Generates DIV2K like files with random data for testing."""
2+
3+
from __future__ import absolute_import
4+
from __future__ import division
5+
from __future__ import print_function
6+
7+
import os
8+
import zipfile
9+
10+
from absl import app
11+
from absl import flags
12+
13+
import tensorflow as tf
14+
15+
from tensorflow_datasets.core.utils import py_utils
16+
from tensorflow_datasets.testing import fake_data_utils
17+
18+
flags.DEFINE_string("tfds_dir", py_utils.tfds_dir(),
19+
"Path to tensorflow_datasets directory")
20+
21+
FLAGS = flags.FLAGS
22+
23+
DATA = {
24+
"DIV2K_train_HR": "0001.png",
25+
"DIV2K_train_LR_bicubic_X2": "0001x2.png",
26+
"DIV2K_train_LR_bicubic_X3": "0001x3.png",
27+
"DIV2K_train_LR_bicubic_X4": "0001x4.png",
28+
"DIV2K_train_LR_difficult": "0001x4d.png",
29+
"DIV2K_train_LR_mild": "0001x4m.png",
30+
"DIV2K_train_LR_unknown_X2": "0001x2.png",
31+
"DIV2K_train_LR_unknown_X3": "0001x3.png",
32+
"DIV2K_train_LR_unknown_X4": "0001x4.png",
33+
"DIV2K_train_LR_wild": "0001x4w.png",
34+
"DIV2K_train_LR_x8": "0001x8.png",
35+
"DIV2K_valid_HR": "0002.png",
36+
"DIV2K_valid_LR_bicubic_X2": "0002x2.png",
37+
"DIV2K_valid_LR_bicubic_X3": "0002x3.png",
38+
"DIV2K_valid_LR_bicubic_X4": "0002x4.png",
39+
"DIV2K_valid_LR_difficult": "0002x4d.png",
40+
"DIV2K_valid_LR_mild": "0002x4m.png",
41+
"DIV2K_valid_LR_unknown_X2": "0002x2.png",
42+
"DIV2K_valid_LR_unknown_X3": "0002x3.png",
43+
"DIV2K_valid_LR_unknown_X4": "0002x4.png",
44+
"DIV2K_valid_LR_wild": "0002x4w.png",
45+
"DIV2K_valid_LR_x8": "0002x8.png",
46+
}
47+
48+
def _output_dir():
49+
"""Returns output directory."""
50+
return os.path.join(FLAGS.tfds_dir, "testing", "test_data", "fake_examples",
51+
"div2k")
52+
53+
def _generate_image(fdir, fname):
54+
dirname = os.path.join(_output_dir(), fdir)
55+
if not os.path.exists(dirname):
56+
os.makedirs(dirname)
57+
tf.io.gfile.copy(
58+
fake_data_utils.get_random_png(1, 1),
59+
os.path.join(dirname, fname),
60+
overwrite=True)
61+
62+
def main(argv):
63+
for fdir, fname in DATA.items():
64+
_generate_image(fdir, fname)
65+
66+
if __name__ == "__main__":
67+
app.run(main)

0 commit comments

Comments
 (0)