Skip to content

Commit a3a00b9

Browse files
Make device_reference<T>::operator= const (#4740)
This is in line with device_reference being a proxy reference, where assigning to a proxy reference does not change the reference itself, but the referred object (the const is shallow). This also makes device_ptr satisfy std::indirectly_writable Fixes: #4621 Co-authored-by: Michael Schellenberger Costa <miscco@nvidia.com>
1 parent e553155 commit a3a00b9

File tree

4 files changed

+64
-32
lines changed

4 files changed

+64
-32
lines changed

thrust/testing/device_ptr.cu

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,17 @@
11
#include <thrust/device_ptr.h>
22
#include <thrust/device_vector.h>
33

4+
#include <cuda/std/iterator>
5+
6+
#include <iterator>
7+
48
#include <unittest/unittest.h>
59

10+
#ifdef __cpp_lib_concepts
11+
static_assert(std::indirectly_writable<thrust::device_ptr<uint8_t>, uint8_t>);
12+
#endif // __cpp_lib_concepts
13+
static_assert(cuda::std::indirectly_writable<thrust::device_ptr<uint8_t>, uint8_t>);
14+
615
void TestDevicePointerManipulation()
716
{
817
thrust::device_vector<int> data(5);

thrust/testing/device_reference.cu

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -58,29 +58,47 @@ void TestDeviceReferenceAssignmentFromDeviceReference()
5858
{
5959
// test same types
6060
using T0 = int;
61-
thrust::device_vector<T0> v0(2, 0);
61+
thrust::device_vector<T0> v0{0, 0};
6262
thrust::device_reference<T0> ref0 = v0[0];
6363
thrust::device_reference<T0> ref1 = v0[1];
6464

6565
ref0 = 13;
66-
6766
ref1 = ref0;
6867

6968
// ref1 equals 13
7069
ASSERT_EQUAL(13, ref1);
7170
ASSERT_EQUAL(ref0, ref1);
7271

72+
// test const references
73+
const thrust::device_reference<T0> cref0 = v0[0];
74+
const thrust::device_reference<T0> cref1 = v0[1];
75+
76+
cref0 = 13;
77+
cref1 = cref0;
78+
79+
// cref1 equals 13
80+
ASSERT_EQUAL(13, cref1);
81+
ASSERT_EQUAL(cref0, cref1);
82+
83+
// mix const and non-const references
84+
ref0 = 12;
85+
cref0 = ref0;
86+
ASSERT_EQUAL(12, cref0);
87+
88+
cref0 = 11;
89+
ref0 = cref0;
90+
ASSERT_EQUAL(11, cref0);
91+
7392
// test different types
7493
using T1 = float;
75-
thrust::device_vector<T1> v1(1, 0.0f);
94+
thrust::device_vector<T1> v1{0.0f};
7695
thrust::device_reference<T1> ref2 = v1[0];
7796

78-
ref2 = ref1;
97+
ref2 = ref0;
7998

80-
// ref2 equals 13.0f
81-
ASSERT_EQUAL(13.0f, ref2);
99+
// ref2 equals 11.0f
100+
ASSERT_EQUAL(11.0f, ref2);
82101
ASSERT_EQUAL(ref0, ref2);
83-
ASSERT_EQUAL(ref1, ref2);
84102
}
85103
DECLARE_UNITTEST(TestDeviceReferenceAssignmentFromDeviceReference);
86104

thrust/thrust/detail/reference.h

Lines changed: 21 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ class reference
107107
*
108108
* \return <tt>*this</tt>.
109109
*/
110-
_CCCL_HOST_DEVICE derived_type& operator=(reference const& other)
110+
_CCCL_HOST_DEVICE const derived_type& operator=(reference const& other) const
111111
{
112112
assign_from(&other);
113113
return derived();
@@ -124,21 +124,14 @@ class reference
124124
*
125125
* \return <tt>*this</tt>.
126126
*/
127-
template <typename OtherElement, typename OtherPointer, typename OtherDerived>
128-
_CCCL_HOST_DEVICE
129-
/*! \cond
130-
*/
131-
typename std::enable_if<
132-
std::is_convertible<typename reference<OtherElement, OtherPointer, OtherDerived>::pointer, pointer>::value,
133-
/*! \endcond
134-
*/
135-
derived_type&
136-
/*! \cond
137-
*/
138-
>::type
139-
/*! \endcond
140-
*/
141-
operator=(reference<OtherElement, OtherPointer, OtherDerived> const& other)
127+
template <
128+
typename OtherElement,
129+
typename OtherPointer,
130+
typename OtherDerived,
131+
::cuda::std::enable_if_t<
132+
::cuda::std::is_convertible_v<typename reference<OtherElement, OtherPointer, OtherDerived>::pointer, pointer>,
133+
int> = 0>
134+
_CCCL_HOST_DEVICE const derived_type& operator=(reference<OtherElement, OtherPointer, OtherDerived> const& other) const
142135
{
143136
assign_from(&other);
144137
return derived();
@@ -150,7 +143,7 @@ class reference
150143
*
151144
* \return <tt>*this</tt>.
152145
*/
153-
_CCCL_HOST_DEVICE derived_type& operator=(value_type const& rhs)
146+
_CCCL_HOST_DEVICE const derived_type& operator=(value_type const& rhs) const
154147
{
155148
assign_from(&rhs);
156149
return derived();
@@ -323,6 +316,11 @@ class reference
323316
return static_cast<derived_type&>(*this);
324317
}
325318

319+
_CCCL_HOST_DEVICE const derived_type& derived() const
320+
{
321+
return static_cast<const derived_type&>(*this);
322+
}
323+
326324
template <typename System>
327325
_CCCL_HOST_DEVICE value_type convert_to_value_type(System* system) const
328326
{
@@ -340,14 +338,14 @@ class reference
340338
}
341339

342340
template <typename System0, typename System1, typename OtherPointer>
343-
_CCCL_HOST_DEVICE void assign_from(System0* system0, System1* system1, OtherPointer src)
341+
_CCCL_HOST_DEVICE void assign_from(System0* system0, System1* system1, OtherPointer src) const
344342
{
345343
using thrust::system::detail::generic::select_system;
346344
strip_const_assign_value(select_system(*system0, *system1), src);
347345
}
348346

349347
template <typename OtherPointer>
350-
_CCCL_HOST_DEVICE void assign_from(OtherPointer src)
348+
_CCCL_HOST_DEVICE void assign_from(OtherPointer src) const
351349
{
352350
// Avoid default-constructing systems; instead, just use a null pointer
353351
// for dispatch. This assumes that `get_value` will not access any system
@@ -358,7 +356,7 @@ class reference
358356
}
359357

360358
template <typename System, typename OtherPointer>
361-
_CCCL_HOST_DEVICE void strip_const_assign_value(System const& system, OtherPointer src)
359+
_CCCL_HOST_DEVICE void strip_const_assign_value(System const& system, OtherPointer src) const
362360
{
363361
System& non_const_system = const_cast<System&>(system);
364362

@@ -445,7 +443,7 @@ class tagged_reference
445443
*
446444
* \return <tt>*this</tt>.
447445
*/
448-
_CCCL_HOST_DEVICE tagged_reference& operator=(tagged_reference const& other)
446+
_CCCL_HOST_DEVICE const tagged_reference& operator=(tagged_reference const& other) const
449447
{
450448
return base_type::operator=(other);
451449
}
@@ -461,7 +459,7 @@ class tagged_reference
461459
* \return <tt>*this</tt>.
462460
*/
463461
template <typename OtherElement, typename OtherTag>
464-
_CCCL_HOST_DEVICE tagged_reference& operator=(tagged_reference<OtherElement, OtherTag> const& other)
462+
_CCCL_HOST_DEVICE const tagged_reference& operator=(tagged_reference<OtherElement, OtherTag> const& other) const
465463
{
466464
return base_type::operator=(other);
467465
}
@@ -472,7 +470,7 @@ class tagged_reference
472470
*
473471
* \return <tt>*this</tt>.
474472
*/
475-
_CCCL_HOST_DEVICE tagged_reference& operator=(value_type const& rhs)
473+
_CCCL_HOST_DEVICE const tagged_reference& operator=(value_type const& rhs) const
476474
{
477475
return base_type::operator=(rhs);
478476
}

thrust/thrust/device_reference.h

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,8 @@ class device_reference : public thrust::reference<T, thrust::device_ptr<T>, thru
203203
*/
204204
using pointer = typename super_t::pointer;
205205

206+
device_reference(const device_reference& other) = default;
207+
206208
/*! This copy constructor accepts a const reference to another
207209
* \p device_reference. After this \p device_reference is constructed,
208210
* it shall refer to the same object as \p other.
@@ -273,6 +275,11 @@ class device_reference : public thrust::reference<T, thrust::device_ptr<T>, thru
273275
: super_t(ptr)
274276
{}
275277

278+
_CCCL_HOST_DEVICE const device_reference& operator=(const device_reference& other) const
279+
{
280+
return super_t::operator=(other);
281+
}
282+
276283
/*! This assignment operator assigns the value of the object referenced by
277284
* the given \p device_reference to the object referenced by this
278285
* \p device_reference.
@@ -281,7 +288,7 @@ class device_reference : public thrust::reference<T, thrust::device_ptr<T>, thru
281288
* \return <tt>*this</tt>
282289
*/
283290
template <typename OtherT>
284-
_CCCL_HOST_DEVICE device_reference& operator=(const device_reference<OtherT>& other)
291+
_CCCL_HOST_DEVICE const device_reference& operator=(const device_reference<OtherT>& other) const
285292
{
286293
return super_t::operator=(other);
287294
}
@@ -292,7 +299,7 @@ class device_reference : public thrust::reference<T, thrust::device_ptr<T>, thru
292299
* \param x The value to assign from.
293300
* \return <tt>*this</tt>
294301
*/
295-
_CCCL_HOST_DEVICE device_reference& operator=(const value_type& x)
302+
_CCCL_HOST_DEVICE const device_reference& operator=(const value_type& x) const
296303
{
297304
return super_t::operator=(x);
298305
}

0 commit comments

Comments
 (0)