Skip to content

Commit a329127

Browse files
Restrict binary_search to Vecs and take advantage of size limit
1 parent e7477cd commit a329127

File tree

1 file changed

+23
-6
lines changed

1 file changed

+23
-6
lines changed

src/treefrog.rs

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -362,7 +362,7 @@ pub(crate) mod extend_with {
362362
{
363363
fn count(&mut self, prefix: &Tuple) -> usize {
364364
let key = (self.key_func)(prefix);
365-
self.start = binary_search(&self.relation[..], |x| &x.0 < &key);
365+
self.start = binary_search(&self.relation.elements, |x| &x.0 < &key);
366366
let slice1 = &self.relation[self.start..];
367367
let slice2 = gallop(slice1, |x| &x.0 <= &key);
368368
self.end = self.relation.len() - slice2.len();
@@ -454,7 +454,7 @@ pub(crate) mod extend_anti {
454454
}
455455
fn intersect(&mut self, prefix: &Tuple, values: &mut Vec<&'leap Val>) {
456456
let key = (self.key_func)(prefix);
457-
let start = binary_search(&self.relation[..], |x| &x.0 < &key);
457+
let start = binary_search(&self.relation.elements, |x| &x.0 < &key);
458458
let slice1 = &self.relation[start..];
459459
let slice2 = gallop(slice1, |x| &x.0 <= &key);
460460
let mut slice = &slice1[..(slice1.len() - slice2.len())];
@@ -642,15 +642,32 @@ pub(crate) mod filter_anti {
642642
}
643643
}
644644

645-
fn binary_search<T>(slice: &[T], mut cmp: impl FnMut(&T) -> bool) -> usize {
645+
/// Returns the lowest index for which `cmp(&vec[i])` returns `false`, assuming `vec` is in sorted
646+
/// order.
647+
///
648+
/// By accepting a vector instead of a slice, we can do a small optimization when computing the
649+
/// midpoint.
650+
fn binary_search<T>(vec: &Vec<T>, mut cmp: impl FnMut(&T) -> bool) -> usize {
651+
// The midpoint calculation we use below is only correct for vectors with less than `isize::MAX`
652+
// elements. This is always true for vectors of non-zero-sized types but maybe not for ZSTs? Sorting
653+
// ZSTs doesn't make much sense, so just forbid it here.
654+
assert!(std::mem::size_of::<T>() > 0);
655+
646656
// we maintain the invariant that `lo` many elements of `vec` satisfy `cmp`.
647657
// `hi` is maintained at the first element we know does not satisfy `cmp`.
648658

649-
let mut hi = slice.len();
659+
let mut hi = vec.len();
650660
let mut lo = 0;
651661
while lo < hi {
652-
let mid = lo + (hi - lo) / 2;
653-
let el: &T = unsafe { slice.get_unchecked(mid) };
662+
// Unlike in the general case, this expression cannot overflow because `Vec` is limited to
663+
// `isize::MAX` capacity and we disallow ZSTs above. If we needed to support slices or
664+
// vectors of ZSTs, which don't have an upper bound on their size AFAIK, we would need to
665+
// use a slightly less efficient version that cannot overflow: `lo + (hi - lo) / 2`.
666+
let mid = (hi + lo) / 2;
667+
668+
// LLVM seems to be unable to prove that `mid` is always less than `vec.len()`, so use
669+
// `get_unchecked` to avoid a bounds check since this code is hot.
670+
let el: &T = unsafe { vec.get_unchecked(mid) };
654671
if cmp(el) {
655672
lo = mid + 1;
656673
} else {

0 commit comments

Comments
 (0)