Skip to content

Commit 6b6d930

Browse files
Allow mutation through a transform_iterator
But only if the transform iterator's base iterator returns a true l-value reference (and not a proxy reference).
1 parent 56d99db commit 6b6d930

File tree

4 files changed

+56
-5
lines changed

4 files changed

+56
-5
lines changed

thrust/testing/transform_iterator.cu

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,3 +108,33 @@ void TestTransformIteratorNonCopyable()
108108
}
109109

110110
DECLARE_UNITTEST(TestTransformIteratorNonCopyable);
111+
112+
struct foo
113+
{
114+
int x, y;
115+
};
116+
117+
struct access_x
118+
{
119+
_CCCL_HOST_DEVICE int& operator()(foo& f) const noexcept
120+
{
121+
return f.x;
122+
}
123+
};
124+
125+
void TestTransformIteratorAsDestination()
126+
{
127+
constexpr auto n = 10;
128+
thrust::host_vector<int> src(n, 1234);
129+
thrust::host_vector<foo> dst(n, foo{1, 2});
130+
131+
thrust::copy(src.begin(), src.end(), thrust::make_transform_iterator(dst.begin(), access_x{}));
132+
133+
for (const auto& f : dst)
134+
{
135+
ASSERT_EQUAL(f.x, 1234);
136+
ASSERT_EQUAL(f.y, 2);
137+
}
138+
}
139+
140+
DECLARE_UNITTEST(TestTransformIteratorAsDestination);

thrust/thrust/detail/reference.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -308,7 +308,7 @@ class reference
308308
pointer const ptr;
309309

310310
// `thrust::detail::is_wrapped_reference` is a trait that indicates whether
311-
// a type is a fancy reference. It detects such types by loooking for a
311+
// a type is a fancy reference. It detects such types by looking for a
312312
// nested `wrapped_reference_hint` type.
313313
struct wrapped_reference_hint
314314
{};

thrust/thrust/iterator/detail/transform_iterator.inl

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,10 +45,14 @@ template <class UnaryFunc, class Iterator, class Reference, class Value>
4545
struct transform_iterator_base
4646
{
4747
private:
48+
using unary_func_input_t =
49+
::cuda::std::_If<::cuda::std::is_lvalue_reference<iterator_reference_t<Iterator>>::value,
50+
iterator_reference_t<Iterator>,
51+
const iterator_value_t<Iterator>&>;
52+
4853
// By default, dereferencing the iterator yields the same as the function.
49-
using reference = typename thrust::detail::ia_dflt_help<
50-
Reference,
51-
thrust::detail::result_of_adaptable_function<UnaryFunc(typename thrust::iterator_value<Iterator>::type)>>::type;
54+
using reference = typename thrust::detail::
55+
ia_dflt_help<Reference, thrust::detail::result_of_adaptable_function<UnaryFunc(unary_func_input_t)>>::type;
5256

5357
// To get the default for Value: remove cvref on the result type.
5458
using value_type = typename thrust::detail::ia_dflt_help<Value, thrust::remove_cvref<reference>>::type;

thrust/thrust/iterator/transform_iterator.h

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -298,8 +298,25 @@ class transform_iterator
298298
// See goo.gl/LELTNp
299299
THRUST_DISABLE_MSVC_WARNING_BEGIN(4172)
300300

301-
_CCCL_EXEC_CHECK_DISABLE
302301
_CCCL_HOST_DEVICE typename super_t::reference dereference() const
302+
{
303+
// TODO(bgruber): I am not sure this is the correct check here. There is also the trait
304+
// thrust::detail::is_wrapped_reference that sounds fitting. Only allowing to pass through l-value references
305+
// strikes me as more conservative though.
306+
// TODO(bgruber): use an if constexpr in C++17
307+
return dereference_impl(::cuda::std::is_lvalue_reference<iterator_reference_t<Iterator>>{});
308+
}
309+
310+
_CCCL_EXEC_CHECK_DISABLE
311+
_CCCL_HOST_DEVICE
312+
typename super_t::reference dereference_impl(::cuda::std::true_type /* iterator returns a T& */) const
313+
{
314+
return m_f(*this->base());
315+
}
316+
317+
_CCCL_EXEC_CHECK_DISABLE
318+
_CCCL_HOST_DEVICE
319+
typename super_t::reference dereference_impl(::cuda::std::false_type /* iterator returns a proxy ref */) const
303320
{
304321
// Create a temporary to allow iterators with wrapped references to
305322
// convert to their value type before calling m_f. Note that this

0 commit comments

Comments
 (0)