Skip to content

Commit 9e11471

Browse files
address requested changes
1 parent 053c7dd commit 9e11471

File tree

1 file changed

+8
-12
lines changed

1 file changed

+8
-12
lines changed

tensorflow_datasets/image/div2k.py

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -97,12 +97,14 @@ def _info(self):
9797
"lr": tfds.features.Image(),
9898
"hr": tfds.features.Image(),
9999
}),
100+
supervised_keys=("lr", "hr"),
100101
homepage=_DL_URL,
101102
citation=_CITATION,
102103
)
103104

104105
def _split_generators(self, dl_manager):
105106
"""Returns SplitGenerators."""
107+
print("EXTRACTING", self.builder_config.download_urls)
106108
extracted_paths = dl_manager.download_and_extract(
107109
self.builder_config.download_urls)
108110

@@ -125,16 +127,10 @@ def _split_generators(self, dl_manager):
125127

126128
def _generate_examples(self, lr_path, hr_path):
127129
"""Yields examples."""
128-
if not tf.io.gfile.listdir(hr_path)[0].endswith(".png"):
129-
hr_path = os.path.join(hr_path, tf.io.gfile.listdir(hr_path)[0])
130-
131130
for root, _, files in tf.io.gfile.walk(lr_path):
132-
if len(files):
133-
for file_path in files:
134-
yield root + file_path, {
135-
"lr": os.path.join(root, file_path),
136-
#extract for corresponding file with matching 4 digit id
137-
"hr": os.path.join(hr_path,
138-
re.search(r'\d{4}',
139-
str(file_path)).group(0) + ".png")
140-
}
131+
for file_path in files:
132+
yield file_path, {
133+
"lr": os.path.join(root, file_path),
134+
#Extract the image id from the filename: "0001x2.png"
135+
"hr": os.path.join(hr_path, file_path[:4]+".png")
136+
}

0 commit comments

Comments
 (0)