File tree Expand file tree Collapse file tree 4 files changed +6
-24
lines changed Expand file tree Collapse file tree 4 files changed +6
-24
lines changed Original file line number Diff line number Diff line change 24
24
25
25
#include " attention_dtypes.h"
26
26
#include " attention_utils.cuh"
27
+ #include " cuda_compat.h"
27
28
28
29
#ifdef USE_ROCM
29
30
#include < hip/hip_bf16.h>
@@ -33,12 +34,6 @@ typedef __hip_bfloat16 __nv_bfloat16;
33
34
#include " ../quantization/fp8/nvidia/quant_utils.cuh"
34
35
#endif
35
36
36
- #ifndef USE_ROCM
37
- #define WARP_SIZE 32
38
- #else
39
- #define WARP_SIZE warpSize
40
- #endif
41
-
42
37
#define MAX (a, b ) ((a) > (b) ? (a) : (b))
43
38
#define MIN (a, b ) ((a) < (b) ? (a) : (b))
44
39
#define DIVIDE_ROUND_UP (a, b ) (((a) + (b) - 1 ) / (b))
@@ -670,7 +665,6 @@ __global__ void paged_attention_v2_reduce_kernel(
670
665
671
666
} // namespace vllm
672
667
673
- #undef WARP_SIZE
674
668
#undef MAX
675
669
#undef MIN
676
670
#undef DIVIDE_ROUND_UP
Original file line number Diff line number Diff line change 18
18
*/
19
19
20
20
#include " attention_kernels.cuh"
21
-
22
- #ifndef USE_ROCM
23
- #define WARP_SIZE 32
24
- #else
25
- #define WARP_SIZE warpSize
26
- #endif
21
+ #include " cuda_compat.h"
27
22
28
23
#define MAX (a, b ) ((a) > (b) ? (a) : (b))
29
24
#define MIN (a, b ) ((a) < (b) ? (a) : (b))
@@ -187,7 +182,6 @@ void paged_attention_v1(
187
182
CALL_V1_LAUNCHER_BLOCK_SIZE)
188
183
}
189
184
190
- #undef WARP_SIZE
191
185
#undef MAX
192
186
#undef MIN
193
187
#undef DIVIDE_ROUND_UP
Original file line number Diff line number Diff line change 18
18
*/
19
19
20
20
#include " attention_kernels.cuh"
21
-
22
- #ifndef USE_ROCM
23
- #define WARP_SIZE 32
24
- #else
25
- #define WARP_SIZE warpSize
26
- #endif
21
+ #include " cuda_compat.h"
27
22
28
23
#define MAX (a, b ) ((a) > (b) ? (a) : (b))
29
24
#define MIN (a, b ) ((a) < (b) ? (a) : (b))
@@ -197,7 +192,6 @@ void paged_attention_v2(
197
192
CALL_V2_LAUNCHER_BLOCK_SIZE)
198
193
}
199
194
200
- #undef WARP_SIZE
201
195
#undef MAX
202
196
#undef MIN
203
197
#undef DIVIDE_ROUND_UP
Original file line number Diff line number Diff line change 4
4
#include <hip/hip_runtime.h>
5
5
#endif
6
6
7
- #ifndef USE_ROCM
8
- #define WARP_SIZE 32
7
+ #if defined( USE_ROCM ) && defined( __GFX9__ )
8
+ #define WARP_SIZE 64
9
9
#else
10
- #define WARP_SIZE warpSize
10
+ #define WARP_SIZE 32
11
11
#endif
12
12
13
13
#ifndef USE_ROCM
You can’t perform that action at this time.
0 commit comments