Skip to content

[wip] optimize memory-bound perf for machete #20641

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions benchmarks/kernels/weight_shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,4 +95,10 @@
([2048, 2816], 1),
([1408, 2048], 0),
],
"CohereLabs/c4ai-command-a-03-2025": [
([12288, 73728], 1), # gate_up_proj
([36864, 12288], 0), # down_proj
([12288, 14336], 1), # qkv_proj
([12288, 12288], 0), # o_proj
],
}
25 changes: 18 additions & 7 deletions csrc/quantization/machete/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,8 +169,7 @@
to_cute_constant(sch.tile_shape_mn)|join(', ')}}>;
using ClusterShape = Shape<{{
to_cute_constant(sch.cluster_shape_mnk)|join(', ')}}>;
// TODO: Reimplement
// using KernelSchedule = {{KernelScheduleTag[sch.kernel_schedule]}};
using KernelSchedule = {{KernelScheduleTag[sch.kernel_schedule]}};
using EpilogueSchedule = {{EpilogueScheduleTag[sch.epilogue_schedule]}};
using TileScheduler = {{TileSchedulerTag[sch.tile_scheduler]}};
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
Expand All @@ -192,7 +191,6 @@
{{DataTypeTag[t.b_group_zeropoint]}}, // GroupZeroT
{{DataTypeTag[t.b_channel_scale]}}, // ChannelScaleT
{{DataTypeTag[t.a_token_scale]}}, // TokenScaleT
cutlass::gemm::KernelTmaWarpSpecializedCooperative,
Sch>;

{% for sch in schs %}
Expand Down Expand Up @@ -504,16 +502,29 @@ def generate():
"M > 16 && K <= 12288 && N <= 8192": ((128, 32), (2, 1, 1)),
"M > 16": ((256, 32), (2, 1, 1)),
#### M = 1-16
"N >= 26624": ((256, 16), (1, 1, 1)),
None: ((128, 16), (1, 1, 1)),
}

sch_config_overrides = {
#### M = 1-16
None:
dict(
kernel_schedule=MixedInputKernelScheduleType.TmaWarpSpecialized,
epilogue_schedule=TmaCoop,
tile_scheduler=TileSchedulerType.Default,
)
}

# For now we use the same heuristic for all types
# Heuristic is currently tuned for H100s
default_heuristic = [
(cond, ScheduleConfig(*tile_config,
**sch_common_params)) # type: ignore
for cond, tile_config in default_tile_heuristic_config.items()
(
cond,
ScheduleConfig(
*tile_config,
**sch_config_overrides.get(cond,
sch_common_params) # type: ignore
)) for cond, tile_config in default_tile_heuristic_config.items()
]

def get_unique_schedules(heuristic: dict[str, ScheduleConfig]):
Expand Down
4 changes: 2 additions & 2 deletions csrc/quantization/machete/machete_mm_kernel.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,7 @@ using namespace cute;
// we compute the transpose to move it to the left-hand side.
template <typename ElementA_, typename ElementB_, typename ElementD_,
typename AccumulatorT, typename GroupScaleT, typename GroupZeroT,
typename ChannelScaleT, typename TokenScaleT, class KernelSchedule,
typename ScheduleConfig>
typename ChannelScaleT, typename TokenScaleT, typename ScheduleConfig>
struct MacheteKernelTemplate {
static constexpr bool with_C = false; // not ever used
static constexpr bool with_group_scales = !std::is_same_v<GroupScaleT, void>;
Expand Down Expand Up @@ -101,6 +100,7 @@ struct MacheteKernelTemplate {
using ArchTag = cutlass::arch::Sm90;
using OperatorClass = cutlass::arch::OpClassTensorOp;

using KernelSchedule = typename ScheduleConfig::KernelSchedule;
using PrepackedLayoutB =
PrepackedLayoutBTemplate<ElementA_, ElementB_, ElementConvertGroup,
AccumulatorT, LayoutA_Transpose, KernelSchedule>;
Expand Down