@@ -27,6 +27,26 @@ __device__ inline void vectorize_with_alignment(
constexpr int WIDTH = VEC_SIZE * sizeof(InT);  // e.g. 64 B
uintptr_t addr = reinterpret_cast<uintptr_t>(in);

+ // fast path when the whole region is already aligned
+ // Note: currently the output is guaranteed to be the same as the input, so
+ // we don't check it here; this note is kept for future reference.
+ bool can_vec = ((addr & (WIDTH - 1)) == 0) && ((len & (VEC_SIZE - 1)) == 0);
+ if (can_vec) {
+   int num_vec = len / VEC_SIZE;
+
+   using vin_t = vec_n_t<InT, VEC_SIZE>;
+   using vout_t = vec_n_t<OutT, VEC_SIZE>;
+   auto* v_in = reinterpret_cast<const vin_t*>(in);
+   auto* v_out = reinterpret_cast<vout_t*>(out);
+
+   for (int i = tid; i < num_vec; i += stride) {
+     vout_t tmp;
+     vec_op(tmp, v_in[i]);
+     v_out[i] = tmp;
+   }
+   return;
+ }
+
int misalignment_offset = addr & (WIDTH - 1);        // addr % 64
int alignment_bytes = WIDTH - misalignment_offset;   // 64 - (addr % 64)
int prefix_elems = alignment_bytes & (WIDTH - 1);    // wraps 64 back to 0 (already aligned)
@@ -72,4 +92,81 @@ __device__ __forceinline__ void vectorize_with_alignment(const InT* in,
std::forward<ScaOp>(scalar_op));
}
+ template <int VEC_SIZE, typename InT, typename ScaOp>
+ struct DefaultReadVecOp {
+   ScaOp scalar_op;
+
+   __device__ __forceinline__ void operator()(
+       const vec_n_t<InT, VEC_SIZE>& src) const {
+   #pragma unroll
+     for (int i = 0; i < VEC_SIZE; ++i) {
+       scalar_op(src.val[i]);
+     }
+   }
+ };
+
+ // read-only version: iterate over the input with alignment guarantees
+ template <int VEC_SIZE, typename InT, typename VecOp, typename ScaOp>
+ __device__ inline void vectorize_read_with_alignment(const InT* in, int len,
+                                                       int tid, int stride,
+                                                       VecOp&& vec_op,
+                                                       ScaOp&& scalar_op) {
+   static_assert(VEC_SIZE > 0 && (VEC_SIZE & (VEC_SIZE - 1)) == 0,
+                 "VEC_SIZE must be a positive power-of-two");
+   constexpr int WIDTH = VEC_SIZE * sizeof(InT);
+   uintptr_t addr = reinterpret_cast<uintptr_t>(in);
+
+   // fast path when the whole region is already aligned
+   bool can_vec = ((addr & (WIDTH - 1)) == 0) && ((len & (VEC_SIZE - 1)) == 0);
+   if (can_vec) {
+     int num_vec = len / VEC_SIZE;
+
+     using vin_t = vec_n_t<InT, VEC_SIZE>;
+     auto* v_in = reinterpret_cast<const vin_t*>(in);
+
+     for (int i = tid; i < num_vec; i += stride) {
+       vec_op(v_in[i]);
+     }
+     return;
+   }
+
+   int misalignment_offset = addr & (WIDTH - 1);
+   int alignment_bytes = WIDTH - misalignment_offset;
+   int prefix_elems = alignment_bytes & (WIDTH - 1);
+   prefix_elems /= sizeof(InT);
+   prefix_elems = min(prefix_elems, len);
+
+   // 1. handle the possibly unaligned prefix with scalar access.
+   for (int i = tid; i < prefix_elems; i += stride) {
+     scalar_op(in[i]);
+   }
+
+   in += prefix_elems;
+   len -= prefix_elems;
+
+   int num_vec = len / VEC_SIZE;
+   using vin_t = vec_n_t<InT, VEC_SIZE>;
+   auto* v_in = reinterpret_cast<const vin_t*>(in);
+
+   // 2. vectorized traversal of the main aligned region.
+   for (int i = tid; i < num_vec; i += stride) {
+     vec_op(v_in[i]);
+   }
+
+   // 3. handle remaining tail elements.
+   int tail_start = num_vec * VEC_SIZE;
+   for (int i = tid + tail_start; i < len; i += stride) {
+     scalar_op(in[i]);
+   }
+ }
+
+ // overload that requires only a scalar_op
+ template <int VEC_SIZE, typename InT, typename ScaOp>
+ __device__ __forceinline__ void vectorize_read_with_alignment(
+     const InT* in, int len, int tid, int stride, ScaOp&& scalar_op) {
+   using Vec = DefaultReadVecOp<VEC_SIZE, InT, std::decay_t<ScaOp>>;
+   vectorize_read_with_alignment<VEC_SIZE>(in, len, tid, stride, Vec{scalar_op},
+                                           std::forward<ScaOp>(scalar_op));
+ }
+
} // namespace vllm
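
For reference, a minimal usage sketch of the new read-only helper (not part of this change): the kernel name, buffer names, and the choice of VEC_SIZE = 4 floats (16-byte vectors) below are illustrative assumptions. Each thread hands its strided share of `input` to `vectorize_read_with_alignment`, which reads the aligned middle of the buffer with vector loads and falls back to scalar reads for the misaligned prefix and the tail.

```cuda
// Hypothetical example, assuming this header is included; not part of the PR.
__global__ void absmax_kernel_example(const float* __restrict__ input,
                                      float* __restrict__ out, int len) {
  float local_max = 0.f;

  // Scalar-only overload: DefaultReadVecOp applies the lambda to every lane
  // of each 4-float vector on the vectorized path.
  vllm::vectorize_read_with_alignment<4>(
      input, len, threadIdx.x, blockDim.x,
      [&](const float& x) { local_max = fmaxf(local_max, fabsf(x)); });

  // ... block-level reduction of local_max into *out elided ...
}
```

When `input` is 16-byte aligned and `len` is a multiple of 4, the new fast path skips the prefix/tail bookkeeping entirely and every iteration issues a vector load.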