Skip to content

Commit 718bbfd

Browse files
committed
Port thrust::transform_iterator to cuda
1 parent 6fa7512 commit 718bbfd

13 files changed

+1329
-0
lines changed

libcudacxx/include/cuda/__iterator/transform_iterator.h

Lines changed: 434 additions & 0 deletions
Large diffs are not rendered by default.

libcudacxx/include/cuda/iterator

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
//===----------------------------------------------------------------------===//
2+
//
3+
// Part of libcu++, the C++ Standard Library for your entire system,
4+
// under the Apache License v2.0 with LLVM Exceptions.
5+
// See https://llvm.org/LICENSE.txt for license information.
6+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7+
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
8+
//
9+
//===----------------------------------------------------------------------===//
10+
11+
#ifndef _CUDA_ITERATOR
12+
#define _CUDA_ITERATOR
13+
14+
#include <cuda/std/detail/__config>
15+
16+
#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC)
17+
# pragma GCC system_header
18+
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG)
19+
# pragma clang system_header
20+
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC)
21+
# pragma system_header
22+
#endif // no system header
23+
24+
#include <cuda/__iterator/transform_iterator.h>
25+
#include <cuda/std/iterator>
26+
27+
#endif // _CUDA_ITERATOR
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
//===----------------------------------------------------------------------===//
2+
//
3+
// Part of libcu++, the C++ Standard Library for your entire system,
4+
// under the Apache License v2.0 with LLVM Exceptions.
5+
// See https://llvm.org/LICENSE.txt for license information.
6+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7+
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
8+
//
9+
//===----------------------------------------------------------------------===//
10+
11+
// transform_iterator::operator{++,--,+=,-=}
12+
13+
#include <cuda/iterator>
14+
#include <cuda/std/cassert>
15+
#include <cuda/std/utility>
16+
17+
#include "test_iterators.h"
18+
#include "test_macros.h"
19+
#include "types.h"
20+
21+
template <class Iter>
22+
_CCCL_CONCEPT can_decrement = _CCCL_REQUIRES_EXPR((Iter), Iter i)((--i));
23+
template <class Iter>
24+
_CCCL_CONCEPT can_post_decrement = _CCCL_REQUIRES_EXPR((Iter), Iter i)((i--));
25+
26+
template <class Iter>
27+
_CCCL_CONCEPT can_plus_equal = _CCCL_REQUIRES_EXPR((Iter), Iter i)((i += 1));
28+
template <class Iter>
29+
_CCCL_CONCEPT can_minus_equal = _CCCL_REQUIRES_EXPR((Iter), Iter i)((i -= 1));
30+
31+
template <class Iter>
32+
__host__ __device__ constexpr void test()
33+
{
34+
int buffer[8] = {0, 1, 2, 3, 4, 5, 6, 7};
35+
36+
cuda::transform_iterator iter{Iter{buffer}, PlusOne{}};
37+
assert((++iter).base() == Iter{buffer + 1});
38+
39+
if constexpr (cuda::std::forward_iterator<Iter>)
40+
{
41+
assert((iter++).base() == Iter{buffer + 1});
42+
}
43+
else
44+
{
45+
iter++;
46+
static_assert(cuda::std::is_same_v<decltype(iter++), void>);
47+
}
48+
assert(iter.base() == Iter{buffer + 2});
49+
50+
if constexpr (cuda::std::bidirectional_iterator<Iter>)
51+
{
52+
assert((--iter).base() == Iter{buffer + 1});
53+
assert((iter--).base() == Iter{buffer + 1});
54+
assert(iter.base() == Iter{buffer});
55+
}
56+
else
57+
{
58+
static_assert(!can_decrement<Iter>);
59+
static_assert(!can_post_decrement<Iter>);
60+
}
61+
62+
if constexpr (cuda::std::random_access_iterator<Iter>)
63+
{
64+
assert((iter += 4).base() == Iter{buffer + 4});
65+
assert((iter -= 3).base() == Iter{buffer + 1});
66+
}
67+
else
68+
{
69+
static_assert(!can_plus_equal<Iter>);
70+
static_assert(!can_minus_equal<Iter>);
71+
}
72+
}
73+
74+
__host__ __device__ constexpr bool test()
75+
{
76+
test<cpp17_input_iterator<int*>>();
77+
test<forward_iterator<int*>>();
78+
test<bidirectional_iterator<int*>>();
79+
test<random_access_iterator<int*>>();
80+
test<int*>();
81+
82+
return true;
83+
}
84+
85+
int main(int, char**)
86+
{
87+
test();
88+
static_assert(test(), "");
89+
90+
return 0;
91+
}
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
//===----------------------------------------------------------------------===//
2+
//
3+
// Part of libcu++, the C++ Standard Library for your entire system,
4+
// under the Apache License v2.0 with LLVM Exceptions.
5+
// See https://llvm.org/LICENSE.txt for license information.
6+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7+
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
8+
//
9+
//===----------------------------------------------------------------------===//
10+
11+
// transform_iterator::base
12+
13+
#include <cuda/iterator>
14+
#include <cuda/std/cassert>
15+
#include <cuda/std/type_traits>
16+
#include <cuda/std/utility>
17+
18+
#include "test_iterators.h"
19+
#include "test_macros.h"
20+
#include "types.h"
21+
22+
template <class Iter>
23+
__host__ __device__ constexpr void test()
24+
{
25+
int buffer[8] = {0, 1, 2, 3, 4, 5, 6, 7};
26+
27+
cuda::transform_iterator iter{Iter{buffer}, PlusOne{}};
28+
using transform_iter = decltype(iter);
29+
static_assert(cuda::std::is_same_v<decltype(static_cast<transform_iter&>(iter).base()), Iter const&>);
30+
static_assert(cuda::std::is_same_v<decltype(static_cast<transform_iter&&>(iter).base()), Iter>);
31+
static_assert(cuda::std::is_same_v<decltype(static_cast<const transform_iter&>(iter).base()), Iter const&>);
32+
static_assert(cuda::std::is_same_v<decltype(static_cast<const transform_iter&&>(iter).base()), Iter const&>);
33+
static_assert(noexcept(iter.base()));
34+
static_assert(
35+
noexcept(static_cast<transform_iter&&>(iter).base()) == cuda::std::is_nothrow_move_constructible_v<Iter>);
36+
assert(base(iter.base()) == buffer);
37+
assert(base(cuda::std::move(iter).base()) == buffer);
38+
}
39+
40+
__host__ __device__ constexpr bool test()
41+
{
42+
test<cpp17_input_iterator<int*>>();
43+
test<random_access_iterator<int*>>();
44+
test<int*>();
45+
46+
return true;
47+
}
48+
49+
int main(int, char**)
50+
{
51+
test();
52+
static_assert(test(), "");
53+
54+
return 0;
55+
}
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
//===----------------------------------------------------------------------===//
2+
//
3+
// Part of libcu++, the C++ Standard Library for your entire system,
4+
// under the Apache License v2.0 with LLVM Exceptions.
5+
// See https://llvm.org/LICENSE.txt for license information.
6+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7+
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
8+
//
9+
//===----------------------------------------------------------------------===//
10+
11+
// transform_iterator::operator{<,>,<=,>=,==,!=,<=>}
12+
13+
#include <cuda/iterator>
14+
#if _LIBCUDACXX_HAS_SPACESHIP_OPERATOR()
15+
# include <cuda/std/compare>
16+
#endif // _LIBCUDACXX_HAS_SPACESHIP_OPERATOR()
17+
18+
#include "test_iterators.h"
19+
#include "test_macros.h"
20+
#include "types.h"
21+
22+
template <class Iter>
23+
__host__ __device__ constexpr void test()
24+
{
25+
int buffer[8] = {0, 1, 2, 3, 4, 5, 6, 7};
26+
27+
cuda::transform_iterator iter1{Iter{buffer}, PlusOne{}};
28+
cuda::transform_iterator iter2{Iter{buffer + 4}, PlusOne{}};
29+
30+
assert(!(iter1 < iter1));
31+
assert(iter1 < iter2);
32+
assert(!(iter2 < iter1));
33+
assert(iter1 <= iter1);
34+
assert(iter1 <= iter2);
35+
assert(!(iter2 <= iter1));
36+
assert(!(iter1 > iter1));
37+
assert(!(iter1 > iter2));
38+
assert(iter2 > iter1);
39+
assert(iter1 >= iter1);
40+
assert(!(iter1 >= iter2));
41+
assert(iter2 >= iter1);
42+
assert(iter1 == iter1);
43+
assert(!(iter1 == iter2));
44+
assert(iter2 == iter2);
45+
assert(!(iter1 != iter1));
46+
assert(iter1 != iter2);
47+
assert(!(iter2 != iter2));
48+
49+
#if TEST_HAS_SPACESHIP()
50+
// Test a new-school iterator with operator<=>; the transform iterator should also have operator<=>.
51+
if constexpr (cuda::std::is_same_v<Iter, three_way_contiguous_iterator<int*>>)
52+
{
53+
static_assert(cuda::std::three_way_comparable<Iter>);
54+
static_assert(cuda::std::three_way_comparable<decltype(iter1)>);
55+
56+
assert((iter1 <=> iter2) == cuda::std::strong_ordering::less);
57+
assert((iter1 <=> iter1) == cuda::std::strong_ordering::equal);
58+
assert((iter2 <=> iter1) == cuda::std::strong_ordering::greater);
59+
}
60+
#endif // TEST_HAS_SPACESHIP()
61+
}
62+
63+
__host__ __device__ constexpr bool test()
64+
{
65+
test<random_access_iterator<int*>>();
66+
test<contiguous_iterator<int*>>();
67+
test<int*>();
68+
69+
#if TEST_HAS_SPACESHIP()
70+
test<three_way_contiguous_iterator<int*>>();
71+
#endif // TEST_HAS_SPACESHIP()
72+
73+
return true;
74+
}
75+
76+
int main(int, char**)
77+
{
78+
test();
79+
static_assert(test(), "");
80+
81+
return 0;
82+
}
Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
//===----------------------------------------------------------------------===//
2+
//
3+
// Part of libcu++, the C++ Standard Library for your entire system,
4+
// under the Apache License v2.0 with LLVM Exceptions.
5+
// See https://llvm.org/LICENSE.txt for license information.
6+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7+
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
8+
//
9+
//===----------------------------------------------------------------------===//
10+
11+
// transform_iterator::transform_iterator();
12+
13+
#include <cuda/iterator>
14+
#include <cuda/std/cassert>
15+
#include <cuda/std/concepts>
16+
17+
#include "test_iterators.h"
18+
#include "test_macros.h"
19+
#include "types.h"
20+
21+
struct NoDefaultInitIter
22+
{
23+
int* ptr_;
24+
typedef cuda::std::random_access_iterator_tag iterator_category;
25+
typedef int value_type;
26+
typedef cuda::std::ptrdiff_t difference_type;
27+
typedef int* pointer;
28+
typedef int& reference;
29+
typedef NoDefaultInitIter self;
30+
31+
__host__ __device__ constexpr NoDefaultInitIter(int* ptr)
32+
: ptr_(ptr)
33+
{}
34+
35+
__host__ __device__ constexpr reference operator*() const;
36+
__host__ __device__ constexpr pointer operator->() const;
37+
#if TEST_HAS_SPACESHIP()
38+
__host__ __device__ constexpr auto operator<=>(const self&) const = default;
39+
#else // ^^^ TEST_HAS_SPACESHIP() ^^^ / vvv !TEST_HAS_SPACESHIP() vvv
40+
__host__ __device__ constexpr bool operator<(const self&) const;
41+
__host__ __device__ constexpr bool operator<=(const self&) const;
42+
__host__ __device__ constexpr bool operator>(const self&) const;
43+
__host__ __device__ constexpr bool operator>=(const self&) const;
44+
#endif // !TEST_HAS_SPACESHIP()
45+
46+
__host__ __device__ constexpr friend bool operator==(const self& lhs, const self& rhs)
47+
{
48+
return lhs.ptr_ == rhs.ptr_;
49+
}
50+
#if TEST_STD_VER <= 2017
51+
__host__ __device__ constexpr friend bool operator!=(const self& lhs, const self& rhs)
52+
{
53+
return lhs.ptr_ != rhs.ptr_;
54+
}
55+
#endif // TEST_STD_VER <= 2017
56+
57+
__host__ __device__ constexpr self& operator++();
58+
__host__ __device__ constexpr self operator++(int);
59+
60+
__host__ __device__ constexpr self& operator--();
61+
__host__ __device__ constexpr self operator--(int);
62+
63+
__host__ __device__ constexpr self& operator+=(difference_type n);
64+
__host__ __device__ constexpr self operator+(difference_type n) const;
65+
__host__ __device__ constexpr friend self operator+(difference_type n, self x);
66+
67+
__host__ __device__ constexpr self& operator-=(difference_type n);
68+
__host__ __device__ constexpr self operator-(difference_type n) const;
69+
__host__ __device__ constexpr difference_type operator-(const self&) const;
70+
71+
__host__ __device__ constexpr reference operator[](difference_type n) const;
72+
};
73+
74+
struct NoDefaultInitFunc
75+
{
76+
int val_;
77+
78+
__host__ __device__ constexpr NoDefaultInitFunc(int val)
79+
: val_(val)
80+
{}
81+
82+
__host__ __device__ constexpr int operator()(int x) const
83+
{
84+
return x * val_;
85+
}
86+
};
87+
88+
template <class Iter, class Fn>
89+
__host__ __device__ constexpr void test(Fn fun)
90+
{
91+
int buffer[8] = {0, 1, 2, 3, 4, 5, 6, 7};
92+
93+
{ // default initialization
94+
constexpr bool can_default_init = cuda::std::default_initializable<Iter> && cuda::std::default_initializable<Fn>;
95+
static_assert(cuda::std::default_initializable<cuda::transform_iterator<Iter, Fn>> == can_default_init);
96+
if constexpr (can_default_init)
97+
{
98+
[[maybe_unused]] cuda::transform_iterator<Iter, Fn> iter{};
99+
}
100+
}
101+
102+
{ // construction from iter and functor
103+
cuda::transform_iterator iter{Iter{buffer}, fun};
104+
assert(iter.base() == Iter{buffer});
105+
}
106+
}
107+
108+
__host__ __device__ constexpr bool test()
109+
{
110+
test<NoDefaultInitIter>(PlusOne{});
111+
test<random_access_iterator<int*>>(PlusOne{});
112+
113+
NoDefaultInitFunc func{42};
114+
test<NoDefaultInitIter>(func);
115+
test<random_access_iterator<int*>>(func);
116+
117+
return true;
118+
}
119+
120+
int main(int, char**)
121+
{
122+
test();
123+
static_assert(test(), "");
124+
125+
return 0;
126+
}

0 commit comments

Comments
 (0)