Skip to content

Commit 1434cf3

Browse files
committed
Refactor probing logic into an external iterator
1 parent 0c2cda1 commit 1434cf3

File tree

1 file changed

+81
-15
lines changed

1 file changed

+81
-15
lines changed

src/raw/mod.rs

Lines changed: 81 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ cfg_if! {
3434

3535
mod bitmask;
3636

37-
use self::bitmask::BitMask;
37+
use self::bitmask::{BitMask, BitMaskIter};
3838
use self::imp::Group;
3939

4040
// Branch prediction hint. This is currently only available on nightly but it
@@ -938,23 +938,14 @@ impl<T> RawTable<T> {
938938
#[inline]
939939
pub fn find(&self, hash: u64, mut eq: impl FnMut(&T) -> bool) -> Option<Bucket<T>> {
940940
unsafe {
941-
for pos in self.probe_seq(hash) {
942-
let group = Group::load(self.ctrl(pos));
943-
for bit in group.match_byte(h2(hash)) {
944-
let index = (pos + bit) & self.bucket_mask;
945-
let bucket = self.bucket(index);
946-
if likely(eq(bucket.as_ref())) {
947-
return Some(bucket);
948-
}
949-
}
950-
if likely(group.match_empty().any_bit_set()) {
951-
return None;
941+
for bucket in self.iter_hash(hash) {
942+
let elm = bucket.as_ref();
943+
if likely(eq(elm)) {
944+
return Some(bucket);
952945
}
953946
}
947+
None
954948
}
955-
956-
// probe_seq never returns.
957-
unreachable!();
958949
}
959950

960951
/// Returns the number of elements the map can hold without reallocating.
@@ -1004,6 +995,18 @@ impl<T> RawTable<T> {
1004995
}
1005996
}
1006997

998+
/// Returns an iterator over occupied buckets that could match a given hash.
999+
///
1000+
/// In rare cases, the iterator may return a bucket with a different hash.
1001+
///
/// # Safety
///
1002+
/// It is up to the caller to ensure that the `RawTable` outlives the
1003+
/// `RawIterHash`. Because we cannot make the `next` method unsafe on the
1004+
/// `RawIterHash` struct, we have to make the `iter_hash` method unsafe.
1005+
#[cfg_attr(feature = "inline-more", inline)]
1006+
pub unsafe fn iter_hash(&self, hash: u64) -> RawIterHash<'_, T> {
// Thin wrapper: all probing state is set up by `RawIterHash::new`.
1007+
RawIterHash::new(self, hash)
1008+
}
1009+
10071010
/// Returns an iterator which removes all elements from the table without
10081011
/// freeing the memory.
10091012
///
@@ -1737,3 +1740,66 @@ impl<T> Iterator for RawDrain<'_, T> {
17371740

17381741
impl<T> ExactSizeIterator for RawDrain<'_, T> {}
17391742
impl<T> FusedIterator for RawDrain<'_, T> {}
1743+
1744+
/// Iterator over occupied buckets that could match a given hash.
1745+
///
1746+
/// In rare cases, the iterator may return a bucket with a different hash.
///
/// Candidates are selected by comparing `h2_hash` against each probed group
/// (see `next`), so callers still have to check the yielded element
/// themselves, the way `RawTable::find` does with its `eq` closure.
1747+
pub struct RawIterHash<'a, T> {
// The table being probed, borrowed for the lifetime of the iterator.
1748+
table: &'a RawTable<T>,
1749+
1750+
// The top 7 bits of the hash.
// Computed once in `new` via `h2(hash)` and reused for every group.
1751+
h2_hash: u8,
1752+
1753+
// The sequence of groups to probe in the search.
// Obtained from `table.probe_seq(hash)`; advanced one step per group.
1754+
probe_seq: ProbeSeq,
1755+
1756+
// The current group and its position.
// `pos` is the offset that was passed to `table.ctrl` to load `group`.
1757+
pos: usize,
1758+
group: Group,
1759+
1760+
// The elements within the group with a matching h2-hash.
// Rebuilt from `group.match_byte(h2_hash)` whenever a new group is loaded.
1761+
bitmask: BitMaskIter,
1762+
}
1763+
1764+
impl<'a, T> RawIterHash<'a, T> {
/// Creates an iterator positioned on the first group of the probe
/// sequence for `hash`, with the h2 matches of that group already
/// computed.
1765+
fn new(table: &'a RawTable<T>, hash: u64) -> Self {
// NOTE(review): this safe function wraps its whole body in `unsafe`;
// which calls actually require it (presumably `table.ctrl` and/or
// `Group::load`) is not visible here — confirm. The only caller shown,
// `RawTable::iter_hash`, is itself an `unsafe fn`.
1766+
unsafe {
1767+
let h2_hash = h2(hash);
1768+
let mut probe_seq = table.probe_seq(hash);
// NOTE(review): `unwrap` relies on the probe sequence always yielding
// a position — confirm against `ProbeSeq`.
1769+
let pos = probe_seq.next().unwrap();
1770+
let group = Group::load(table.ctrl(pos));
1771+
let bitmask = group.match_byte(h2_hash).into_iter();
1772+
1773+
RawIterHash {
1774+
table,
1775+
h2_hash,
1776+
probe_seq,
1777+
pos,
1778+
group,
1779+
bitmask,
1780+
}
1781+
}
1782+
}
1783+
}
1784+
1785+
impl<'a, T> Iterator for RawIterHash<'a, T> {
1786+
type Item = Bucket<T>;
1787+
/// Returns the next bucket whose control byte matched `h2_hash`, or
/// `None` once a probed group containing an empty slot has been exhausted.
1788+
fn next(&mut self) -> Option<Bucket<T>> {
1789+
unsafe {
1790+
loop {
// 1. Drain the pending matches of the current group first.
1791+
if let Some(bit) = self.bitmask.next() {
// `bit` is relative to the group start; masking with
// `bucket_mask` wraps the index back into the table.
1792+
let index = (self.pos + bit) & self.table.bucket_mask;
1793+
let bucket = self.table.bucket(index);
1794+
return Some(bucket);
1795+
}
// 2. No matches left in this group: if the group has an empty slot,
// the search ends here and no further groups are probed.
1796+
if likely(self.group.match_empty().any_bit_set()) {
1797+
return None;
1798+
}
// 3. Otherwise move to the next group in the probe sequence and
// recompute the matches for it.
// NOTE(review): `unwrap` relies on the probe sequence never
// ending — confirm against `ProbeSeq`.
1799+
self.pos = self.probe_seq.next().unwrap();
1800+
self.group = Group::load(self.table.ctrl(self.pos));
1801+
self.bitmask = self.group.match_byte(self.h2_hash).into_iter();
1802+
}
1803+
}
1804+
}
1805+
}

0 commit comments

Comments
 (0)