We read every piece of feedback and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 576cf6b · commit f38c272 — Copy full SHA for f38c272
torchao/float8/float8_ops.py
@@ -41,11 +41,12 @@ def addmm_float8_unwrapped(
41
b_inverse_scale = b_scale.reciprocal()
42
43
post_inverse_scale = None
44
- if (
45
- a_scale.shape == (a_data.shape[0], 1)
46
- and b_scale.shape == (1, b_data.shape[1])
47
- and not use_fast_accum
48
- ):
+ is_rowwise_scaling = a_scale.shape == (a_data.shape[0], 1) and b_scale.shape == (
+ 1,
+ b_data.shape[1],
+ )
+
49
+ if is_rowwise_scaling and not use_fast_accum:
50
# The rowwise CUTLASS-based kernel is so slow without fast-accum that
51
# we'd rather use the tensorwise cuBLAS-based kernel and do the scaling
52
# manually afterwards (hoping Inductor will be able to fuse it).
0 commit comments