Skip to content

Commit 51c55d9

Browse files
Merge pull request #212 from ChanchalKumarMaji:master
PiperOrigin-RevId: 240346070
2 parents c138fef + 09a52d4 commit 51c55d9

File tree

12 files changed

+206
-0
lines changed

12 files changed

+206
-0
lines changed

tensorflow_datasets/image/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from tensorflow_datasets.image.coco import Coco2014
2626
from tensorflow_datasets.image.colorectal_histology import ColorectalHistology
2727
from tensorflow_datasets.image.colorectal_histology import ColorectalHistologyLarge
28+
from tensorflow_datasets.image.cycle_gan import CycleGAN
2829
from tensorflow_datasets.image.diabetic_retinopathy_detection import DiabeticRetinopathyDetection
2930
from tensorflow_datasets.image.dsprites import Dsprites
3031
from tensorflow_datasets.image.flowers import TFFlowers
Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
# coding=utf-8
2+
# Copyright 2019 The TensorFlow Datasets Authors.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""CycleGAN dataset."""
17+
18+
from __future__ import absolute_import
19+
from __future__ import division
20+
from __future__ import print_function
21+
22+
import os
23+
24+
import tensorflow as tf
25+
26+
from tensorflow_datasets.core import api_utils
27+
import tensorflow_datasets.public_api as tfds
28+
29+
# From https://arxiv.org/abs/1703.10593
30+
_CITATION = """\
31+
@article{DBLP:journals/corr/ZhuPIE17,
32+
author = {Jun{-}Yan Zhu and
33+
Taesung Park and
34+
Phillip Isola and
35+
Alexei A. Efros},
36+
title = {Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial
37+
Networks},
38+
journal = {CoRR},
39+
volume = {abs/1703.10593},
40+
year = {2017},
41+
url = {http://arxiv.org/abs/1703.10593},
42+
archivePrefix = {arXiv},
43+
eprint = {1703.10593},
44+
timestamp = {Mon, 13 Aug 2018 16:48:06 +0200},
45+
biburl = {https://dblp.org/rec/bib/journals/corr/ZhuPIE17},
46+
bibsource = {dblp computer science bibliography, https://dblp.org}
47+
}
48+
"""
49+
50+
_DL_URL = "https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/"
51+
52+
# "ae_photos" : Not added because trainA and trainB are missing.
53+
_DATA_OPTIONS = [
54+
"apple2orange", "summer2winter_yosemite", "horse2zebra", "monet2photo",
55+
"cezanne2photo", "ukiyoe2photo", "vangogh2photo", "maps", "cityscapes",
56+
"facades", "iphone2dslr_flower"
57+
]
58+
59+
_DL_URLS = {name: _DL_URL + name + ".zip" for name in _DATA_OPTIONS}
60+
61+
62+
class CycleGANConfig(tfds.core.BuilderConfig):
63+
"""BuilderConfig for CycleGAN."""
64+
65+
@api_utils.disallow_positional_args
66+
def __init__(self, data=None, **kwargs):
67+
"""Constructs a CycleGANConfig.
68+
69+
Args:
70+
data: `str`, one of `_DATA_OPTIONS`.
71+
**kwargs: keyword arguments forwarded to super.
72+
"""
73+
if data not in _DATA_OPTIONS:
74+
raise ValueError("data must be one of %s" % _DATA_OPTIONS)
75+
76+
super(CycleGANConfig, self).__init__(**kwargs)
77+
self.data = data
78+
79+
80+
class CycleGAN(tfds.core.GeneratorBasedBuilder):
81+
"""CycleGAN dataset."""
82+
83+
BUILDER_CONFIGS = [
84+
CycleGANConfig( # pylint: disable=g-complex-comprehension
85+
name=config_name,
86+
description=("A dataset consisting of images from two classes: "
87+
"A and B for example: horses and zebras."),
88+
version="0.1.0",
89+
data=config_name,
90+
) for config_name in _DATA_OPTIONS
91+
]
92+
93+
def _info(self):
94+
return tfds.core.DatasetInfo(
95+
builder=self,
96+
description=("Dataset with images from 2 classes (see config name for "
97+
"information on the specific class)"),
98+
features=tfds.features.FeaturesDict({
99+
"image": tfds.features.Image(),
100+
"label": tfds.features.ClassLabel(names=["A", "B"]),
101+
}),
102+
supervised_keys=("image", "label"),
103+
urls=[
104+
"https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/"
105+
],
106+
)
107+
108+
def _split_generators(self, dl_manager):
109+
url = _DL_URLS[self.builder_config.name]
110+
data_dirs = dl_manager.download_and_extract(url)
111+
112+
path_to_dataset = os.path.join(data_dirs, tf.io.gfile.listdir(data_dirs)[0])
113+
114+
train_a_path = os.path.join(path_to_dataset, "trainA")
115+
train_b_path = os.path.join(path_to_dataset, "trainB")
116+
test_a_path = os.path.join(path_to_dataset, "testA")
117+
test_b_path = os.path.join(path_to_dataset, "testB")
118+
119+
return [
120+
tfds.core.SplitGenerator(
121+
name="trainA",
122+
num_shards=10,
123+
gen_kwargs={
124+
"path": train_a_path,
125+
"label": "A",
126+
}),
127+
tfds.core.SplitGenerator(
128+
name="trainB",
129+
num_shards=10,
130+
gen_kwargs={
131+
"path": train_b_path,
132+
"label": "B",
133+
}),
134+
tfds.core.SplitGenerator(
135+
name="testA",
136+
num_shards=1,
137+
gen_kwargs={
138+
"path": test_a_path,
139+
"label": "A",
140+
}),
141+
tfds.core.SplitGenerator(
142+
name="testB",
143+
num_shards=1,
144+
gen_kwargs={
145+
"path": test_b_path,
146+
"label": "B",
147+
}),
148+
]
149+
150+
def _generate_examples(self, path, label):
151+
images = tf.io.gfile.listdir(path)
152+
153+
for image in images:
154+
yield {
155+
"image": os.path.join(path, image),
156+
"label": label,
157+
}
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
# coding=utf-8
2+
# Copyright 2019 The TensorFlow Datasets Authors.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Tests for cycle_gan dataset module."""
17+
18+
from __future__ import absolute_import
19+
from __future__ import division
20+
from __future__ import print_function
21+
22+
from tensorflow_datasets import testing
23+
from tensorflow_datasets.image import cycle_gan
24+
25+
26+
class CycleGANTest(testing.DatasetBuilderTestCase):
27+
DATASET_CLASS = cycle_gan.CycleGAN
28+
BUILDER_CONFIG_NAMES_TO_TEST = ["horse2zebra"]
29+
SPLITS = {
30+
"trainA": 2,
31+
"testA": 2,
32+
"trainB": 2,
33+
"testB": 2,
34+
}
35+
36+
if __name__ == "__main__":
37+
testing.test_main()
Loading
Loading
Loading
Loading
Loading
Loading
Loading

0 commit comments

Comments
 (0)