Skip to content

Allow mutation through a transform_iterator #2006

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 91 additions & 1 deletion thrust/testing/transform_iterator.cu
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ void TestTransformIteratorReferenceAndValueType()
(void) it_tr_ref;

auto it_tr_fwd = thrust::make_transform_iterator(it, forward{});
static_assert(is_same<decltype(it_tr_fwd)::reference, bool&&>::value, "");
static_assert(is_same<decltype(it_tr_fwd)::reference, bool&>::value, "");
static_assert(is_same<decltype(it_tr_fwd)::value_type, bool>::value, "");
(void) it_tr_fwd;

Expand Down Expand Up @@ -264,3 +264,93 @@ void TestTransformIteratorIdentity()
}

DECLARE_UNITTEST(TestTransformIteratorIdentity);

struct foo
{
int x, y;
};

struct access_x
{
_CCCL_HOST_DEVICE int& operator()(foo& f) const noexcept
{
return f.x;
}
};

void TestTransformIteratorAsDestination()
{
constexpr auto n = 10;

auto check = [](auto& dst) {
const thrust::host_vector<foo>& dst_h = dst; // no copy when Vec is a host vector
for (const auto& f : dst_h)
{
ASSERT_EQUAL(f.x, 1234);
ASSERT_EQUAL(f.y, 2);
}
};

// host -> host
{
thrust::host_vector<int> src(n, 1234);
thrust::host_vector<foo> dst(n, foo{1, 2});

// can use iterator and raw pointer as base iterator
thrust::copy(src.begin(), src.end(), thrust::make_transform_iterator(dst.begin(), access_x{}));
check(dst);
thrust::copy(
src.begin(), src.end(), thrust::make_transform_iterator(thrust::raw_pointer_cast(dst.data()), access_x{}));
check(dst);
}

// device -> device
{
thrust::device_vector<int> src(n, 1234);
thrust::device_vector<foo> dst(n, foo{1, 2});

// either unwrap base iterator and specify device execution (thrust would wrongly determine a cross_system)
thrust::copy(thrust::device,
src.begin(),
src.end(),
thrust::make_transform_iterator(thrust::raw_pointer_cast(dst.data()), access_x{}));
check(dst);

// or unwrap base iterator and specify device system (thrust would wrongly determine a cross_system)
using It = thrust::transform_iterator<access_x, foo*, int&, int, thrust::device_system_tag>;
thrust::copy(src.begin(), src.end(), It(thrust::raw_pointer_cast(dst.data()), access_x{}));
check(dst);
}

// host -> device
{
thrust::host_vector<int> src(n, 1234);
thrust::device_vector<foo> dst(n, foo{1, 2});

// either unwrap base iterator and specify device execution (thrust would wrongly determine a cross_system)
thrust::copy(thrust::device,
src.begin(),
src.end(),
thrust::make_transform_iterator(thrust::raw_pointer_cast(dst.data()), access_x{}));
check(dst);

// or unwrap base iterator and specify device system (thrust would wrongly determine a cross_system)
using It = thrust::transform_iterator<access_x, foo*, int&, int, thrust::device_system_tag>;
thrust::copy(src.begin(), src.end(), It(thrust::raw_pointer_cast(dst.data()), access_x{}));
check(dst);
}

// device -> host
{
thrust::device_vector<int> src(n, 1234);
thrust::host_vector<foo> dst(n, foo{1, 2});

// can use iterator and raw pointer as base iterator
thrust::copy(src.begin(), src.end(), thrust::make_transform_iterator(dst.begin(), access_x{}));
check(dst);
thrust::copy(
src.begin(), src.end(), thrust::make_transform_iterator(thrust::raw_pointer_cast(dst.data()), access_x{}));
check(dst);
}
}
DECLARE_UNITTEST(TestTransformIteratorAsDestination);
64 changes: 42 additions & 22 deletions thrust/thrust/iterator/transform_iterator.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,17 +53,21 @@

THRUST_NAMESPACE_BEGIN

template <class UnaryFunction, class Iterator, class Reference, class Value>
template <class UnaryFunction, class Iterator, class Reference, class Value, class System>
class transform_iterator;

namespace detail
{

template <class UnaryFunc, class Iterator>
struct transform_iterator_reference
{
static constexpr bool base_iter_ref_needs_decay =
!::cuda::std::is_same_v<it_reference_t<Iterator>, it_value_t<Iterator>&>;

using func_input_t = ::cuda::std::_If<base_iter_ref_needs_decay, it_value_t<Iterator>, it_reference_t<Iterator>>;

// by default, dereferencing the iterator yields the same as the function.
using type = decltype(::cuda::std::declval<UnaryFunc>()(::cuda::std::declval<it_value_t<Iterator>>()));
using type = decltype(::cuda::std::declval<UnaryFunc>()(::cuda::std::declval<func_input_t>()));
};

// for certain function objects, we need to tweak the reference type. Notably, identity functions must decay to values.
Expand All @@ -81,7 +85,7 @@ struct transform_iterator_reference<functional::actor<Eval>, Iterator>
};

// Type function to compute the iterator_adaptor instantiation to be used for transform_iterator
template <class UnaryFunc, class Iterator, class Reference, class Value>
template <class UnaryFunc, class Iterator, class Reference, class Value, class System>
struct make_transform_iterator_base
{
private:
Expand All @@ -90,11 +94,17 @@ struct make_transform_iterator_base

public:
using type =
iterator_adaptor<transform_iterator<UnaryFunc, Iterator, Reference, Value>,
iterator_adaptor<transform_iterator<UnaryFunc, Iterator, Reference, Value, System>,
Iterator,
value_type,
use_default,
typename ::cuda::std::iterator_traits<Iterator>::iterator_category,
System,
use_default, // FIXME(bgruber): this should probably be
// `typename ::cuda::std::iterator_traits<Iterator>::iterator_category` but with the system
// replaced by System. Something like (doesn't work):
// iterator_category_with_system_and_traversal<typename
// ::cuda::std::iterator_traits<Iterator>::iterator_category,
// System,
// iterator_traversal_t<Iterator>>,
reference>;
};

Expand Down Expand Up @@ -228,15 +238,19 @@ struct make_transform_iterator_base
//! \endcode
//!
//! \see make_transform_iterator
template <class AdaptableUnaryFunction, class Iterator, class Reference = use_default, class Value = use_default>
template <class AdaptableUnaryFunction,
class Iterator,
class Reference = use_default,
class Value = use_default,
class System = use_default>
class transform_iterator
: public detail::make_transform_iterator_base<AdaptableUnaryFunction, Iterator, Reference, Value>::type
: public detail::make_transform_iterator_base<AdaptableUnaryFunction, Iterator, Reference, Value, System>::type
{
//! \cond

public:
using super_t =
typename detail::make_transform_iterator_base<AdaptableUnaryFunction, Iterator, Reference, Value>::type;
typename detail::make_transform_iterator_base<AdaptableUnaryFunction, Iterator, Reference, Value, System>::type;

friend class iterator_core_access;
//! \endcond
Expand Down Expand Up @@ -306,10 +320,6 @@ class transform_iterator
//! \cond

private:
// MSVC 2013 and 2015 incorrectly warning about returning a reference to
// a local/temporary here.
// See goo.gl/LELTNp

_CCCL_EXEC_CHECK_DISABLE
_CCCL_HOST_DEVICE typename super_t::reference dereference() const
{
Expand All @@ -326,14 +336,24 @@ class transform_iterator
// std::cout << e << '\n';
// See: https://godbolt.org/z/jrKcnMqhK

// The workaround is to create a temporary to allow iterators with wrapped/proxy references to convert to their
// value type before calling m_f. This also loads values from a different memory space (cf. `device_reference`).
// Note that this disallows mutable operations through m_f.
detail::it_value_t<Iterator> const& x = *this->base();
// FIXME(bgruber): x may be a reference to a temporary (e.g. if the base iterator is a counting_iterator). If `m_f`
// does not produce an independent copy and super_t::reference is a reference, we return a dangling reference (e.g.
// for any `[thrust|::cuda::std]::identity` functor).
return m_f(x);
static constexpr bool base_iter_ref_needs_decay =
!::cuda::std::is_same_v<detail::it_reference_t<Iterator>, detail::it_value_t<Iterator>&>;

if constexpr (base_iter_ref_needs_decay)
{
// The workaround is to create a temporary to allow iterators with wrapped/proxy references to convert to their
// value type before calling m_f. This also loads values from a different memory space (cf. `device_reference`).
// Note that this disallows mutable operations through m_f.
detail::it_value_t<Iterator> const& x = *this->base();
// FIXME(bgruber): x may be a reference to a temporary (e.g. if the base iterator is a counting_iterator). If
// `m_f` does not produce an independent copy and super_t::reference is a reference, we return a dangling
// reference (e.g. for any `[thrust|::cuda::std]::identity` functor).
return m_f(x);
}
else
{
return ::cuda::std::invoke(m_f, *this->base());
}
}

// tag this as mutable per Dave Abrahams in this thread:
Expand Down