Skip to content

Commit b2afef0

Browse files
committed
Expand mutable capture check for is_iter_with_side_effects()
1 parent c6d76bb commit b2afef0

File tree

3 files changed

+98
-27
lines changed

3 files changed

+98
-27
lines changed

clippy_utils/src/ty/mod.rs

Lines changed: 32 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1377,33 +1377,40 @@ pub fn option_arg_ty<'tcx>(cx: &LateContext<'tcx>, ty: Ty<'tcx>) -> Option<Ty<'t
13771377
}
13781378
}
13791379

1380-
/// Check if `ty` is an `Iterator` and has side effects when iterated over. Currently, this only
1381-
/// checks if the `ty` contains mutable captures, and thus may be imcomplete.
1380+
/// Check if a Ty<'_> of `Iterator` has side effects when iterated over by checking if it
1381+
/// captures any mutable references or equivalents.
13821382
pub fn is_iter_with_side_effects<'tcx>(cx: &LateContext<'tcx>, iter_ty: Ty<'tcx>) -> bool {
1383-
let Some(iter_trait) = cx.tcx.lang_items().iterator_trait() else {
1384-
return false;
1385-
};
1386-
1387-
is_iter_with_side_effects_impl(cx, iter_ty, iter_trait)
1383+
cx.tcx
1384+
.lang_items()
1385+
.iterator_trait()
1386+
.is_some_and(|iter_trait| has_non_owning_mutable_access(cx, iter_ty, iter_trait))
13881387
}
13891388

1390-
fn is_iter_with_side_effects_impl<'tcx>(cx: &LateContext<'tcx>, iter_ty: Ty<'tcx>, iter_trait: DefId) -> bool {
1391-
if implements_trait(cx, iter_ty, iter_trait, &[])
1392-
&& let ty::Adt(_, args) = iter_ty.kind()
1393-
{
1394-
return args.types().any(|arg_ty| {
1395-
if let ty::Closure(_, closure_args) = arg_ty.kind()
1396-
&& let Some(captures) = closure_args.types().next_back()
1397-
{
1398-
captures
1399-
.tuple_fields()
1400-
.iter()
1401-
.any(|capture_ty| matches!(capture_ty.ref_mutability(), Some(Mutability::Mut)))
1402-
} else {
1403-
is_iter_with_side_effects_impl(cx, arg_ty, iter_trait)
1404-
}
1405-
});
1389+
/// Check if `ty` contains mutable references or equivalent, which includes:
1390+
/// - A mutable reference/pointer.
1391+
/// - A reference/pointer to a non-`Freeze` type.
1392+
/// - A `PhantomData` type containing any of the previous.
1393+
fn has_non_owning_mutable_access<'tcx>(cx: &LateContext<'tcx>, ty: Ty<'tcx>, iter_trait: DefId) -> bool {
1394+
match ty.kind() {
1395+
ty::Adt(adt_def, args) if adt_def.is_phantom_data() => args
1396+
.types()
1397+
.any(|arg_ty| has_non_owning_mutable_access(cx, arg_ty, iter_trait)),
1398+
ty::Adt(adt_def, args) => adt_def
1399+
.all_fields()
1400+
.any(|field| has_non_owning_mutable_access(cx, field.ty(cx.tcx, args), iter_trait)),
1401+
ty::Array(elem_ty, _) | ty::Slice(elem_ty) => has_non_owning_mutable_access(cx, *elem_ty, iter_trait),
1402+
ty::Ref(_, pointee_ty, Mutability::Mut) if implements_trait(cx, *pointee_ty, iter_trait, &[]) => {
1403+
has_non_owning_mutable_access(cx, *pointee_ty, iter_trait)
1404+
},
1405+
ty::RawPtr(pointee_ty, mutability) | ty::Ref(_, pointee_ty, mutability) => {
1406+
mutability.is_mut() || !pointee_ty.is_freeze(cx.tcx, cx.typing_env())
1407+
},
1408+
ty::Closure(_, closure_args) => {
1409+
matches!(closure_args.types().next_back(), Some(captures) if has_non_owning_mutable_access(cx, captures, iter_trait))
1410+
},
1411+
ty::Tuple(tuple_args) => tuple_args
1412+
.iter()
1413+
.any(|arg_ty| has_non_owning_mutable_access(cx, arg_ty, iter_trait)),
1414+
_ => false,
14061415
}
1407-
1408-
false
14091416
}

tests/ui/needless_collect.fixed

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,13 +128,45 @@ fn bar<I: IntoIterator<Item = usize>>(_: Vec<usize>, _: I) {}
128128
fn baz<I: IntoIterator<Item = usize>>(_: I, _: (), _: impl IntoIterator<Item = char>) {}
129129

130130
mod issue9191 {
131+
use std::cell::Cell;
131132
use std::collections::HashSet;
133+
use std::hash::Hash;
134+
use std::marker::PhantomData;
135+
use std::ops::Deref;
132136

133-
fn foo(xs: Vec<i32>, mut ys: HashSet<i32>) {
137+
fn captures_ref_mut(xs: Vec<i32>, mut ys: HashSet<i32>) {
134138
if xs.iter().map(|x| ys.remove(x)).collect::<Vec<_>>().contains(&true) {
135139
todo!()
136140
}
137141
}
142+
143+
#[derive(Debug, Clone)]
144+
struct MyRef<'a>(PhantomData<&'a mut Cell<HashSet<i32>>>, *mut Cell<HashSet<i32>>);
145+
146+
impl MyRef<'_> {
147+
fn new(target: &mut Cell<HashSet<i32>>) -> Self {
148+
MyRef(PhantomData, target)
149+
}
150+
151+
fn get(&mut self) -> &mut Cell<HashSet<i32>> {
152+
unsafe { &mut *self.1 }
153+
}
154+
}
155+
156+
fn captures_phantom(xs: Vec<i32>, mut ys: Cell<HashSet<i32>>) {
157+
let mut ys_ref = MyRef::new(&mut ys);
158+
if xs
159+
.iter()
160+
.map({
161+
let mut ys_ref = ys_ref.clone();
162+
move |x| ys_ref.get().get_mut().remove(x)
163+
})
164+
.collect::<Vec<_>>()
165+
.contains(&true)
166+
{
167+
todo!()
168+
}
169+
}
138170
}
139171

140172
pub fn issue8055(v: impl IntoIterator<Item = i32>) -> Result<impl Iterator<Item = i32>, usize> {

tests/ui/needless_collect.rs

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,13 +128,45 @@ fn bar<I: IntoIterator<Item = usize>>(_: Vec<usize>, _: I) {}
128128
fn baz<I: IntoIterator<Item = usize>>(_: I, _: (), _: impl IntoIterator<Item = char>) {}
129129

130130
mod issue9191 {
131+
use std::cell::Cell;
131132
use std::collections::HashSet;
133+
use std::hash::Hash;
134+
use std::marker::PhantomData;
135+
use std::ops::Deref;
132136

133-
fn foo(xs: Vec<i32>, mut ys: HashSet<i32>) {
137+
fn captures_ref_mut(xs: Vec<i32>, mut ys: HashSet<i32>) {
134138
if xs.iter().map(|x| ys.remove(x)).collect::<Vec<_>>().contains(&true) {
135139
todo!()
136140
}
137141
}
142+
143+
#[derive(Debug, Clone)]
144+
struct MyRef<'a>(PhantomData<&'a mut Cell<HashSet<i32>>>, *mut Cell<HashSet<i32>>);
145+
146+
impl MyRef<'_> {
147+
fn new(target: &mut Cell<HashSet<i32>>) -> Self {
148+
MyRef(PhantomData, target)
149+
}
150+
151+
fn get(&mut self) -> &mut Cell<HashSet<i32>> {
152+
unsafe { &mut *self.1 }
153+
}
154+
}
155+
156+
fn captures_phantom(xs: Vec<i32>, mut ys: Cell<HashSet<i32>>) {
157+
let mut ys_ref = MyRef::new(&mut ys);
158+
if xs
159+
.iter()
160+
.map({
161+
let mut ys_ref = ys_ref.clone();
162+
move |x| ys_ref.get().get_mut().remove(x)
163+
})
164+
.collect::<Vec<_>>()
165+
.contains(&true)
166+
{
167+
todo!()
168+
}
169+
}
138170
}
139171

140172
pub fn issue8055(v: impl IntoIterator<Item = i32>) -> Result<impl Iterator<Item = i32>, usize> {

0 commit comments

Comments
 (0)