Skip to content

Commit b6e83ae

Browse files
Make device_reference<T>::operator= const
This is in line with device_reference being a proxy reference, where assigning to a proxy reference does not change the reference itself, but the referred object (the const is shallow). This also makes device_ptr satisfy std::indirectly_writable Fixes: NVIDIA#4621
1 parent dde8e11 commit b6e83ae

File tree

4 files changed

+64
-32
lines changed

4 files changed

+64
-32
lines changed

thrust/testing/device_ptr.cu

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,15 @@
11
#include <thrust/device_ptr.h>
22
#include <thrust/device_vector.h>
33

4+
#include <cuda/std/iterator>
5+
46
#include <unittest/unittest.h>
57

8+
#if _CCCL_STD_VER >= 2020
9+
static_assert(std::indirectly_writable<thrust::device_ptr<uint8_t>, uint8_t>);
10+
static_assert(cuda::std::indirectly_writable<thrust::device_ptr<uint8_t>, uint8_t>);
11+
#endif // _CCCL_STD_VER >= 2020
12+
613
void TestDevicePointerManipulation()
714
{
815
thrust::device_vector<int> data(5);

thrust/testing/device_reference.cu

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -58,29 +58,47 @@ void TestDeviceReferenceAssignmentFromDeviceReference()
5858
{
5959
// test same types
6060
using T0 = int;
61-
thrust::device_vector<T0> v0(2, 0);
61+
thrust::device_vector<T0> v0{0, 0};
6262
thrust::device_reference<T0> ref0 = v0[0];
6363
thrust::device_reference<T0> ref1 = v0[1];
6464

6565
ref0 = 13;
66-
6766
ref1 = ref0;
6867

6968
// ref1 equals 13
7069
ASSERT_EQUAL(13, ref1);
7170
ASSERT_EQUAL(ref0, ref1);
7271

72+
// test const references
73+
const thrust::device_reference<T0> cref0 = v0[0];
74+
const thrust::device_reference<T0> cref1 = v0[1];
75+
76+
cref0 = 13;
77+
cref1 = cref0;
78+
79+
// cref1 equals 13
80+
ASSERT_EQUAL(13, cref1);
81+
ASSERT_EQUAL(cref0, cref1);
82+
83+
// mix const and non-const references
84+
ref0 = 12;
85+
cref0 = ref0;
86+
ASSERT_EQUAL(12, cref0);
87+
88+
cref0 = 11;
89+
ref0 = cref0;
90+
ASSERT_EQUAL(11, cref0);
91+
7392
// test different types
7493
using T1 = float;
75-
thrust::device_vector<T1> v1(1, 0.0f);
94+
thrust::device_vector<T1> v1{0.0f};
7695
thrust::device_reference<T1> ref2 = v1[0];
7796

78-
ref2 = ref1;
97+
ref2 = ref0;
7998

80-
// ref2 equals 13.0f
81-
ASSERT_EQUAL(13.0f, ref2);
99+
// ref2 equals 11.0f
100+
ASSERT_EQUAL(11.0f, ref2);
82101
ASSERT_EQUAL(ref0, ref2);
83-
ASSERT_EQUAL(ref1, ref2);
84102
}
85103
DECLARE_UNITTEST(TestDeviceReferenceAssignmentFromDeviceReference);
86104

thrust/thrust/detail/reference.h

Lines changed: 21 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ class reference
107107
*
108108
* \return <tt>*this</tt>.
109109
*/
110-
_CCCL_HOST_DEVICE derived_type& operator=(reference const& other)
110+
_CCCL_HOST_DEVICE const derived_type& operator=(reference const& other) const
111111
{
112112
assign_from(&other);
113113
return derived();
@@ -124,21 +124,14 @@ class reference
124124
*
125125
* \return <tt>*this</tt>.
126126
*/
127-
template <typename OtherElement, typename OtherPointer, typename OtherDerived>
128-
_CCCL_HOST_DEVICE
129-
/*! \cond
130-
*/
131-
typename std::enable_if<
132-
std::is_convertible<typename reference<OtherElement, OtherPointer, OtherDerived>::pointer, pointer>::value,
133-
/*! \endcond
134-
*/
135-
derived_type&
136-
/*! \cond
137-
*/
138-
>::type
139-
/*! \endcond
140-
*/
141-
operator=(reference<OtherElement, OtherPointer, OtherDerived> const& other)
127+
template <
128+
typename OtherElement,
129+
typename OtherPointer,
130+
typename OtherDerived,
131+
::cuda::std::enable_if_t<
132+
::cuda::std::is_convertible_v<typename reference<OtherElement, OtherPointer, OtherDerived>::pointer, pointer>,
133+
int> = 0>
134+
_CCCL_HOST_DEVICE const derived_type& operator=(reference<OtherElement, OtherPointer, OtherDerived> const& other) const
142135
{
143136
assign_from(&other);
144137
return derived();
@@ -150,7 +143,7 @@ class reference
150143
*
151144
* \return <tt>*this</tt>.
152145
*/
153-
_CCCL_HOST_DEVICE derived_type& operator=(value_type const& rhs)
146+
_CCCL_HOST_DEVICE const derived_type& operator=(value_type const& rhs) const
154147
{
155148
assign_from(&rhs);
156149
return derived();
@@ -323,6 +316,11 @@ class reference
323316
return static_cast<derived_type&>(*this);
324317
}
325318

319+
_CCCL_HOST_DEVICE const derived_type& derived() const
320+
{
321+
return static_cast<const derived_type&>(*this);
322+
}
323+
326324
template <typename System>
327325
_CCCL_HOST_DEVICE value_type convert_to_value_type(System* system) const
328326
{
@@ -340,14 +338,14 @@ class reference
340338
}
341339

342340
template <typename System0, typename System1, typename OtherPointer>
343-
_CCCL_HOST_DEVICE void assign_from(System0* system0, System1* system1, OtherPointer src)
341+
_CCCL_HOST_DEVICE void assign_from(System0* system0, System1* system1, OtherPointer src) const
344342
{
345343
using thrust::system::detail::generic::select_system;
346344
strip_const_assign_value(select_system(*system0, *system1), src);
347345
}
348346

349347
template <typename OtherPointer>
350-
_CCCL_HOST_DEVICE void assign_from(OtherPointer src)
348+
_CCCL_HOST_DEVICE void assign_from(OtherPointer src) const
351349
{
352350
// Avoid default-constructing systems; instead, just use a null pointer
353351
// for dispatch. This assumes that `get_value` will not access any system
@@ -358,7 +356,7 @@ class reference
358356
}
359357

360358
template <typename System, typename OtherPointer>
361-
_CCCL_HOST_DEVICE void strip_const_assign_value(System const& system, OtherPointer src)
359+
_CCCL_HOST_DEVICE void strip_const_assign_value(System const& system, OtherPointer src) const
362360
{
363361
System& non_const_system = const_cast<System&>(system);
364362

@@ -445,7 +443,7 @@ class tagged_reference
445443
*
446444
* \return <tt>*this</tt>.
447445
*/
448-
_CCCL_HOST_DEVICE tagged_reference& operator=(tagged_reference const& other)
446+
_CCCL_HOST_DEVICE const tagged_reference& operator=(tagged_reference const& other) const
449447
{
450448
return base_type::operator=(other);
451449
}
@@ -461,7 +459,7 @@ class tagged_reference
461459
* \return <tt>*this</tt>.
462460
*/
463461
template <typename OtherElement, typename OtherTag>
464-
_CCCL_HOST_DEVICE tagged_reference& operator=(tagged_reference<OtherElement, OtherTag> const& other)
462+
_CCCL_HOST_DEVICE const tagged_reference& operator=(tagged_reference<OtherElement, OtherTag> const& other) const
465463
{
466464
return base_type::operator=(other);
467465
}
@@ -472,7 +470,7 @@ class tagged_reference
472470
*
473471
* \return <tt>*this</tt>.
474472
*/
475-
_CCCL_HOST_DEVICE tagged_reference& operator=(value_type const& rhs)
473+
_CCCL_HOST_DEVICE const tagged_reference& operator=(value_type const& rhs) const
476474
{
477475
return base_type::operator=(rhs);
478476
}

thrust/thrust/device_reference.h

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,10 @@ class device_reference : public thrust::reference<T, thrust::device_ptr<T>, thru
203203
*/
204204
using pointer = typename super_t::pointer;
205205

206+
_CCCL_HOST_DEVICE device_reference(const device_reference& other)
207+
: super_t(other)
208+
{}
209+
206210
/*! This copy constructor accepts a const reference to another
207211
* \p device_reference. After this \p device_reference is constructed,
208212
* it shall refer to the same object as \p other.
@@ -273,6 +277,11 @@ class device_reference : public thrust::reference<T, thrust::device_ptr<T>, thru
273277
: super_t(ptr)
274278
{}
275279

280+
_CCCL_HOST_DEVICE const device_reference& operator=(const device_reference& other) const
281+
{
282+
return super_t::operator=(other);
283+
}
284+
276285
/*! This assignment operator assigns the value of the object referenced by
277286
* the given \p device_reference to the object referenced by this
278287
* \p device_reference.
@@ -281,7 +290,7 @@ class device_reference : public thrust::reference<T, thrust::device_ptr<T>, thru
281290
* \return <tt>*this</tt>
282291
*/
283292
template <typename OtherT>
284-
_CCCL_HOST_DEVICE device_reference& operator=(const device_reference<OtherT>& other)
293+
_CCCL_HOST_DEVICE const device_reference& operator=(const device_reference<OtherT>& other) const
285294
{
286295
return super_t::operator=(other);
287296
}
@@ -292,7 +301,7 @@ class device_reference : public thrust::reference<T, thrust::device_ptr<T>, thru
292301
* \param x The value to assign from.
293302
* \return <tt>*this</tt>
294303
*/
295-
_CCCL_HOST_DEVICE device_reference& operator=(const value_type& x)
304+
_CCCL_HOST_DEVICE const device_reference& operator=(const value_type& x) const
296305
{
297306
return super_t::operator=(x);
298307
}

0 commit comments

Comments
 (0)