Skip to content

Commit a5d0580

Browse files
TensorFlow Datasets Teamcopybara-github
authored andcommitted
Creates train/val/test split for Kitti using video ids.
PiperOrigin-RevId: 258103690
1 parent 7e5aaa0 commit a5d0580

File tree

7 files changed

+138
-6
lines changed

7 files changed

+138
-6
lines changed

tensorflow_datasets/image/kitti.py

Lines changed: 102 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
_DATA_URL = "https://s3.eu-central-1.amazonaws.com/avg-kitti"
4848
_IMAGES_FNAME = "data_object_image_2.zip"
4949
_LABELS_FNAME = "data_object_label_2.zip"
50+
_DEVKIT_FNAME = "devkit_object.zip"
5051
_OBJECT_LABELS = [
5152
"Car",
5253
"Van",
@@ -57,6 +58,10 @@
5758
"Tram",
5859
"Misc",
5960
]
61+
# The percentage of trainset videos to put into validation and test sets.
62+
# The released test images do not have labels.
63+
_VALIDATION_SPLIT_PERCENT_VIDEOS = 10
64+
_TEST_SPLIT_PERCENT_VIDEOS = 10
6065

6166
# Raw Kitti representation of a bounding box. Coordinates are in pixels,
6267
# measured from the top-left hand corner.
@@ -67,12 +72,13 @@
6772
class Kitti(tfds.core.GeneratorBasedBuilder):
6873
"""Kitti dataset."""
6974

70-
VERSION = tfds.core.Version("1.0.0")
75+
VERSION = tfds.core.Version("3.0.0")
7176
SUPPORTED_VERSIONS = [
7277
tfds.core.Version("2.0.0"),
7378
]
7479
# Version history:
7580
# 2.0.0: S3 with new hashing function (different shuffle).
81+
# 3.0.0: Train/val/test splits based on random video IDs created.
7682

7783
def _info(self):
7884
# Annotation descriptions are in the object development kit.
@@ -102,26 +108,48 @@ def _split_generators(self, dl_manager):
102108
filenames = {
103109
"images": os.path.join(_DATA_URL, _IMAGES_FNAME),
104110
"annotations": os.path.join(_DATA_URL, _LABELS_FNAME),
111+
"devkit": os.path.join(_DATA_URL, _DEVKIT_FNAME),
105112
}
106113
files = dl_manager.download(filenames)
114+
train_images, validation_images, test_images = _build_splits(
115+
dl_manager.iter_archive(files["devkit"]))
116+
107117
return [
108118
tfds.core.SplitGenerator(
109119
name=tfds.Split.TRAIN,
110120
gen_kwargs={
111121
"images": dl_manager.iter_archive(files["images"]),
112122
"annotations": dl_manager.iter_archive(files["annotations"]),
113123
"subdir": "training",
124+
"image_ids": train_images,
125+
}),
126+
tfds.core.SplitGenerator(
127+
name=tfds.Split.VALIDATION,
128+
gen_kwargs={
129+
"images": dl_manager.iter_archive(files["images"]),
130+
"annotations": dl_manager.iter_archive(files["annotations"]),
131+
"subdir": "training",
132+
"image_ids": validation_images,
133+
}),
134+
tfds.core.SplitGenerator(
135+
name=tfds.Split.TEST,
136+
gen_kwargs={
137+
"images": dl_manager.iter_archive(files["images"]),
138+
"annotations": dl_manager.iter_archive(files["annotations"]),
139+
"subdir": "training",
140+
"image_ids": test_images,
114141
}),
115142
]
116143

117-
def _generate_examples(self, images, annotations, subdir):
144+
def _generate_examples(self, images, annotations, subdir, image_ids):
118145
"""Yields images and annotations.
119146
120147
Args:
121148
images: object that iterates over the archive of images.
122149
annotations: object that iterates over the archive of annotations.
123150
subdir: subdirectory from which to extract images and annotations, e.g.
124151
training or testing.
152+
image_ids: file ids for images in this split.
125153
126154
Yields:
127155
A tuple containing the example's key, and the example.
@@ -145,8 +173,10 @@ def _generate_examples(self, images, annotations, subdir):
145173
continue
146174
if prefix.split("/")[0] != subdir:
147175
continue
148-
149-
annotations = all_annotations[int(prefix[-6:])]
176+
image_id = int(prefix[-6:])
177+
if image_id not in image_ids:
178+
continue
179+
annotations = all_annotations[image_id]
150180
img = cv2.imdecode(np.fromstring(fobj.read(), dtype=np.uint8),
151181
cv2.IMREAD_COLOR)
152182
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
@@ -211,3 +241,71 @@ def _parse_kitti_annotations(annotations_csv):
211241
"rotation_y": float(rotation_y),
212242
})
213243
return annotations
244+
245+
246+
def _build_splits(devkit):
247+
"""Splits the train data into train/val/test by video.
248+
249+
Ensures that images from the same video do not traverse the splits.
250+
251+
Args:
252+
devkit: object that iterates over the devkit archive.
253+
254+
Returns:
255+
train_images: File ids for the training set images.
256+
validation_images: File ids for the validation set images.
257+
test_images: File ids for the test set images.
258+
"""
259+
mapping_line_ids = None
260+
mapping_lines = None
261+
for fpath, fobj in devkit:
262+
if fpath == "mapping/train_rand.txt":
263+
# Converts 1-based line index to 0-based line index.
264+
mapping_line_ids = [
265+
int(x.strip()) - 1 for x in fobj.read().decode("utf-8").split(",")
266+
]
267+
if fpath == "mapping/train_mapping.txt":
268+
mapping_lines = fobj.readlines()
269+
mapping_lines = [x.decode("utf-8") for x in mapping_lines]
270+
271+
assert mapping_line_ids
272+
assert mapping_lines
273+
274+
video_to_image = collections.defaultdict(list)
275+
for image_id, mapping_lineid in enumerate(mapping_line_ids):
276+
line = mapping_lines[mapping_lineid]
277+
video_id = line.split(" ")[1]
278+
video_to_image[video_id].append(image_id)
279+
280+
# Sets numpy random state.
281+
numpy_original_state = np.random.get_state()
282+
np.random.seed(seed=123)
283+
284+
# Max 1 for testing.
285+
num_test_videos = max(1,
286+
_TEST_SPLIT_PERCENT_VIDEOS * len(video_to_image) // 100)
287+
num_validation_videos = max(
288+
1,
289+
_VALIDATION_SPLIT_PERCENT_VIDEOS * len(video_to_image) // 100)
290+
test_videos = set(
291+
np.random.choice(
292+
list(video_to_image.keys()), num_test_videos, replace=False))
293+
validation_videos = set(
294+
np.random.choice(
295+
list(set(video_to_image.keys()) - set(test_videos)),
296+
num_validation_videos,
297+
replace=False))
298+
test_images = []
299+
validation_images = []
300+
train_images = []
301+
for k, v in video_to_image.items():
302+
if k in test_videos:
303+
test_images.extend(v)
304+
elif k in validation_videos:
305+
validation_images.extend(v)
306+
else:
307+
train_images.extend(v)
308+
309+
# Resets numpy random state.
310+
np.random.set_state(numpy_original_state)
311+
return train_images, validation_images, test_images

tensorflow_datasets/image/kitti_test.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,14 @@
2626
class KittiTest(testing.DatasetBuilderTestCase):
2727
DATASET_CLASS = kitti.Kitti
2828
SPLITS = {
29-
"train": 5,
29+
"train": 6,
30+
"validation": 2,
31+
"test": 2,
3032
}
3133
DL_EXTRACT_RESULT = {
3234
"images": "data_object_image_2.zip",
3335
"annotations": "data_object_label_2.zip",
36+
"devkit": "devkit_object.zip",
3437
}
3538

3639

tensorflow_datasets/testing/kitti.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,8 @@
4040
"Path to tensorflow_datasets directory")
4141

4242
FLAGS = flags.FLAGS
43-
NUM_IMAGES = 5
43+
NUM_IMAGES = 10
44+
NUM_VIDEOS = 5
4445
HEIGHT = 375
4546
WIDTH = 1242
4647
OBJECTS = [
@@ -139,6 +140,28 @@ def _get_label_file(annotation):
139140
return fobj.name
140141

141142

143+
def _get_mapping_files():
144+
"""Returns dummy image to video mapping files."""
145+
# Random indices file.
146+
train_rand = np.random.permutation(range(1, NUM_IMAGES + 1)) # 1-based index
147+
fobj_rand = tempfile.NamedTemporaryFile(
148+
delete=False, mode="wb", suffix=".txt")
149+
fobj_rand.write(",".join([str(x) for x in train_rand]))
150+
fobj_rand.close()
151+
152+
# Mapping file.
153+
fobj_map = tempfile.NamedTemporaryFile(delete=False, mode="wb", suffix=".txt")
154+
assert NUM_IMAGES > NUM_VIDEOS
155+
assert NUM_IMAGES % NUM_VIDEOS == 0
156+
vid_ids = list(range(NUM_VIDEOS)) * (NUM_IMAGES // NUM_VIDEOS)
157+
for vid in vid_ids:
158+
row = "2011_09_26 2011_09_26_drive_00{:02d}_sync 0000000123".format(vid)
159+
fobj_map.write(row + "\n")
160+
fobj_map.close()
161+
162+
return fobj_rand.name, fobj_map.name
163+
164+
142165
def _create_zip_files():
143166
"""Saves png and label using name index."""
144167
if not os.path.exists(_output_dir()):
@@ -161,6 +184,13 @@ def _create_zip_files():
161184
label,
162185
os.path.join("training", "label_2", "label_{:06d}.txt".format(i)))
163186

187+
devkit_out_path = os.path.join(_output_dir(), "devkit_object.zip")
188+
with zipfile.ZipFile(devkit_out_path, "w") as devkit_zip:
189+
train_rand, train_mapping = _get_mapping_files()
190+
devkit_zip.write(train_rand, os.path.join("mapping", "train_rand.txt"))
191+
devkit_zip.write(train_mapping, os.path.join("mapping",
192+
"train_mapping.txt"))
193+
164194

165195
def main(argv):
166196
if len(argv) > 1:
Binary file not shown.
Binary file not shown.
Binary file not shown.
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
https://s3.eu-central-1.amazonaws.com/avg-kitti/data_object_image_2.zip 12569945557 351c5a2aa0cd9238b50174a3a62b846bc5855da256b82a196431d60ff8d43617
22
https://s3.eu-central-1.amazonaws.com/avg-kitti/data_object_label_2.zip 5601213 4efc76220d867e1c31bb980bbf8cbc02599f02a9cb4350effa98dbb04aaed880
3+
https://s3.eu-central-1.amazonaws.com/avg-kitti/devkit_object.zip 63778 cfde67e531832618ea0fd4844f91e34a45068025e2bef79e278cb812dc2537d0

0 commit comments

Comments
 (0)