Skip to content

Commit 689b837

Browse files
committed
Make usage of mutable functors a bit more sane
1 parent 6dfdea7 commit 689b837

File tree

4 files changed

+62
-5
lines changed

4 files changed

+62
-5
lines changed

libcudacxx/include/cuda/__iterator/transform_iterator.h

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -215,10 +215,19 @@ class transform_iterator : public __transform_iterator_category_base<_Iter, _Fn>
215215
}
216216

217217
//! @brief Invokes the stored functor with the value pointed to by the stored iterator
218+
_CCCL_TEMPLATE(class _Iter2 = _Iter)
219+
_CCCL_REQUIRES(_CUDA_VSTD::regular_invocable<const _Fn&, _CUDA_VSTD::iter_reference_t<const _Iter2>>)
218220
[[nodiscard]] _LIBCUDACXX_HIDE_FROM_ABI constexpr decltype(auto) operator*() const
219-
noexcept(noexcept(_CUDA_VSTD::invoke(const_cast<_Fn&>(*__func_), *__current_)))
221+
noexcept(noexcept(_CUDA_VSTD::invoke(*__func_, *__current_)))
220222
{
221-
return _CUDA_VSTD::invoke(const_cast<_Fn&>(*__func_), *__current_);
223+
return _CUDA_VSTD::invoke(*__func_, *__current_);
224+
}
225+
226+
//! @brief Invokes the stored functor with the value pointed to by the stored iterator
227+
[[nodiscard]] _LIBCUDACXX_HIDE_FROM_ABI constexpr decltype(auto)
228+
operator*() noexcept(noexcept(_CUDA_VSTD::invoke(*__func_, *__current_)))
229+
{
230+
return _CUDA_VSTD::invoke(*__func_, *__current_);
222231
}
223232

224233
//! @brief Increments the stored iterator
@@ -289,11 +298,23 @@ class transform_iterator : public __transform_iterator_category_base<_Iter, _Fn>
289298
//! @param __n The additional offset
290299
//! @returns _CUDA_VSTD::invoke(__func_, __current_[__n])
291300
_CCCL_TEMPLATE(class _Iter2 = _Iter)
292-
_CCCL_REQUIRES(_CUDA_VSTD::random_access_iterator<_Iter2>)
301+
_CCCL_REQUIRES(_CUDA_VSTD::random_access_iterator<_Iter2> _CCCL_AND
302+
_CUDA_VSTD::regular_invocable<const _Fn&, _CUDA_VSTD::iter_reference_t<const _Iter2>>)
293303
[[nodiscard]] _LIBCUDACXX_HIDE_FROM_ABI constexpr decltype(auto) operator[](difference_type __n) const
294-
noexcept(__transform_iterator_nothrow_subscript<_Fn, _Iter2>)
304+
noexcept(__transform_iterator_nothrow_subscript<const _Fn, _Iter2>)
305+
{
306+
return _CUDA_VSTD::invoke(*__func_, __current_[__n]);
307+
}
308+
309+
//! @brief Subscripts the stored iterator by \p __n and applies the stored functor to the result
310+
//! @param __n The additional offset
311+
//! @returns _CUDA_VSTD::invoke(__func_, __current_[__n])
312+
_CCCL_TEMPLATE(class _Iter2 = _Iter)
313+
_CCCL_REQUIRES(_CUDA_VSTD::random_access_iterator<_Iter2>)
314+
[[nodiscard]] _LIBCUDACXX_HIDE_FROM_ABI constexpr decltype(auto)
315+
operator[](difference_type __n) noexcept(__transform_iterator_nothrow_subscript<_Fn, _Iter2>)
295316
{
296-
return _CUDA_VSTD::invoke(const_cast<_Fn&>(*__func_), __current_[__n]);
317+
return _CUDA_VSTD::invoke(*__func_, __current_[__n]);
297318
}
298319

299320
//! @brief Compares two \c transform_iterator for equality, directly comparing the stored iterators

libcudacxx/test/libcudacxx/cuda/iterators/transform_iterator/deref.pass.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ __host__ __device__ constexpr void test()
2727
{
2828
cuda::transform_iterator iter{Iter{buffer}, PlusOne{}};
2929
assert(*iter == 1);
30+
assert(*cuda::std::as_const(iter) == 1);
3031
assert(*buffer == 0);
3132

3233
static_assert(!noexcept(*iter));
@@ -68,6 +69,17 @@ __host__ __device__ constexpr void test()
6869
static_assert(!noexcept(*iter));
6970
static_assert(cuda::std::is_same_v<int&&, decltype(*iter)>);
7071
}
72+
73+
{
74+
cuda::transform_iterator iter{Iter{buffer}, PlusWithMutableMember{3}};
75+
assert(*iter == 5);
76+
assert(*iter == 6);
77+
assert(*iter == 7);
78+
assert(*buffer == 2);
79+
80+
static_assert(noexcept(*iter) == noexcept(*cuda::std::declval<Iter>()));
81+
static_assert(cuda::std::is_same_v<int, decltype(*iter)>);
82+
}
7183
}
7284

7385
__host__ __device__ constexpr bool test()

libcudacxx/test/libcudacxx/cuda/iterators/transform_iterator/subscript.pass.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ __host__ __device__ constexpr void test()
3131
{
3232
cuda::transform_iterator iter{Iter{buffer}, PlusOne{}};
3333
assert(iter[4] == 5);
34+
assert(cuda::std::as_const(iter)[4] == 5);
3435
assert(buffer[4] == 4);
3536

3637
static_assert(!noexcept(iter[4]));
@@ -72,6 +73,17 @@ __host__ __device__ constexpr void test()
7273
static_assert(!noexcept(iter[4]));
7374
static_assert(cuda::std::is_same_v<int&&, decltype(iter[4])>);
7475
}
76+
77+
{
78+
cuda::transform_iterator iter{Iter{buffer}, PlusWithMutableMember{3}};
79+
assert(iter[4] == 9);
80+
assert(iter[4] == 10);
81+
assert(iter[4] == 11);
82+
assert(buffer[4] == 6);
83+
84+
static_assert(noexcept(iter[4]) == noexcept(*cuda::std::declval<Iter>()));
85+
static_assert(cuda::std::is_same_v<int, decltype(iter[4])>);
86+
}
7587
}
7688
else
7789
{

libcudacxx/test/libcudacxx/cuda/iterators/transform_iterator/types.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,4 +63,16 @@ struct PlusOneNoexcept
6363
}
6464
};
6565

66+
struct PlusWithMutableMember
67+
{
68+
int val_ = 0;
69+
__host__ __device__ constexpr PlusWithMutableMember(const int val) noexcept
70+
: val_(val)
71+
{}
72+
__host__ __device__ constexpr int operator()(int x) noexcept
73+
{
74+
return x + val_++;
75+
}
76+
};
77+
6678
#endif // TEST_CUDA_TRANSFORM_ITERATOR_TYPES_H

0 commit comments

Comments
 (0)