Skip to content

Commit f38c272

Browse files
[float8] Make check for rowwise scales more explicit/readable (#1890)
1 parent 576cf6b commit f38c272

File tree

1 file changed

+6
-5
lines changed

1 file changed

+6
-5
lines changed

torchao/float8/float8_ops.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -41,11 +41,12 @@ def addmm_float8_unwrapped(
4141
b_inverse_scale = b_scale.reciprocal()
4242

4343
post_inverse_scale = None
44-
if (
45-
a_scale.shape == (a_data.shape[0], 1)
46-
and b_scale.shape == (1, b_data.shape[1])
47-
and not use_fast_accum
48-
):
44+
is_rowwise_scaling = a_scale.shape == (a_data.shape[0], 1) and b_scale.shape == (
45+
1,
46+
b_data.shape[1],
47+
)
48+
49+
if is_rowwise_scaling and not use_fast_accum:
4950
# The rowwise CUTLASS-based kernel is so slow without fast-accum that
5051
# we'd rather use the tensorwise cuBLAS-based kernel and do the scaling
5152
# manually afterwards (hoping Inductor will be able to fuse it).

0 commit comments

Comments
 (0)