Skip to content

Commit f3b4d15

Browse files
TensorFlow Datasets Teamcopybara-github
authored andcommitted
Added segmentation mask to Oxford-IIIT pets dataset.
Extended oxford_iiit_pet.py and other related files PiperOrigin-RevId: 257082403
1 parent 1482e7c commit f3b4d15

File tree

13 files changed

+34
-8
lines changed

13 files changed

+34
-8
lines changed

tensorflow_datasets/image/oxford_iiit_pet.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,13 @@ class OxfordIIITPet(tfds.core.GeneratorBasedBuilder):
4848

4949
VERSION = tfds.core.Version("1.1.0")
5050
SUPPORTED_VERSIONS = [
51+
tfds.core.Version("2.1.0", experiments={tfds.core.Experiment.S3: True}),
5152
tfds.core.Version("2.0.0", experiments={tfds.core.Experiment.S3: True}),
53+
tfds.core.Version("1.2.0"),
5254
tfds.core.Version("1.1.0"),
5355
]
5456
# Version history:
57+
# 2.1.0, 1.2.0: addition of the segmentation_mask feature.
5558
# 2.0.0: S3 (new shuffling, sharding and slicing mechanism).
5659

5760
def _info(self):
@@ -62,6 +65,7 @@ def _info(self):
6265
"image": tfds.features.Image(),
6366
"label": tfds.features.ClassLabel(num_classes=37),
6467
"file_name": tfds.features.Text(),
68+
"segmentation_mask": tfds.features.Image(shape=(None, None, 1))
6569
}),
6670
supervised_keys=("image", "label"),
6771
urls=["http://www.robots.ox.ac.uk/~vgg/data/pets/"],
@@ -91,6 +95,7 @@ def _split_generators(self, dl_manager):
9195
num_shards=_NUM_SHARDS,
9296
gen_kwargs={
9397
"images_dir_path": images_path_dir,
98+
"annotations_dir_path": annotations_path_dir,
9499
"images_list_file": os.path.join(annotations_path_dir,
95100
"trainval.txt"),
96101
},
@@ -100,25 +105,31 @@ def _split_generators(self, dl_manager):
100105
num_shards=_NUM_SHARDS,
101106
gen_kwargs={
102107
"images_dir_path": images_path_dir,
108+
"annotations_dir_path": annotations_path_dir,
103109
"images_list_file": os.path.join(annotations_path_dir,
104110
"test.txt")
105111
},
106112
)
107113

108114
return [train_split, test_split]
109115

110-
def _generate_examples(self, images_dir_path, images_list_file):
116+
def _generate_examples(self, images_dir_path, annotations_dir_path,
117+
images_list_file):
111118
with tf.io.gfile.GFile(images_list_file, "r") as images_list:
112119
for line in images_list:
113120
image_name, label, _, _ = line.strip().split(" ")
114121

122+
trimaps_dir_path = os.path.join(annotations_dir_path, "trimaps")
123+
124+
trimap_name = image_name + ".png"
115125
image_name += ".jpg"
116126
label = int(label) - 1
117127

118128
record = {
119129
"image": os.path.join(images_dir_path, image_name),
120130
"label": int(label),
121-
"file_name": image_name
131+
"file_name": image_name,
132+
"segmentation_mask": os.path.join(trimaps_dir_path, trimap_name)
122133
}
123134
if self.version.implements(tfds.core.Experiment.S3):
124135
yield image_name, record

tensorflow_datasets/testing/fake_data_utils.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,17 +32,17 @@
3232
CHANNELS_NB = 3
3333

3434

35-
def get_random_picture(height=None, width=None):
35+
def get_random_picture(height=None, width=None, channels=CHANNELS_NB):
3636
"""Returns random picture as np.ndarray (int)."""
3737
height = height or random.randrange(MIN_HEIGHT_WIDTH, MAX_HEIGHT_WIDTH)
3838
width = width or random.randrange(MIN_HEIGHT_WIDTH, MAX_HEIGHT_WIDTH)
3939
return np.random.randint(
40-
256, size=(height, width, CHANNELS_NB), dtype=np.uint8)
40+
256, size=(height, width, channels), dtype=np.uint8)
4141

4242

43-
def get_random_jpeg(height=None, width=None):
43+
def get_random_jpeg(height=None, width=None, channels=CHANNELS_NB):
4444
"""Returns path to JPEG picture."""
45-
image = get_random_picture(height, width)
45+
image = get_random_picture(height, width, channels)
4646
jpeg = tf.image.encode_jpeg(image)
4747
with utils.nogpu_session() as sess:
4848
res = sess.run(jpeg)
@@ -52,9 +52,9 @@ def get_random_jpeg(height=None, width=None):
5252
return fobj.name
5353

5454

55-
def get_random_png(height=None, width=None):
55+
def get_random_png(height=None, width=None, channels=CHANNELS_NB):
5656
"""Returns path to PNG picture."""
57-
image = get_random_picture(height, width)
57+
image = get_random_picture(height, width, channels)
5858
png = tf.image.encode_png(image)
5959
with utils.nogpu_session() as sess:
6060
res = sess.run(png)

tensorflow_datasets/testing/oxford_iiit_pet.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,16 @@ def _generate_data():
6060

6161
# Generate annotations
6262
annotations_dir = os.path.join(_output_dir(), 'annotations')
63+
6364
if not tf.io.gfile.exists(annotations_dir):
6465
tf.io.gfile.makedirs(annotations_dir)
66+
67+
# Generate trimaps
68+
trimaps_dir = os.path.join(annotations_dir, 'trimaps')
69+
70+
if not tf.io.gfile.exists(trimaps_dir):
71+
tf.io.gfile.makedirs(trimaps_dir)
72+
6573
global_count = 0
6674
for filename, num_examples in [('trainval.txt', _TRAIN_IMAGES_NUMBER),
6775
('test.txt', _TEST_IMAGES_NUMBER)]:
@@ -73,6 +81,13 @@ def _generate_data():
7381
tf.io.gfile.copy(fobj.name, os.path.join(annotations_dir, filename),
7482
overwrite=True)
7583

84+
# Create trimaps
85+
for i in range(_TRAIN_IMAGES_NUMBER + _TEST_IMAGES_NUMBER):
86+
trimap_name = 'image{:03d}.png'.format(i)
87+
tf.io.gfile.copy(fake_data_utils.get_random_png(channels=1),
88+
os.path.join(trimaps_dir, trimap_name),
89+
overwrite=True)
90+
7691

7792
def main(argv):
7893
if len(argv) > 1:

0 commit comments

Comments
 (0)