Skip to content

Commit fd4a4f7

Browse files
pierrot0 and copybara-github
authored and committed
TFDS Issue #737: use md5 instead of siphash.
PiperOrigin-RevId: 257178144
1 parent fecbb15 commit fd4a4f7

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

81 files changed

+271
-202
lines changed

setup.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,6 @@
4242
'absl-py',
4343
'attrs',
4444
'dill', # TODO(tfds): move to TESTS_REQUIRE.
45-
'siphash',
4645
'future',
4746
'numpy',
4847
'promise',
@@ -58,7 +57,6 @@
5857

5958
TESTS_REQUIRE = [
6059
'apache-beam',
61-
# 'csiphash', # https://github.com/tensorflow/datasets/issues/737
6260
'jupyter',
6361
'pytest',
6462
'pytest-xdist',
@@ -124,8 +122,6 @@
124122

125123
EXTRAS_REQUIRE = {
126124
'apache-beam': ['apache-beam'],
127-
# https://github.com/tensorflow/datasets/issues/737
128-
# 'siphash': ['csiphash'],
129125
'tensorflow': ['tensorflow>=1.13.0'],
130126
'tensorflow_gpu': ['tensorflow-gpu>=1.13.0'],
131127
'tests': TESTS_REQUIRE + all_dataset_extras,

tensorflow_datasets/core/dataset_builder.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -898,7 +898,8 @@ def _prepare_split(self, split_generator, max_examples_per_split):
898898
return self._prepare_split_legacy(generator, split_info)
899899
fname = "{}-{}.tfrecord".format(self.name, split_generator.name)
900900
fpath = os.path.join(self._data_dir, fname)
901-
writer = tfrecords_writer.Writer(self._example_specs, fpath)
901+
writer = tfrecords_writer.Writer(self._example_specs, fpath,
902+
hash_salt=split_generator.name)
902903
for key, record in utils.tqdm(generator, unit=" examples",
903904
total=split_info.num_examples, leave=False):
904905
example = self.info.features.encode_example(record)

tensorflow_datasets/core/hashing.py

Lines changed: 47 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -13,80 +13,75 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16-
"""Stable hashing function using SipHash (https://131002.net/siphash/).
16+
"""Stable hashing function using md5.
1717
1818
Note that the properties we are looking at here are:
19-
1- Good distribution of hashes;
19+
1- Good distribution of hashes (random uniformity);
2020
2- Speed;
2121
3- Availability as portable library, giving the same hash independently of
2222
platform.
2323
2424
Crypto level hashing is not a requirement.
2525
26-
Although SipHash being slower than alternatives (eg: CityHash), its
27-
implementation is simpler and it has a pure Python implementation, making it
28-
easier to distribute TFDS on Windows or MAC. The speed cost being only paid at
29-
data preparation time, portability wins.
30-
31-
Python3 uses SipHash internally, but since it is not exposed and `hash` function
32-
is not guaranteed to use SipHash in the future, we have to import libraries.
26+
A bit of history:
27+
- CityHash was first used. However the C implementation used complex
28+
instructions and was hard to compile on some platforms.
29+
- Siphash was chosen as a replacement, because although being slower,
30+
it has a simpler implementation and it has a pure Python implementation, making
31+
it easier to distribute TFDS on Windows or MAC. However, the used library
32+
(reference C implementation wrapped using cffi) crashed the python interpreter
33+
on py3 with tf1.13.
34+
- So md5, although being slower than the two above, works everywhere and is
35+
still faster than a pure python implementation of siphash.
3336
3437
Changing the hash function should be done thoughtfully, as it would change the
3538
order of datasets (and thus sets of records when using slicing API). If done,
3639
all datasets would need to have their major version bumped.
3740
3841
Note that if we were to find a dataset for which two different keys give the
3942
same hash (collision), a solution could be to append the key to its hash.
43+
44+
The split name is being used as salt to avoid having the same keys in two splits
45+
result in same order.
4046
"""
4147

4248
from __future__ import absolute_import
4349
from __future__ import division
4450
from __future__ import print_function
4551

46-
import struct
52+
import hashlib
4753

48-
import siphash
4954
import six
55+
import tensorflow as tf
56+
57+
58+
def _to_bytes(data):
59+
if not isinstance(data, (six.string_types, bytes)):
60+
data = str(data)
61+
return tf.compat.as_bytes(data)
62+
63+
64+
class Hasher(object):
65+
"""Hasher: to initialize a md5 with salt."""
66+
67+
def __init__(self, salt):
68+
self._md5 = hashlib.md5(_to_bytes(salt))
69+
70+
def hash_key(self, key):
71+
"""Returns 128 bits hash of given key.
5072
51-
_CSIPHASH_AVAILABLE = False
52-
try:
53-
import csiphash # pylint: disable=g-import-not-at-top
54-
_CSIPHASH_AVAILABLE = True
55-
except ImportError:
56-
pass
57-
58-
# SipHash needs a 16 bits key.
59-
_SECRET = b'\0' * 16
60-
61-
62-
def _siphash(data):
63-
if _CSIPHASH_AVAILABLE:
64-
hash_bytes = csiphash.siphash24(_SECRET, data)
65-
# Equivalent to `int.from_bytes(hash_bytes, sys.byteorder)` in py3,
66-
# but py2 compatible:
67-
return struct.unpack('<Q', hash_bytes)[0]
68-
else:
69-
return siphash.SipHash24(_SECRET, data).hash()
70-
71-
72-
def hash_key(key):
73-
"""Returns 64 bits hash of given key.
74-
75-
Args:
76-
key (bytes, string or anything convertible to a string): key to be hashed.
77-
If the key is a string, it will be encoded to bytes using utf-8.
78-
If the key is neither a string nor bytes, it will be converted to a str,
79-
then to bytes.
80-
This means that `"1"` (str) and `1` (int) will have the same hash. The
81-
intent of the hash being to shuffle keys, it is recommended that all keys
82-
of a given set to shuffle use a single type.
83-
84-
Returns:
85-
64 bits integer, hash of key.
86-
"""
87-
if not isinstance(key, (six.string_types, bytes)):
88-
key = str(key)
89-
if not isinstance(key, bytes):
90-
key = key.encode('utf-8')
91-
return _siphash(key)
73+
Args:
74+
key (bytes, string or anything convertible to a string): key to be hashed.
75+
If the key is a string, it will be encoded to bytes using utf-8.
76+
If the key is neither a string nor bytes, it will be converted to a str,
77+
then to bytes.
78+
This means that `"1"` (str) and `1` (int) will have the same hash. The
79+
intent of the hash being to shuffle keys, it is recommended that all
80+
keys of a given set to shuffle use a single type.
9281
82+
Returns:
83+
128 bits integer, hash of key.
84+
"""
85+
md5 = self._md5.copy()
86+
md5.update(_to_bytes(key))
87+
return int(md5.hexdigest(), 16)

tensorflow_datasets/core/hashing_test.py

Lines changed: 8 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -19,89 +19,23 @@
1919
from __future__ import division
2020
from __future__ import print_function
2121

22-
import string
23-
2422
from tensorflow_datasets import testing
2523
from tensorflow_datasets.core import hashing
2624

27-
EXPECTED_ASCII_HASHES = [
28-
10863254463029944905, 17270894748891556580, 1421958803217889556,
29-
2755936516345535118, 10743397483099227609, 14847904966537228119,
30-
8418086691718335568, 11915783202300667430, 9716548853924804360,
31-
1754241953508342506, 3694815635762763843, 3596675173038026136,
32-
8070480893871644201, 10828895317094483299, 6744895540894632903,
33-
14841125270424448140, 17688157251436176611, 233329666533924104,
34-
12405195962649360880, 6834955004197680363, 13969673605019971927,
35-
7535964442537143799, 4233129741512941149, 6982149393400456965,
36-
6390938501586896152, 10691964138414231865, 2507792285634992701,
37-
8595505165718044977, 12668683514780208634, 4572719121647279127,
38-
3793178347273963799, 12176803637319179781, 8203705581228860887,
39-
6621546828295374121, 10489668860986262979, 2912571526136585784,
40-
7938998531316856084, 9873236742883981124, 11452724586647173125,
41-
17125311730203373552, 14019180516172084896, 4703666581826301487,
42-
855367633508639125, 14196452709219207778, 12230774050365420318,
43-
2537184170178946005, 7687283993009238581, 8820003036235826491,
44-
17283458755135480410, 17416811303727010451, 5922709015603821319,
45-
14428075881971433291]
46-
47-
EXPECTED_INT_HASHES = [
48-
13163753087252323914, 5003827105613308882, 11449545338359147399,
49-
3672830208859661989, 5406800756778728304, 14481025309921804611,
50-
1609946449970207933, 8255655750251093705, 8491335656787965458,
51-
614618433269745394, 13284406229155908416, 15288099744915406683,
52-
15843726386018726000, 5119555216682145351, 18142174963634300418,
53-
9628501536262262909, 6944376202371927192, 14434159142587098114,
54-
9198994799148619777, 10976198626365556062, 12201794547984161272,
55-
3793153257673850654, 17671034425519684114, 1052793885524652273,
56-
6624479012289984384, 8863054234635171593, 4965745346604460244,
57-
9391177234155550519, 4717670777148145883, 17524121804784260174,
58-
11037926627716248709, 6960985957824457329, 12195906204051370437,
59-
10975328135781691390, 5730073803446725122, 13712792850809427923,
60-
4455483863044151629, 3518672581294300691, 15747605586304671771,
61-
13668305533495453291, 4654232860820002596, 10574044313476412487,
62-
11212237458977771261, 15365614270461858889, 13872585532843456912,
63-
17241372482471269826, 3462651383276316179, 12647419365920702661,
64-
1995464140987078702, 1972561720831881829, 3955328643960597520,
65-
1027369990565220197, 11322815499836299743, 5956980248780520012,
66-
18278096046037631584, 5241067853637136955, 7630275157338448032,
67-
1913046367367276708, 7440126551217879907, 7220748048216444121,
68-
5064805925883037540, 1318705738537093225, 2730963110225791297,
69-
6920195161195209846, 4682001368639242156, 9166607120404080113,
70-
11268721706256216334, 5379201623047735445, 15999685243572303930,
71-
13046608731927560566, 16276928149450612660, 16298571539550440629,
72-
17035045450101282343, 14240119263724925078, 9965315260615748500,
73-
14921974451741066715, 3620887669157180415, 14499246414755411500,
74-
188410546870139183, 14101909720529780551, 1623775225152586541,
75-
1826999275929156985, 5289921295512723016, 151075781360207052,
76-
17598366794955569210, 1171265316432012145, 104641658363814304,
77-
9264688353594391671, 458105873437653640, 1830791798008018143,
78-
15150529348822956655, 7610023982731034430, 12031109555968877759,
79-
2814999125315447998, 14537302745610214253, 14150033554738292901,
80-
1316273336836242886, 4973610113424219627, 11435740729903845598,
81-
4598673630928554393]
82-
8325

8426
class HashingTest(testing.TestCase):
8527

86-
def _assert_hashes(self, keys, expected_hashes):
87-
hashes = [hashing.hash_key(k) for k in keys]
88-
self.assertEqual(hashes, expected_hashes)
89-
9028
def test_ints(self):
91-
ints = list(range(100))
92-
hashing._CSIPHASH_AVAILABLE = False
93-
self._assert_hashes(ints, EXPECTED_INT_HASHES)
94-
# https://github.com/tensorflow/datasets/issues/737
95-
# hashing._CSIPHASH_AVAILABLE = True
96-
# self._assert_hashes(ints, EXPECTED_INT_HASHES)
29+
hasher = hashing.Hasher(salt='')
30+
res = hasher.hash_key(0)
31+
self.assertEqual(res, 276215275525073243129443018166533317850)
32+
res = hasher.hash_key(123455678901234567890)
33+
self.assertEqual(res, 6876359009333865997613257802033240610)
9734

9835
def test_ascii(self):
99-
letters = string.ascii_lowercase + string.ascii_uppercase
100-
hashing._CSIPHASH_AVAILABLE = False
101-
self._assert_hashes(letters, EXPECTED_ASCII_HASHES)
102-
# https://github.com/tensorflow/datasets/issues/737
103-
# hashing._CSIPHASH_AVAILABLE = True
104-
# self._assert_hashes(letters, EXPECTED_ASCII_HASHES)
36+
hasher = hashing.Hasher(salt='')
37+
res = hasher.hash_key('foo')
38+
self.assertEqual(res, 229609063533823256041787889330700985560)
10539

10640

10741
if __name__ == '__main__':

tensorflow_datasets/core/shuffle.py

Lines changed: 30 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,6 @@
2929

3030
from tensorflow_datasets.core import hashing
3131

32-
HKEY_SIZE = 64 # Hash of keys is 64 bits.
33-
3432
# Approximately how much data to store in memory before writing to disk.
3533
# If the amount of data to shuffle is < MAX_MEM_BUFFER_SIZE, no intermediary
3634
# data is written to disk.
@@ -47,6 +45,21 @@
4745
# the number of buckets might warrant some changes in implementation.
4846
BUCKETS_NUMBER = 1000 # Number of buckets to pre-sort and hold generated data.
4947

48+
HKEY_SIZE = 128 # Hash of keys is 128 bits (md5).
49+
HKEY_SIZE_BYTES = HKEY_SIZE // 8
50+
51+
52+
def _hkey_to_bytes(hkey):
53+
"""Converts 128 bits integer hkey to binary representation."""
54+
max_int64 = 0xFFFFFFFFFFFFFFFF
55+
return struct.pack('QQ', (hkey >> 64) & max_int64, hkey & max_int64)
56+
57+
58+
def _read_hkey(buff):
59+
"""Reads from fobj and returns hkey (128 bites integer)."""
60+
a, b = struct.unpack('QQ', buff)
61+
return (a << 64) | b
62+
5063

5164
def _get_shard(hkey, shards_number):
5265
"""Returns shard number (int) for given hashed key (int)."""
@@ -93,7 +106,8 @@ def add(self, key, data):
93106
if not self._fobj:
94107
self._fobj = tf.io.gfile.GFile(self._path, mode='wb')
95108
data_size = len(data)
96-
self._fobj.write(struct.pack('LL', key, data_size))
109+
self._fobj.write(_hkey_to_bytes(key))
110+
self._fobj.write(struct.pack('L', data_size))
97111
self._fobj.write(data)
98112
self._length += 1
99113
self._size += data_size
@@ -107,10 +121,12 @@ def read_values(self):
107121
res = []
108122
with tf.io.gfile.GFile(path, 'rb') as fobj:
109123
while True:
110-
buff = fobj.read(16)
124+
buff = fobj.read(HKEY_SIZE_BYTES)
111125
if not buff:
112126
break
113-
hkey, size = struct.unpack('LL', buff)
127+
hkey = _read_hkey(buff)
128+
size_bytes = fobj.read(8)
129+
size = struct.unpack('L', size_bytes)[0]
114130
data = fobj.read(size)
115131
res.append((hkey, data))
116132
return res
@@ -121,14 +137,17 @@ def del_file(self):
121137

122138

123139
class Shuffler(object):
124-
"""Stores data in temp buckets, restitute it shuffled.
140+
"""Stores data in temp buckets, restitute it shuffled."""
125141

126-
Args:
127-
dirpath: directory in which to store temporary files.
128-
"""
142+
def __init__(self, dirpath, hash_salt):
143+
"""Initialize Shuffler.
129144
130-
def __init__(self, dirpath):
145+
Args:
146+
dirpath (string): directory in which to store temporary files.
147+
hash_salt (string or bytes): salt to hash keys.
148+
"""
131149
grp_name = uuid.uuid4()
150+
self._hasher = hashing.Hasher(hash_salt)
132151
self._buckets = [
133152
_Bucket(os.path.join(dirpath, 'bucket_%s_%03d.tmp' % (grp_name, i)))
134153
for i in range(BUCKETS_NUMBER)]
@@ -162,7 +181,7 @@ def add(self, key, data):
162181
if not isinstance(data, six.binary_type):
163182
raise AssertionError('Only bytes (not %s) can be stored in Shuffler!' % (
164183
type(data)))
165-
hkey = hashing.hash_key(key)
184+
hkey = self._hasher.hash_key(key)
166185
self._total_bytes += len(data)
167186
if self._in_memory:
168187
self._add_to_mem_buffer(hkey, data)

0 commit comments

Comments
 (0)