
Commit 7c23cc9

Merge pull request #37 from raphlinus/master

Use minimal perfect hashing for lookups

2 parents f24cb8a + 40f9ba6

8 files changed: +21617 −10783 lines

scripts/unicode.py

100 additions & 63 deletions
@@ -18,7 +18,7 @@
 # Since this should not require frequent updates, we just store this
 # out-of-line and check the unicode.rs file into git.
 import collections
-import requests
+import urllib.request
 
 UNICODE_VERSION = "9.0.0"
 UCD_URL = "https://www.unicode.org/Public/%s/ucd/" % UNICODE_VERSION
@@ -68,9 +68,9 @@ def __init__(self):
 
         def stats(name, table):
            count = sum(len(v) for v in table.values())
-            print "%s: %d chars => %d decomposed chars" % (name, len(table), count)
+            print("%s: %d chars => %d decomposed chars" % (name, len(table), count))
 
-        print "Decomposition table stats:"
+        print("Decomposition table stats:")
         stats("Canonical decomp", self.canon_decomp)
         stats("Compatible decomp", self.compat_decomp)
         stats("Canonical fully decomp", self.canon_fully_decomp)
@@ -79,8 +79,8 @@ def stats(name, table):
         self.ss_leading, self.ss_trailing = self._compute_stream_safe_tables()
 
     def _fetch(self, filename):
-        resp = requests.get(UCD_URL + filename)
-        return resp.text
+        resp = urllib.request.urlopen(UCD_URL + filename)
+        return resp.read().decode('utf-8')
 
     def _load_unicode_data(self):
         self.combining_classes = {}
@@ -234,7 +234,7 @@ def _decompose(char_int, compatible):
         # need to store their overlap when they agree. When they don't agree,
         # store the decomposition in the compatibility table since we'll check
         # that first when normalizing to NFKD.
-        assert canon_fully_decomp <= compat_fully_decomp
+        assert set(canon_fully_decomp) <= set(compat_fully_decomp)
 
         for ch in set(canon_fully_decomp) & set(compat_fully_decomp):
             if canon_fully_decomp[ch] == compat_fully_decomp[ch]:
@@ -284,27 +284,37 @@ def _compute_stream_safe_tables(self):
 
         return leading_nonstarters, trailing_nonstarters
 
-hexify = lambda c: hex(c)[2:].upper().rjust(4, '0')
+hexify = lambda c: '{:04X}'.format(c)
 
-def gen_combining_class(combining_classes, out):
-    out.write("#[inline]\n")
-    out.write("pub fn canonical_combining_class(c: char) -> u8 {\n")
-    out.write("    match c {\n")
-
-    for char, combining_class in sorted(combining_classes.items()):
-        out.write("        '\u{%s}' => %s,\n" % (hexify(char), combining_class))
+def gen_mph_data(name, d, kv_type, kv_callback):
+    (salt, keys) = minimal_perfect_hash(d)
+    out.write("pub(crate) const %s_SALT: &[u16] = &[\n" % name.upper())
+    for s in salt:
+        out.write("    0x{:x},\n".format(s))
+    out.write("];\n")
+    out.write("pub(crate) const {}_KV: &[{}] = &[\n".format(name.upper(), kv_type))
+    for k in keys:
+        out.write("    {},\n".format(kv_callback(k)))
+    out.write("];\n\n")
 
-    out.write("        _ => 0,\n")
-    out.write("    }\n")
-    out.write("}\n")
+def gen_combining_class(combining_classes, out):
+    gen_mph_data('canonical_combining_class', combining_classes, 'u32',
+        lambda k: "0x{:X}".format(int(combining_classes[k]) | (k << 8)))
 
 def gen_composition_table(canon_comp, out):
-    out.write("#[inline]\n")
-    out.write("pub fn composition_table(c1: char, c2: char) -> Option<char> {\n")
+    table = {}
+    for (c1, c2), c3 in canon_comp.items():
+        if c1 < 0x10000 and c2 < 0x10000:
+            table[(c1 << 16) | c2] = c3
+    (salt, keys) = minimal_perfect_hash(table)
+    gen_mph_data('COMPOSITION_TABLE', table, '(u32, char)',
+        lambda k: "(0x%s, '\\u{%s}')" % (hexify(k), hexify(table[k])))
+
+    out.write("pub(crate) fn composition_table_astral(c1: char, c2: char) -> Option<char> {\n")
     out.write("    match (c1, c2) {\n")
-
     for (c1, c2), c3 in sorted(canon_comp.items()):
-        out.write("        ('\u{%s}', '\u{%s}') => Some('\u{%s}'),\n" % (hexify(c1), hexify(c2), hexify(c3)))
+        if c1 >= 0x10000 and c2 >= 0x10000:
+            out.write("        ('\\u{%s}', '\\u{%s}') => Some('\\u{%s}'),\n" % (hexify(c1), hexify(c2), hexify(c3)))
 
     out.write("        _ => None,\n")
     out.write("    }\n")
@@ -313,23 +323,9 @@ def gen_composition_table(canon_comp, out):
 def gen_decomposition_tables(canon_decomp, compat_decomp, out):
     tables = [(canon_decomp, 'canonical'), (compat_decomp, 'compatibility')]
     for table, name in tables:
-        out.write("#[inline]\n")
-        out.write("pub fn %s_fully_decomposed(c: char) -> Option<&'static [char]> {\n" % name)
-        # The "Some" constructor is around the match statement here, because
-        # putting it into the individual arms would make the item_bodies
-        # checking of rustc takes almost twice as long, and it's already pretty
-        # slow because of the huge number of match arms and the fact that there
-        # is a borrow inside each arm
-        out.write("    Some(match c {\n")
-
-        for char, chars in sorted(table.items()):
-            d = ", ".join("'\u{%s}'" % hexify(c) for c in chars)
-            out.write("        '\u{%s}' => &[%s],\n" % (hexify(char), d))
-
-        out.write("        _ => return None,\n")
-        out.write("    })\n")
-        out.write("}\n")
-        out.write("\n")
+        gen_mph_data(name + '_decomposed', table, "(u32, &'static [char])",
+            lambda k: "(0x{:x}, &[{}])".format(k,
+                ", ".join("'\\u{%s}'" % hexify(c) for c in table[k])))
 
 def gen_qc_match(prop_table, out):
     out.write("    match c {\n")
@@ -371,40 +367,25 @@ def gen_nfkd_qc(prop_tables, out):
     out.write("}\n")
 
 def gen_combining_mark(general_category_mark, out):
-    out.write("#[inline]\n")
-    out.write("pub fn is_combining_mark(c: char) -> bool {\n")
-    out.write("    match c {\n")
-
-    for char in general_category_mark:
-        out.write("        '\u{%s}' => true,\n" % hexify(char))
-
-    out.write("        _ => false,\n")
-    out.write("    }\n")
-    out.write("}\n")
+    gen_mph_data('combining_mark', general_category_mark, 'u32',
+        lambda k: '0x{:04x}'.format(k))
 
 def gen_stream_safe(leading, trailing, out):
+    # This could be done as a hash but the table is very small.
     out.write("#[inline]\n")
     out.write("pub fn stream_safe_leading_nonstarters(c: char) -> usize {\n")
     out.write("    match c {\n")
 
-    for char, num_leading in leading.items():
-        out.write("        '\u{%s}' => %d,\n" % (hexify(char), num_leading))
+    for char, num_leading in sorted(leading.items()):
+        out.write("        '\\u{%s}' => %d,\n" % (hexify(char), num_leading))
 
     out.write("        _ => 0,\n")
     out.write("    }\n")
     out.write("}\n")
     out.write("\n")
 
-    out.write("#[inline]\n")
-    out.write("pub fn stream_safe_trailing_nonstarters(c: char) -> usize {\n")
-    out.write("    match c {\n")
-
-    for char, num_trailing in trailing.items():
-        out.write("        '\u{%s}' => %d,\n" % (hexify(char), num_trailing))
-
-    out.write("        _ => 0,\n")
-    out.write("    }\n")
-    out.write("}\n")
+    gen_mph_data('trailing_nonstarters', trailing, 'u32',
+        lambda k: "0x{:X}".format(int(trailing[k]) | (k << 8)))
 
 def gen_tests(tests, out):
     out.write("""#[derive(Debug)]
@@ -419,7 +400,7 @@ def gen_tests(tests, out):
 """)
 
     out.write("pub const NORMALIZATION_TESTS: &[NormalizationTest] = &[\n")
-    str_literal = lambda s: '"%s"' % "".join("\u{%s}" % c for c in s)
+    str_literal = lambda s: '"%s"' % "".join("\\u{%s}" % c for c in s)
 
     for test in tests:
         out.write("    NormalizationTest {\n")
@@ -432,9 +413,65 @@ def gen_tests(tests, out):
 
     out.write("];\n")
 
+# Guaranteed to be less than n.
+def my_hash(x, salt, n):
+    # This is hash based on the theory that multiplication is efficient
+    mask_32 = 0xffffffff
+    y = ((x + salt) * 2654435769) & mask_32
+    y ^= (x * 0x31415926) & mask_32
+    return (y * n) >> 32
+
+# Compute minimal perfect hash function, d can be either a dict or list of keys.
+def minimal_perfect_hash(d):
+    n = len(d)
+    buckets = dict((h, []) for h in range(n))
+    for key in d:
+        h = my_hash(key, 0, n)
+        buckets[h].append(key)
+    bsorted = [(len(buckets[h]), h) for h in range(n)]
+    bsorted.sort(reverse = True)
+    claimed = [False] * n
+    salts = [0] * n
+    keys = [0] * n
+    for (bucket_size, h) in bsorted:
+        # Note: the traditional perfect hashing approach would also special-case
+        # bucket_size == 1 here and assign any empty slot, rather than iterating
+        # until rehash finds an empty slot. But we're not doing that so we can
+        # avoid the branch.
+        if bucket_size == 0:
+            break
+        else:
+            for salt in range(1, 32768):
+                rehashes = [my_hash(key, salt, n) for key in buckets[h]]
+                # Make sure there are no rehash collisions within this bucket.
+                if all(not claimed[hash] for hash in rehashes):
+                    if len(set(rehashes)) < bucket_size:
+                        continue
+                    salts[h] = salt
+                    for key in buckets[h]:
+                        rehash = my_hash(key, salt, n)
+                        claimed[rehash] = True
+                        keys[rehash] = key
+                    break
+            if salts[h] == 0:
+                print("minimal perfect hashing failed")
+                # Note: if this happens (because of unfortunate data), then there are
+                # a few things that could be done. First, the hash function could be
+                # tweaked. Second, the bucket order could be scrambled (especially the
+                # singletons). Right now, the buckets are sorted, which has the advantage
+                # of being deterministic.
+                #
+                # As a more extreme approach, the singleton bucket optimization could be
+                # applied (give the direct address for singleton buckets, rather than
+                # relying on a rehash). That is definitely the more standard approach in
+                # the minimal perfect hashing literature, but in testing the branch was a
+                # significant slowdown.
+                exit(1)
+    return (salts, keys)
+
 if __name__ == '__main__':
     data = UnicodeData()
-    with open("tables.rs", "w") as out:
+    with open("tables.rs", "w", newline = "\n") as out:
         out.write(PREAMBLE)
         out.write("use quick_check::IsNormalized;\n")
         out.write("use quick_check::IsNormalized::*;\n")
@@ -470,6 +507,6 @@ def gen_tests(tests, out):
         gen_stream_safe(data.ss_leading, data.ss_trailing, out)
         out.write("\n")
 
-    with open("normalization_tests.rs", "w") as out:
+    with open("normalization_tests.rs", "w", newline = "\n") as out:
         out.write(PREAMBLE)
         gen_tests(data.norm_tests, out)
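
For the u8-valued tables (canonical combining classes and trailing nonstarters), gen_mph_data packs each entry into a single u32: the codepoint in the high 24 bits and the value in the low 8 (the int(...) | (k << 8) lambdas above). A worked example of that encoding — the codepoint and class come from the UCD, but the snippet itself is illustrative and not part of the commit:

fn main() {
    // U+0301 COMBINING ACUTE ACCENT has canonical combining class 230 (0xE6),
    // so gen_combining_class emits the entry 0x301E6.
    let kv: u32 = (0x0301 << 8) | 230;
    assert_eq!(kv >> 8, 0x0301); // key half: the codepoint
    assert_eq!(kv & 0xff, 230);  // value half: the combining class
}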

src/lib.rs

3 additions & 5 deletions

@@ -65,7 +65,9 @@ pub use stream_safe::StreamSafe;
 use std::str::Chars;
 
 mod decompose;
+mod lookups;
 mod normalize;
+mod perfect_hash;
 mod recompose;
 mod quick_check;
 mod stream_safe;
@@ -80,11 +82,7 @@ mod normalization_tests;
 pub mod char {
     pub use normalize::{decompose_canonical, decompose_compatible, compose};
 
-    /// Look up the canonical combining class of a character.
-    pub use tables::canonical_combining_class;
-
-    /// Return whether the given character is a combining mark (`General_Category=Mark`)
-    pub use tables::is_combining_mark;
+    pub use lookups::{canonical_combining_class, is_combining_mark};
 }
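The public char API is unchanged by the switch from generated match statements to hashed lookups; only the backing module moves from tables to lookups. A quick illustrative check of the re-exports, assuming the crate is consumed as unicode_normalization (not part of the commit):

use unicode_normalization::char::{canonical_combining_class, is_combining_mark};

fn main() {
    // U+0301 COMBINING ACUTE ACCENT: combining class 230, General_Category=Mn.
    assert_eq!(canonical_combining_class('\u{0301}'), 230);
    assert!(is_combining_mark('\u{0301}'));
    assert!(!is_combining_mark('a'));
}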
src/lookups.rs

89 additions & 0 deletions

@@ -0,0 +1,89 @@
+// Copyright 2019 The Rust Project Developers. See the COPYRIGHT
+// file at the top-level directory of this distribution and at
+// http://rust-lang.org/COPYRIGHT.
+//
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+//! Lookups of unicode properties using minimal perfect hashing.
+
+use perfect_hash::mph_lookup;
+use tables::*;
+
+/// Look up the canonical combining class for a codepoint.
+///
+/// The value returned is as defined in the Unicode Character Database.
+pub fn canonical_combining_class(c: char) -> u8 {
+    mph_lookup(c.into(), CANONICAL_COMBINING_CLASS_SALT, CANONICAL_COMBINING_CLASS_KV,
+        u8_lookup_fk, u8_lookup_fv, 0)
+}
+
+pub(crate) fn composition_table(c1: char, c2: char) -> Option<char> {
+    if c1 < '\u{10000}' && c2 < '\u{10000}' {
+        mph_lookup((c1 as u32) << 16 | (c2 as u32),
+            COMPOSITION_TABLE_SALT, COMPOSITION_TABLE_KV,
+            pair_lookup_fk, pair_lookup_fv_opt, None)
+    } else {
+        composition_table_astral(c1, c2)
+    }
+}
+
+pub(crate) fn canonical_fully_decomposed(c: char) -> Option<&'static [char]> {
+    mph_lookup(c.into(), CANONICAL_DECOMPOSED_SALT, CANONICAL_DECOMPOSED_KV,
+        pair_lookup_fk, pair_lookup_fv_opt, None)
+}
+
+pub(crate) fn compatibility_fully_decomposed(c: char) -> Option<&'static [char]> {
+    mph_lookup(c.into(), COMPATIBILITY_DECOMPOSED_SALT, COMPATIBILITY_DECOMPOSED_KV,
+        pair_lookup_fk, pair_lookup_fv_opt, None)
+}
+
+/// Return whether the given character is a combining mark (`General_Category=Mark`)
+pub fn is_combining_mark(c: char) -> bool {
+    mph_lookup(c.into(), COMBINING_MARK_SALT, COMBINING_MARK_KV,
+        bool_lookup_fk, bool_lookup_fv, false)
+}
+
+pub fn stream_safe_trailing_nonstarters(c: char) -> usize {
+    mph_lookup(c.into(), TRAILING_NONSTARTERS_SALT, TRAILING_NONSTARTERS_KV,
+        u8_lookup_fk, u8_lookup_fv, 0) as usize
+}
+
+/// Extract the key in a 24 bit key and 8 bit value packed in a u32.
+#[inline]
+fn u8_lookup_fk(kv: u32) -> u32 {
+    kv >> 8
+}
+
+/// Extract the value in a 24 bit key and 8 bit value packed in a u32.
+#[inline]
+fn u8_lookup_fv(kv: u32) -> u8 {
+    (kv & 0xff) as u8
+}
+
+/// Extract the key for a boolean lookup.
+#[inline]
+fn bool_lookup_fk(kv: u32) -> u32 {
+    kv
+}
+
+/// Extract the value for a boolean lookup.
+#[inline]
+fn bool_lookup_fv(_kv: u32) -> bool {
+    true
+}
+
+/// Extract the key in a pair.
+#[inline]
+fn pair_lookup_fk<T>(kv: (u32, T)) -> u32 {
+    kv.0
+}
+
+/// Extract the value in a pair, returning an option.
+#[inline]
+fn pair_lookup_fv_opt<T>(kv: (u32, T)) -> Option<T> {
+    Some(kv.1)
+}
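src/lookups.rs pulls mph_lookup from the new perfect_hash module, which is among the 8 changed files but not shown in this excerpt. A minimal sketch of the two-step lookup it presumably performs, assuming it mirrors the Python my_hash above — the names and signature here are inferred from the call sites, not copied from the commit:

/// Hash mirroring my_hash in scripts/unicode.py; the result is always < n.
fn my_hash(x: u32, salt: u32, n: usize) -> usize {
    let y = x.wrapping_add(salt).wrapping_mul(2654435769);
    let y = y ^ x.wrapping_mul(0x31415926);
    (((y as u64) * (n as u64)) >> 32) as usize
}

/// Sketch: hash once with salt 0 to fetch a per-bucket salt, then rehash
/// with that salt to land on the key's unique slot. One probe plus one key
/// comparison; fk/fv unpack the stored entry, `default` covers misses.
pub(crate) fn mph_lookup<KV: Copy, V>(
    x: u32,
    salt: &[u16],
    kv: &[KV],
    fk: fn(KV) -> u32,
    fv: fn(KV) -> V,
    default: V,
) -> V {
    let s = salt[my_hash(x, 0, salt.len())] as u32;
    let entry = kv[my_hash(x, s, kv.len())];
    if fk(entry) == x { fv(entry) } else { default }
}

The construction side guarantees this works: minimal_perfect_hash assigns salts starting from 1 (a remaining 0 marks failure) and places the fullest buckets first, while the table is emptiest.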