Skip to content

Commit 46e580c

Browse files
added div2k main and test files
1 parent bc11057 commit 46e580c

File tree

2 files changed

+311
-0
lines changed

2 files changed

+311
-0
lines changed

tensorflow_datasets/image/div2k.py

Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
"""DIV2K dataset: DIVerse 2K resolution high quality images as used for the challenges @ NTIRE (CVPR 2017 and CVPR 2018) and @ PIRM (ECCV 2018)"""
2+
3+
from __future__ import absolute_import
4+
from __future__ import division
5+
from __future__ import print_function
6+
7+
import os.path
8+
import re
9+
10+
import tensorflow as tf
11+
import tensorflow_datasets.public_api as tfds
12+
13+
_CITATION = """@InProceedings{Ignatov_2018_ECCV_Workshops,
14+
author = {Ignatov, Andrey and Timofte, Radu and others},
15+
title = {PIRM challenge on perceptual image enhancement on smartphones: report},
16+
booktitle = {European Conference on Computer Vision (ECCV) Workshops},
17+
url = "http://www.vision.ee.ethz.ch/~timofter/publications/Agustsson-CVPRW-2017.pdf",
18+
month = {January},
19+
year = {2019}
20+
}
21+
"""
22+
23+
_DESCRIPTION = """
24+
DIV2K dataset: DIVerse 2K resolution high quality images as used for the challenges @ NTIRE (CVPR 2017 and CVPR 2018) and @ PIRM (ECCV 2018)
25+
"""
26+
27+
_DL_URL = "https://data.vision.ee.ethz.ch/cvl/DIV2K/"
28+
29+
_DL_URLS = {
30+
"train_hr": _DL_URL + "DIV2K_train_HR.zip",
31+
"valid_hr": _DL_URL + "DIV2K_valid_HR.zip",
32+
"train_bicubic_x2": _DL_URL + "DIV2K_train_LR_bicubic_X2.zip",
33+
"train_unknown_x2": _DL_URL + "DIV2K_train_LR_unknown_X2.zip",
34+
"valid_bicubic_x2": _DL_URL + "DIV2K_valid_LR_bicubic_X2.zip",
35+
"valid_unknown_x2": _DL_URL + "DIV2K_valid_LR_unknown_X2.zip",
36+
"train_bicubic_x3": _DL_URL + "DIV2K_train_LR_bicubic_X3.zip",
37+
"train_unknown_x3": _DL_URL + "DIV2K_train_LR_unknown_X3.zip",
38+
"valid_bicubic_x3": _DL_URL + "DIV2K_valid_LR_bicubic_X3.zip",
39+
"valid_unknown_x3": _DL_URL + "DIV2K_valid_LR_unknown_X3.zip",
40+
"train_bicubic_x4": _DL_URL + "DIV2K_train_LR_bicubic_X4.zip",
41+
"train_unknown_x4": _DL_URL + "DIV2K_train_LR_unknown_X4.zip",
42+
"valid_bicubic_x4": _DL_URL + "DIV2K_valid_LR_bicubic_X4.zip",
43+
"valid_unknown_x4": _DL_URL + "DIV2K_valid_LR_unknown_X4.zip",
44+
"train_bicubic_x8": _DL_URL + "DIV2K_train_LR_x8.zip",
45+
"valid_bicubic_x8": _DL_URL + "DIV2K_valid_LR_x8.zip",
46+
"train_realistic_mild_x4": _DL_URL + "DIV2K_train_LR_mild.zip",
47+
"valid_realistic_mild_x4": _DL_URL + "DIV2K_valid_LR_mild.zip",
48+
"train_realistic_difficult_x4": _DL_URL + "DIV2K_train_LR_difficult.zip",
49+
"valid_realistic_difficult_x4": _DL_URL + "DIV2K_valid_LR_difficult.zip",
50+
"train_realistic_wild_x4": _DL_URL + "DIV2K_train_LR_wild.zip",
51+
"valid_realistic_wild_x4": _DL_URL + "DIV2K_valid_LR_wild.zip",
52+
}
53+
54+
_DATA_OPTIONS = ["bicubic_x2", "bicubic_x3", "bicubic_x4", "bicubic_x8",
55+
"unknown_x2", "unknown_x3", "unknown_x4",
56+
"realistic_mild_x4", "realistic_difficult_x4",
57+
"realistic_wild_x4"]
58+
59+
class Div2kConfig(tfds.core.BuilderConfig):
60+
"""BuilderConfig for Div2k."""
61+
62+
def __init__(self, data, **kwargs):
63+
"""Constructs a Div2kConfig."""
64+
if data not in _DATA_OPTIONS:
65+
raise ValueError("data must be one of %s" % _DATA_OPTIONS)
66+
67+
name = kwargs.get("name")
68+
if name is None:
69+
name = data
70+
kwargs["name"] = name
71+
72+
description = kwargs.get("description")
73+
if description is None:
74+
description = "Uses %s data." % data
75+
kwargs["description"] = description
76+
77+
super(Div2kConfig, self).__init__(**kwargs)
78+
self.data = data
79+
80+
def download_urls(self):
81+
"""Returns train and validation download urls for this config."""
82+
urls = {
83+
"train_lr_url": _DL_URLS["train_"+self.data],
84+
"valid_lr_url": _DL_URLS["valid_"+self.data],
85+
"train_hr_url": _DL_URLS["train_hr"],
86+
"valid_hr_url": _DL_URLS["valid_hr"],
87+
}
88+
return urls
89+
90+
def _make_builder_configs():
91+
configs = []
92+
for data in _DATA_OPTIONS:
93+
configs.append(Div2kConfig(
94+
version=tfds.core.Version("1.0.0"),
95+
data=data))
96+
return configs
97+
98+
class Div2k(tfds.core.GeneratorBasedBuilder):
99+
"""DIV2K dataset: DIVerse 2K resolution high quality images"""
100+
101+
BUILDER_CONFIGS = _make_builder_configs()
102+
103+
def _info(self):
104+
return tfds.core.DatasetInfo(
105+
builder=self,
106+
description=_DESCRIPTION,
107+
features=tfds.features.FeaturesDict({
108+
"lr": tfds.features.Image(),
109+
"hr": tfds.features.Image(),
110+
}),
111+
citation=_CITATION,
112+
)
113+
114+
def _split_generators(self, dl_manager):
115+
"""Returns SplitGenerators."""
116+
117+
extracted_paths = dl_manager.download_and_extract(
118+
self.builder_config.download_urls())
119+
120+
return [
121+
tfds.core.SplitGenerator(
122+
name=tfds.Split.TRAIN,
123+
gen_kwargs={
124+
"lr_path": extracted_paths["train_lr_url"],
125+
"hr_path": extracted_paths["train_hr_url"],
126+
}
127+
),
128+
tfds.core.SplitGenerator(
129+
name=tfds.Split.VALIDATION,
130+
gen_kwargs={
131+
"lr_path": extracted_paths["valid_lr_url"],
132+
"hr_path": extracted_paths["valid_hr_url"],
133+
}
134+
),
135+
]
136+
137+
def _generate_examples(self, lr_path, hr_path):
138+
"""Yields examples."""
139+
if not tf.io.gfile.listdir(hr_path)[0].endswith(".png"):
140+
hr_path = os.path.join(hr_path, tf.io.gfile.listdir(hr_path)[0])
141+
142+
for root, dirs, files in tf.io.gfile.walk(lr_path):
143+
if len(files) == 0:
144+
continue
145+
for file in files:
146+
yield root + file, {
147+
"lr": os.path.join(root, file),
148+
"hr": os.path.join(hr_path, re.search(r'\d{4}',
149+
str(file)).group(0) + ".png")
150+
}
Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,161 @@
1+
"""Test for div2k dataset."""
2+
3+
from __future__ import absolute_import
4+
from __future__ import division
5+
from __future__ import print_function
6+
7+
from tensorflow_datasets import testing
8+
from tensorflow_datasets.image import div2k
9+
10+
class Div2kTest_bicubic_x2(testing.DatasetBuilderTestCase):
11+
DATASET_CLASS = div2k.Div2k
12+
BUILDER_CONFIG_NAMES_TO_TEST = ["bicubic_x2"]
13+
SPLITS = {
14+
"train": 1,
15+
"validation": 1,
16+
}
17+
18+
DL_EXTRACT_RESULT = {
19+
"train_hr_url": "DIV2K_train_HR",
20+
"valid_hr_url": "DIV2K_valid_HR",
21+
"train_lr_url": "DIV2K_train_LR_bicubic_X2",
22+
"valid_lr_url": "DIV2K_valid_LR_bicubic_X2",
23+
}
24+
25+
class Div2kTest_bicubic_x3(testing.DatasetBuilderTestCase):
26+
DATASET_CLASS = div2k.Div2k
27+
BUILDER_CONFIG_NAMES_TO_TEST = ["bicubic_x3"]
28+
SPLITS = {
29+
"train": 1,
30+
"validation": 1,
31+
}
32+
33+
DL_EXTRACT_RESULT = {
34+
"train_hr_url": "DIV2K_train_HR",
35+
"valid_hr_url": "DIV2K_valid_HR",
36+
"train_lr_url": "DIV2K_train_LR_bicubic_X3",
37+
"valid_lr_url": "DIV2K_valid_LR_bicubic_X3",
38+
}
39+
40+
class Div2kTest_bicubic_x4(testing.DatasetBuilderTestCase):
41+
DATASET_CLASS = div2k.Div2k
42+
BUILDER_CONFIG_NAMES_TO_TEST = ["bicubic_x4"]
43+
SPLITS = {
44+
"train": 1,
45+
"validation": 1,
46+
}
47+
48+
DL_EXTRACT_RESULT = {
49+
"train_hr_url": "DIV2K_train_HR",
50+
"valid_hr_url": "DIV2K_valid_HR",
51+
"train_lr_url": "DIV2K_train_LR_bicubic_X4",
52+
"valid_lr_url": "DIV2K_valid_LR_bicubic_X4",
53+
}
54+
55+
class Div2kTest_bicubic_x4(testing.DatasetBuilderTestCase):
56+
DATASET_CLASS = div2k.Div2k
57+
BUILDER_CONFIG_NAMES_TO_TEST = ["bicubic_x8"]
58+
SPLITS = {
59+
"train": 1,
60+
"validation": 1,
61+
}
62+
63+
DL_EXTRACT_RESULT = {
64+
"train_hr_url": "DIV2K_train_HR",
65+
"valid_hr_url": "DIV2K_valid_HR",
66+
"train_lr_url": "DIV2K_train_LR_x8",
67+
"valid_lr_url": "DIV2K_valid_LR_x8",
68+
}
69+
70+
class Div2kTest_unknown_x2(testing.DatasetBuilderTestCase):
71+
DATASET_CLASS = div2k.Div2k
72+
BUILDER_CONFIG_NAMES_TO_TEST = ["unknown_x2"]
73+
SPLITS = {
74+
"train": 1,
75+
"validation": 1,
76+
}
77+
78+
DL_EXTRACT_RESULT = {
79+
"train_hr_url": "DIV2K_train_HR",
80+
"valid_hr_url": "DIV2K_valid_HR",
81+
"train_lr_url": "DIV2K_train_LR_unknown_X2",
82+
"valid_lr_url": "DIV2K_valid_LR_unknown_X2",
83+
}
84+
85+
class Div2kTest_unknown_x3(testing.DatasetBuilderTestCase):
86+
DATASET_CLASS = div2k.Div2k
87+
BUILDER_CONFIG_NAMES_TO_TEST = ["unknown_x3"]
88+
SPLITS = {
89+
"train": 1,
90+
"validation": 1,
91+
}
92+
93+
DL_EXTRACT_RESULT = {
94+
"train_hr_url": "DIV2K_train_HR",
95+
"valid_hr_url": "DIV2K_valid_HR",
96+
"train_lr_url": "DIV2K_train_LR_unknown_X3",
97+
"valid_lr_url": "DIV2K_valid_LR_unknown_X3",
98+
}
99+
100+
class Div2kTest_unknown_x4(testing.DatasetBuilderTestCase):
101+
DATASET_CLASS = div2k.Div2k
102+
BUILDER_CONFIG_NAMES_TO_TEST = ["unknown_x4"]
103+
SPLITS = {
104+
"train": 1,
105+
"validation": 1,
106+
}
107+
108+
DL_EXTRACT_RESULT = {
109+
"train_hr_url": "DIV2K_train_HR",
110+
"valid_hr_url": "DIV2K_valid_HR",
111+
"train_lr_url": "DIV2K_train_LR_unknown_X4",
112+
"valid_lr_url": "DIV2K_valid_LR_unknown_X4",
113+
}
114+
115+
class Div2kTest_realistic_mild_x4(testing.DatasetBuilderTestCase):
116+
DATASET_CLASS = div2k.Div2k
117+
BUILDER_CONFIG_NAMES_TO_TEST = ["realistic_mild_x4"]
118+
SPLITS = {
119+
"train": 1,
120+
"validation": 1,
121+
}
122+
123+
DL_EXTRACT_RESULT = {
124+
"train_hr_url": "DIV2K_train_HR",
125+
"valid_hr_url": "DIV2K_valid_HR",
126+
"train_lr_url": "DIV2K_train_LR_mild",
127+
"valid_lr_url": "DIV2K_valid_LR_mild",
128+
}
129+
130+
class Div2kTest_realistic_difficult_x4(testing.DatasetBuilderTestCase):
131+
DATASET_CLASS = div2k.Div2k
132+
BUILDER_CONFIG_NAMES_TO_TEST = ["realistic_difficult_x4"]
133+
SPLITS = {
134+
"train": 1,
135+
"validation": 1,
136+
}
137+
138+
DL_EXTRACT_RESULT = {
139+
"train_hr_url": "DIV2K_train_HR",
140+
"valid_hr_url": "DIV2K_valid_HR",
141+
"train_lr_url": "DIV2K_train_LR_difficult",
142+
"valid_lr_url": "DIV2K_valid_LR_difficult",
143+
}
144+
145+
class Div2kTest_realistic_wild_x4(testing.DatasetBuilderTestCase):
146+
DATASET_CLASS = div2k.Div2k
147+
BUILDER_CONFIG_NAMES_TO_TEST = ["realistic_wild_x4"]
148+
SPLITS = {
149+
"train": 1,
150+
"validation": 1,
151+
}
152+
153+
DL_EXTRACT_RESULT = {
154+
"train_hr_url": "DIV2K_train_HR",
155+
"valid_hr_url": "DIV2K_valid_HR",
156+
"train_lr_url": "DIV2K_train_LR_wild",
157+
"valid_lr_url": "DIV2K_valid_LR_wild",
158+
}
159+
160+
if __name__ == "__main__":
161+
testing.test_main()

0 commit comments

Comments
 (0)