Skip to content

Make device_reference<T>::operator= const #4740

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
May 20, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions thrust/testing/device_ptr.cu
Original file line number Diff line number Diff line change
@@ -1,8 +1,17 @@
#include <thrust/device_ptr.h>
#include <thrust/device_vector.h>

#include <cuda/std/iterator>

#include <iterator>

#include <unittest/unittest.h>

#ifdef __cpp_lib_concepts
static_assert(std::indirectly_writable<thrust::device_ptr<uint8_t>, uint8_t>);
#endif // __cpp_lib_concepts
static_assert(cuda::std::indirectly_writable<thrust::device_ptr<uint8_t>, uint8_t>);

void TestDevicePointerManipulation()
{
thrust::device_vector<int> data(5);
Expand Down
32 changes: 25 additions & 7 deletions thrust/testing/device_reference.cu
Original file line number Diff line number Diff line change
Expand Up @@ -58,29 +58,47 @@ void TestDeviceReferenceAssignmentFromDeviceReference()
{
// test same types
using T0 = int;
thrust::device_vector<T0> v0(2, 0);
thrust::device_vector<T0> v0{0, 0};
thrust::device_reference<T0> ref0 = v0[0];
thrust::device_reference<T0> ref1 = v0[1];

ref0 = 13;

ref1 = ref0;

// ref1 equals 13
ASSERT_EQUAL(13, ref1);
ASSERT_EQUAL(ref0, ref1);

// test const references
const thrust::device_reference<T0> cref0 = v0[0];
const thrust::device_reference<T0> cref1 = v0[1];

cref0 = 13;
cref1 = cref0;

// cref1 equals 13
ASSERT_EQUAL(13, cref1);
ASSERT_EQUAL(cref0, cref1);

// mix const and non-const references
ref0 = 12;
cref0 = ref0;
ASSERT_EQUAL(12, cref0);

cref0 = 11;
ref0 = cref0;
ASSERT_EQUAL(11, cref0);

// test different types
using T1 = float;
thrust::device_vector<T1> v1(1, 0.0f);
thrust::device_vector<T1> v1{0.0f};
thrust::device_reference<T1> ref2 = v1[0];

ref2 = ref1;
ref2 = ref0;

// ref2 equals 13.0f
ASSERT_EQUAL(13.0f, ref2);
// ref2 equals 11.0f
ASSERT_EQUAL(11.0f, ref2);
ASSERT_EQUAL(ref0, ref2);
ASSERT_EQUAL(ref1, ref2);
}
DECLARE_UNITTEST(TestDeviceReferenceAssignmentFromDeviceReference);

Expand Down
44 changes: 21 additions & 23 deletions thrust/thrust/detail/reference.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ class reference
*
* \return <tt>*this</tt>.
*/
_CCCL_HOST_DEVICE derived_type& operator=(reference const& other)
_CCCL_HOST_DEVICE const derived_type& operator=(reference const& other) const
{
assign_from(&other);
return derived();
Expand All @@ -124,21 +124,14 @@ class reference
*
* \return <tt>*this</tt>.
*/
template <typename OtherElement, typename OtherPointer, typename OtherDerived>
_CCCL_HOST_DEVICE
/*! \cond
*/
typename std::enable_if<
std::is_convertible<typename reference<OtherElement, OtherPointer, OtherDerived>::pointer, pointer>::value,
/*! \endcond
*/
derived_type&
/*! \cond
*/
>::type
/*! \endcond
*/
operator=(reference<OtherElement, OtherPointer, OtherDerived> const& other)
template <
typename OtherElement,
typename OtherPointer,
typename OtherDerived,
::cuda::std::enable_if_t<
::cuda::std::is_convertible_v<typename reference<OtherElement, OtherPointer, OtherDerived>::pointer, pointer>,
int> = 0>
_CCCL_HOST_DEVICE const derived_type& operator=(reference<OtherElement, OtherPointer, OtherDerived> const& other) const
{
assign_from(&other);
return derived();
Expand All @@ -150,7 +143,7 @@ class reference
*
* \return <tt>*this</tt>.
*/
_CCCL_HOST_DEVICE derived_type& operator=(value_type const& rhs)
_CCCL_HOST_DEVICE const derived_type& operator=(value_type const& rhs) const
{
assign_from(&rhs);
return derived();
Expand Down Expand Up @@ -323,6 +316,11 @@ class reference
return static_cast<derived_type&>(*this);
}

_CCCL_HOST_DEVICE const derived_type& derived() const
{
return static_cast<const derived_type&>(*this);
}

template <typename System>
_CCCL_HOST_DEVICE value_type convert_to_value_type(System* system) const
{
Expand All @@ -340,14 +338,14 @@ class reference
}

template <typename System0, typename System1, typename OtherPointer>
_CCCL_HOST_DEVICE void assign_from(System0* system0, System1* system1, OtherPointer src)
_CCCL_HOST_DEVICE void assign_from(System0* system0, System1* system1, OtherPointer src) const
{
using thrust::system::detail::generic::select_system;
strip_const_assign_value(select_system(*system0, *system1), src);
}

template <typename OtherPointer>
_CCCL_HOST_DEVICE void assign_from(OtherPointer src)
_CCCL_HOST_DEVICE void assign_from(OtherPointer src) const
{
// Avoid default-constructing systems; instead, just use a null pointer
// for dispatch. This assumes that `get_value` will not access any system
Expand All @@ -358,7 +356,7 @@ class reference
}

template <typename System, typename OtherPointer>
_CCCL_HOST_DEVICE void strip_const_assign_value(System const& system, OtherPointer src)
_CCCL_HOST_DEVICE void strip_const_assign_value(System const& system, OtherPointer src) const
{
System& non_const_system = const_cast<System&>(system);

Expand Down Expand Up @@ -445,7 +443,7 @@ class tagged_reference
*
* \return <tt>*this</tt>.
*/
_CCCL_HOST_DEVICE tagged_reference& operator=(tagged_reference const& other)
_CCCL_HOST_DEVICE const tagged_reference& operator=(tagged_reference const& other) const
{
return base_type::operator=(other);
}
Expand All @@ -461,7 +459,7 @@ class tagged_reference
* \return <tt>*this</tt>.
*/
template <typename OtherElement, typename OtherTag>
_CCCL_HOST_DEVICE tagged_reference& operator=(tagged_reference<OtherElement, OtherTag> const& other)
_CCCL_HOST_DEVICE const tagged_reference& operator=(tagged_reference<OtherElement, OtherTag> const& other) const
{
return base_type::operator=(other);
}
Expand All @@ -472,7 +470,7 @@ class tagged_reference
*
* \return <tt>*this</tt>.
*/
_CCCL_HOST_DEVICE tagged_reference& operator=(value_type const& rhs)
_CCCL_HOST_DEVICE const tagged_reference& operator=(value_type const& rhs) const
{
return base_type::operator=(rhs);
}
Expand Down
11 changes: 9 additions & 2 deletions thrust/thrust/device_reference.h
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,8 @@ class device_reference : public thrust::reference<T, thrust::device_ptr<T>, thru
*/
using pointer = typename super_t::pointer;

device_reference(const device_reference& other) = default;

/*! This copy constructor accepts a const reference to another
* \p device_reference. After this \p device_reference is constructed,
* it shall refer to the same object as \p other.
Expand Down Expand Up @@ -273,6 +275,11 @@ class device_reference : public thrust::reference<T, thrust::device_ptr<T>, thru
: super_t(ptr)
{}

_CCCL_HOST_DEVICE const device_reference& operator=(const device_reference& other) const
{
return super_t::operator=(other);
}

/*! This assignment operator assigns the value of the object referenced by
* the given \p device_reference to the object referenced by this
* \p device_reference.
Expand All @@ -281,7 +288,7 @@ class device_reference : public thrust::reference<T, thrust::device_ptr<T>, thru
* \return <tt>*this</tt>
*/
template <typename OtherT>
_CCCL_HOST_DEVICE device_reference& operator=(const device_reference<OtherT>& other)
_CCCL_HOST_DEVICE const device_reference& operator=(const device_reference<OtherT>& other) const
{
return super_t::operator=(other);
}
Expand All @@ -292,7 +299,7 @@ class device_reference : public thrust::reference<T, thrust::device_ptr<T>, thru
* \param x The value to assign from.
* \return <tt>*this</tt>
*/
_CCCL_HOST_DEVICE device_reference& operator=(const value_type& x)
_CCCL_HOST_DEVICE const device_reference& operator=(const value_type& x) const
{
return super_t::operator=(x);
}
Expand Down
Loading