Skip to content

Commit f2c7982

Browse files
Merge pull request #1290 from WilliamHYZhang:flic_dataset
PiperOrigin-RevId: 290918143
2 parents 851eddb + fb8afb9 commit f2c7982

File tree

11 files changed

+330
-0
lines changed

11 files changed

+330
-0
lines changed

tensorflow_datasets/image/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
from tensorflow_datasets.image.dtd import Dtd
5151
from tensorflow_datasets.image.duke_ultrasound import DukeUltrasound
5252
from tensorflow_datasets.image.eurosat import Eurosat
53+
from tensorflow_datasets.image.flic import Flic
5354
from tensorflow_datasets.image.flowers import TFFlowers
5455
from tensorflow_datasets.image.food101 import Food101
5556
from tensorflow_datasets.image.horses_or_humans import HorsesOrHumans

tensorflow_datasets/image/flic.py

Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,174 @@
1+
# coding=utf-8
2+
# Copyright 2019 The TensorFlow Datasets Authors.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Frames Labeled In Cinema (FLIC)."""
17+
18+
from __future__ import absolute_import
19+
from __future__ import division
20+
from __future__ import print_function
21+
22+
import os
23+
24+
import tensorflow.compat.v2 as tf
25+
import tensorflow_datasets.public_api as tfds
26+
27+
# BibTeX entry for the MODEC paper (CVPR 2013) the dataset is cited by.
_CITATION = """@inproceedings{modec13,
title={MODEC: Multimodal Decomposable Models for Human Pose Estimation},
author={Sapp, Benjamin and Taskar, Ben},
booktitle={In Proc. CVPR},
year={2013},
}
"""

# Human-readable dataset description shown in the dataset info.
_DESCRIPTION = """
From the paper: We collected a 5003 image dataset automatically from popular
Hollywood movies. The images were obtained by running a state-of-the-art person
detector on every tenth frame of 30 movies. People detected with high confidence
(roughly 20K candidates) were then sent to the crowdsourcing marketplace Amazon
Mechanical Turk to obtain groundtruth labeling. Each image was annotated by five
Turkers for $0.01 each to label 10 upperbody joints. The median-of-five labeling
was taken in each image to be robust to outlier annotation. Finally, images were
rejected manually by us if the person was occluded or severely non-frontal. We
set aside 20% (1016 images) of the data for testing.
"""

# Accepted values for the `data` argument of `FlicConfig`; one builder config
# is created per entry.
_DATA_OPTIONS = ["small", "full"]

_HOMEPAGE_URL = "https://bensapp.github.io/flic-dataset.html"

# Archive locations: the subset backs the "small" config, the superset backs
# the "full" config.
_URL_SUBSET = "https://drive.google.com/uc?id=0B4K3PZp8xXDJN0Fpb0piVjQ3Y3M&export=download"
_URL_SUPERSET = "https://drive.google.com/uc?id=0B4K3PZp8xXDJd2VwblhhOVBfMDg&export=download"
53+
54+
55+
def _normalize_bbox(raw_bbox, img_path):
  """Scales a pixel-space torso box by the size of its image.

  Args:
    raw_bbox: Indexable holding four numbers. Entries 0 and 2 are treated as
      x values (divided by the image width), entries 1 and 3 as y values
      (divided by the image height).
    img_path: Path of the image the box belongs to. The file is opened only to
      read its dimensions.

  Returns:
    A `tfds.features.BBox` built from the scaled values.
  """
  with tf.io.gfile.GFile(img_path, "rb") as image_file:
    opened = tfds.core.lazy_imports.PIL_Image.open(image_file)
    img_width, img_height = opened.size

  x_lo, y_lo, x_hi, y_hi = raw_bbox[0], raw_bbox[1], raw_bbox[2], raw_bbox[3]
  return tfds.features.BBox(
      ymin=y_lo / img_height,
      ymax=y_hi / img_height,
      xmin=x_lo / img_width,
      xmax=x_hi / img_width,
  )
67+
68+
69+
class FlicConfig(tfds.core.BuilderConfig):
  """BuilderConfig for FLIC, selecting the "small" or "full" archive."""

  @tfds.core.disallow_positional_args
  def __init__(self, data, **kwargs):
    """Constructs a FlicConfig.

    Args:
      data: Which variant to build; must be one of `_DATA_OPTIONS`.
      **kwargs: Forwarded to `tfds.core.BuilderConfig`. When no "description"
        is supplied, one is generated from `data`.

    Raises:
      ValueError: If `data` is not one of `_DATA_OPTIONS`.
    """
    if data not in _DATA_OPTIONS:
      raise ValueError("data must be one of %s" % _DATA_OPTIONS)

    is_small = data == "small"
    if is_small:
      summary = "5003 examples used in CVPR13 MODEC paper."
    else:
      summary = ("20928 examples, a superset of FLIC consisting of more "
                 "difficult examples.")
    # Only fills in the description when the caller did not pass one.
    kwargs.setdefault("description", "Uses %s" % summary)

    super(FlicConfig, self).__init__(**kwargs)
    self.data = data
    self.url = _URL_SUBSET if is_small else _URL_SUPERSET
    # Name of the top-level directory inside the extracted archive.
    self.dir = "FLIC" if is_small else "FLIC-full"
91+
92+
93+
def _make_builder_configs():
  """Returns one `FlicConfig` (version 2.0.0) per entry of `_DATA_OPTIONS`."""
  return [
      FlicConfig(name=option, version=tfds.core.Version("2.0.0"), data=option)
      for option in _DATA_OPTIONS
  ]
99+
100+
101+
class Flic(tfds.core.GeneratorBasedBuilder):
  """Frames Labeled In Cinema (FLIC)."""

  BUILDER_CONFIGS = _make_builder_configs()

  def _info(self):
    """Returns the `tfds.core.DatasetInfo` with the feature specification."""
    feature_spec = {
        "image":
            tfds.features.Image(shape=(480, 720, 3), encoding_format="jpeg"),
        "poselet_hit_idx":
            tfds.features.Sequence(tf.uint16),
        "moviename":
            tfds.features.Text(),
        "xcoords":
            tfds.features.Sequence(tf.float64),
        "ycoords":
            tfds.features.Sequence(tf.float64),
        "currframe":
            tfds.features.Tensor(shape=(), dtype=tf.float64),
        "torsobox":
            tfds.features.BBoxFeature(),
    }
    return tfds.core.DatasetInfo(
        builder=self,
        description=_DESCRIPTION,
        features=tfds.features.FeaturesDict(feature_spec),
        homepage=_HOMEPAGE_URL,
        citation=_CITATION,
    )

  def _split_generators(self, dl_manager):
    """Returns SplitGenerators.

    Downloads and extracts the archive of the active config, loads its
    `examples.mat` annotation file once, and hands the loaded data to both
    the train and the test generator.

    Args:
      dl_manager: Download manager used to fetch and extract the archive.

    Returns:
      A list with the TRAIN and TEST `tfds.core.SplitGenerator`s.
    """
    extract_path = dl_manager.download_and_extract(self.builder_config.url)

    annotations_path = os.path.join(extract_path, self.builder_config.dir,
                                    "examples.mat")
    with tf.io.gfile.GFile(annotations_path, "rb") as mat_file:
      data = tfds.core.lazy_imports.scipy.io.loadmat(
          mat_file, struct_as_record=True, squeeze_me=True, mat_dtype=True)

    # Each example row carries per-split flags: column 7 indicates train split
    # selection, column 8 indicates test split selection.
    split_columns = (
        (tfds.Split.TRAIN, 7),
        (tfds.Split.TEST, 8),
    )
    return [
        tfds.core.SplitGenerator(
            name=split_name,
            gen_kwargs={
                "extract_path": extract_path,
                "data": data,
                "selection_column": column,
            },
        )
        for split_name, column in split_columns
    ]

  def _generate_examples(self, extract_path, data, selection_column):
    """Yields examples.

    Args:
      extract_path: Directory the archive was extracted to.
      data: Loaded content of `examples.mat`; its "examples" entry is iterated.
      selection_column: Index of the per-row flag deciding whether the row
        belongs to the split being generated.

    Yields:
      `(key, example)` pairs, where the key is the row's position in
      `data["examples"]`.
    """
    image_dir = os.path.join(extract_path, self.builder_config.dir, "images")
    for row_index, row in enumerate(data["examples"]):
      if not row[selection_column]:
        continue
      # Column 3 holds the image file name relative to the images directory.
      img_path = os.path.join(image_dir, row[3])
      yield row_index, {
          "image": img_path,
          "poselet_hit_idx": row[0],
          "moviename": row[1],
          "xcoords": row[2][0],
          "ycoords": row[2][1],
          "currframe": row[5],
          "torsobox": _normalize_bbox(row[6], img_path),
      }

tensorflow_datasets/image/flic_test.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
# coding=utf-8
2+
# Copyright 2019 The TensorFlow Datasets Authors.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Test for FLIC dataset."""
17+
18+
from __future__ import absolute_import
19+
from __future__ import division
20+
from __future__ import print_function
21+
22+
from tensorflow_datasets import testing
23+
from tensorflow_datasets.image import flic
24+
25+
26+
class FlicTestSmall(testing.DatasetBuilderTestCase):
  """Runs the shared dataset-builder test case on the "small" FLIC config."""

  DATASET_CLASS = flic.Flic
  BUILDER_CONFIG_NAMES_TO_TEST = ["small"]
  # Presumably the expected example count per split; confirm against
  # `testing.DatasetBuilderTestCase`.
  SPLITS = dict(train=1, test=1)
33+
34+
35+
class FlicTestFull(testing.DatasetBuilderTestCase):
  """Runs the shared dataset-builder test case on the "full" FLIC config."""

  DATASET_CLASS = flic.Flic
  BUILDER_CONFIG_NAMES_TO_TEST = ["full"]
  # Presumably the expected example count per split; confirm against
  # `testing.DatasetBuilderTestCase`.
  SPLITS = dict(train=1, test=1)
42+
43+
44+
# Run the test cases only when this file is executed as a script.
if __name__ == "__main__":
  testing.test_main()

tensorflow_datasets/testing/flic.py

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
# coding=utf-8
2+
# Copyright 2019 The TensorFlow Datasets Authors.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Generates FLIC like files with random data for testing."""
17+
18+
from __future__ import absolute_import
19+
from __future__ import division
20+
from __future__ import print_function
21+
22+
import os
23+
24+
from absl import app
25+
from absl import flags
26+
27+
import numpy as np
28+
import scipy.io
29+
import tensorflow as tf
30+
31+
from tensorflow_datasets.core.utils import py_utils
32+
from tensorflow_datasets.testing import fake_data_utils
33+
34+
# Root under which the fake examples are written; defaults to the value
# returned by `py_utils.tfds_dir()`.
flags.DEFINE_string(
    "tfds_dir",
    py_utils.tfds_dir(),
    "Path to tensorflow_datasets directory",
)

FLAGS = flags.FLAGS
38+
39+
40+
def _output_dir(data):
  """Returns the fake-example directory for a FLIC variant.

  Args:
    data: "small" selects the `FLIC` directory; any other value selects
      `FLIC-full`.

  Returns:
    Path below `<tfds_dir>/testing/test_data/fake_examples/flic`.
  """
  if data == "small":
    archive_dir = "FLIC"
  else:
    archive_dir = "FLIC-full"
  return os.path.join(FLAGS.tfds_dir, "testing", "test_data", "fake_examples",
                      "flic", archive_dir)
45+
46+
47+
def _generate_image(data, fdir, fname):
  """Writes a random JPEG to `<output dir for data>/<fdir>/<fname>`.

  Args:
    data: Variant name passed to `_output_dir`.
    fdir: Sub-directory of the output directory; created when missing.
    fname: File name of the image to write. An existing file is overwritten.
  """
  target_dir = os.path.join(_output_dir(data), fdir)
  if not os.path.exists(target_dir):
    os.makedirs(target_dir)
  random_jpeg = fake_data_utils.get_random_jpeg(480, 720)
  destination = os.path.join(target_dir, fname)
  tf.io.gfile.copy(random_jpeg, destination, overwrite=True)
55+
56+
57+
def _fake_example(image_fname, in_train, in_test):
  """Builds one fake row of the `examples` array for `image_fname`."""
  return np.array([
      np.array([1, 2, 3], dtype=np.uint16),
      "example_movie",
      np.array(
          [np.array([1.0, 2.0, 3.0]),
           np.array([1.0, 2.0, 3.0])]),
      image_fname,
      np.array([1.0, 2.0, 3.0]),
      1.0,
      np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32),
      in_train,  # column 7
      in_test,  # column 8
  ])


def _generate_mat(data, train_fname, test_fname):
  """Generate MAT file for given data type (small or full).

  Writes `examples.mat` into the variant's output directory. The file holds
  two rows: one flagged as train referencing `train_fname`, and one flagged as
  test referencing `test_fname`.

  Args:
    data: Variant name passed to `_output_dir`.
    train_fname: Image file name stored in the train row.
    test_fname: Image file name stored in the test row.
  """
  mat_path = os.path.join(_output_dir(data), "examples.mat")
  examples = np.array([
      _fake_example(train_fname, True, False),
      _fake_example(test_fname, False, True),
  ])
  scipy.io.savemat(mat_path, {"examples": examples})
93+
94+
95+
def main(unused_argv):
  """Writes two fake images and an `examples.mat` for each FLIC variant."""
  # (variant, train image file name, test image file name)
  fixtures = (
      ("small", "example_movie00000001.jpg", "example_movie00000002.jpg"),
      ("full", "example_movie00000003.jpg", "example_movie00000004.jpg"),
  )
  for variant, train_fname, test_fname in fixtures:
    _generate_image(variant, "images", train_fname)
    _generate_image(variant, "images", test_fname)
    _generate_mat(variant, train_fname, test_fname)
105+
106+
107+
# Script entry point: absl parses the flags and then calls `main`.
if __name__ == "__main__":
  app.run(main)
Binary file not shown.
Binary file not shown.

0 commit comments

Comments
 (0)