diff --git a/benchmarks/kernels/weight_shapes.py b/benchmarks/kernels/weight_shapes.py index a27f02394af..99adcfdf5c0 100644 --- a/benchmarks/kernels/weight_shapes.py +++ b/benchmarks/kernels/weight_shapes.py @@ -95,4 +95,10 @@ ([2048, 2816], 1), ([1408, 2048], 0), ], + "CohereLabs/c4ai-command-a-03-2025": [ + ([12288, 73728], 1), # gate_up_proj + ([36864, 12288], 0), # down_proj + ([12288, 14336], 1), # qkv_proj + ([12288, 12288], 0), # o_proj + ], } diff --git a/csrc/quantization/machete/generate.py b/csrc/quantization/machete/generate.py index 9af7833d09f..c0ff78370ce 100644 --- a/csrc/quantization/machete/generate.py +++ b/csrc/quantization/machete/generate.py @@ -169,8 +169,7 @@ to_cute_constant(sch.tile_shape_mn)|join(', ')}}>; using ClusterShape = Shape<{{ to_cute_constant(sch.cluster_shape_mnk)|join(', ')}}>; - // TODO: Reimplement - // using KernelSchedule = {{KernelScheduleTag[sch.kernel_schedule]}}; + using KernelSchedule = {{KernelScheduleTag[sch.kernel_schedule]}}; using EpilogueSchedule = {{EpilogueScheduleTag[sch.epilogue_schedule]}}; using TileScheduler = {{TileSchedulerTag[sch.tile_scheduler]}}; using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto; @@ -192,7 +191,6 @@ {{DataTypeTag[t.b_group_zeropoint]}}, // GroupZeroT {{DataTypeTag[t.b_channel_scale]}}, // ChannelScaleT {{DataTypeTag[t.a_token_scale]}}, // TokenScaleT - cutlass::gemm::KernelTmaWarpSpecializedCooperative, Sch>; {% for sch in schs %} @@ -504,16 +502,29 @@ def generate(): "M > 16 && K <= 12288 && N <= 8192": ((128, 32), (2, 1, 1)), "M > 16": ((256, 32), (2, 1, 1)), #### M = 1-16 - "N >= 26624": ((256, 16), (1, 1, 1)), None: ((128, 16), (1, 1, 1)), } + sch_config_overrides = { + #### M = 1-16 + None: + dict( + kernel_schedule=MixedInputKernelScheduleType.TmaWarpSpecialized, + epilogue_schedule=TmaCoop, + tile_scheduler=TileSchedulerType.Default, + ) + } + # For now we use the same heuristic for all types # Heuristic is currently tuned for H100s default_heuristic = [ - (cond, ScheduleConfig(*tile_config, - **sch_common_params)) # type: ignore - for cond, tile_config in default_tile_heuristic_config.items() + ( + cond, + ScheduleConfig( + *tile_config, + **sch_config_overrides.get(cond, + sch_common_params) # type: ignore + )) for cond, tile_config in default_tile_heuristic_config.items() ] def get_unique_schedules(heuristic: dict[str, ScheduleConfig]): diff --git a/csrc/quantization/machete/machete_mm_kernel.cuh b/csrc/quantization/machete/machete_mm_kernel.cuh index cc50e68b058..39710202d46 100644 --- a/csrc/quantization/machete/machete_mm_kernel.cuh +++ b/csrc/quantization/machete/machete_mm_kernel.cuh @@ -40,8 +40,7 @@ using namespace cute; // we compute the transpose to move it to the left-hand side. template + typename ChannelScaleT, typename TokenScaleT, typename ScheduleConfig> struct MacheteKernelTemplate { static constexpr bool with_C = false; // not ever used static constexpr bool with_group_scales = !std::is_same_v; @@ -101,6 +100,7 @@ struct MacheteKernelTemplate { using ArchTag = cutlass::arch::Sm90; using OperatorClass = cutlass::arch::OpClassTensorOp; + using KernelSchedule = typename ScheduleConfig::KernelSchedule; using PrepackedLayoutB = PrepackedLayoutBTemplate;