
Commit 1d0c9d6

[Kernel] some optimizations for dense marlin and moe marlin (vllm-project#16850)
Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>
Parent: f62cad6

File tree: 26 files changed, +3501 -3257 lines changed

CMakeLists.txt
Lines changed: 46 additions & 2 deletions

@@ -301,8 +301,52 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
   # are not supported by Machete yet.
   cuda_archs_loose_intersection(MARLIN_ARCHS "8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0" "${CUDA_ARCHS}")
   if (MARLIN_ARCHS)
+
+    #
+    # For the Marlin kernels we automatically generate sources for various
+    # preselected input type pairs and schedules.
+    # Generate sources:
+    set(MARLIN_GEN_SCRIPT
+      ${CMAKE_CURRENT_SOURCE_DIR}/csrc/quantization/gptq_marlin/generate_kernels.py)
+    file(MD5 ${MARLIN_GEN_SCRIPT} MARLIN_GEN_SCRIPT_HASH)
+
+    message(STATUS "Marlin generation script hash: ${MARLIN_GEN_SCRIPT_HASH}")
+    message(STATUS "Last run Marlin generate script hash: $CACHE{MARLIN_GEN_SCRIPT_HASH}")
+
+    if (NOT DEFINED CACHE{MARLIN_GEN_SCRIPT_HASH}
+        OR NOT $CACHE{MARLIN_GEN_SCRIPT_HASH} STREQUAL ${MARLIN_GEN_SCRIPT_HASH})
+      execute_process(
+        COMMAND ${CMAKE_COMMAND} -E env
+        PYTHONPATH=$PYTHONPATH
+          ${Python_EXECUTABLE} ${MARLIN_GEN_SCRIPT}
+        RESULT_VARIABLE marlin_generation_result
+        OUTPUT_VARIABLE marlin_generation_result
+        OUTPUT_FILE ${CMAKE_CURRENT_BINARY_DIR}/marlin_generation.log
+        ERROR_FILE ${CMAKE_CURRENT_BINARY_DIR}/marlin_generation.log
+      )
+
+      if (NOT marlin_generation_result EQUAL 0)
+        message(FATAL_ERROR "Marlin generation failed."
+                " Result: \"${marlin_generation_result}\""
+                "\nCheck the log for details: "
+                "${CMAKE_CURRENT_BINARY_DIR}/marlin_generation.log")
+      else()
+        set(MARLIN_GEN_SCRIPT_HASH ${MARLIN_GEN_SCRIPT_HASH}
+            CACHE STRING "Last run Marlin generate script hash" FORCE)
+        message(STATUS "Marlin generation completed successfully.")
+      endif()
+    else()
+      message(STATUS "Marlin generation script has not changed, skipping generation.")
+    endif()
+
+    file(GLOB MARLIN_TEMPLATE_KERNEL_SRC "csrc/quantization/gptq_marlin/kernel_*.cu")
+    set_gencode_flags_for_srcs(
+      SRCS "${MARLIN_TEMPLATE_KERNEL_SRC}"
+      CUDA_ARCHS "${MARLIN_ARCHS}")
+
+    list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_KERNEL_SRC})
+
     set(MARLIN_SRCS
-      "csrc/quantization/fp8/fp8_marlin.cu"
       "csrc/quantization/marlin/dense/marlin_cuda_kernel.cu"
       "csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu"
       "csrc/quantization/marlin/qqq/marlin_qqq_gemm_kernel.cu"

@@ -644,7 +688,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
        OR NOT $CACHE{MOE_MARLIN_GEN_SCRIPT_HASH} STREQUAL ${MOE_MARLIN_GEN_SCRIPT_HASH})
      execute_process(
        COMMAND ${CMAKE_COMMAND} -E env
-       PYTHONPATH=${CMAKE_CURRENT_SOURCE_DIR}/csrc/cutlass_extensions/:${CUTLASS_DIR}/python/:${VLLM_PYTHON_PATH}:$PYTHONPATH
+       PYTHONPATH=$PYTHONPATH
         ${Python_EXECUTABLE} ${MOE_MARLIN_GEN_SCRIPT}
       RESULT_VARIABLE moe_marlin_generation_result
       OUTPUT_VARIABLE moe_marlin_generation_output
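The new CMakeLists.txt block gates dense-Marlin kernel generation on the MD5 hash of generate_kernels.py, mirroring the existing MoE-Marlin flow: hash the generator script, skip regeneration when the cached hash matches, otherwise run the script and fail the configure step on a non-zero exit code. Below is a minimal Python sketch of that guard logic; the paths, the hash-file name, and the regenerate_if_needed helper are illustrative only — the commit itself keeps this logic in CMake and stores the last hash in a CACHE variable rather than a file.

import hashlib
import subprocess
import sys
from pathlib import Path

# Illustrative paths; the real build uses CMake cache variables, not a hash file.
GEN_SCRIPT = Path("csrc/quantization/gptq_marlin/generate_kernels.py")
HASH_FILE = Path("build/.marlin_gen_script_hash")


def regenerate_if_needed() -> None:
    """Re-run the kernel generator only when the generator script itself changed."""
    current_hash = hashlib.md5(GEN_SCRIPT.read_bytes()).hexdigest()
    last_hash = HASH_FILE.read_text().strip() if HASH_FILE.exists() else None

    if last_hash == current_hash:
        print("Marlin generation script has not changed, skipping generation.")
        return

    # Rough analogue of the execute_process() call: run the generator with the
    # current Python interpreter and fail loudly on a non-zero exit code.
    result = subprocess.run([sys.executable, str(GEN_SCRIPT)])
    if result.returncode != 0:
        raise RuntimeError(f"Marlin generation failed with code {result.returncode}")

    HASH_FILE.parent.mkdir(parents=True, exist_ok=True)
    HASH_FILE.write_text(current_hash)
    print("Marlin generation completed successfully.")

The generated kernel_*.cu files are then collected by the file(GLOB ...) call and appended to VLLM_EXT_SRC, so the dense Marlin build now consumes generated sources the same way the MoE Marlin path already does.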

csrc/moe/marlin_moe_wna16/.gitignore
Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
+kernel_*.cu

csrc/moe/marlin_moe_wna16/generate_kernels.py
Lines changed: 12 additions & 8 deletions

@@ -25,15 +25,13 @@
     "{{thread_k_blocks}}, "
     "{{'true' if m_block_size_8 else 'false'}}, "
     "{{stages}}, "
-    "{{'true' if has_act_order else 'false'}}, "
-    "{{'true' if has_zp else 'false'}}, "
     "{{group_blocks}}, "
     "{{'true' if is_zp_float else 'false'}}>"
     "( MARLIN_KERNEL_PARAMS );")

 # int8 with zero point case (vllm::kU8) is also supported,
 # we don't add it to reduce wheel size.
-SCALAR_TYPES = ["vllm::kU4", "vllm::kU4B8", "vllm::kU8B128"]
+SCALAR_TYPES = ["vllm::kU4", "vllm::kU4B8", "vllm::kU8B128", "vllm::kFE4M3fn"]
 THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128)]

 THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4]

@@ -52,21 +50,29 @@ def remove_old_kernels():

 def generate_new_kernels():
     for scalar_type, dtype in itertools.product(SCALAR_TYPES, DTYPES):
-        has_zp = "B" not in scalar_type
         all_template_str_list = []

         for group_blocks, m_blocks, thread_configs in itertools.product(
                 GROUP_BLOCKS, THREAD_M_BLOCKS, THREAD_CONFIGS):

-            has_act_order = group_blocks == 0
-            if has_zp and has_act_order:
+            # act order case only support gptq-int4 and gptq-int8
+            if group_blocks == 0 and scalar_type not in [
+                    "vllm::kU4B8", "vllm::kU8B128"
+            ]:
                 continue
             if thread_configs[2] == 256:
+                # for small batch (m_blocks == 1), we only need (128, 128, 256)
+                # for large batch (m_blocks > 1), we only need (64, 256, 256)
                 if m_blocks <= 1 and thread_configs[0] != 128:
                     continue
                 if m_blocks > 1 and thread_configs[0] != 64:
                     continue

+            # we only support channelwise quantization and group_size == 128
+            # for fp8
+            if scalar_type == "vllm::kFE4M3fn" and group_blocks not in [-1, 8]:
+                continue
+
             k_blocks = thread_configs[0] // 16
             n_blocks = thread_configs[1] // 16
             threads = thread_configs[2]

@@ -82,8 +88,6 @@ def generate_new_kernels():
                 thread_k_blocks=k_blocks,
                 m_block_size_8=m_blocks == 0.5,
                 stages="pipe_stages",
-                has_act_order=has_act_order,
-                has_zp=has_zp,
                 group_blocks=group_blocks,
                 is_zp_float=False,
             )
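Taken together, the generator now enumerates FP8 (vllm::kFE4M3fn) alongside the GPTQ/AWQ integer types, restricts act-order kernels (group_blocks == 0) to the GPTQ int4/int8 types, and limits FP8 to channelwise and group_size == 128 quantization. The sketch below pulls out just these filters so their net effect is easier to see; the GROUP_BLOCKS list is an assumption for illustration (the committed script defines its own values, and per-dtype expansion via DTYPES is omitted).

import itertools

# Mirrors the filtering rules from the diff above.  GROUP_BLOCKS is assumed
# here (-1 = channelwise, 0 = act-order, otherwise group_size // 16); the real
# list lives elsewhere in generate_kernels.py and is not shown in this hunk.
SCALAR_TYPES = ["vllm::kU4", "vllm::kU4B8", "vllm::kU8B128", "vllm::kFE4M3fn"]
GROUP_BLOCKS = [-1, 0, 2, 4, 8]  # assumption for illustration
THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128)]
THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4]


def kept_configs():
    """Yield the (scalar_type, group_blocks, m_blocks, thread_config) combos
    that survive the filters introduced in this commit."""
    for scalar_type, group_blocks, m_blocks, tc in itertools.product(
            SCALAR_TYPES, GROUP_BLOCKS, THREAD_M_BLOCKS, THREAD_CONFIGS):
        # act-order (group_blocks == 0) is only generated for GPTQ int4/int8
        if group_blocks == 0 and scalar_type not in ("vllm::kU4B8", "vllm::kU8B128"):
            continue
        # 256-thread configs are split by batch size: (128, 128, 256) for
        # small batches, (64, 256, 256) for larger ones
        if tc[2] == 256:
            if m_blocks <= 1 and tc[0] != 128:
                continue
            if m_blocks > 1 and tc[0] != 64:
                continue
        # fp8 only gets channelwise (-1) and group_size == 128 (group_blocks == 8)
        if scalar_type == "vllm::kFE4M3fn" and group_blocks not in (-1, 8):
            continue
        yield scalar_type, group_blocks, m_blocks, tc


if __name__ == "__main__":
    fp8 = [c for c in kept_configs() if c[0] == "vllm::kFE4M3fn"]
    print(f"fp8 template instantiations kept: {len(fp8)}")

Running the sketch prints how many FP8 instantiations survive under these assumptions, which is a quick way to gauge how a filter change affects the number of generated kernels (and hence wheel size, which the script's own comments call out).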

csrc/moe/marlin_moe_wna16/kernel.h
Lines changed: 4 additions & 6 deletions

@@ -18,7 +18,7 @@
       const float *__restrict__ topk_weights_ptr, int top_k, \
       bool mul_topk_weights, bool is_ep, int num_groups, int prob_m, \
       int prob_n, int prob_k, int *locks, bool use_atomic_add, \
-      bool use_fp32_reduce
+      bool use_fp32_reduce, int max_shared_mem

 namespace MARLIN_NAMESPACE_NAME {
 template <typename scalar_t,  // compute dtype, half or nv_float16

@@ -33,11 +33,9 @@ template <typename scalar_t,  // compute dtype, half or nv_float16
                                      // only works when thread_m_blocks == 1
           const int stages,          // number of stages for the async global->shared
                                      // fetch pipeline
-          const bool has_act_order,  // whether act_order is enabled
-          const bool has_zp,         // whether zero-points are enabled
-          const int group_blocks,    // number of consecutive 16x16 blocks
-                                     // with a separate quantization scale
-          const bool is_zp_float     // is zero point of float16 type?
+          const int group_blocks,  // number of consecutive 16x16 blocks
+                                   // with a separate quantization scale
+          const bool is_zp_float   // is zero point of float16 type?
           >
 __global__ void Marlin(MARLIN_KERNEL_PARAMS);
