Skip to content

Commit 7f0bfa7

Browse files
committed
small bs optimization
1 parent 906e05d commit 7f0bfa7

File tree

3 files changed

+23
-7
lines changed

3 files changed

+23
-7
lines changed

benchmarks/kernels/weight_shapes.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,4 +95,10 @@
9595
([2048, 2816], 1),
9696
([1408, 2048], 0),
9797
],
98+
"CohereLabs/c4ai-command-a-03-2025": [
99+
([12288, 36864*2], 1), # gate_up_proj
100+
([36864, 12288], 0), # down_proj
101+
([12288, 14336], 1), # qkv_proj
102+
([12288, 12288], 0), # o_proj
103+
]
98104
}

csrc/quantization/machete/generate.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -169,8 +169,7 @@
169169
to_cute_constant(sch.tile_shape_mn)|join(', ')}}>;
170170
using ClusterShape = Shape<{{
171171
to_cute_constant(sch.cluster_shape_mnk)|join(', ')}}>;
172-
// TODO: Reimplement
173-
// using KernelSchedule = {{KernelScheduleTag[sch.kernel_schedule]}};
172+
using KernelSchedule = {{KernelScheduleTag[sch.kernel_schedule]}};
174173
using EpilogueSchedule = {{EpilogueScheduleTag[sch.epilogue_schedule]}};
175174
using TileScheduler = {{TileSchedulerTag[sch.tile_scheduler]}};
176175
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
@@ -192,7 +191,7 @@
192191
{{DataTypeTag[t.b_group_zeropoint]}}, // GroupZeroT
193192
{{DataTypeTag[t.b_channel_scale]}}, // ChannelScaleT
194193
{{DataTypeTag[t.a_token_scale]}}, // TokenScaleT
195-
cutlass::gemm::KernelTmaWarpSpecializedCooperative,
194+
// cutlass::gemm::KernelTmaWarpSpecializedCooperative, // moved to schedule config
196195
Sch>;
197196
198197
{% for sch in schs %}
@@ -504,15 +503,25 @@ def generate():
504503
"M > 16 && K <= 12288 && N <= 8192": ((128, 32), (2, 1, 1)),
505504
"M > 16": ((256, 32), (2, 1, 1)),
506505
#### M = 1-16
507-
"N >= 26624": ((256, 16), (1, 1, 1)),
508506
None: ((128, 16), (1, 1, 1)),
509507
}
510508

509+
sch_config_overrides = {
510+
#### M = 1-16
511+
None: dict(
512+
kernel_schedule=MixedInputKernelScheduleType.TmaWarpSpecialized,
513+
epilogue_schedule=TmaCoop,
514+
tile_scheduler=TileSchedulerType.Default,
515+
)
516+
}
517+
511518
# For now we use the same heuristic for all types
512519
# Heuristic is currently tuned for H100s
513520
default_heuristic = [
514-
(cond, ScheduleConfig(*tile_config,
515-
**sch_common_params)) # type: ignore
521+
(cond, ScheduleConfig(
522+
*tile_config,
523+
**sch_config_overrides.get(cond, sch_common_params) # type: ignore
524+
))
516525
for cond, tile_config in default_tile_heuristic_config.items()
517526
]
518527

csrc/quantization/machete/machete_mm_kernel.cuh

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ using namespace cute;
4040
// we compute the transpose to move it to the left-hand side.
4141
template <typename ElementA_, typename ElementB_, typename ElementD_,
4242
typename AccumulatorT, typename GroupScaleT, typename GroupZeroT,
43-
typename ChannelScaleT, typename TokenScaleT, class KernelSchedule,
43+
typename ChannelScaleT, typename TokenScaleT,
4444
typename ScheduleConfig>
4545
struct MacheteKernelTemplate {
4646
static constexpr bool with_C = false; // not ever used
@@ -101,6 +101,7 @@ struct MacheteKernelTemplate {
101101
using ArchTag = cutlass::arch::Sm90;
102102
using OperatorClass = cutlass::arch::OpClassTensorOp;
103103

104+
using KernelSchedule = typename ScheduleConfig::KernelSchedule;
104105
using PrepackedLayoutB =
105106
PrepackedLayoutBTemplate<ElementA_, ElementB_, ElementConvertGroup,
106107
AccumulatorT, LayoutA_Transpose, KernelSchedule>;

0 commit comments

Comments
 (0)