Skip to content

Commit 7b67e27

Browse files
committed
lint
Signed-off-by: czhu-cohere <conway.zhu@cohere.com>
1 parent 7f0bfa7 commit 7b67e27

File tree

3 files changed

+11
-10
lines changed

3 files changed

+11
-10
lines changed

benchmarks/kernels/weight_shapes.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,9 +96,9 @@
9696
([1408, 2048], 0),
9797
],
9898
"CohereLabs/c4ai-command-a-03-2025": [
99-
([12288, 36864*2], 1), # gate_up_proj
99+
([12288, 73728], 1), # gate_up_proj
100100
([36864, 12288], 0), # down_proj
101101
([12288, 14336], 1), # qkv_proj
102102
([12288, 12288], 0), # o_proj
103-
]
103+
],
104104
}

csrc/quantization/machete/generate.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -508,7 +508,8 @@ def generate():
508508

509509
sch_config_overrides = {
510510
#### M = 1-16
511-
None: dict(
511+
None:
512+
dict(
512513
kernel_schedule=MixedInputKernelScheduleType.TmaWarpSpecialized,
513514
epilogue_schedule=TmaCoop,
514515
tile_scheduler=TileSchedulerType.Default,
@@ -518,11 +519,12 @@ def generate():
518519
# For now we use the same heuristic for all types
519520
# Heuristic is currently tuned for H100s
520521
default_heuristic = [
521-
(cond, ScheduleConfig(
522-
*tile_config,
523-
**sch_config_overrides.get(cond, sch_common_params) # type: ignore
524-
))
525-
for cond, tile_config in default_tile_heuristic_config.items()
522+
(
523+
cond, ScheduleConfig(
524+
*tile_config,
525+
**sch_config_overrides.get(cond,
526+
sch_common_params) # type: ignore
527+
)) for cond, tile_config in default_tile_heuristic_config.items()
526528
]
527529

528530
def get_unique_schedules(heuristic: dict[str, ScheduleConfig]):

csrc/quantization/machete/machete_mm_kernel.cuh

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,7 @@ using namespace cute;
4040
// we compute the transpose to move it to the left-hand side.
4141
template <typename ElementA_, typename ElementB_, typename ElementD_,
4242
typename AccumulatorT, typename GroupScaleT, typename GroupZeroT,
43-
typename ChannelScaleT, typename TokenScaleT,
44-
typename ScheduleConfig>
43+
typename ChannelScaleT, typename TokenScaleT, typename ScheduleConfig>
4544
struct MacheteKernelTemplate {
4645
static constexpr bool with_C = false; // not ever used
4746
static constexpr bool with_group_scales = !std::is_same_v<GroupScaleT, void>;

0 commit comments

Comments
 (0)