Skip to content

Commit 4af4331

Browse files
committed
refactor accesses slightly
1 parent 811ef75 commit 4af4331

File tree

8 files changed

+149
-51
lines changed

8 files changed

+149
-51
lines changed

assets/scripts/bevy_api.lua

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ function on_event()
1515
print(script)
1616
print(world)
1717

18-
print(world:hello(entity))
18+
print(world:hello(entity, entity))
1919
print(world:test_vec({entity, entity})[1])
2020

2121

crates/bevy_mod_scripting_core/src/bindings/access_map.rs

Lines changed: 97 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,29 @@
1-
use std::sync::atomic::{AtomicBool, AtomicUsize};
1+
use std::{
2+
sync::atomic::{AtomicBool, AtomicUsize},
3+
thread::ThreadId,
4+
};
25

36
use bevy::{
47
ecs::{component::ComponentId, world::unsafe_world_cell::UnsafeWorldCell},
58
prelude::Resource,
69
};
710
use dashmap::{try_result::TryResult, DashMap, Entry, Map};
11+
use smallvec::SmallVec;
812

913
use super::{ReflectAllocationId, ReflectBase};
1014

11-
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
15+
#[derive(Debug, Clone, PartialEq, Eq)]
16+
pub struct ClaimOwner {
17+
id: ThreadId,
18+
location: std::panic::Location<'static>,
19+
}
20+
21+
#[derive(Debug, Clone, PartialEq, Eq)]
1222
pub struct AccessCount {
13-
count: usize,
14-
/// set if somebody is writing
15-
written_by: Option<std::panic::Location<'static>>,
23+
/// The number of readers including thread information
24+
read_by: SmallVec<[ClaimOwner; 1]>,
25+
/// If the current read is a write access, this will be set
26+
written: bool,
1627
}
1728

1829
impl Default for AccessCount {
@@ -24,25 +35,25 @@ impl Default for AccessCount {
2435
impl AccessCount {
2536
fn new() -> Self {
2637
Self {
27-
count: 0,
28-
written_by: None,
38+
read_by: Default::default(),
39+
written: false,
2940
}
3041
}
3142

3243
fn can_read(&self) -> bool {
33-
self.written_by.is_none()
44+
!self.written
3445
}
3546

3647
fn can_write(&self) -> bool {
37-
self.count == 0 && self.written_by.is_none()
48+
self.read_by.is_empty() && !self.written
3849
}
3950

4051
fn as_location(&self) -> Option<std::panic::Location<'static>> {
41-
self.written_by
52+
self.read_by.first().map(|o| o.location.clone())
4253
}
4354

4455
fn readers(&self) -> usize {
45-
self.count
56+
self.read_by.len()
4657
}
4758
}
4859

@@ -174,6 +185,7 @@ pub struct AccessMap {
174185

175186
impl AccessMap {
176187
/// Tries to claim read access, will return false if somebody else is writing to the same key, or holding a global lock
188+
#[track_caller]
177189
pub fn claim_read_access<K: AccessMapKey>(&self, key: K) -> bool {
178190
if self.global_lock.load(std::sync::atomic::Ordering::Relaxed) {
179191
return false;
@@ -182,7 +194,10 @@ impl AccessMap {
182194
let access = self.individual_accesses.try_entry(key);
183195
match access.map(Entry::or_default) {
184196
Some(mut entry) if entry.can_read() => {
185-
entry.count += 1;
197+
entry.read_by.push(ClaimOwner {
198+
id: std::thread::current().id(),
199+
location: *std::panic::Location::caller(),
200+
});
186201
true
187202
}
188203
_ => false,
@@ -199,8 +214,11 @@ impl AccessMap {
199214
let access = self.individual_accesses.try_entry(key);
200215
match access.map(Entry::or_default) {
201216
Some(mut entry) if entry.can_write() => {
202-
entry.count += 1;
203-
entry.written_by = Some(*std::panic::Location::caller());
217+
entry.read_by.push(ClaimOwner {
218+
id: std::thread::current().id(),
219+
location: *std::panic::Location::caller(),
220+
});
221+
entry.written = true;
204222
true
205223
}
206224
_ => false,
@@ -210,7 +228,7 @@ impl AccessMap {
210228
/// Tries to claim global access. This type of access prevents any other access from happening simulatenously
211229
/// Will return false if anybody else is currently accessing any part of the map
212230
pub fn claim_global_access(&self) -> bool {
213-
self.individual_accesses.len() == 0
231+
self.individual_accesses.is_empty()
214232
&& self
215233
.global_lock
216234
.compare_exchange(
@@ -222,17 +240,25 @@ impl AccessMap {
222240
.is_ok()
223241
}
224242

243+
/// Releases an access
244+
///
245+
/// # Panics
246+
/// if the access is released from a different thread than it was claimed from
225247
pub fn release_access<K: AccessMapKey>(&self, key: K) {
226248
let key = key.as_usize();
227249
let access = self.individual_accesses.entry(key);
228250
match access {
229251
dashmap::mapref::entry::Entry::Occupied(mut entry) => {
230252
let entry_mut = entry.get_mut();
231-
if entry_mut.written_by.is_some() {
232-
entry_mut.written_by = None;
253+
entry_mut.written = false;
254+
if let Some(p) = entry_mut.read_by.pop() {
255+
assert!(
256+
p.id == std::thread::current().id(),
257+
"Access released from wrong thread, claimed at {}",
258+
p.location.display_location()
259+
);
233260
}
234-
entry_mut.count -= 1;
235-
if entry_mut.count == 0 {
261+
if entry_mut.readers() == 0 {
236262
entry.remove();
237263
}
238264
}
@@ -253,15 +279,32 @@ impl AccessMap {
253279
.collect()
254280
}
255281

282+
pub fn count_thread_acceesses(&self) -> usize {
283+
self.individual_accesses
284+
.iter()
285+
.filter(|e| {
286+
e.value()
287+
.read_by
288+
.iter()
289+
.any(|o| o.id == std::thread::current().id())
290+
})
291+
.count()
292+
}
293+
256294
pub fn access_location<K: AccessMapKey>(
257295
&self,
258296
key: K,
259297
) -> Option<std::panic::Location<'static>> {
260298
self.individual_accesses
261299
.try_get(&key.as_usize())
262300
.try_unwrap()
263-
.map(|access| access.as_location())
264-
.flatten()
301+
.and_then(|access| access.as_location())
302+
}
303+
304+
pub fn access_first_location(&self) -> Option<std::panic::Location<'static>> {
305+
self.individual_accesses
306+
.iter()
307+
.find_map(|e| e.value().as_location())
265308
}
266309
}
267310

@@ -325,8 +368,11 @@ macro_rules! with_global_access {
325368
($access_map:expr, $msg:expr, $body:block) => {
326369
if !$access_map.claim_global_access() {
327370
panic!(
328-
"{}. Another access is held somewhere else preventing locking the world",
329-
$msg
371+
"{}. Another access is held somewhere else preventing locking the world: {}",
372+
$msg,
373+
$crate::bindings::access_map::DisplayCodeLocation::display_location(
374+
$access_map.access_first_location()
375+
)
330376
);
331377
} else {
332378
let result = (|| $body)();
@@ -355,8 +401,8 @@ mod test {
355401
assert_eq!(access_0.1.readers(), 1);
356402
assert_eq!(access_1.1.readers(), 1);
357403

358-
assert_eq!(access_0.1.written_by, None);
359-
assert!(access_1.1.written_by.is_some());
404+
assert!(!access_0.1.written);
405+
assert!(access_1.1.written);
360406
}
361407

362408
#[test]
@@ -403,4 +449,30 @@ mod test {
403449
assert!(access_map.claim_write_access(0));
404450
assert!(!access_map.claim_global_access());
405451
}
452+
453+
#[test]
454+
#[should_panic]
455+
fn releasing_read_access_from_wrong_thread_panics() {
456+
let access_map = AccessMap::default();
457+
458+
access_map.claim_read_access(0);
459+
std::thread::spawn(move || {
460+
access_map.release_access(0);
461+
})
462+
.join()
463+
.unwrap();
464+
}
465+
466+
#[test]
467+
#[should_panic]
468+
fn releasing_write_access_from_wrong_thread_panics() {
469+
let access_map = AccessMap::default();
470+
471+
access_map.claim_write_access(0);
472+
std::thread::spawn(move || {
473+
access_map.release_access(0);
474+
})
475+
.join()
476+
.unwrap();
477+
}
406478
}

crates/bevy_mod_scripting_core/src/bindings/function/from.rs

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,10 @@ impl<T: FromReflect> FromScript for Ref<'_, T> {
170170
})?;
171171
Ok(Ref(cast))
172172
} else {
173-
Err(InteropError::cannot_claim_access(reflect_reference.base))
173+
Err(InteropError::cannot_claim_access(
174+
reflect_reference.base,
175+
world.get_access_location(raid),
176+
))
174177
}
175178
}
176179
_ => Err(InteropError::value_mismatch(
@@ -227,7 +230,10 @@ impl<T: FromReflect> FromScript for Mut<'_, T> {
227230
})?;
228231
Ok(Mut(cast))
229232
} else {
230-
Err(InteropError::cannot_claim_access(reflect_reference.base))
233+
Err(InteropError::cannot_claim_access(
234+
reflect_reference.base,
235+
world.get_access_location(raid),
236+
))
231237
}
232238
}
233239
_ => Err(InteropError::value_mismatch(

crates/bevy_mod_scripting_core/src/bindings/function/mod.rs

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,10 @@ impl<I: Iterator<Item = ScriptValue>> IntoArgsListWithAccess for I {
245245
Err(ref_.expect_err("invariant"))
246246
}
247247
} else {
248-
Err(InteropError::cannot_claim_access(arg_ref.base.clone()))
248+
Err(InteropError::cannot_claim_access(
249+
arg_ref.base.clone(),
250+
world.get_access_location(access_id),
251+
))
249252
}
250253
}
251254
Ownership::Mut => {
@@ -258,7 +261,10 @@ impl<I: Iterator<Item = ScriptValue>> IntoArgsListWithAccess for I {
258261
Err(mut_ref.expect_err("invariant"))
259262
}
260263
} else {
261-
Err(InteropError::cannot_claim_access(arg_ref.base.clone()))
264+
Err(InteropError::cannot_claim_access(
265+
arg_ref.base.clone(),
266+
world.get_access_location(access_id),
267+
))
262268
}
263269
}
264270
_ => unreachable!(),

crates/bevy_mod_scripting_core/src/bindings/function/script_function.rs

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ use super::{from::FromScript, into::IntoScript};
1414
message = "Only functions with all arguments impplementing FromScript and return values supporting IntoScript are supported. use assert_impls_into_script!(MyArg) and assert_impls_from_script!(MyReturnType) to verify yours do.",
1515
note = "If you're trying to return a non-primitive type, you might need to use Val<T> Ref<T> or Mut<T> wrappers"
1616
)]
17-
pub trait ScriptFunction<Marker> {
17+
pub trait ScriptFunction<'env, Marker> {
1818
fn into_dynamic_function(self) -> DynamicFunction<'static>;
1919
}
2020

@@ -34,23 +34,24 @@ macro_rules! impl_script_function {
3434
(@ $( $param:ident ),* : $(($callback:ident: $callbackty:ty))? -> O => $res:ty $(where $out:ident)?) => {
3535
#[allow(non_snake_case)]
3636
impl<
37-
'l,
37+
'env,
3838
$( $param: FromScript, )*
3939
O,
4040
F
41-
> ScriptFunction<
41+
> ScriptFunction<'env,
4242
fn( $( $callbackty, )? $($param ),* ) -> $res
4343
> for F
4444
where
4545
O: IntoScript,
4646
F: Fn( $( $callbackty, )? $($param ),* ) -> $res + Send + Sync + 'static,
47-
$( $param::This<'l>: Into<$param>),*
47+
$( $param::This<'env>: Into<$param>),*
4848
{
4949
fn into_dynamic_function(self) -> DynamicFunction<'static> {
5050
(move |world: WorldCallbackAccess, $( $param: ScriptValue ),* | {
5151
let res: Result<ScriptValue, InteropError> = (|| {
5252
$( let $callback = world.clone(); )?
5353
let world = world.read().ok_or_else(|| InteropError::stale_world_access())?;
54+
// TODO: snapshot the accesses and release them after
5455
$( let $param = <$param>::from_script($param, world.clone())?; )*
5556
let out = self( $( $callback, )? $( $param.into(), )* );
5657
$(
@@ -95,7 +96,10 @@ macro_rules! assert_impls_from_script {
9596
trait Check: $crate::bindings::function::from::FromScript {}
9697
impl Check for $ty {}
9798
};
98-
() => {};
99+
($l:lifetime $ty:ty) => {
100+
trait Check: $crate::bindings::function::from::FromScript {}
101+
impl<$l> Check for $ty {}
102+
};
99103
}
100104

101105
/// Utility for quickly checking your function can be used as a script function
@@ -107,7 +111,7 @@ macro_rules! assert_impls_from_script {
107111
#[macro_export]
108112
macro_rules! assert_is_script_function {
109113
($($tt:tt)*) => {
110-
fn _check<M,F: ScriptFunction<M>>(f: F) {
114+
fn _check<'env,M,F: ScriptFunction<'env, M>>(f: F) {
111115

112116
}
113117

crates/bevy_mod_scripting_core/src/bindings/world.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -309,10 +309,14 @@ impl<'w> WorldAccessGuard<'w> {
309309
self.0.cell.components().get_resource_id(id)
310310
}
311311

312-
pub fn get_access_location(&self, raid: ReflectAccessId) -> Option<std::panic::Location<'w>> {
312+
pub fn get_access_location(
313+
&self,
314+
raid: ReflectAccessId,
315+
) -> Option<std::panic::Location<'static>> {
313316
self.0.accesses.access_location(raid)
314317
}
315318

319+
#[track_caller]
316320
pub fn claim_read_access(&self, raid: ReflectAccessId) -> bool {
317321
self.0.accesses.claim_read_access(raid)
318322
}

0 commit comments

Comments
 (0)