1
- use std:: sync:: atomic:: { AtomicBool , AtomicUsize } ;
1
+ use std:: {
2
+ sync:: atomic:: { AtomicBool , AtomicUsize } ,
3
+ thread:: ThreadId ,
4
+ } ;
2
5
3
6
use bevy:: {
4
7
ecs:: { component:: ComponentId , world:: unsafe_world_cell:: UnsafeWorldCell } ,
5
8
prelude:: Resource ,
6
9
} ;
7
10
use dashmap:: { try_result:: TryResult , DashMap , Entry , Map } ;
11
+ use smallvec:: SmallVec ;
8
12
9
13
use super :: { ReflectAllocationId , ReflectBase } ;
10
14
11
- #[ derive( Debug , Clone , Copy , PartialEq , Eq ) ]
15
+ #[ derive( Debug , Clone , PartialEq , Eq ) ]
16
+ pub struct ClaimOwner {
17
+ id : ThreadId ,
18
+ location : std:: panic:: Location < ' static > ,
19
+ }
20
+
21
+ #[ derive( Debug , Clone , PartialEq , Eq ) ]
12
22
pub struct AccessCount {
13
- count : usize ,
14
- /// set if somebody is writing
15
- written_by : Option < std:: panic:: Location < ' static > > ,
23
+ /// The number of readers including thread information
24
+ read_by : SmallVec < [ ClaimOwner ; 1 ] > ,
25
+ /// If the current read is a write access, this will be set
26
+ written : bool ,
16
27
}
17
28
18
29
impl Default for AccessCount {
@@ -24,25 +35,25 @@ impl Default for AccessCount {
24
35
impl AccessCount {
25
36
fn new ( ) -> Self {
26
37
Self {
27
- count : 0 ,
28
- written_by : None ,
38
+ read_by : Default :: default ( ) ,
39
+ written : false ,
29
40
}
30
41
}
31
42
32
43
fn can_read ( & self ) -> bool {
33
- self . written_by . is_none ( )
44
+ ! self . written
34
45
}
35
46
36
47
fn can_write ( & self ) -> bool {
37
- self . count == 0 && self . written_by . is_none ( )
48
+ self . read_by . is_empty ( ) && ! self . written
38
49
}
39
50
40
51
fn as_location ( & self ) -> Option < std:: panic:: Location < ' static > > {
41
- self . written_by
52
+ self . read_by . first ( ) . map ( |o| o . location . clone ( ) )
42
53
}
43
54
44
55
fn readers ( & self ) -> usize {
45
- self . count
56
+ self . read_by . len ( )
46
57
}
47
58
}
48
59
@@ -174,6 +185,7 @@ pub struct AccessMap {
174
185
175
186
impl AccessMap {
176
187
/// Tries to claim read access, will return false if somebody else is writing to the same key, or holding a global lock
188
+ #[ track_caller]
177
189
pub fn claim_read_access < K : AccessMapKey > ( & self , key : K ) -> bool {
178
190
if self . global_lock . load ( std:: sync:: atomic:: Ordering :: Relaxed ) {
179
191
return false ;
@@ -182,7 +194,10 @@ impl AccessMap {
182
194
let access = self . individual_accesses . try_entry ( key) ;
183
195
match access. map ( Entry :: or_default) {
184
196
Some ( mut entry) if entry. can_read ( ) => {
185
- entry. count += 1 ;
197
+ entry. read_by . push ( ClaimOwner {
198
+ id : std:: thread:: current ( ) . id ( ) ,
199
+ location : * std:: panic:: Location :: caller ( ) ,
200
+ } ) ;
186
201
true
187
202
}
188
203
_ => false ,
@@ -199,8 +214,11 @@ impl AccessMap {
199
214
let access = self . individual_accesses . try_entry ( key) ;
200
215
match access. map ( Entry :: or_default) {
201
216
Some ( mut entry) if entry. can_write ( ) => {
202
- entry. count += 1 ;
203
- entry. written_by = Some ( * std:: panic:: Location :: caller ( ) ) ;
217
+ entry. read_by . push ( ClaimOwner {
218
+ id : std:: thread:: current ( ) . id ( ) ,
219
+ location : * std:: panic:: Location :: caller ( ) ,
220
+ } ) ;
221
+ entry. written = true ;
204
222
true
205
223
}
206
224
_ => false ,
@@ -210,7 +228,7 @@ impl AccessMap {
210
228
/// Tries to claim global access. This type of access prevents any other access from happening simulatenously
211
229
/// Will return false if anybody else is currently accessing any part of the map
212
230
pub fn claim_global_access ( & self ) -> bool {
213
- self . individual_accesses . len ( ) == 0
231
+ self . individual_accesses . is_empty ( )
214
232
&& self
215
233
. global_lock
216
234
. compare_exchange (
@@ -222,17 +240,25 @@ impl AccessMap {
222
240
. is_ok ( )
223
241
}
224
242
243
+ /// Releases an access
244
+ ///
245
+ /// # Panics
246
+ /// if the access is released from a different thread than it was claimed from
225
247
pub fn release_access < K : AccessMapKey > ( & self , key : K ) {
226
248
let key = key. as_usize ( ) ;
227
249
let access = self . individual_accesses . entry ( key) ;
228
250
match access {
229
251
dashmap:: mapref:: entry:: Entry :: Occupied ( mut entry) => {
230
252
let entry_mut = entry. get_mut ( ) ;
231
- if entry_mut. written_by . is_some ( ) {
232
- entry_mut. written_by = None ;
253
+ entry_mut. written = false ;
254
+ if let Some ( p) = entry_mut. read_by . pop ( ) {
255
+ assert ! (
256
+ p. id == std:: thread:: current( ) . id( ) ,
257
+ "Access released from wrong thread, claimed at {}" ,
258
+ p. location. display_location( )
259
+ ) ;
233
260
}
234
- entry_mut. count -= 1 ;
235
- if entry_mut. count == 0 {
261
+ if entry_mut. readers ( ) == 0 {
236
262
entry. remove ( ) ;
237
263
}
238
264
}
@@ -253,15 +279,32 @@ impl AccessMap {
253
279
. collect ( )
254
280
}
255
281
282
+ pub fn count_thread_acceesses ( & self ) -> usize {
283
+ self . individual_accesses
284
+ . iter ( )
285
+ . filter ( |e| {
286
+ e. value ( )
287
+ . read_by
288
+ . iter ( )
289
+ . any ( |o| o. id == std:: thread:: current ( ) . id ( ) )
290
+ } )
291
+ . count ( )
292
+ }
293
+
256
294
pub fn access_location < K : AccessMapKey > (
257
295
& self ,
258
296
key : K ,
259
297
) -> Option < std:: panic:: Location < ' static > > {
260
298
self . individual_accesses
261
299
. try_get ( & key. as_usize ( ) )
262
300
. try_unwrap ( )
263
- . map ( |access| access. as_location ( ) )
264
- . flatten ( )
301
+ . and_then ( |access| access. as_location ( ) )
302
+ }
303
+
304
+ pub fn access_first_location ( & self ) -> Option < std:: panic:: Location < ' static > > {
305
+ self . individual_accesses
306
+ . iter ( )
307
+ . find_map ( |e| e. value ( ) . as_location ( ) )
265
308
}
266
309
}
267
310
@@ -325,8 +368,11 @@ macro_rules! with_global_access {
325
368
( $access_map: expr, $msg: expr, $body: block) => {
326
369
if !$access_map. claim_global_access( ) {
327
370
panic!(
328
- "{}. Another access is held somewhere else preventing locking the world" ,
329
- $msg
371
+ "{}. Another access is held somewhere else preventing locking the world: {}" ,
372
+ $msg,
373
+ $crate:: bindings:: access_map:: DisplayCodeLocation :: display_location(
374
+ $access_map. access_first_location( )
375
+ )
330
376
) ;
331
377
} else {
332
378
let result = ( || $body) ( ) ;
@@ -355,8 +401,8 @@ mod test {
355
401
assert_eq ! ( access_0. 1 . readers( ) , 1 ) ;
356
402
assert_eq ! ( access_1. 1 . readers( ) , 1 ) ;
357
403
358
- assert_eq ! ( access_0. 1 . written_by , None ) ;
359
- assert ! ( access_1. 1 . written_by . is_some ( ) ) ;
404
+ assert ! ( ! access_0. 1 . written ) ;
405
+ assert ! ( access_1. 1 . written ) ;
360
406
}
361
407
362
408
#[ test]
@@ -403,4 +449,30 @@ mod test {
403
449
assert ! ( access_map. claim_write_access( 0 ) ) ;
404
450
assert ! ( !access_map. claim_global_access( ) ) ;
405
451
}
452
+
453
+ #[ test]
454
+ #[ should_panic]
455
+ fn releasing_read_access_from_wrong_thread_panics ( ) {
456
+ let access_map = AccessMap :: default ( ) ;
457
+
458
+ access_map. claim_read_access ( 0 ) ;
459
+ std:: thread:: spawn ( move || {
460
+ access_map. release_access ( 0 ) ;
461
+ } )
462
+ . join ( )
463
+ . unwrap ( ) ;
464
+ }
465
+
466
+ #[ test]
467
+ #[ should_panic]
468
+ fn releasing_write_access_from_wrong_thread_panics ( ) {
469
+ let access_map = AccessMap :: default ( ) ;
470
+
471
+ access_map. claim_write_access ( 0 ) ;
472
+ std:: thread:: spawn ( move || {
473
+ access_map. release_access ( 0 ) ;
474
+ } )
475
+ . join ( )
476
+ . unwrap ( ) ;
477
+ }
406
478
}
0 commit comments