Skip to content

Commit dabc0cb

Browse files
committed
parsing the correct labels
1 parent c2dfce8 commit dabc0cb

File tree

2 files changed

+26
-4
lines changed

2 files changed

+26
-4
lines changed

tensorflow_datasets/image/deep_weeds.py

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,10 @@
2424
import tensorflow as tf
2525
import tensorflow_datasets.public_api as tfds
2626

27+
import numpy as np
2728

2829
_URL = "https://nextcloud.qriscloud.org.au/index.php/s/a3KxPawpqkiorST/download"
30+
_URL_LBL = "https://raw.githubusercontent.com/AlexOlsen/DeepWeeds/master/labels/labels.csv"
2931

3032
_DESCRIPTION = (
3133
"""The DeepWeeds dataset consists of 17,509 images capturing eight different weed species native to Australia """
@@ -86,7 +88,7 @@ def _info(self):
8688
"label": tfds.features.ClassLabel(names=_NAMES),
8789
}),
8890
supervised_keys=("image", "label"),
89-
homepage="https://github.com/AlexOlsen/DeepWeeds",
91+
urls=[_URL, _URL_LBL],
9092
citation=_CITATION,
9193
)
9294

@@ -98,19 +100,38 @@ def _split_generators(self, dl_manager):
98100
url=_URL,
99101
extract_method=tfds.download.ExtractMethod.ZIP))
100102

103+
104+
path_lbl = dl_manager.download_and_extract(
105+
tfds.download.Resource(
106+
url=_URL_LBL,
107+
extract_method=None))
108+
109+
110+
# there are different label set for train and test
111+
# for now we return the full dataset as 'train' set.
101112
return [
102113
tfds.core.SplitGenerator(
103114
name="train",
104115
gen_kwargs={
105116
"data_dir_path": path,
117+
"label_dir_path": path_lbl,
106118
},
107119
),
108120
]
109121

110-
def _generate_examples(self, data_dir_path):
122+
def _generate_examples(self, data_dir_path, label_dir_path):
111123
"""Generate images and labels for splits."""
112-
124+
# parse the csv-label data
125+
csv = np.loadtxt(label_dir_path,
126+
dtype={'names': ('Filename', 'Label', 'Species'), 'formats': ('S21', 'i4', 'S1')},
127+
skiprows=1,
128+
delimiter=',')
129+
130+
label_dict = {}
131+
for entry in csv:
132+
label_dict[entry[0].decode('UTF-8')] = int(entry[1])
133+
113134
for file_name in tf.io.gfile.listdir(data_dir_path):
114135
image = os.path.join(data_dir_path, file_name)
115-
label = _NAMES[int(file_name.split("-")[2].split(".")[0])]
136+
label = _NAMES[label_dict[file_name]]
116137
yield file_name, {"image": image, "label": label}
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
https://nextcloud.qriscloud.org.au/index.php/s/a3KxPawpqkiorST/download 935276050 d616b33efe097909bdac9623abc5705b59ad11f6796a26a956c1aa5de652f1ba
2+
https://raw.githubusercontent.com/AlexOlsen/DeepWeeds/master/labels/labels.csv 598898 6fb95b89fd9d384f94e185a5cab6c5da7c987649399c90618ecc60acfb0112eb

0 commit comments

Comments
 (0)