Skip to content

Commit e4e97ae

Browse files
committed
with_shard
1 parent 085c75d commit e4e97ae

File tree

5 files changed

+177
-117
lines changed

5 files changed

+177
-117
lines changed

compiler/rustc_data_structures/src/sharded.rs

Lines changed: 62 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,24 @@ impl<T> Sharded<T> {
4141
}
4242
}
4343

44+
/// The shard is selected by hashing `val` with `FxHasher`.
45+
#[inline]
46+
pub fn with_get_shard_by_value<K: Hash + ?Sized, F: FnOnce(&mut T) -> R, R>(
47+
&self,
48+
val: &K,
49+
f: F,
50+
) -> R {
51+
if likely(self.single_thread) {
52+
let shard = &self.shard;
53+
assert!(!shard.borrow.replace(true));
54+
let r = unsafe { f(&mut *shard.data.get()) };
55+
shard.borrow.set(false);
56+
r
57+
} else {
58+
self.shards[get_shard_index_by_hash(make_hash(val))].0.with_mt_lock(f)
59+
}
60+
}
61+
4462
/// The shard is selected by hashing `val` with `FxHasher`.
4563
#[inline]
4664
pub fn get_shard_by_value<K: Hash + ?Sized>(&self, val: &K) -> &Lock<T> {
@@ -51,6 +69,23 @@ impl<T> Sharded<T> {
5169
}
5270
}
5371

72+
#[inline]
73+
pub fn with_get_shard_by_hash<F: FnOnce(&mut T) -> R, R>(
74+
&self,
75+
hash: u64,
76+
f: F,
77+
) -> R {
78+
if likely(self.single_thread) {
79+
let shard = &self.shard;
80+
assert!(!shard.borrow.replace(true));
81+
let r = unsafe { f(&mut *shard.data.get()) };
82+
shard.borrow.set(false);
83+
r
84+
} else {
85+
self.shards[get_shard_index_by_hash(hash)].0.with_mt_lock(f)
86+
}
87+
}
88+
5489
#[inline]
5590
pub fn get_shard_by_hash(&self, hash: u64) -> &Lock<T> {
5691
if likely(self.single_thread) {
@@ -93,17 +128,18 @@ impl<K: Eq + Hash + Copy> ShardedHashMap<K, ()> {
93128
Q: Hash + Eq,
94129
{
95130
let hash = make_hash(value);
96-
let mut shard = self.get_shard_by_hash(hash).lock();
97-
let entry = shard.raw_entry_mut().from_key_hashed_nocheck(hash, value);
98-
99-
match entry {
100-
RawEntryMut::Occupied(e) => *e.key(),
101-
RawEntryMut::Vacant(e) => {
102-
let v = make();
103-
e.insert_hashed_nocheck(hash, v, ());
104-
v
131+
self.with_get_shard_by_hash(hash, |shard| {
132+
let entry = shard.raw_entry_mut().from_key_hashed_nocheck(hash, value);
133+
134+
match entry {
135+
RawEntryMut::Occupied(e) => *e.key(),
136+
RawEntryMut::Vacant(e) => {
137+
let v = make();
138+
e.insert_hashed_nocheck(hash, v, ());
139+
v
140+
}
105141
}
106-
}
142+
})
107143
}
108144

109145
#[inline]
@@ -113,17 +149,18 @@ impl<K: Eq + Hash + Copy> ShardedHashMap<K, ()> {
113149
Q: Hash + Eq,
114150
{
115151
let hash = make_hash(&value);
116-
let mut shard = self.get_shard_by_hash(hash).lock();
117-
let entry = shard.raw_entry_mut().from_key_hashed_nocheck(hash, &value);
118-
119-
match entry {
120-
RawEntryMut::Occupied(e) => *e.key(),
121-
RawEntryMut::Vacant(e) => {
122-
let v = make(value);
123-
e.insert_hashed_nocheck(hash, v, ());
124-
v
152+
self.with_get_shard_by_hash(hash, |shard| {
153+
let entry = shard.raw_entry_mut().from_key_hashed_nocheck(hash, &value);
154+
155+
match entry {
156+
RawEntryMut::Occupied(e) => *e.key(),
157+
RawEntryMut::Vacant(e) => {
158+
let v = make(value);
159+
e.insert_hashed_nocheck(hash, v, ());
160+
v
161+
}
125162
}
126-
}
163+
})
127164
}
128165
}
129166

@@ -135,9 +172,11 @@ pub trait IntoPointer {
135172
impl<K: Eq + Hash + Copy + IntoPointer> ShardedHashMap<K, ()> {
136173
pub fn contains_pointer_to<T: Hash + IntoPointer>(&self, value: &T) -> bool {
137174
let hash = make_hash(&value);
138-
let shard = self.get_shard_by_hash(hash).lock();
139-
let value = value.into_pointer();
140-
shard.raw_entry().from_hash(hash, |entry| entry.into_pointer() == value).is_some()
175+
176+
self.with_get_shard_by_hash(hash, |shard| {
177+
let value = value.into_pointer();
178+
shard.raw_entry().from_hash(hash, |entry| entry.into_pointer() == value).is_some()
179+
})
141180
}
142181
}
143182

compiler/rustc_data_structures/src/sync.rs

Lines changed: 44 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -591,8 +591,8 @@ impl<K: Eq + Hash, V: Eq, S: BuildHasher> HashMapExt<K, V> for HashMap<K, V, S>
591591

592592
pub struct Lock<T> {
593593
single_thread: bool,
594-
data: UnsafeCell<T>,
595-
borrow: Cell<bool>,
594+
pub(crate) data: UnsafeCell<T>,
595+
pub(crate) borrow: Cell<bool>,
596596
mutex: RawMutex,
597597
}
598598

@@ -657,8 +657,7 @@ impl<T> Lock<T> {
657657
#[inline(never)]
658658
fn lock_raw(&self) {
659659
if likely(self.single_thread) {
660-
assert!(!self.borrow.get());
661-
self.borrow.set(true);
660+
assert!(!self.borrow.replace(true));
662661
} else {
663662
self.mutex.lock();
664663
}
@@ -671,10 +670,50 @@ impl<T> Lock<T> {
671670
LockGuard { lock: &self, marker: PhantomData }
672671
}
673672

673+
#[inline(never)]
674+
pub(crate) fn with_mt_lock<F: FnOnce(&mut T) -> R, R>(&self, f: F) -> R {
675+
unsafe {
676+
self.mutex.lock();
677+
let r = f(&mut *self.data.get());
678+
self.mutex.unlock();
679+
r
680+
}
681+
}
682+
674683
#[inline(always)]
675684
#[track_caller]
676685
pub fn with_lock<F: FnOnce(&mut T) -> R, R>(&self, f: F) -> R {
677-
f(&mut *self.lock())
686+
if likely(self.single_thread) {
687+
assert!(!self.borrow.replace(true));
688+
let r = unsafe { f(&mut *self.data.get()) };
689+
self.borrow.set(false);
690+
r
691+
} else {
692+
self.with_mt_lock(f)
693+
}
694+
}
695+
696+
#[inline(never)]
697+
fn with_mt_borrow<F: FnOnce(&T) -> R, R>(&self, f: F) -> R {
698+
unsafe {
699+
self.mutex.lock();
700+
let r = f(&*self.data.get());
701+
self.mutex.unlock();
702+
r
703+
}
704+
}
705+
706+
#[inline(always)]
707+
#[track_caller]
708+
pub fn with_borrow<F: FnOnce(&T) -> R, R>(&self, f: F) -> R {
709+
if likely(self.single_thread) {
710+
assert!(!self.borrow.replace(true));
711+
let r = unsafe { f(&*self.data.get()) };
712+
self.borrow.set(false);
713+
r
714+
} else {
715+
self.with_mt_borrow(f)
716+
}
678717
}
679718

680719
#[inline(always)]

compiler/rustc_query_system/src/dep_graph/graph.rs

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -609,10 +609,7 @@ impl<K: DepKind> DepGraphData<K> {
609609
} else {
610610
self.current
611611
.new_node_to_index
612-
.get_shard_by_value(dep_node)
613-
.lock()
614-
.get(dep_node)
615-
.copied()
612+
.with_get_shard_by_value(dep_node, |node| node.get(dep_node).copied())
616613
}
617614
}
618615

@@ -1180,16 +1177,16 @@ impl<K: DepKind> CurrentDepGraph<K> {
11801177
edges: EdgesVec,
11811178
current_fingerprint: Fingerprint,
11821179
) -> DepNodeIndex {
1183-
let dep_node_index = match self.new_node_to_index.get_shard_by_value(&key).lock().entry(key)
1184-
{
1185-
Entry::Occupied(entry) => *entry.get(),
1186-
Entry::Vacant(entry) => {
1187-
let dep_node_index =
1188-
self.encoder.borrow().send(profiler, key, current_fingerprint, edges);
1189-
entry.insert(dep_node_index);
1190-
dep_node_index
1191-
}
1192-
};
1180+
let dep_node_index =
1181+
self.new_node_to_index.with_get_shard_by_value(&key, |node| match node.entry(key) {
1182+
Entry::Occupied(entry) => *entry.get(),
1183+
Entry::Vacant(entry) => {
1184+
let dep_node_index =
1185+
self.encoder.borrow().send(profiler, key, current_fingerprint, edges);
1186+
entry.insert(dep_node_index);
1187+
dep_node_index
1188+
}
1189+
});
11931190

11941191
#[cfg(debug_assertions)]
11951192
self.record_edge(dep_node_index, key, current_fingerprint);

compiler/rustc_query_system/src/query/caches.rs

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -56,20 +56,17 @@ where
5656
#[inline(always)]
5757
fn lookup(&self, key: &K) -> Option<(V, DepNodeIndex)> {
5858
let key_hash = sharded::make_hash(key);
59-
let lock = self.cache.get_shard_by_hash(key_hash).lock();
60-
61-
let result = lock.raw_entry().from_key_hashed_nocheck(key_hash, key);
62-
63-
if let Some((_, value)) = result { Some(*value) } else { None }
59+
self.cache.with_get_shard_by_hash(key_hash, |lock| {
60+
let result = lock.raw_entry().from_key_hashed_nocheck(key_hash, key);
61+
if let Some((_, value)) = result { Some(*value) } else { None }
62+
})
6463
}
6564

6665
#[inline]
6766
fn complete(&self, key: K, value: V, index: DepNodeIndex) {
68-
let mut lock = self.cache.get_shard_by_value(&key).lock();
69-
7067
// We may be overwriting another value. This is all right, since the dep-graph
7168
// will check that the fingerprint matches.
72-
lock.insert(key, (value, index));
69+
self.cache.with_get_shard_by_value(&key, |cache| cache.insert(key, (value, index)));
7370
}
7471

7572
fn iter(&self, f: &mut dyn FnMut(&Self::Key, &Self::Value, DepNodeIndex)) {
@@ -150,16 +147,16 @@ where
150147

151148
#[inline(always)]
152149
fn lookup(&self, key: &K) -> Option<(V, DepNodeIndex)> {
153-
let lock = self.cache.get_shard_by_hash(key.index() as u64).lock();
154-
155-
if let Some(Some(value)) = lock.get(*key) { Some(*value) } else { None }
150+
self.cache.with_get_shard_by_hash(key.index() as u64, |lock| {
151+
if let Some(Some(value)) = lock.get(*key) { Some(*value) } else { None }
152+
})
156153
}
157154

158155
#[inline]
159156
fn complete(&self, key: K, value: V, index: DepNodeIndex) {
160-
let mut lock = self.cache.get_shard_by_hash(key.index() as u64).lock();
161-
162-
lock.insert(key, (value, index));
157+
self.cache.with_get_shard_by_hash(key.index() as u64, |lock| {
158+
lock.insert(key, (value, index));
159+
})
163160
}
164161

165162
fn iter(&self, f: &mut dyn FnMut(&Self::Key, &Self::Value, DepNodeIndex)) {

0 commit comments

Comments
 (0)