Skip to content

Commit e152522

Browse files
committed
Fix Dictionary::get unsoundness and add a few convenience functions
1 parent 0413391 commit e152522

File tree

5 files changed

+82
-23
lines changed

5 files changed

+82
-23
lines changed

gdnative-core/src/core_types/dictionary.rs

Lines changed: 64 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -69,21 +69,42 @@ impl<Access: ThreadAccess> Dictionary<Access> {
6969
unsafe { (get_api().godot_dictionary_has_all)(self.sys(), keys.sys()) }
7070
}
7171

72-
/// Returns a copy of the value corresponding to the key.
72+
/// Returns a copy of the value corresponding to the key if it exists.
7373
#[inline]
74-
pub fn get<K>(&self, key: K) -> Variant
74+
pub fn get<K>(&self, key: K) -> Option<Variant>
7575
where
7676
K: ToVariant + ToVariantEq,
77+
{
78+
let key = key.to_variant();
79+
self.contains(&key).then(|| self.get_or_nil(key))
80+
}
81+
82+
/// Returns a copy of the value corresponding to the key, or `default` if it doesn't exist
83+
#[inline]
84+
pub fn get_or<K, D>(&self, default: D, key: K) -> Variant
85+
where
86+
K: ToVariant + ToVariantEq,
87+
D: ToVariant + ToVariantEq,
7788
{
7889
unsafe {
79-
Variant((get_api().godot_dictionary_get)(
90+
Variant((get_api().godot_dictionary_get_with_default)(
8091
self.sys(),
8192
key.to_variant().sys(),
93+
default.to_variant().sys(),
8294
))
8395
}
8496
}
8597

86-
/// Update an existing element corresponding ot the key.
98+
/// Returns a copy of the element corresponding to the key, or `Nil` if it doesn't exist.
99+
#[inline]
100+
pub fn get_or_nil<K>(&self, key: K) -> Variant
101+
where
102+
K: ToVariant + ToVariantEq,
103+
{
104+
self.get_or(Variant::new(), key)
105+
}
106+
107+
/// Update an existing element corresponding to the key.
87108
///
88109
/// # Panics
89110
///
@@ -106,12 +127,14 @@ impl<Access: ThreadAccess> Dictionary<Access> {
106127
}
107128
}
108129

109-
/// Returns a reference to the value corresponding to the key.
130+
/// Returns a reference to the value corresponding to the key, inserting `Nil` first if
131+
/// it does not exist.
110132
///
111133
/// # Safety
112134
///
113135
/// The returned reference is invalidated if the same container is mutated through another
114-
/// reference.
136+
/// reference, and other references may be invalidated if the entry does not already exist
137+
/// (which causes this function to insert `Nil` and thus possibly re-allocate).
115138
///
116139
/// `Variant` is reference-counted and thus cheaply cloned. Consider using `get` instead.
117140
#[inline]
@@ -125,13 +148,16 @@ impl<Access: ThreadAccess> Dictionary<Access> {
125148
))
126149
}
127150

128-
/// Returns a mutable reference to the value corresponding to the key.
151+
/// Returns a mutable reference to the value corresponding to the key, inserting `Nil`
152+
/// first if it does not exist.
129153
///
130154
/// # Safety
131155
///
132156
/// The returned reference is invalidated if the same container is mutated through another
133-
/// reference. It is possible to create two mutable references to the same memory location
134-
/// if the same `key` is provided, causing undefined behavior.
157+
/// reference, and other references may be invalidated if the `key` does not already exist
158+
/// (which causes this function to insert `Nil` and thus possibly re-allocate). It is also
159+
/// possible to create two mutable references to the same memory location if the same `key`
160+
/// is provided, causing undefined behavior.
135161
#[inline]
136162
#[allow(clippy::mut_from_ref)]
137163
pub unsafe fn get_mut_ref<K>(&self, key: K) -> &mut Variant
@@ -266,6 +292,19 @@ impl Dictionary<Shared> {
266292
pub unsafe fn clear(&self) {
267293
(get_api().godot_dictionary_clear)(self.sys_mut())
268294
}
295+
296+
/// Returns a copy of the value corresponding to the key, inserting `Nil` first if it does not exist.
297+
#[doc_variant_collection_safety]
298+
#[inline]
299+
pub unsafe fn get_or_insert_nil<K>(&self, key: K) -> Variant
300+
where
301+
K: ToVariant + ToVariantEq,
302+
{
303+
Variant((get_api().godot_dictionary_get)(
304+
self.sys(),
305+
key.to_variant().sys(),
306+
))
307+
}
269308
}
270309

271310
/// Operations allowed on Dictionaries that may only be shared on the current thread.
@@ -327,6 +366,20 @@ impl<Access: LocalThreadAccess> Dictionary<Access> {
327366
pub fn clear(&self) {
328367
unsafe { (get_api().godot_dictionary_clear)(self.sys_mut()) }
329368
}
369+
370+
/// Returns a copy of the value corresponding to the key, inserting `Nil` first if it does not exist.
371+
#[inline]
372+
pub fn get_or_insert_nil<K>(&self, key: K) -> Variant
373+
where
374+
K: ToVariant + ToVariantEq,
375+
{
376+
unsafe {
377+
Variant((get_api().godot_dictionary_get)(
378+
self.sys(),
379+
key.to_variant().sys(),
380+
))
381+
}
382+
}
330383
}
331384

332385
/// Operations allowed on unique Dictionaries.
@@ -425,7 +478,7 @@ unsafe fn iter_next<Access: ThreadAccess>(
425478
None
426479
} else {
427480
let key = Variant::cast_ref(next_ptr).clone();
428-
let value = dic.get(&key);
481+
let value = dic.get_or_nil(&key);
429482
*last_key = Some(key.clone());
430483
Some((key, value))
431484
}
@@ -591,7 +644,7 @@ godot_test!(test_dictionary {
591644
let mut iter_keys = HashSet::new();
592645
let expected_keys = ["foo", "bar"].iter().map(|&s| s.to_string()).collect::<HashSet<_>>();
593646
for (key, value) in &dict {
594-
assert_eq!(value, dict.get(&key));
647+
assert_eq!(Some(value), dict.get(&key));
595648
if !iter_keys.insert(key.to_string()) {
596649
panic!("key is already contained in set: {:?}", key);
597650
}

gdnative-core/src/core_types/variant.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1753,7 +1753,7 @@ impl<T: FromVariant, E: FromVariant> FromVariant for Result<T, E> {
17531753

17541754
match key.as_str() {
17551755
"Ok" => {
1756-
let val = T::from_variant(&dict.get(key_variant)).map_err(|err| {
1756+
let val = T::from_variant(&dict.get_or_nil(key_variant)).map_err(|err| {
17571757
FVE::InvalidEnumVariant {
17581758
variant: "Ok",
17591759
error: Box::new(err),
@@ -1762,7 +1762,7 @@ impl<T: FromVariant, E: FromVariant> FromVariant for Result<T, E> {
17621762
Ok(Ok(val))
17631763
}
17641764
"Err" => {
1765-
let err = E::from_variant(&dict.get(key_variant)).map_err(|err| {
1765+
let err = E::from_variant(&dict.get_or_nil(key_variant)).map_err(|err| {
17661766
FVE::InvalidEnumVariant {
17671767
variant: "Err",
17681768
error: Box::new(err),
@@ -1912,11 +1912,11 @@ godot_test!(
19121912
test_variant_result {
19131913
let variant = Result::<i64, ()>::Ok(42_i64).to_variant();
19141914
let dict = variant.try_to_dictionary().expect("should be dic");
1915-
assert_eq!(Some(42), dict.get("Ok").try_to_i64());
1915+
assert_eq!(Some(42), dict.get("Ok").and_then(|v| v.try_to_i64()));
19161916

19171917
let variant = Result::<(), i64>::Err(54_i64).to_variant();
19181918
let dict = variant.try_to_dictionary().expect("should be dic");
1919-
assert_eq!(Some(54), dict.get("Err").try_to_i64());
1919+
assert_eq!(Some(54), dict.get("Err").and_then(|v| v.try_to_i64()));
19201920

19211921
let variant = Variant::from_bool(true);
19221922
assert_eq!(

gdnative-derive/src/variant/from.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ pub(crate) fn expand_from_variant(derive_data: DeriveData) -> Result<TokenStream
8787
match __key.as_str() {
8888
#(
8989
#ref_var_ident_string_literals => {
90-
let #var_input_ident_iter = &__dict.get(&__keys.get(0));
90+
let #var_input_ident_iter = &__dict.get_or_nil(&__keys.get(0));
9191
(#var_from_variants).map_err(|err| FVE::InvalidEnumVariant {
9292
variant: #ref_var_ident_string_literals,
9393
error: Box::new(err),

gdnative-derive/src/variant/repr.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -280,7 +280,7 @@ impl VariantRepr {
280280
let name_string_literals =
281281
name_strings.iter().map(|string| Literal::string(&string));
282282

283-
let expr_variant = &quote!(&__dict.get(&__key));
283+
let expr_variant = &quote!(&__dict.get_or_nil(&__key));
284284
let exprs = non_skipped_fields
285285
.iter()
286286
.map(|f| f.from_variant(expr_variant));

test/src/test_derive.rs

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -83,19 +83,25 @@ fn test_derive_to_variant() -> bool {
8383

8484
let variant = data.to_variant();
8585
let dictionary = variant.try_to_dictionary().expect("should be dictionary");
86-
assert_eq!(Some(42), dictionary.get("foo").try_to_i64());
87-
assert_eq!(Some(54.0), dictionary.get("bar").try_to_f64());
86+
assert_eq!(Some(42), dictionary.get("foo").and_then(|v| v.try_to_i64()));
87+
assert_eq!(
88+
Some(54.0),
89+
dictionary.get("bar").and_then(|v| v.try_to_f64())
90+
);
8891
assert_eq!(
8992
Some("*mut ()".into()),
90-
dictionary.get("ptr").try_to_string()
93+
dictionary.get("ptr").and_then(|v| v.try_to_string())
9194
);
9295
assert!(!dictionary.contains("skipped"));
9396

9497
let enum_dict = dictionary
9598
.get("baz")
96-
.try_to_dictionary()
99+
.and_then(|v| v.try_to_dictionary())
97100
.expect("should be dictionary");
98-
assert_eq!(Some(true), enum_dict.get("Foo").try_to_bool());
101+
assert_eq!(
102+
Some(true),
103+
enum_dict.get("Foo").and_then(|v| v.try_to_bool())
104+
);
99105

100106
assert_eq!(
101107
Ok(ToVar::<f64, i128> {
@@ -146,7 +152,7 @@ fn test_derive_owned_to_variant() -> bool {
146152
let dictionary = variant.try_to_dictionary().expect("should be dictionary");
147153
let array = dictionary
148154
.get("arr")
149-
.try_to_array()
155+
.and_then(|v| v.try_to_array())
150156
.expect("should be array");
151157
assert_eq!(3, array.len());
152158
assert_eq!(

0 commit comments

Comments
 (0)