# Base URL of the KITTI object-detection data mirror on S3.
_DATA_URL = "https://s3.eu-central-1.amazonaws.com/avg-kitti"
# Archive of the "image_2" object-detection images — presumably the left
# color camera set per KITTI naming; confirm against the devkit docs.
_IMAGES_FNAME = "data_object_image_2.zip"
# Archive of the per-image bounding-box annotation files.
_LABELS_FNAME = "data_object_label_2.zip"
# Object development kit; contains mapping/train_rand.txt and
# mapping/train_mapping.txt, used to map image ids back to source videos.
_DEVKIT_FNAME = "devkit_object.zip"
50
51
_OBJECT_LABELS = [
51
52
"Car" ,
52
53
"Van" ,
57
58
"Tram" ,
58
59
"Misc" ,
59
60
]
61
# The percentage of trainset videos to put into validation and test sets.
# The released test images do not have labels, so all three splits are
# carved out of the labeled training archive, per source video.
_VALIDATION_SPLIT_PERCENT_VIDEOS = 10
_TEST_SPLIT_PERCENT_VIDEOS = 10
60
65
61
66
# Raw Kitti representation of a bounding box. Coordinates are in pixels,
62
67
# measured from the top-left hand corner.
67
72
class Kitti (tfds .core .GeneratorBasedBuilder ):
68
73
"""Kitti dataset."""
69
74
70
- VERSION = tfds .core .Version ("1 .0.0" )
75
+ VERSION = tfds .core .Version ("3 .0.0" )
71
76
SUPPORTED_VERSIONS = [
72
77
tfds .core .Version ("2.0.0" ),
73
78
]
74
79
# Version history:
75
80
# 2.0.0: S3 with new hashing function (different shuffle).
81
+ # 3.0.0: Train/val/test splits based on random video IDs created.
76
82
77
83
def _info (self ):
78
84
# Annotation descriptions are in the object development kit.
@@ -102,26 +108,48 @@ def _split_generators(self, dl_manager):
102
108
filenames = {
103
109
"images" : os .path .join (_DATA_URL , _IMAGES_FNAME ),
104
110
"annotations" : os .path .join (_DATA_URL , _LABELS_FNAME ),
111
+ "devkit" : os .path .join (_DATA_URL , _DEVKIT_FNAME ),
105
112
}
106
113
files = dl_manager .download (filenames )
114
+ train_images , validation_images , test_images = _build_splits (
115
+ dl_manager .iter_archive (files ["devkit" ]))
116
+
107
117
return [
108
118
tfds .core .SplitGenerator (
109
119
name = tfds .Split .TRAIN ,
110
120
gen_kwargs = {
111
121
"images" : dl_manager .iter_archive (files ["images" ]),
112
122
"annotations" : dl_manager .iter_archive (files ["annotations" ]),
113
123
"subdir" : "training" ,
124
+ "image_ids" : train_images ,
125
+ }),
126
+ tfds .core .SplitGenerator (
127
+ name = tfds .Split .VALIDATION ,
128
+ gen_kwargs = {
129
+ "images" : dl_manager .iter_archive (files ["images" ]),
130
+ "annotations" : dl_manager .iter_archive (files ["annotations" ]),
131
+ "subdir" : "training" ,
132
+ "image_ids" : validation_images ,
133
+ }),
134
+ tfds .core .SplitGenerator (
135
+ name = tfds .Split .TEST ,
136
+ gen_kwargs = {
137
+ "images" : dl_manager .iter_archive (files ["images" ]),
138
+ "annotations" : dl_manager .iter_archive (files ["annotations" ]),
139
+ "subdir" : "training" ,
140
+ "image_ids" : test_images ,
114
141
}),
115
142
]
116
143
117
- def _generate_examples (self , images , annotations , subdir ):
144
+ def _generate_examples (self , images , annotations , subdir , image_ids ):
118
145
"""Yields images and annotations.
119
146
120
147
Args:
121
148
images: object that iterates over the archive of images.
122
149
annotations: object that iterates over the archive of annotations.
123
150
subdir: subdirectory from which to extract images and annotations, e.g.
124
151
training or testing.
152
+ image_ids: file ids for images in this split.
125
153
126
154
Yields:
127
155
A tuple containing the example's key, and the example.
@@ -145,8 +173,10 @@ def _generate_examples(self, images, annotations, subdir):
145
173
continue
146
174
if prefix .split ("/" )[0 ] != subdir :
147
175
continue
148
-
149
- annotations = all_annotations [int (prefix [- 6 :])]
176
+ image_id = int (prefix [- 6 :])
177
+ if image_id not in image_ids :
178
+ continue
179
+ annotations = all_annotations [image_id ]
150
180
img = cv2 .imdecode (np .fromstring (fobj .read (), dtype = np .uint8 ),
151
181
cv2 .IMREAD_COLOR )
152
182
img = cv2 .cvtColor (img , cv2 .COLOR_BGR2RGB )
@@ -211,3 +241,71 @@ def _parse_kitti_annotations(annotations_csv):
211
241
"rotation_y" : float (rotation_y ),
212
242
})
213
243
return annotations
244
+
245
+
246
def _build_splits(devkit):
  """Splits the train data into train/val/test by video.

  Ensures that images from the same video do not traverse the splits.

  Args:
    devkit: object that iterates over the devkit archive, yielding
      `(filename, file object)` pairs as produced by
      `dl_manager.iter_archive`.

  Returns:
    train_images: File ids for the training set images.
    validation_images: File ids for the validation set images.
    test_images: File ids for the test set images.

  Raises:
    AssertionError: if either mapping file is missing from the devkit.
  """
  mapping_line_ids = None
  mapping_lines = None
  for fpath, fobj in devkit:
    if fpath == "mapping/train_rand.txt":
      # One comma-separated list of 1-based line indices into
      # train_mapping.txt; convert to 0-based.
      mapping_line_ids = [
          int(x.strip()) - 1 for x in fobj.read().decode("utf-8").split(",")
      ]
    elif fpath == "mapping/train_mapping.txt":
      mapping_lines = [x.decode("utf-8") for x in fobj.readlines()]

  assert mapping_line_ids, "mapping/train_rand.txt not found in devkit."
  assert mapping_lines, "mapping/train_mapping.txt not found in devkit."

  # Image id i maps (via train_rand) to a line of train_mapping whose
  # second space-separated field is the source video id.
  video_to_image = collections.defaultdict(list)
  for image_id, mapping_lineid in enumerate(mapping_line_ids):
    video_id = mapping_lines[mapping_lineid].split(" ")[1]
    video_to_image[video_id].append(image_id)

  # Use a dedicated legacy generator instead of seeding the global numpy
  # RNG and saving/restoring its state: RandomState(123) produces the
  # exact same draw sequence np.random.seed(123) did, without mutating
  # global state other code may depend on.
  rng = np.random.RandomState(123)

  # max(..., 1) so tiny inputs (e.g. dataset tests) still yield non-empty
  # test and validation splits.
  num_test_videos = max(
      1, _TEST_SPLIT_PERCENT_VIDEOS * len(video_to_image) // 100)
  num_validation_videos = max(
      1, _VALIDATION_SPLIT_PERCENT_VIDEOS * len(video_to_image) // 100)
  test_videos = set(
      rng.choice(list(video_to_image.keys()), num_test_videos, replace=False))
  validation_videos = set(
      rng.choice(
          list(set(video_to_image.keys()) - test_videos),
          num_validation_videos,
          replace=False))

  test_images = []
  validation_images = []
  train_images = []
  for video_id, image_ids in video_to_image.items():
    if video_id in test_videos:
      test_images.extend(image_ids)
    elif video_id in validation_videos:
      validation_images.extend(image_ids)
    else:
      train_images.extend(image_ids)
  return train_images, validation_images, test_images
0 commit comments