Skip to content

Commit c28a63f

Browse files
committed
apply the suggestions from the review; no numpy; only GFile;
1 parent 08e499e commit c28a63f

File tree

1 file changed

+26
-31
lines changed

1 file changed

+26
-31
lines changed

tensorflow_datasets/image/deep_weeds.py

Lines changed: 26 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
import tensorflow as tf
2525
import tensorflow_datasets.public_api as tfds
2626

27-
import numpy as np
27+
import csv
2828

2929
_URL = "https://nextcloud.qriscloud.org.au/index.php/s/a3KxPawpqkiorST/download"
3030
_URL_LBL = "https://raw.githubusercontent.com/AlexOlsen/DeepWeeds/master/labels/labels.csv"
@@ -85,54 +85,49 @@ def _info(self):
8585
description=(_DESCRIPTION),
8686
features=tfds.features.FeaturesDict({
8787
"image": tfds.features.Image(shape=_IMAGE_SHAPE),
88-
"label": tfds.features.ClassLabel(names=_NAMES),
88+
"label": tfds.features.ClassLabel(num_classes=9),
8989
}),
9090
supervised_keys=("image", "label"),
9191
homepage="https://github.com/AlexOlsen/DeepWeeds",
92-
urls=[_URL, _URL_LBL],
9392
citation=_CITATION,
9493
)
9594

9695
def _split_generators(self, dl_manager):
  """Download the image archive and the label CSV; return one "train" split.

  Args:
    dl_manager: Download manager whose `download_and_extract` is called with
      a dict of resources and whose result is indexed by the same keys.

  Returns:
    A one-element list holding the "train" `tfds.core.SplitGenerator`.
  """
  # The image file is in ZIP format, but its URL doesn't mention it, so the
  # extraction method is spelled out explicitly.
  image_archive = tfds.download.Resource(
      url=_URL,
      extract_method=tfds.download.ExtractMethod.ZIP,
  )
  resources = {"image": image_archive, "label": _URL_LBL}
  downloaded = dl_manager.download_and_extract(resources)

  # Keyword arguments forwarded to `_generate_examples`.
  train_kwargs = {
      "data_dir_path": downloaded["image"],
      "label_path": downloaded["label"],
  }
  train_split = tfds.core.SplitGenerator(name="train", gen_kwargs=train_kwargs)
  return [train_split]
122113

123-
def _generate_examples(self, data_dir_path, label_dir_path):
114+
def _generate_examples(self, data_dir_path, label_path):
  """Generate images and labels for splits.

  Args:
    data_dir_path: Directory whose entries are listed; each entry is yielded
      as one example.
    label_path: Path of the label CSV. Its header must provide the columns
      `Filename`, `Label` and `Species`.

  Yields:
    `(file_name, example)` tuples where `example` maps "image" to the file's
    full path and "label" to the `Species` string of that file.

  Raises:
    KeyError: If a listed file has no row in the label CSV.
    ValueError: If a `Label` cell is not an integer.
  """
  with tf.io.gfile.GFile(label_path) as f:
    # `csv.DictReader` is a one-shot iterator bound to the open file.
    # Materialize the rows while the file is still open so they can be
    # traversed twice below; iterating the reader a second time would
    # silently produce nothing.
    rows = list(csv.DictReader(f))

  # Extract the mapping int -> str and save the label name string to the
  # feature. Ids are converted to int so the ordering is numeric rather than
  # lexicographic ("10" would otherwise sort before "2").
  label_id_to_name = {int(row['Label']): row['Species'] for row in rows}
  self.info.features['label'].names = [
      name for _, name in sorted(label_id_to_name.items())
  ]

  filename_to_label = {row['Filename']: row['Species'] for row in rows}
  for file_name in tf.io.gfile.listdir(data_dir_path):
    yield file_name, {
        "image": os.path.join(data_dir_path, file_name),
        "label": filename_to_label[file_name],
    }

0 commit comments

Comments
 (0)