Skip to content

Commit 796ac5b

Browse files
authored
Merge pull request #30 from ecstatic-morse/binary-search-micro
Micro-optimize `binary_search`
2 parents 883f028 + a329127 commit 796ac5b

File tree

1 file changed

+24
-7
lines changed

1 file changed

+24
-7
lines changed

src/treefrog.rs

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,6 @@ pub(crate) mod filters {
233233
values.retain(|val| (self.predicate)(prefix, val));
234234
}
235235
}
236-
237236
}
238237

239238
/// Extension method for relations.
@@ -363,7 +362,7 @@ pub(crate) mod extend_with {
363362
{
364363
fn count(&mut self, prefix: &Tuple) -> usize {
365364
let key = (self.key_func)(prefix);
366-
self.start = binary_search(&self.relation[..], |x| &x.0 < &key);
365+
self.start = binary_search(&self.relation.elements, |x| &x.0 < &key);
367366
let slice1 = &self.relation[self.start..];
368367
let slice2 = gallop(slice1, |x| &x.0 <= &key);
369368
self.end = self.relation.len() - slice2.len();
@@ -455,7 +454,7 @@ pub(crate) mod extend_anti {
455454
}
456455
fn intersect(&mut self, prefix: &Tuple, values: &mut Vec<&'leap Val>) {
457456
let key = (self.key_func)(prefix);
458-
let start = binary_search(&self.relation[..], |x| &x.0 < &key);
457+
let start = binary_search(&self.relation.elements, |x| &x.0 < &key);
459458
let slice1 = &self.relation[start..];
460459
let slice2 = gallop(slice1, |x| &x.0 <= &key);
461460
let mut slice = &slice1[..(slice1.len() - slice2.len())];
@@ -643,15 +642,33 @@ pub(crate) mod filter_anti {
643642
}
644643
}
645644

646-
fn binary_search<T>(slice: &[T], mut cmp: impl FnMut(&T) -> bool) -> usize {
645+
/// Returns the lowest index for which `cmp(&vec[i])` returns `false`, assuming `vec` is in sorted
646+
/// order.
647+
///
648+
/// By accepting a vector instead of a slice, we can do a small optimization when computing the
649+
/// midpoint.
650+
fn binary_search<T>(vec: &Vec<T>, mut cmp: impl FnMut(&T) -> bool) -> usize {
651+
// The midpoint calculation we use below is only correct for vectors with at most `isize::MAX`
652+
// elements. This is always true for vectors of sized types but maybe not for ZSTs? Sorting
653+
// ZSTs doesn't make much sense, so just forbid it here.
654+
assert!(std::mem::size_of::<T>() > 0);
655+
647656
// we maintain the invariant that `lo` many elements of `vec` satisfy `cmp`.
648657
// `hi` is maintained at the first element we know does not satisfy `cmp`.
649658

650-
let mut hi = slice.len();
659+
let mut hi = vec.len();
651660
let mut lo = 0;
652661
while lo < hi {
653-
let mid = lo + (hi - lo) / 2;
654-
if cmp(&slice[mid]) {
662+
// Unlike in the general case, this expression cannot overflow because `Vec` is limited to
663+
// `isize::MAX` capacity and we disallow ZSTs above. If we needed to support slices or
664+
// vectors of ZSTs, which don't have an upper bound on their size AFAIK, we would need to
665+
// use a slightly less efficient version that cannot overflow: `lo + (hi - lo) / 2`.
666+
let mid = (hi + lo) / 2;
667+
668+
// LLVM seems to be unable to prove that `mid` is always less than `vec.len()`, so use
669+
// `get_unchecked` to avoid a bounds check since this code is hot.
670+
let el: &T = unsafe { vec.get_unchecked(mid) };
671+
if cmp(el) {
655672
lo = mid + 1;
656673
} else {
657674
hi = mid;

0 commit comments

Comments
 (0)