Skip to content

Commit d341cee

Browse files
Merge pull request #1 from tensorflow/master
Update fork
2 parents 40be1e6 + 79e8511 commit d341cee

23 files changed: +1369 / -81 lines changed

setup.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -35,6 +35,7 @@
3535
REQUIRED_PKGS = [
3636
'absl-py',
3737
'future',
38+
'numpy',
3839
'promise',
3940
'protobuf>=3.6.1',
4041
'requests',

tensorflow_datasets/audio/librispeech_test.py

Lines changed: 9 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -20,40 +20,40 @@
2020
from __future__ import print_function
2121

2222
from tensorflow_datasets import testing
23+
from tensorflow_datasets.audio import librispeech
2324
import tensorflow_datasets.public_api as tfds
24-
from tensorflow_datasets.audio import librispeech
2525

2626

27-
class LibrispeechTest(testing.DatasetBuilderTestCase):
27+
class LibrispeechTest100(testing.DatasetBuilderTestCase):
2828
DATASET_CLASS = librispeech.Librispeech
2929
BUILDER_CONFIG_NAMES_TO_TEST = ["clean-100"]
3030
SPLITS = {
3131
"train": 2,
3232
"test": 1,
3333
"dev": 1,
3434
}
35-
35+
3636
DL_EXTRACT_RESULT = {
3737
tfds.Split.TRAIN: ["train-clean-100"],
3838
tfds.Split.TEST: ["test-clean"],
3939
tfds.Split.VALIDATION: ["dev-clean"],
4040
}
41-
42-
43-
class LibrispeechTest(testing.DatasetBuilderTestCase):
41+
42+
43+
class LibrispeechTest360(testing.DatasetBuilderTestCase):
4444
DATASET_CLASS = librispeech.Librispeech
4545
BUILDER_CONFIG_NAMES_TO_TEST = ["clean-360"]
4646
SPLITS = {
4747
"train": 1,
4848
"test": 1,
4949
"dev": 1,
5050
}
51-
51+
5252
DL_EXTRACT_RESULT = {
53-
tfds.Split.TRAIN: ["train-clean-100", "train-clean-360"],
53+
tfds.Split.TRAIN: ["train-clean-100", "train-clean-360"],
5454
tfds.Split.TEST: ["test-clean"],
5555
tfds.Split.VALIDATION: ["dev-clean"],
56-
}
56+
}
5757

5858

5959
if __name__ == "__main__":

tensorflow_datasets/audio/nsynth.py

Lines changed: 45 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -26,7 +26,7 @@
2626
_DESCRIPTION = """\
2727
The NSynth Dataset is an audio dataset containing ~300k musical notes, each
2828
with a unique pitch, timbre, and envelope. Each note is annotated with three
29-
additional pieces of information based on a combination of human evaluation
29+
additional pieces of information based on a combination of human evaluation
3030
and heuristic algorithms:
3131
-Source: The method of sound production for the note's instrument.
3232
-Family: The high-level family of which the note's instrument is a member.
@@ -63,8 +63,16 @@
6363
"string", "synth_lead", "vocal"]
6464
_INSTRUMENT_SOURCES = ["acoustic", "electronic", "synthetic"]
6565
_QUALITIES = [
66-
"bright", "dark", "distortion", "fast_decay", "long_release", "multiphonic",
67-
"nonlinear_env", "percussive", "reverb", "tempo-synced"]
66+
"bright",
67+
"dark",
68+
"distortion",
69+
"fast_decay",
70+
"long_release",
71+
"multiphonic",
72+
"nonlinear_env",
73+
"percussive",
74+
"reverb",
75+
"tempo-synced"]
6876

6977
_BASE_DOWNLOAD_PATH = "http://download.magenta.tensorflow.org/datasets/nsynth/nsynth-"
7078

@@ -86,11 +94,14 @@ def _info(self):
8694
builder=self,
8795
description=_DESCRIPTION,
8896
features=tfds.features.FeaturesDict({
89-
"id": tf.string,
90-
"audio": tfds.features.Tensor(
91-
shape=(_SAMPLE_LENGTH,), dtype=tf.float32),
92-
"pitch": tfds.features.ClassLabel(num_classes=128),
93-
"velocity": tfds.features.ClassLabel(num_classes=128),
97+
"id":
98+
tf.string,
99+
"audio":
100+
tfds.features.Tensor(shape=(_SAMPLE_LENGTH,), dtype=tf.float32),
101+
"pitch":
102+
tfds.features.ClassLabel(num_classes=128),
103+
"velocity":
104+
tfds.features.ClassLabel(num_classes=128),
94105
"instrument": {
95106
# We read the list of labels in _split_generators.
96107
"label": tfds.features.ClassLabel(num_classes=1006),
@@ -105,17 +116,20 @@ def _info(self):
105116

106117
def _split_generators(self, dl_manager):
107118
dl_urls = {
108-
split: _BASE_DOWNLOAD_PATH + "%s.tfrecord" % split for split in _SPLITS}
109-
dl_urls["instrument_labels"] = _BASE_DOWNLOAD_PATH + "instrument_labels.txt"
119+
split: _BASE_DOWNLOAD_PATH + "%s.tfrecord" % split for split in _SPLITS
120+
}
121+
dl_urls["instrument_labels"] = (_BASE_DOWNLOAD_PATH +
122+
"instrument_labels.txt")
110123
dl_paths = dl_manager.download_and_extract(dl_urls)
111124

112-
instrument_labels = tf.io.gfile.GFile(
113-
dl_paths["instrument_labels"], "r").read().strip().split("\n")
125+
instrument_labels = tf.io.gfile.GFile(dl_paths["instrument_labels"],
126+
"r").read().strip().split("\n")
114127
self.info.features["instrument"]["label"].names = instrument_labels
115128

116129
return [
117-
tfds.core.SplitGenerator(
118-
name=split, num_shards=_SPLIT_SHARDS[split],
130+
tfds.core.SplitGenerator( # pylint: disable=g-complex-comprehension
131+
name=split,
132+
num_shards=_SPLIT_SHARDS[split],
119133
gen_kwargs={"path": dl_paths[split]}) for split in _SPLITS
120134
]
121135

@@ -126,15 +140,24 @@ def _generate_examples(self, path):
126140
example = tf.train.Example.FromString(example_str)
127141
features = example.features.feature
128142
yield {
129-
"id": features["note_str"].bytes_list.value[0],
130-
"audio": np.array(
131-
features["audio"].float_list.value, dtype=np.float32),
132-
"pitch": features["pitch"].int64_list.value[0],
133-
"velocity": features["velocity"].int64_list.value[0],
143+
"id":
144+
features["note_str"].bytes_list.value[0],
145+
"audio":
146+
np.array(features["audio"].float_list.value, dtype=np.float32),
147+
"pitch":
148+
features["pitch"].int64_list.value[0],
149+
"velocity":
150+
features["velocity"].int64_list.value[0],
134151
"instrument": {
135-
"label": features["instrument_str"].bytes_list.value[0],
136-
"family": features["instrument_family_str"].bytes_list.value[0],
137-
"source": features["instrument_source_str"].bytes_list.value[0]
152+
"label":
153+
tf.compat.as_text(
154+
features["instrument_str"].bytes_list.value[0]),
155+
"family":
156+
tf.compat.as_text(
157+
features["instrument_family_str"].bytes_list.value[0]),
158+
"source":
159+
tf.compat.as_text(
160+
features["instrument_source_str"].bytes_list.value[0])
138161
},
139162
"qualities": {
140163
q: features["qualities"].int64_list.value[i]
(file header missing from the capture; presumably tensorflow_datasets/audio/nsynth_test.py)

Lines changed: 34 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,34 @@
1+
# coding=utf-8
2+
# Copyright 2019 The TensorFlow Datasets Authors.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Nsynth Dataset Builder test."""
17+
from tensorflow_datasets.audio import nsynth
18+
import tensorflow_datasets.testing as tfds_test
19+
20+
21+
class NsynthTest(tfds_test.DatasetBuilderTestCase):
22+
"""Test Nsynth."""
23+
DATASET_CLASS = nsynth.Nsynth
24+
SPLITS = {"train": 3, "test": 3, "valid": 3}
25+
DL_EXTRACT_RESULT = {
26+
"train": "nsynth-train.tfrecord",
27+
"test": "nsynth-test.tfrecord",
28+
"valid": "nsynth-valid.tfrecord",
29+
"instrument_labels": "nsynth-instrument_labels.txt"
30+
}
31+
32+
33+
if __name__ == "__main__":
34+
tfds_test.test_main()

tensorflow_datasets/core/download/downloader.py

Lines changed: 17 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -24,12 +24,13 @@
2424
import io
2525
import os
2626
import re
27-
2827
import concurrent.futures
2928
import promise
3029
import requests
31-
import tensorflow as tf
3230

31+
from six.moves import urllib
32+
33+
import tensorflow as tf
3334
from tensorflow_datasets.core import units
3435
from tensorflow_datasets.core import utils
3536
from tensorflow_datasets.core.download import kaggle
@@ -90,7 +91,9 @@ def tqdm(self):
9091
yield
9192

9293
def download(self, url_info, destination_path):
93-
"""Download url to given path. Returns Promise -> sha256 of downloaded file.
94+
"""Download url to given path.
95+
96+
Returns Promise -> sha256 of downloaded file.
9497
9598
Args:
9699
url_info: `UrlInfo`, resource to download.
@@ -139,11 +142,21 @@ def _sync_file_copy(self, filepath, destination_path):
139142
out_path, checksum_cls=self._checksumer)
140143
return hexdigest, size
141144

145+
def _sync_ftp_download(self, url, destination_path):
146+
out_path = os.path.join(destination_path, download_util.get_file_name(url))
147+
urllib.request.urlretrieve(url, out_path)
148+
hexdigest, size = utils.read_checksum_digest(
149+
out_path, checksum_cls=self._checksumer)
150+
return hexdigest, size
151+
142152
def _sync_download(self, url, destination_path):
143153
"""Synchronous version of `download` method."""
144154
if kaggle.KaggleFile.is_kaggle_url(url):
145155
return self._sync_kaggle_download(url, destination_path)
146156

157+
if url.startswith('ftp'):
158+
return self._sync_ftp_download(url, destination_path)
159+
147160
try:
148161
# If url is on a filesystem that gfile understands, use copy. Otherwise,
149162
# use requests.
@@ -166,8 +179,7 @@ def _sync_download(self, url, destination_path):
166179
size_mb = 0
167180
unit_mb = units.MiB
168181
self._pbar_dl_size.update_total(
169-
int(response.headers.get('Content-length', 0)) // unit_mb
170-
)
182+
int(response.headers.get('Content-length', 0)) // unit_mb)
171183
with tf.io.gfile.GFile(path, 'wb') as file_:
172184
checksum = self._checksumer()
173185
for block in response.iter_content(chunk_size=io.DEFAULT_BUFFER_SIZE):

tensorflow_datasets/core/download/downloader_test.py

Lines changed: 33 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -64,14 +64,19 @@ def setUp(self):
6464
'get',
6565
lambda *a, **kw: _FakeResponse(self.url, self.response, self.cookies),
6666
).start()
67-
absltest.mock.patch.object(
68-
downloader.requests.Session,
69-
'get',
70-
lambda *a, **kw: _FakeResponse(self.url, self.response, self.cookies),
71-
).start()
7267
self.downloader._pbar_url = absltest.mock.MagicMock()
7368
self.downloader._pbar_dl_size = absltest.mock.MagicMock()
7469

70+
def write_fake_ftp_result(_, filename):
71+
with open(filename, 'wb') as result:
72+
result.write(self.response)
73+
74+
absltest.mock.patch.object(
75+
downloader.urllib.request,
76+
'urlretrieve',
77+
write_fake_ftp_result,
78+
).start()
79+
7580
def test_ok(self):
7681
promise = self.downloader.download(self.resource, self.tmp_dir)
7782
checksum, _ = promise.get()
@@ -122,6 +127,28 @@ def test_kaggle_api(self):
122127
with tf.io.gfile.GFile(os.path.join(self.tmp_dir, fname)) as f:
123128
self.assertEqual(fname, f.read())
124129

130+
def test_ftp(self):
131+
resource = resource_lib.Resource(
132+
url='ftp://username:password@example.com/foo.tar.gz')
133+
promise = self.downloader.download(resource, self.tmp_dir)
134+
checksum, _ = promise.get()
135+
self.assertEqual(checksum, self.resp_checksum)
136+
with open(self.path, 'rb') as result:
137+
self.assertEqual(result.read(), self.response)
138+
self.assertFalse(tf.io.gfile.exists(self.incomplete_path))
139+
140+
def test_ftp_error(self):
141+
error = downloader.urllib.error.URLError('Problem serving file.')
142+
absltest.mock.patch.object(
143+
downloader.urllib.request,
144+
'urlretrieve',
145+
side_effect=error,
146+
).start()
147+
resource = resource_lib.Resource(url='ftp://example.com/foo.tar.gz')
148+
promise = self.downloader.download(resource, self.tmp_dir)
149+
with self.assertRaises(downloader.urllib.error.URLError):
150+
promise.get()
151+
125152

126153
class GetFilenameTest(testing.TestCase):
127154

@@ -139,5 +166,6 @@ def test_headers(self):
139166
res = downloader._get_filename(resp)
140167
self.assertEqual(res, 'hello.zip')
141168

169+
142170
if __name__ == '__main__':
143171
testing.test_main()

tensorflow_datasets/core/download/kaggle.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -101,7 +101,9 @@ def competition_files(self):
101101
self._competition_name,
102102
]
103103
output = _run_kaggle_command(command, self._competition_name)
104-
return sorted([line.split(",")[0] for line in output.split("\n")[1:]])
104+
return sorted([
105+
line.split(",")[0] for line in output.split("\n")[1:] if line
106+
])
105107

106108
@utils.memoized_property
107109
def competition_urls(self):

tensorflow_datasets/image/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -30,6 +30,7 @@
3030
from tensorflow_datasets.image.image_folder import ImageLabelFolder
3131
from tensorflow_datasets.image.imagenet import Imagenet2012
3232
from tensorflow_datasets.image.lsun import Lsun
33+
from tensorflow_datasets.image.mnist import EMNIST
3334
from tensorflow_datasets.image.mnist import FashionMNIST
3435
from tensorflow_datasets.image.mnist import KMNIST
3536
from tensorflow_datasets.image.mnist import MNIST

0 commit comments

Comments
 (0)