Skip to content

Commit 98c4e2e

Browse files
authored
Fix potential out-of-bound access in int8_mm.py (#1751)
* fix potential out-of-bound access * remove unused EVEN_K * refactor fix with triton.heuristics * restore EVEN_K as an input * fix typo * fix another typo * ruff reformatted
1 parent 38e36de commit 98c4e2e

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

torchao/prototype/quantized_training/int8_mm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454

5555

5656
@triton.autotune(configs=configs, key=["M", "N", "K", "stride_ak", "stride_bk"])
57+
@triton.heuristics({"EVEN_K": lambda args: args["K"] % args["BLOCK_K"] == 0})
5758
@triton.jit
5859
def _scaled_int8_mm_kernel(
5960
A_ptr,
@@ -176,7 +177,6 @@ def scaled_int8_mm_cuda(A: Tensor, B: Tensor, row_scale: Tensor, col_scale: Tens
176177
*A.stride(),
177178
*B.stride(),
178179
*C.stride(),
179-
EVEN_K=K % 2 == 0,
180180
COL_SCALE_SCALAR=col_scale.numel() == 1,
181181
)
182182
return C

0 commit comments

Comments
 (0)