From bf58e85bc2083f7f82c09ec12f17a04f4c36cc6d Mon Sep 17 00:00:00 2001 From: Bernhard Manfred Gruber Date: Thu, 18 Jul 2024 22:43:32 +0200 Subject: [PATCH] Allow mutation through a transform_iterator But only if the transform iterator's base iterator does not return a wrapped reference and is not a device_vector --- thrust/testing/transform_iterator.cu | 92 ++++++++++++++++++++- thrust/thrust/iterator/transform_iterator.h | 64 +++++++++----- 2 files changed, 133 insertions(+), 23 deletions(-) diff --git a/thrust/testing/transform_iterator.cu b/thrust/testing/transform_iterator.cu index 7a013b99395..6f254e46426 100644 --- a/thrust/testing/transform_iterator.cu +++ b/thrust/testing/transform_iterator.cu @@ -183,7 +183,7 @@ void TestTransformIteratorReferenceAndValueType() (void) it_tr_ref; auto it_tr_fwd = thrust::make_transform_iterator(it, forward{}); - static_assert(is_same::value, ""); + static_assert(is_same::value, ""); static_assert(is_same::value, ""); (void) it_tr_fwd; @@ -264,3 +264,93 @@ void TestTransformIteratorIdentity() } DECLARE_UNITTEST(TestTransformIteratorIdentity); + +struct foo +{ + int x, y; +}; + +struct access_x +{ + _CCCL_HOST_DEVICE int& operator()(foo& f) const noexcept + { + return f.x; + } +}; + +void TestTransformIteratorAsDestination() +{ + constexpr auto n = 10; + + auto check = [](auto& dst) { + const thrust::host_vector& dst_h = dst; // no copy when Vec is a host vector + for (const auto& f : dst_h) + { + ASSERT_EQUAL(f.x, 1234); + ASSERT_EQUAL(f.y, 2); + } + }; + + // host -> host + { + thrust::host_vector src(n, 1234); + thrust::host_vector dst(n, foo{1, 2}); + + // can use iterator and raw pointer as base iterator + thrust::copy(src.begin(), src.end(), thrust::make_transform_iterator(dst.begin(), access_x{})); + check(dst); + thrust::copy( + src.begin(), src.end(), thrust::make_transform_iterator(thrust::raw_pointer_cast(dst.data()), access_x{})); + check(dst); + } + + // device -> device + { + thrust::device_vector src(n, 1234); + thrust::device_vector dst(n, foo{1, 2}); + + // either unwrap base iterator and specify device execution (thrust would wrongly determine a cross_system) + thrust::copy(thrust::device, + src.begin(), + src.end(), + thrust::make_transform_iterator(thrust::raw_pointer_cast(dst.data()), access_x{})); + check(dst); + + // or unwrap base iterator and specify device system (thrust would wrongly determine a cross_system) + using It = thrust::transform_iterator; + thrust::copy(src.begin(), src.end(), It(thrust::raw_pointer_cast(dst.data()), access_x{})); + check(dst); + } + + // host -> device + { + thrust::host_vector src(n, 1234); + thrust::device_vector dst(n, foo{1, 2}); + + // either unwrap base iterator and specify device execution (thrust would wrongly determine a cross_system) + thrust::copy(thrust::device, + src.begin(), + src.end(), + thrust::make_transform_iterator(thrust::raw_pointer_cast(dst.data()), access_x{})); + check(dst); + + // or unwrap base iterator and specify device system (thrust would wrongly determine a cross_system) + using It = thrust::transform_iterator; + thrust::copy(src.begin(), src.end(), It(thrust::raw_pointer_cast(dst.data()), access_x{})); + check(dst); + } + + // device -> host + { + thrust::device_vector src(n, 1234); + thrust::host_vector dst(n, foo{1, 2}); + + // can use iterator and raw pointer as base iterator + thrust::copy(src.begin(), src.end(), thrust::make_transform_iterator(dst.begin(), access_x{})); + check(dst); + thrust::copy( + src.begin(), src.end(), thrust::make_transform_iterator(thrust::raw_pointer_cast(dst.data()), access_x{})); + check(dst); + } +} +DECLARE_UNITTEST(TestTransformIteratorAsDestination); diff --git a/thrust/thrust/iterator/transform_iterator.h b/thrust/thrust/iterator/transform_iterator.h index 7783d39bfa6..89481e3969b 100644 --- a/thrust/thrust/iterator/transform_iterator.h +++ b/thrust/thrust/iterator/transform_iterator.h @@ -53,17 +53,21 @@ THRUST_NAMESPACE_BEGIN -template +template class transform_iterator; namespace detail { - template struct transform_iterator_reference { + static constexpr bool base_iter_ref_needs_decay = + !::cuda::std::is_same_v, it_value_t&>; + + using func_input_t = ::cuda::std::_If, it_reference_t>; + // by default, dereferencing the iterator yields the same as the function. - using type = decltype(::cuda::std::declval()(::cuda::std::declval>())); + using type = decltype(::cuda::std::declval()(::cuda::std::declval())); }; // for certain function objects, we need to tweak the reference type. Notably, identity functions must decay to values. @@ -81,7 +85,7 @@ struct transform_iterator_reference, Iterator> }; // Type function to compute the iterator_adaptor instantiation to be used for transform_iterator -template +template struct make_transform_iterator_base { private: @@ -90,11 +94,17 @@ struct make_transform_iterator_base public: using type = - iterator_adaptor, + iterator_adaptor, Iterator, value_type, - use_default, - typename ::cuda::std::iterator_traits::iterator_category, + System, + use_default, // FIXME(bgruber): this should probably be + // `typename ::cuda::std::iterator_traits::iterator_category` but with the system + // replaced by System. Something like (doesn't work): + // iterator_category_with_system_and_traversal::iterator_category, + // System, + // iterator_traversal_t>, reference>; }; @@ -228,15 +238,19 @@ struct make_transform_iterator_base //! \endcode //! //! \see make_transform_iterator -template +template class transform_iterator - : public detail::make_transform_iterator_base::type + : public detail::make_transform_iterator_base::type { //! \cond public: using super_t = - typename detail::make_transform_iterator_base::type; + typename detail::make_transform_iterator_base::type; friend class iterator_core_access; //! \endcond @@ -306,10 +320,6 @@ class transform_iterator //! \cond private: - // MSVC 2013 and 2015 incorrectly warning about returning a reference to - // a local/temporary here. - // See goo.gl/LELTNp - _CCCL_EXEC_CHECK_DISABLE _CCCL_HOST_DEVICE typename super_t::reference dereference() const { @@ -326,14 +336,24 @@ class transform_iterator // std::cout << e << '\n'; // See: https://godbolt.org/z/jrKcnMqhK - // The workaround is to create a temporary to allow iterators with wrapped/proxy references to convert to their - // value type before calling m_f. This also loads values from a different memory space (cf. `device_reference`). - // Note that this disallows mutable operations through m_f. - detail::it_value_t const& x = *this->base(); - // FIXME(bgruber): x may be a reference to a temporary (e.g. if the base iterator is a counting_iterator). If `m_f` - // does not produce an independent copy and super_t::reference is a reference, we return a dangling reference (e.g. - // for any `[thrust|::cuda::std]::identity` functor). - return m_f(x); + static constexpr bool base_iter_ref_needs_decay = + !::cuda::std::is_same_v, detail::it_value_t&>; + + if constexpr (base_iter_ref_needs_decay) + { + // The workaround is to create a temporary to allow iterators with wrapped/proxy references to convert to their + // value type before calling m_f. This also loads values from a different memory space (cf. `device_reference`). + // Note that this disallows mutable operations through m_f. + detail::it_value_t const& x = *this->base(); + // FIXME(bgruber): x may be a reference to a temporary (e.g. if the base iterator is a counting_iterator). If + // `m_f` does not produce an independent copy and super_t::reference is a reference, we return a dangling + // reference (e.g. for any `[thrust|::cuda::std]::identity` functor). + return m_f(x); + } + else + { + return ::cuda::std::invoke(m_f, *this->base()); + } } // tag this as mutable per Dave Abrahams in this thread: