@@ -73,6 +73,53 @@ struct __align__(OP_ALIGNMENT) {0} {{
73
73
return std::format (format_template, diff_t , alignment, size, value_t , deref, advance, iter_def);
74
74
};
75
75
76
+ std::string make_kernel_input_iterator_with_host_increment (
77
+ std::string_view diff_t ,
78
+ size_t alignment,
79
+ size_t size,
80
+ std::string_view iterator_name,
81
+ std::string_view value_t ,
82
+ std::string_view deref,
83
+ std::string_view advance)
84
+ {
85
+ const std::string iter_def = std::format (R"XXX(
86
+ extern "C" __device__ VALUE_T DEREF(const void *self_ptr);
87
+ extern "C" __device__ void ADVANCE(void *self_ptr, DIFF_T offset);
88
+ struct __align__(OP_ALIGNMENT) {0} {{
89
+ using iterator_category = cuda::std::random_access_iterator_tag;
90
+ using value_type = VALUE_T;
91
+ using difference_type = DIFF_T;
92
+ using pointer = VALUE_T*;
93
+ using reference = VALUE_T&;
94
+
95
+ static_assert(sizeof(difference_type) == sizeof(long long int));
96
+
97
+ __device__ inline value_type operator*() const {{
98
+ const {0} &it = (*this + host_offset);
99
+ return DEREF(it.data);
100
+ }}
101
+ __device__ inline {0}& operator+=(difference_type diff) {{
102
+ ADVANCE(data, diff);
103
+ return *this;
104
+ }}
105
+ __device__ inline value_type operator[](difference_type diff) const {{
106
+ const {0} &it = (*this + (diff + host_offset));
107
+ return DEREF(it.data);
108
+ }}
109
+ __device__ inline {0} operator+(difference_type diff) const {{
110
+ {0} result = *this;
111
+ result += diff;
112
+ return result;
113
+ }}
114
+ char data[OP_SIZE];
115
+ difference_type host_offset;
116
+ }};
117
+ )XXX" ,
118
+ iterator_name);
119
+
120
+ return std::format (format_template, diff_t , alignment, size, value_t , deref, advance, iter_def);
121
+ };
122
+
76
123
std::string make_kernel_input_iterator (
77
124
std::string_view offset_t , std::string_view iterator_name, std::string_view input_value_t , cccl_iterator_t iter)
78
125
{
@@ -85,6 +132,18 @@ std::string make_kernel_input_iterator(
85
132
offset_t , iter.alignment , iter.size , iterator_name, input_value_t , iter.dereference .name , iter.advance .name );
86
133
}
87
134
135
+ std::string make_kernel_input_iterator_with_host_increment (
136
+ std::string_view offset_t , std::string_view iterator_name, std::string_view input_value_t , cccl_iterator_t iter)
137
+ {
138
+ if (iter.type == cccl_iterator_kind_t ::CCCL_POINTER)
139
+ {
140
+ return {};
141
+ }
142
+
143
+ return make_kernel_input_iterator_with_host_increment (
144
+ offset_t , iter.alignment , iter.size , iterator_name, input_value_t , iter.dereference .name , iter.advance .name );
145
+ }
146
+
88
147
std::string make_kernel_output_iterator (
89
148
std::string_view diff_t ,
90
149
size_t alignment,
@@ -148,6 +207,73 @@ std::string make_kernel_output_iterator(
148
207
offset_t , iter.alignment , iter.size , iterator_name, input_value_t , iter.dereference .name , iter.advance .name );
149
208
}
150
209
210
+ std::string make_kernel_output_iterator_with_host_increment (
211
+ std::string_view diff_t ,
212
+ size_t alignment,
213
+ size_t size,
214
+ std::string_view iterator_name,
215
+ std::string_view value_t ,
216
+ std::string_view deref,
217
+ std::string_view advance)
218
+ {
219
+ const std::string iter_def = std::format (R"XXX(
220
+ extern "C" __device__ void DEREF(const void *self_ptr, VALUE_T x);
221
+ extern "C" __device__ void ADVANCE(void *self_ptr, DIFF_T offset);
222
+ struct __align__(OP_ALIGNMENT) {0}_state_t {{
223
+ char data[OP_SIZE];
224
+ }};
225
+ struct {0}_proxy_t {{
226
+ __device__ {0}_proxy_t operator=(VALUE_T x) {{
227
+ DEREF(&state, x);
228
+ return *this;
229
+ }}
230
+ {0}_state_t state;
231
+ }};
232
+ struct {0} {{
233
+ using iterator_category = cuda::std::random_access_iterator_tag;
234
+ using difference_type = DIFF_T;
235
+ using value_type = void;
236
+ using pointer = {0}_proxy_t*;
237
+ using reference = {0}_proxy_t;
238
+ __device__ {0}_proxy_t operator*() const {{
239
+ const {0} &it = (*this + host_offset);
240
+ return {{it.state}};
241
+ }}
242
+ __device__ {0}& operator+=(difference_type diff) {{
243
+ ADVANCE(&state, diff);
244
+ return *this;
245
+ }}
246
+ __device__ {0}_proxy_t operator[](difference_type diff) const {{
247
+ {0} result = *this;
248
+ result += (diff + host_offset);
249
+ return {{ result.state }};
250
+ }}
251
+ __device__ {0} operator+(difference_type diff) const {{
252
+ {0} result = *this;
253
+ result += diff;
254
+ return result;
255
+ }}
256
+ {0}_state_t state;
257
+ difference_type host_offset;
258
+ }};
259
+ )XXX" ,
260
+ iterator_name);
261
+
262
+ return std::format (format_template, diff_t , alignment, size, value_t , deref, advance, iter_def);
263
+ };
264
+
265
+ std::string make_kernel_output_iterator_with_host_increment (
266
+ std::string_view offset_t , std::string_view iterator_name, std::string_view input_value_t , cccl_iterator_t iter)
267
+ {
268
+ if (iter.type == cccl_iterator_kind_t ::CCCL_POINTER)
269
+ {
270
+ return {};
271
+ }
272
+
273
+ return make_kernel_output_iterator_with_host_increment (
274
+ offset_t , iter.alignment , iter.size , iterator_name, input_value_t , iter.dereference .name , iter.advance .name );
275
+ }
276
+
151
277
std::string make_kernel_inout_iterator (
152
278
std::string_view diff_t ,
153
279
size_t alignment,
0 commit comments