From 1c1184f5f5a2882986710970c6a3054e2cbe3aa1 Mon Sep 17 00:00:00 2001 From: czhu-cohere Date: Tue, 8 Jul 2025 19:08:44 +0000 Subject: [PATCH 1/3] small bs optimization Signed-off-by: czhu-cohere --- benchmarks/kernels/weight_shapes.py | 6 ++++++ csrc/quantization/machete/generate.py | 21 +++++++++++++------ .../machete/machete_mm_kernel.cuh | 3 ++- 3 files changed, 23 insertions(+), 7 deletions(-) diff --git a/benchmarks/kernels/weight_shapes.py b/benchmarks/kernels/weight_shapes.py index a27f02394af..f937be2ba04 100644 --- a/benchmarks/kernels/weight_shapes.py +++ b/benchmarks/kernels/weight_shapes.py @@ -95,4 +95,10 @@ ([2048, 2816], 1), ([1408, 2048], 0), ], + "CohereLabs/c4ai-command-a-03-2025": [ + ([12288, 36864*2], 1), # gate_up_proj + ([36864, 12288], 0), # down_proj + ([12288, 14336], 1), # qkv_proj + ([12288, 12288], 0), # o_proj + ] } diff --git a/csrc/quantization/machete/generate.py b/csrc/quantization/machete/generate.py index 9af7833d09f..795f47fe955 100644 --- a/csrc/quantization/machete/generate.py +++ b/csrc/quantization/machete/generate.py @@ -169,8 +169,7 @@ to_cute_constant(sch.tile_shape_mn)|join(', ')}}>; using ClusterShape = Shape<{{ to_cute_constant(sch.cluster_shape_mnk)|join(', ')}}>; - // TODO: Reimplement - // using KernelSchedule = {{KernelScheduleTag[sch.kernel_schedule]}}; + using KernelSchedule = {{KernelScheduleTag[sch.kernel_schedule]}}; using EpilogueSchedule = {{EpilogueScheduleTag[sch.epilogue_schedule]}}; using TileScheduler = {{TileSchedulerTag[sch.tile_scheduler]}}; using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto; @@ -192,7 +191,7 @@ {{DataTypeTag[t.b_group_zeropoint]}}, // GroupZeroT {{DataTypeTag[t.b_channel_scale]}}, // ChannelScaleT {{DataTypeTag[t.a_token_scale]}}, // TokenScaleT - cutlass::gemm::KernelTmaWarpSpecializedCooperative, + // cutlass::gemm::KernelTmaWarpSpecializedCooperative, // moved to schedule config Sch>; {% for sch in schs %} @@ -504,15 +503,25 @@ def generate(): "M > 16 && K <= 12288 && N <= 8192": ((128, 32), (2, 1, 1)), "M > 16": ((256, 32), (2, 1, 1)), #### M = 1-16 - "N >= 26624": ((256, 16), (1, 1, 1)), None: ((128, 16), (1, 1, 1)), } + sch_config_overrides = { + #### M = 1-16 + None: dict( + kernel_schedule=MixedInputKernelScheduleType.TmaWarpSpecialized, + epilogue_schedule=TmaCoop, + tile_scheduler=TileSchedulerType.Default, + ) + } + # For now we use the same heuristic for all types # Heuristic is currently tuned for H100s default_heuristic = [ - (cond, ScheduleConfig(*tile_config, - **sch_common_params)) # type: ignore + (cond, ScheduleConfig( + *tile_config, + **sch_config_overrides.get(cond, sch_common_params) # type: ignore + )) for cond, tile_config in default_tile_heuristic_config.items() ] diff --git a/csrc/quantization/machete/machete_mm_kernel.cuh b/csrc/quantization/machete/machete_mm_kernel.cuh index cc50e68b058..d23c92567b0 100644 --- a/csrc/quantization/machete/machete_mm_kernel.cuh +++ b/csrc/quantization/machete/machete_mm_kernel.cuh @@ -40,7 +40,7 @@ using namespace cute; // we compute the transpose to move it to the left-hand side. template struct MacheteKernelTemplate { static constexpr bool with_C = false; // not ever used @@ -101,6 +101,7 @@ struct MacheteKernelTemplate { using ArchTag = cutlass::arch::Sm90; using OperatorClass = cutlass::arch::OpClassTensorOp; + using KernelSchedule = typename ScheduleConfig::KernelSchedule; using PrepackedLayoutB = PrepackedLayoutBTemplate; From afbb816233c465bb507ff0d7073fdb388903890a Mon Sep 17 00:00:00 2001 From: czhu-cohere Date: Tue, 8 Jul 2025 19:44:12 +0000 Subject: [PATCH 2/3] lint Signed-off-by: czhu-cohere --- benchmarks/kernels/weight_shapes.py | 4 ++-- csrc/quantization/machete/generate.py | 14 ++++++++------ csrc/quantization/machete/machete_mm_kernel.cuh | 3 +-- 3 files changed, 11 insertions(+), 10 deletions(-) diff --git a/benchmarks/kernels/weight_shapes.py b/benchmarks/kernels/weight_shapes.py index f937be2ba04..e9848b2e6ac 100644 --- a/benchmarks/kernels/weight_shapes.py +++ b/benchmarks/kernels/weight_shapes.py @@ -96,9 +96,9 @@ ([1408, 2048], 0), ], "CohereLabs/c4ai-command-a-03-2025": [ - ([12288, 36864*2], 1), # gate_up_proj + ([12288, 73728], 1), # gate_up_proj ([36864, 12288], 0), # down_proj ([12288, 14336], 1), # qkv_proj ([12288, 12288], 0), # o_proj - ] + ], } diff --git a/csrc/quantization/machete/generate.py b/csrc/quantization/machete/generate.py index 795f47fe955..9fab65ec41c 100644 --- a/csrc/quantization/machete/generate.py +++ b/csrc/quantization/machete/generate.py @@ -508,7 +508,8 @@ def generate(): sch_config_overrides = { #### M = 1-16 - None: dict( + None: + dict( kernel_schedule=MixedInputKernelScheduleType.TmaWarpSpecialized, epilogue_schedule=TmaCoop, tile_scheduler=TileSchedulerType.Default, @@ -518,11 +519,12 @@ def generate(): # For now we use the same heuristic for all types # Heuristic is currently tuned for H100s default_heuristic = [ - (cond, ScheduleConfig( - *tile_config, - **sch_config_overrides.get(cond, sch_common_params) # type: ignore - )) - for cond, tile_config in default_tile_heuristic_config.items() + ( + cond, ScheduleConfig( + *tile_config, + **sch_config_overrides.get(cond, + sch_common_params) # type: ignore + )) for cond, tile_config in default_tile_heuristic_config.items() ] def get_unique_schedules(heuristic: dict[str, ScheduleConfig]): diff --git a/csrc/quantization/machete/machete_mm_kernel.cuh b/csrc/quantization/machete/machete_mm_kernel.cuh index d23c92567b0..39710202d46 100644 --- a/csrc/quantization/machete/machete_mm_kernel.cuh +++ b/csrc/quantization/machete/machete_mm_kernel.cuh @@ -40,8 +40,7 @@ using namespace cute; // we compute the transpose to move it to the left-hand side. template + typename ChannelScaleT, typename TokenScaleT, typename ScheduleConfig> struct MacheteKernelTemplate { static constexpr bool with_C = false; // not ever used static constexpr bool with_group_scales = !std::is_same_v; From 7343cc62fb905c0cea44dc1abe142f0219c61d15 Mon Sep 17 00:00:00 2001 From: czhu-cohere Date: Tue, 8 Jul 2025 20:06:30 +0000 Subject: [PATCH 3/3] precommit Signed-off-by: czhu-cohere --- benchmarks/kernels/weight_shapes.py | 8 ++++---- csrc/quantization/machete/generate.py | 6 +++--- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/benchmarks/kernels/weight_shapes.py b/benchmarks/kernels/weight_shapes.py index e9848b2e6ac..99adcfdf5c0 100644 --- a/benchmarks/kernels/weight_shapes.py +++ b/benchmarks/kernels/weight_shapes.py @@ -96,9 +96,9 @@ ([1408, 2048], 0), ], "CohereLabs/c4ai-command-a-03-2025": [ - ([12288, 73728], 1), # gate_up_proj - ([36864, 12288], 0), # down_proj - ([12288, 14336], 1), # qkv_proj - ([12288, 12288], 0), # o_proj + ([12288, 73728], 1), # gate_up_proj + ([36864, 12288], 0), # down_proj + ([12288, 14336], 1), # qkv_proj + ([12288, 12288], 0), # o_proj ], } diff --git a/csrc/quantization/machete/generate.py b/csrc/quantization/machete/generate.py index 9fab65ec41c..c0ff78370ce 100644 --- a/csrc/quantization/machete/generate.py +++ b/csrc/quantization/machete/generate.py @@ -191,7 +191,6 @@ {{DataTypeTag[t.b_group_zeropoint]}}, // GroupZeroT {{DataTypeTag[t.b_channel_scale]}}, // ChannelScaleT {{DataTypeTag[t.a_token_scale]}}, // TokenScaleT - // cutlass::gemm::KernelTmaWarpSpecializedCooperative, // moved to schedule config Sch>; {% for sch in schs %} @@ -520,11 +519,12 @@ def generate(): # Heuristic is currently tuned for H100s default_heuristic = [ ( - cond, ScheduleConfig( + cond, + ScheduleConfig( *tile_config, **sch_config_overrides.get(cond, sch_common_params) # type: ignore - )) for cond, tile_config in default_tile_heuristic_config.items() + )) for cond, tile_config in default_tile_heuristic_config.items() ] def get_unique_schedules(heuristic: dict[str, ScheduleConfig]):