|
| 1 | +"""DIV2K dataset: DIVerse 2K resolution high quality images as used for the challenges @ NTIRE (CVPR 2017 and CVPR 2018) and @ PIRM (ECCV 2018)""" |
| 2 | + |
| 3 | +from __future__ import absolute_import |
| 4 | +from __future__ import division |
| 5 | +from __future__ import print_function |
| 6 | + |
| 7 | +import os.path |
| 8 | +import re |
| 9 | + |
| 10 | +import tensorflow as tf |
| 11 | +import tensorflow_datasets.public_api as tfds |
| 12 | + |
# NOTE(review): this BibTeX entry cites the PIRM 2018 smartphone-enhancement
# report, but the `url` field points to the Agustsson & Timofte DIV2K paper
# (CVPR 2017 Workshops) — confirm which citation is intended.
_CITATION = """@InProceedings{Ignatov_2018_ECCV_Workshops,
author = {Ignatov, Andrey and Timofte, Radu and others},
title = {PIRM challenge on perceptual image enhancement on smartphones: report},
booktitle = {European Conference on Computer Vision (ECCV) Workshops},
url = "http://www.vision.ee.ethz.ch/~timofter/publications/Agustsson-CVPRW-2017.pdf",
month = {January},
year = {2019}
}
"""
| 22 | + |
_DESCRIPTION = """
DIV2K dataset: DIVerse 2K resolution high quality images as used for the challenges @ NTIRE (CVPR 2017 and CVPR 2018) and @ PIRM (ECCV 2018)
"""

# Base URL under which every DIV2K zip archive is hosted.
_DL_URL = "https://data.vision.ee.ethz.ch/cvl/DIV2K/"
| 28 | + |
# Download archives keyed by "<split>_<variant>".  The two HR archives are
# shared by every config; each LR variant has its own train/valid pair.
# Note the x8 and "realistic_*" archive filenames do not follow the
# "<variant>_X<scale>" naming pattern of the x2/x3/x4 archives.
_DL_URLS = {
    "train_hr": _DL_URL + "DIV2K_train_HR.zip",
    "valid_hr": _DL_URL + "DIV2K_valid_HR.zip",
    "train_bicubic_x2": _DL_URL + "DIV2K_train_LR_bicubic_X2.zip",
    "train_unknown_x2": _DL_URL + "DIV2K_train_LR_unknown_X2.zip",
    "valid_bicubic_x2": _DL_URL + "DIV2K_valid_LR_bicubic_X2.zip",
    "valid_unknown_x2": _DL_URL + "DIV2K_valid_LR_unknown_X2.zip",
    "train_bicubic_x3": _DL_URL + "DIV2K_train_LR_bicubic_X3.zip",
    "train_unknown_x3": _DL_URL + "DIV2K_train_LR_unknown_X3.zip",
    "valid_bicubic_x3": _DL_URL + "DIV2K_valid_LR_bicubic_X3.zip",
    "valid_unknown_x3": _DL_URL + "DIV2K_valid_LR_unknown_X3.zip",
    "train_bicubic_x4": _DL_URL + "DIV2K_train_LR_bicubic_X4.zip",
    "train_unknown_x4": _DL_URL + "DIV2K_train_LR_unknown_X4.zip",
    "valid_bicubic_x4": _DL_URL + "DIV2K_valid_LR_bicubic_X4.zip",
    "valid_unknown_x4": _DL_URL + "DIV2K_valid_LR_unknown_X4.zip",
    "train_bicubic_x8": _DL_URL + "DIV2K_train_LR_x8.zip",
    "valid_bicubic_x8": _DL_URL + "DIV2K_valid_LR_x8.zip",
    "train_realistic_mild_x4": _DL_URL + "DIV2K_train_LR_mild.zip",
    "valid_realistic_mild_x4": _DL_URL + "DIV2K_valid_LR_mild.zip",
    "train_realistic_difficult_x4": _DL_URL + "DIV2K_train_LR_difficult.zip",
    "valid_realistic_difficult_x4": _DL_URL + "DIV2K_valid_LR_difficult.zip",
    "train_realistic_wild_x4": _DL_URL + "DIV2K_train_LR_wild.zip",
    "valid_realistic_wild_x4": _DL_URL + "DIV2K_valid_LR_wild.zip",
}

# Valid values for `Div2kConfig.data`: the LR degradation type and scale.
# Each entry must yield "train_<data>" / "valid_<data>" keys in _DL_URLS.
_DATA_OPTIONS = ["bicubic_x2", "bicubic_x3", "bicubic_x4", "bicubic_x8",
                 "unknown_x2", "unknown_x3", "unknown_x4",
                 "realistic_mild_x4", "realistic_difficult_x4",
                 "realistic_wild_x4"]
| 58 | + |
class Div2kConfig(tfds.core.BuilderConfig):
  """BuilderConfig for Div2k.

  Selects which low-resolution variant (`data`) is paired with the shared
  high-resolution images.
  """

  def __init__(self, data, **kwargs):
    """Constructs a Div2kConfig.

    Args:
      data: one of `_DATA_OPTIONS`, the LR degradation variant to use.
      **kwargs: keyword arguments forwarded to `tfds.core.BuilderConfig`.

    Raises:
      ValueError: if `data` is not a known option.
    """
    if data not in _DATA_OPTIONS:
      raise ValueError("data must be one of %s" % _DATA_OPTIONS)

    # Derive a default name and description from the chosen variant when the
    # caller did not supply one.
    if kwargs.get("name") is None:
      kwargs["name"] = data
    if kwargs.get("description") is None:
      kwargs["description"] = "Uses %s data." % data

    super(Div2kConfig, self).__init__(**kwargs)
    self.data = data

  def download_urls(self):
    """Returns train and validation download urls for this config."""
    return {
        "train_lr_url": _DL_URLS["train_" + self.data],
        "valid_lr_url": _DL_URLS["valid_" + self.data],
        "train_hr_url": _DL_URLS["train_hr"],
        "valid_hr_url": _DL_URLS["valid_hr"],
    }
| 89 | + |
def _make_builder_configs():
  """Returns a `Div2kConfig` (version 1.0.0) for each entry in `_DATA_OPTIONS`."""
  return [
      Div2kConfig(version=tfds.core.Version("1.0.0"), data=data)
      for data in _DATA_OPTIONS
  ]
| 97 | + |
class Div2k(tfds.core.GeneratorBasedBuilder):
  """DIV2K dataset: DIVerse 2K resolution high quality images."""

  BUILDER_CONFIGS = _make_builder_configs()

  def _info(self):
    """Returns the dataset metadata: paired LR/HR image features."""
    return tfds.core.DatasetInfo(
        builder=self,
        description=_DESCRIPTION,
        features=tfds.features.FeaturesDict({
            # Low-resolution input and its high-resolution ground truth.
            "lr": tfds.features.Image(),
            "hr": tfds.features.Image(),
        }),
        citation=_CITATION,
    )

  def _split_generators(self, dl_manager):
    """Returns SplitGenerators for the train and validation splits."""
    extracted_paths = dl_manager.download_and_extract(
        self.builder_config.download_urls())

    return [
        tfds.core.SplitGenerator(
            name=tfds.Split.TRAIN,
            gen_kwargs={
                "lr_path": extracted_paths["train_lr_url"],
                "hr_path": extracted_paths["train_hr_url"],
            }),
        tfds.core.SplitGenerator(
            name=tfds.Split.VALIDATION,
            gen_kwargs={
                "lr_path": extracted_paths["valid_lr_url"],
                "hr_path": extracted_paths["valid_hr_url"],
            }),
    ]

  def _generate_examples(self, lr_path, hr_path):
    """Yields (key, example) pairs matching each LR image to its HR image.

    Args:
      lr_path: directory of extracted low-resolution images, possibly nested
        in subdirectories (walked recursively).
      hr_path: directory of extracted high-resolution images, possibly
        wrapped in a single subdirectory.
    """
    # Some HR archives extract into a wrapper directory (no .png files at the
    # top level); descend one level if so.  Hoisted: the original called
    # tf.io.gfile.listdir (a filesystem round-trip) twice.
    hr_entries = tf.io.gfile.listdir(hr_path)
    if not hr_entries[0].endswith(".png"):
      hr_path = os.path.join(hr_path, hr_entries[0])

    for root, _, files in tf.io.gfile.walk(lr_path):
      if not files:  # skip intermediate directories with no images
        continue
      for fname in files:  # renamed from `file` (shadows the Py2 builtin)
        # LR filenames embed a 4-digit image id (e.g. "0001x2.png") that
        # matches the HR filename "0001.png".
        image_id = re.search(r"\d{4}", str(fname)).group(0)
        yield root + fname, {
            "lr": os.path.join(root, fname),
            "hr": os.path.join(hr_path, image_id + ".png"),
        }
0 commit comments