Commit 7debcd9

Update
[ghstack-poisoned]
1 parent c9e26bd commit 7debcd9

1 file changed: +5 -1 lines changed

torchao/float8/float8_python_api.py

Lines changed: 5 additions & 1 deletion
@@ -38,7 +38,11 @@ def addmm_float8_unwrapped(
     b_inverse_scale = b_scale.reciprocal()
 
     post_inverse_scale = None
-    if a_scale.shape == (a_data.shape[0], 1) and b_scale.shape == (1, b_data.shape[1]) and not use_fast_accum:
+    if (
+        a_scale.shape == (a_data.shape[0], 1)
+        and b_scale.shape == (1, b_data.shape[1])
+        and not use_fast_accum
+    ):
         # The rowwise CUTLASS-based kernel is so slow without fast-accum that
         # we'd rather use the tensorwise cuBLAS-based kernel and do the scaling
         # manually afterwards (hoping Inductor will be able to fuse it).
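
The comment in this hunk relies on the fact that rowwise scaling commutes with the matmul: scaling the rows of `a` and the columns of `b` before multiplying gives the same result as multiplying the unscaled data and then scaling each output element by the product of its row and column factors. The sketch below illustrates only that identity. It uses plain Python lists in place of tensors, and the helper names (`matmul`, `rowwise_scaled_matmul`, `matmul_then_scale`) are hypothetical, not torchao functions; it does not reproduce the body of `addmm_float8_unwrapped`, which is not shown in this diff.

```python
# Illustrative sketch only: plain Python lists stand in for tensors, and the
# helper names below are hypothetical (they are not part of torchao).


def matmul(a, b):
    """Plain matrix product of two list-of-lists matrices."""
    return [
        [sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
        for i in range(len(a))
    ]


def rowwise_scaled_matmul(a, b, a_inv, b_inv):
    """Reference path: scale rows of a and columns of b first, then multiply."""
    a_deq = [[a[i][k] * a_inv[i] for k in range(len(a[0]))] for i in range(len(a))]
    b_deq = [[b[k][j] * b_inv[j] for j in range(len(b[0]))] for k in range(len(b))]
    return matmul(a_deq, b_deq)


def matmul_then_scale(a, b, a_inv, b_inv):
    """Alternative path: multiply unscaled data, then scale element (i, j)
    by a_inv[i] * b_inv[j], i.e. the outer product of the two scale vectors."""
    out = matmul(a, b)
    return [
        [out[i][j] * a_inv[i] * b_inv[j] for j in range(len(out[0]))]
        for i in range(len(out))
    ]


# Power-of-two scales keep the arithmetic exact in floating point.
a = [[1.0, 2.0], [3.0, 4.0]]
b = [[5.0, 6.0], [7.0, 8.0]]
a_inv = [0.5, 0.25]  # one inverse scale per row of a, like shape (M, 1)
b_inv = [2.0, 4.0]   # one inverse scale per column of b, like shape (1, N)

reference = rowwise_scaled_matmul(a, b, a_inv, b_inv)
manual = matmul_then_scale(a, b, a_inv, b_inv)
```

With these inputs both paths give the same matrix. In real low-precision arithmetic the two orderings can differ by rounding, so this is an argument about the math, not a claim about the kernels' numerical behavior.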
