24 | 24 | import tensorflow as tf
25 | 25 | import tensorflow_datasets.public_api as tfds
26 | 26 |
27 |    | -import numpy as np
   | 27 | +import csv
28 | 28 |
29 | 29 | _URL = "https://nextcloud.qriscloud.org.au/index.php/s/a3KxPawpqkiorST/download"
30 | 30 | _URL_LBL = "https://raw.githubusercontent.com/AlexOlsen/DeepWeeds/master/labels/labels.csv"
@@ -85,54 +85,49 @@ def _info(self):
85 | 85 |         description=(_DESCRIPTION),
86 | 86 |         features=tfds.features.FeaturesDict({
87 | 87 |             "image": tfds.features.Image(shape=_IMAGE_SHAPE),
88 |    | -            "label": tfds.features.ClassLabel(names=_NAMES),
   | 88 | +            "label": tfds.features.ClassLabel(num_classes=9),
89 | 89 |         }),
90 | 90 |         supervised_keys=("image", "label"),
91 | 91 |         homepage="https://github.com/AlexOlsen/DeepWeeds",
92 |    | -        urls=[_URL, _URL_LBL],
93 | 92 |         citation=_CITATION,
94 | 93 |     )
95 | 94 |
96 | 95 |   def _split_generators(self, dl_manager):
97 | 96 |     """Define Splits."""
98 | 97 |     # The file is in ZIP format, but URL doesn't mention it.
99 |    | -    path = dl_manager.download_and_extract(
100 |   | -        tfds.download.Resource(
101 |   | -            url=_URL,
102 |   | -            extract_method=tfds.download.ExtractMethod.ZIP))
103 |   | -
104 |   | -
105 |   | -    path_lbl = dl_manager.download_and_extract(
106 |   | -        tfds.download.Resource(
107 |   | -            url=_URL_LBL,
108 |   | -            extract_method=None))
    |  98 | +    paths = dl_manager.download_and_extract({
    |  99 | +        "image": tfds.download.Resource(
    | 100 | +            url=_URL,
    | 101 | +            extract_method=tfds.download.ExtractMethod.ZIP),
    | 102 | +        "label": _URL_LBL})
109 | 103 |
110 |    | -
111 |    | -    # there are different label set for train and test
112 |    | -    # for now we return the full dataset as 'train' set.
113 | 104 |     return [
114 | 105 |         tfds.core.SplitGenerator(
115 | 106 |             name="train",
116 | 107 |             gen_kwargs={
117 |    | -                "data_dir_path": path,
118 |    | -                "label_dir_path": path_lbl,
    | 108 | +                "data_dir_path": paths["image"],
    | 109 | +                "label_path": paths["label"],
119 | 110 |             },
120 | 111 |         ),
121 | 112 |     ]
122 | 113 |
123 |    | -  def _generate_examples(self, data_dir_path, label_dir_path):
    | 114 | +  def _generate_examples(self, data_dir_path, label_path):
124 | 115 |     """Generate images and labels for splits."""
125 |    | -    # parse the csv-label data
126 |    | -    csv = np.loadtxt(label_dir_path,
127 |    | -                     dtype={'names': ('Filename', 'Label', 'Species'), 'formats': ('S21', 'i4', 'S1')},
128 |    | -                     skiprows=1,
129 |    | -                     delimiter=',')
130 |    | -
131 |    | -    label_dict = {}
132 |    | -    for entry in csv:
133 |    | -      label_dict[entry[0].decode('UTF-8')] = int(entry[1])
134 | 116 |
    | 117 | +    with tf.io.gfile.GFile(label_path) as f:
    | 118 | +      rows = list(csv.DictReader(f))
    | 119 | +
    | 120 | +      # Extract the label-id -> species-name mapping and register the names on the feature.
    | 121 | +      label_id_to_name = {
    | 122 | +          row['Label']: row['Species'] for row in rows
    | 123 | +      }
    | 124 | +      self.info.features['label'].names = [v for k, v in sorted(label_id_to_name.items())]
    | 125 | +
    | 126 | +      filename_to_label = {
    | 127 | +          row['Filename']: row['Species'] for row in rows
    | 128 | +      }
135 | 129 |     for file_name in tf.io.gfile.listdir(data_dir_path):
136 |    | -      image = os.path.join(data_dir_path, file_name)
137 |    | -      label = _NAMES[label_dict[file_name]]
138 |    | -      yield file_name, {"image": image, "label": label}
    | 130 | +      yield file_name, {
    | 131 | +          "image": os.path.join(data_dir_path, file_name),
    | 132 | +          "label": filename_to_label[file_name]
    | 133 | +      }
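
For reference, a minimal standalone sketch of the label-CSV parsing that the new _generate_examples performs, run against made-up rows in the Filename,Label,Species layout of the DeepWeeds labels.csv (the file names and species strings below are illustrative placeholders, not values from the real file):

    import csv
    import io

    # Illustrative stand-in for labels.csv: same Filename,Label,Species columns,
    # but the rows themselves are made up for this sketch.
    sample_csv = (
        "Filename,Label,Species\n"
        "20170101-000000-0.jpg,0,species_a\n"
        "20170101-000001-0.jpg,1,species_b\n"
    )

    rows = list(csv.DictReader(io.StringIO(sample_csv)))

    # Label id -> species name (used to populate the ClassLabel names) ...
    label_id_to_name = {row["Label"]: row["Species"] for row in rows}
    # ... and file name -> species name (used to label each yielded example).
    filename_to_label = {row["Filename"]: row["Species"] for row in rows}

    print([name for _, name in sorted(label_id_to_name.items())])
    # ['species_a', 'species_b']
    print(filename_to_label["20170101-000000-0.jpg"])
    # species_a

Materializing the reader into a list up front lets both mappings be built from the same parsed rows, since a csv.DictReader is a one-shot iterator.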