Skip to content

Commit 87cc01d

Browse files
pierrot0 authored and copybara-github committed
TFDS shuffling: swap cityhash for siphash (Issue #653, Issue #690).
PiperOrigin-RevId: 255375385
1 parent 762cc55 commit 87cc01d

File tree

6 files changed

+214
-20
lines changed

6 files changed

+214
-20
lines changed

setup.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242
'absl-py',
4343
'attrs',
4444
'dill', # TODO(tfds): move to TESTS_REQUIRE.
45-
'cityhash',
45+
'siphash',
4646
'future',
4747
'numpy',
4848
'promise',
@@ -58,6 +58,7 @@
5858

5959
TESTS_REQUIRE = [
6060
'apache-beam',
61+
'csiphash',
6162
'jupyter',
6263
'pytest',
6364
]
@@ -122,6 +123,7 @@
122123

123124
EXTRAS_REQUIRE = {
124125
'apache-beam': ['apache-beam'],
126+
'siphash': ['csiphash'],
125127
'tensorflow': ['tensorflow>=1.13.0'],
126128
'tensorflow_gpu': ['tensorflow-gpu>=1.13.0'],
127129
'tests': TESTS_REQUIRE + all_dataset_extras,

tensorflow_datasets/core/hashing.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
# coding=utf-8
2+
# Copyright 2019 The TensorFlow Datasets Authors.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Stable hashing function using SipHash (https://131002.net/siphash/).
17+
18+
Note that the properties we are looking at here are:
19+
1- Good distribution of hashes;
20+
2- Speed;
21+
3- Availability as portable library, giving the same hash independently of
22+
platform.
23+
24+
Crypto level hashing is not a requirement.
25+
26+
Although SipHash being slower than alternatives (eg: CityHash), its
27+
implementation is simpler and it has a pure Python implementation, making it
28+
easier to distribute TFDS on Windows or macOS. Since the speed cost is only
paid at data preparation time, portability wins.
30+
31+
Python3 uses SipHash internally, but since it is not exposed and `hash` function
32+
is not guaranteed to use SipHash in the future, we have to import libraries.
33+
34+
Changing the hash function should be done thoughtfully, as it would change the
35+
order of datasets (and thus sets of records when using slicing API). If done,
36+
all datasets would need to have their major version bumped.
37+
38+
Note that if we were to find a dataset for which two different keys give the
39+
same hash (collision), a solution could be to append the key to its hash.
40+
"""
41+
42+
from __future__ import absolute_import
43+
from __future__ import division
44+
from __future__ import print_function
45+
46+
import struct
47+
48+
import siphash
49+
import six
50+
51+
# Prefer the C implementation (csiphash) when installed; fall back to the
# pure-Python `siphash` package otherwise. Both produce identical digests.
_CSIPHASH_AVAILABLE = False
try:
  import csiphash  # pylint: disable=g-import-not-at-top
  _CSIPHASH_AVAILABLE = True
except ImportError:
  pass

# SipHash requires a 16-byte (128-bit) secret key. A fixed all-zero key keeps
# hashes stable across processes and platforms, which is the goal here:
# deterministic shuffling, not cryptographic security.
_SECRET = b'\0' * 16
60+
61+
62+
def _siphash(data):
  """Returns the 64-bit SipHash-2-4 digest of `data` (bytes), as an int.

  Dispatches to the C implementation (csiphash) when available, otherwise to
  the pure-Python `siphash` package; both yield the same value.
  """
  if not _CSIPHASH_AVAILABLE:
    return siphash.SipHash24(_SECRET, data).hash()
  digest = csiphash.siphash24(_SECRET, data)
  # csiphash returns the 8 raw digest bytes; decode them as a little-endian
  # unsigned 64-bit integer. Equivalent to `int.from_bytes(digest, 'little')`
  # in py3, but py2 compatible.
  (value,) = struct.unpack('<Q', digest)
  return value
70+
71+
72+
def hash_key(key):
  """Returns 64 bits hash of given key.

  Args:
    key (bytes, string or anything convertible to a string): key to be hashed.
      If the key is a string, it will be encoded to bytes using utf-8.
      If the key is neither a string nor bytes, it will be converted to a str,
      then to bytes.
      This means that `"1"` (str) and `1` (int) will have the same hash. The
      intent of the hash being to shuffle keys, it is recommended that all keys
      of a given set to shuffle use a single type.

  Returns:
    64 bits integer, hash of key.
  """
  # Normalize the key to bytes before hashing.
  if isinstance(key, bytes):
    data = key
  else:
    if not isinstance(key, six.string_types):
      key = str(key)
    data = key.encode('utf-8')
  return _siphash(data)
92+
Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
# coding=utf-8
2+
# Copyright 2019 The TensorFlow Datasets Authors.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Tests for tensorflow_datasets.core.hashing."""
17+
18+
from __future__ import absolute_import
19+
from __future__ import division
20+
from __future__ import print_function
21+
22+
import string
23+
24+
from tensorflow_datasets import testing
25+
from tensorflow_datasets.core import hashing
26+
27+
EXPECTED_ASCII_HASHES = [
28+
10863254463029944905, 17270894748891556580, 1421958803217889556,
29+
2755936516345535118, 10743397483099227609, 14847904966537228119,
30+
8418086691718335568, 11915783202300667430, 9716548853924804360,
31+
1754241953508342506, 3694815635762763843, 3596675173038026136,
32+
8070480893871644201, 10828895317094483299, 6744895540894632903,
33+
14841125270424448140, 17688157251436176611, 233329666533924104,
34+
12405195962649360880, 6834955004197680363, 13969673605019971927,
35+
7535964442537143799, 4233129741512941149, 6982149393400456965,
36+
6390938501586896152, 10691964138414231865, 2507792285634992701,
37+
8595505165718044977, 12668683514780208634, 4572719121647279127,
38+
3793178347273963799, 12176803637319179781, 8203705581228860887,
39+
6621546828295374121, 10489668860986262979, 2912571526136585784,
40+
7938998531316856084, 9873236742883981124, 11452724586647173125,
41+
17125311730203373552, 14019180516172084896, 4703666581826301487,
42+
855367633508639125, 14196452709219207778, 12230774050365420318,
43+
2537184170178946005, 7687283993009238581, 8820003036235826491,
44+
17283458755135480410, 17416811303727010451, 5922709015603821319,
45+
14428075881971433291]
46+
47+
EXPECTED_INT_HASHES = [
48+
13163753087252323914, 5003827105613308882, 11449545338359147399,
49+
3672830208859661989, 5406800756778728304, 14481025309921804611,
50+
1609946449970207933, 8255655750251093705, 8491335656787965458,
51+
614618433269745394, 13284406229155908416, 15288099744915406683,
52+
15843726386018726000, 5119555216682145351, 18142174963634300418,
53+
9628501536262262909, 6944376202371927192, 14434159142587098114,
54+
9198994799148619777, 10976198626365556062, 12201794547984161272,
55+
3793153257673850654, 17671034425519684114, 1052793885524652273,
56+
6624479012289984384, 8863054234635171593, 4965745346604460244,
57+
9391177234155550519, 4717670777148145883, 17524121804784260174,
58+
11037926627716248709, 6960985957824457329, 12195906204051370437,
59+
10975328135781691390, 5730073803446725122, 13712792850809427923,
60+
4455483863044151629, 3518672581294300691, 15747605586304671771,
61+
13668305533495453291, 4654232860820002596, 10574044313476412487,
62+
11212237458977771261, 15365614270461858889, 13872585532843456912,
63+
17241372482471269826, 3462651383276316179, 12647419365920702661,
64+
1995464140987078702, 1972561720831881829, 3955328643960597520,
65+
1027369990565220197, 11322815499836299743, 5956980248780520012,
66+
18278096046037631584, 5241067853637136955, 7630275157338448032,
67+
1913046367367276708, 7440126551217879907, 7220748048216444121,
68+
5064805925883037540, 1318705738537093225, 2730963110225791297,
69+
6920195161195209846, 4682001368639242156, 9166607120404080113,
70+
11268721706256216334, 5379201623047735445, 15999685243572303930,
71+
13046608731927560566, 16276928149450612660, 16298571539550440629,
72+
17035045450101282343, 14240119263724925078, 9965315260615748500,
73+
14921974451741066715, 3620887669157180415, 14499246414755411500,
74+
188410546870139183, 14101909720529780551, 1623775225152586541,
75+
1826999275929156985, 5289921295512723016, 151075781360207052,
76+
17598366794955569210, 1171265316432012145, 104641658363814304,
77+
9264688353594391671, 458105873437653640, 1830791798008018143,
78+
15150529348822956655, 7610023982731034430, 12031109555968877759,
79+
2814999125315447998, 14537302745610214253, 14150033554738292901,
80+
1316273336836242886, 4973610113424219627, 11435740729903845598,
81+
4598673630928554393]
82+
83+
84+
class HashingTest(testing.TestCase):
  """Checks hash_key against golden digests under both SipHash backends."""

  def _assert_hashes(self, keys, expected_hashes):
    """Asserts hash_key(k) == expected for each key, under both backends.

    Toggles `hashing._CSIPHASH_AVAILABLE` to exercise the pure-Python and the
    C (csiphash) code paths, and restores the flag afterwards so other tests
    observe the real environment. NOTE(review): forcing the flag to True
    assumes csiphash is installed (it is listed in TESTS_REQUIRE).

    Args:
      keys: list of keys to hash.
      expected_hashes: list of int, expected 64-bit hashes, same order as keys.
    """
    original_flag = hashing._CSIPHASH_AVAILABLE
    try:
      for use_csiphash in (False, True):
        hashing._CSIPHASH_AVAILABLE = use_csiphash
        hashes = [hashing.hash_key(k) for k in keys]
        self.assertEqual(hashes, expected_hashes)
    finally:
      # Restore module state: the original code left the flag forced to True,
      # leaking into any test that runs afterwards.
      hashing._CSIPHASH_AVAILABLE = original_flag

  def test_ints(self):
    ints = list(range(100))
    self._assert_hashes(ints, EXPECTED_INT_HASHES)

  def test_ascii(self):
    letters = string.ascii_lowercase + string.ascii_uppercase
    self._assert_hashes(letters, EXPECTED_ASCII_HASHES)
103+
104+
105+
# Allow running this test file directly (outside a test runner).
if __name__ == '__main__':
  testing.test_main()

tensorflow_datasets/core/shuffle.py

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,10 @@
2424
import struct
2525
import uuid
2626

27-
import cityhash
2827
import six
2928
import tensorflow as tf
3029

30+
from tensorflow_datasets.core import hashing
3131

3232
HKEY_SIZE = 64 # Hash of keys is 64 bits.
3333

@@ -48,13 +48,6 @@
4848
BUCKETS_NUMBER = 1000 # Number of buckets to pre-sort and hold generated data.
4949

5050

51-
def _get_hashed_key(key):
52-
"""Returns hash (int) for given key."""
53-
if not isinstance(key, (six.string_types, bytes)):
54-
key = str(key)
55-
return cityhash.CityHash64(key)
56-
57-
5851
def _get_shard(hkey, shards_number):
5952
"""Returns shard number (int) for given hashed key (int)."""
6053
# We purposely do not use modulo (%) to keep global order across shards.
@@ -73,12 +66,14 @@ class _Bucket(object):
7366
key1 (8 bytes) | size1 (8 bytes) | data1 (size1 bytes) |
7467
key2 (8 bytes) | size2 (8 bytes) | data2 (size2 bytes) |
7568
...
76-
77-
Args:
78-
path (str): where to write the bucket file.
7969
"""
8070

8171
def __init__(self, path):
72+
"""Initialize a _Bucket instance.
73+
74+
Args:
75+
path (str): where to write the bucket file.
76+
"""
8277
self._path = path
8378
self._fobj = None
8479
self._length = 0
@@ -167,7 +162,7 @@ def add(self, key, data):
167162
if not isinstance(data, six.binary_type):
168163
raise AssertionError('Only bytes (not %s) can be stored in Shuffler!' % (
169164
type(data)))
170-
hkey = _get_hashed_key(key)
165+
hkey = hashing.hash_key(key)
171166
self._total_bytes += len(data)
172167
if self._in_memory:
173168
self._add_to_mem_buffer(hkey, data)

tensorflow_datasets/core/shuffle_test.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -38,15 +38,15 @@
3838
]
3939

4040
_ORDERED_ITEMS = [
41-
b'over',
42-
b'jumps',
43-
b'quick ',
4441
b' dog.',
42+
b'over',
43+
b'brown',
4544
b'The',
45+
b' fox ',
4646
b' the ',
4747
b'lazy',
48-
b'brown',
49-
b' fox ',
48+
b'quick ',
49+
b'jumps',
5050
]
5151

5252
_TOTAL_SIZE = sum(len(rec) for rec in _ORDERED_ITEMS)
@@ -97,7 +97,6 @@ def test_duplicate_key(self):
9797
shuffler.add(2, b'b')
9898
shuffler.add(1, b'c')
9999
iterator = iter(shuffler)
100-
self.assertEqual(next(iterator), b'b')
101100
self.assertEqual(next(iterator), b'a')
102101
with self.assertRaisesWithPredicateMatch(
103102
AssertionError, 'Two records share the same hashed key!'):

tensorflow_datasets/core/tfrecords_writer_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ def test_write(self):
107107
self.assertEqual(written_files,
108108
['foo.tfrecord-0000%s-of-00005' % i for i in range(5)])
109109
self.assertEqual(all_recs, [
110-
[b'f', b'e'], [b'b'], [b'a', b'g'], [b'h'], [b'c', b'd'],
110+
[b'f', b'c'], [b'a'], [b'd', b'g'], [b'h'], [b'b', b'e'],
111111
])
112112

113113
if __name__ == '__main__':

0 commit comments

Comments
 (0)