
Commit 9b6212f

y-sq authored and facebook-github-bot committed
Also pad the N dimension if inner-padding is enabled (#858)
Summary:

Pull Request resolved: #858

This diff modifies the `padding` option and adds tests with `compile`:

* For a scaled_mm of shape MxKxN, the current `inner_padding` option only pads the `K` dimension. However, if `N` is not divisible by 16, we also get the error

  ```
  E RuntimeError: CUDA error: CUBLAS_STATUS_NOT_SUPPORTED when calling `cublasLtMatmulAlgoGetHeuristic( ltHandle, computeDesc.descriptor(), Adesc.descriptor(), Bdesc.descriptor(), Cdesc.descriptor(), Ddesc.descriptor(), preference.descriptor(), 1, &heuristicResult, &returnedResult)`
  ```

  So the `pad_inner` option is modified to also pad the `N` dimension.

* Compiling with inner padding only works with the Triton PR triton-lang/triton#4222. Before that PR, the inductor code-gen kernel fails at

  ```
  tmp10 = tl.where(tmp6, tmp8, tmp9)
  TypeError: unexpected type fp8e5 and fp8e5
  ```

Reviewed By: irobert0126

Differential Revision: D62003827
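For orientation, here is a minimal pad-then-slice sketch of the scheme this change applies around `torch._scaled_mm` (plain CPU matmul here; `pad_dims_to_16` is a hypothetical stand-in for the library's `pad_tensor_for_matmul`):

```python
import torch
import torch.nn.functional as F


def pad_dims_to_16(t: torch.Tensor, dims) -> torch.Tensor:
    # Zero-pad the requested dims of a 2-D tensor up to the next multiple of 16.
    pad_rows = (-t.size(0)) % 16 if 0 in dims else 0
    pad_cols = (-t.size(1)) % 16 if 1 in dims else 0
    # F.pad takes (left, right, top, bottom) for a 2-D tensor.
    return F.pad(t, (0, pad_cols, 0, pad_rows))


# M=17, K=17, N=35: K and N are not multiples of 16, which is the case
# _scaled_mm rejects with CUBLAS_STATUS_NOT_SUPPORTED.
a = torch.randn(17, 17)  # (M, K)
b = torch.randn(17, 35)  # (K, N)

a_pad = pad_dims_to_16(a, dims=[1])     # pad K only   -> (17, 32)
b_pad = pad_dims_to_16(b, dims=[0, 1])  # pad K and N  -> (32, 48)

# Zero padding leaves the valid (M, N) block of the product unchanged, so
# slicing the padded result back recovers the unpadded matmul exactly.
out = (a_pad @ b_pad)[:, : b.size(1)]
assert torch.allclose(out, a @ b, atol=1e-5)
```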
1 parent 8aa6533 commit 9b6212f

File tree

8 files changed: +60 -23 lines changed

benchmarks/float8/bench_padding.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -50,10 +50,10 @@ def do_fp8_matmul(A, B, fp8_dtype, out_dtype):
     scale_b = torch.tensor([1], device="cuda", dtype=torch.float32)
 
     a_config = ScaledMMConfig(
-        emulate=False, use_fast_accum=True, fp8_output=True, pad_inner_dim=True
+        emulate=False, use_fast_accum=True, fp8_output=True, pad_dimensions=True
     )
     b_config = ScaledMMConfig(
-        emulate=False, use_fast_accum=True, fp8_output=True, pad_inner_dim=True
+        emulate=False, use_fast_accum=True, fp8_output=True, pad_dimensions=True
     )
     a_config = LinearMMConfig(a_config, a_config, a_config)
     b_config = LinearMMConfig(b_config, b_config, b_config)
```

test/float8/test_base.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -494,7 +494,7 @@ def test_different_configs_error(self):
         "base_dtype", [torch.float16, torch.bfloat16, torch.float32]
     )
     @pytest.mark.parametrize("use_fast_accum", [True, False])
-    def test_pad_inner_dim(self, base_dtype, use_fast_accum):
+    def test_pad_dimensions(self, base_dtype, use_fast_accum):
        torch.manual_seed(42)
        input_dtype = torch.float8_e4m3fn
        compare_type = torch.float32
```

test/float8/test_compile.py

Lines changed: 34 additions & 6 deletions
```diff
@@ -40,14 +40,25 @@ def _test_compile_base(
     fullgraph: bool,
     config: Float8LinearConfig,
     dtype: torch.dtype,
+    pad_dimensions: bool,
 ):
     random.seed(0)
     torch.manual_seed(0)
-    x_shape = (16, 16)
+
+    if pad_dimensions:
+        x_shape = (17, 17)
+    else:
+        x_shape = (16, 16)
+
     linear_dtype = torch.bfloat16
 
     x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype)
-    m_ref = nn.Linear(16, 32, bias=True, device="cuda", dtype=linear_dtype)
+
+    if pad_dimensions:
+        m_ref = nn.Linear(17, 35, bias=True, device="cuda", dtype=linear_dtype)
+    else:
+        m_ref = nn.Linear(16, 16, bias=True, device="cuda", dtype=linear_dtype)
+
 
     m_fp8 = Float8Linear.from_float(
         copy.deepcopy(m_ref),
@@ -71,6 +82,7 @@ def _get_config(
     scaling_type_weight,
     scaling_type_grad_output,
     emulate,
+    pad_dimensions,
 ):
     if scaling_type_input is ScalingType.STATIC:
         cast_config_input = CastConfig(
@@ -99,11 +111,13 @@ def _get_config(
         cast_config_weight=cast_config_weight,
         cast_config_grad_output=cast_config_grad_output,
         emulate=emulate,
+        pad_dimensions=pad_dimensions,
     )
     return config
 
 
 @pytest.mark.parametrize("fullgraph", [True])
+@pytest.mark.parametrize("emulate", [False, True] if is_cuda_8_9 else [True])
 @pytest.mark.parametrize(
     "scaling_type_input", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC]
 )
@@ -113,7 +127,9 @@ def _get_config(
 @pytest.mark.parametrize(
     "scaling_type_grad_output", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC]
 )
-@pytest.mark.parametrize("emulate", [False, True] if is_cuda_8_9 else [True])
+@pytest.mark.parametrize(
+    "pad_dimensions", [False, True]
+)
 @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
 @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
 def test_eager_only(
@@ -122,17 +138,19 @@ def test_eager_only(
     scaling_type_input: ScalingType,
     scaling_type_weight: ScalingType,
     scaling_type_grad_output: ScalingType,
+    pad_dimensions: bool,
     dtype: torch.dtype,
 ):
     torch._dynamo.reset()
     config = _get_config(
-        scaling_type_input, scaling_type_weight, scaling_type_grad_output, emulate,
+        scaling_type_input, scaling_type_weight, scaling_type_grad_output, emulate, pad_dimensions,
     )
     _test_compile_base(
         "eager",
         fullgraph,
         config,
         dtype,
+        pad_dimensions,
     )
 
 
@@ -147,6 +165,9 @@ def test_eager_only(
 @pytest.mark.parametrize(
     "scaling_type_grad_output", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC]
 )
+@pytest.mark.parametrize(
+    "pad_dimensions", [False, True]
+)
 @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
 @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
 def test_aot_eager(
@@ -155,17 +176,19 @@ def test_aot_eager(
     scaling_type_input: ScalingType,
     scaling_type_weight: ScalingType,
     scaling_type_grad_output: ScalingType,
+    pad_dimensions: bool,
     dtype: torch.dtype,
 ):
     torch._dynamo.reset()
     config = _get_config(
-        scaling_type_input, scaling_type_weight, scaling_type_grad_output, emulate,
+        scaling_type_input, scaling_type_weight, scaling_type_grad_output, emulate, pad_dimensions,
     )
     _test_compile_base(
         "aot_eager",
         fullgraph,
         config,
         dtype,
+        pad_dimensions,
     )
 
 
@@ -180,6 +203,9 @@ def test_aot_eager(
 @pytest.mark.parametrize(
     "scaling_type_grad_output", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC]
 )
+@pytest.mark.parametrize(
+    "pad_dimensions", [True, False]
+)
 @unittest.skipIf(not torch.cuda.is_available() or not is_cuda_8_9, "CUDA with float8 support not available")
 @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
 def test_inductor(
@@ -188,17 +214,19 @@ def test_inductor(
     scaling_type_input: ScalingType,
     scaling_type_weight: ScalingType,
     scaling_type_grad_output: ScalingType,
+    pad_dimensions: bool,
     dtype: torch.dtype,
 ):
     torch._dynamo.reset()
     config = _get_config(
-        scaling_type_input, scaling_type_weight, scaling_type_grad_output, emulate,
+        scaling_type_input, scaling_type_weight, scaling_type_grad_output, emulate, pad_dimensions,
     )
     _test_compile_base(
         "inductor",
         fullgraph,
         config,
         dtype,
+        pad_dimensions,
     )
```

torchao/float8/config.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -120,7 +120,7 @@ class Float8LinearConfig:
     # inner dimension of a (dim 1) and b (dim 2) with 0s. This is needed for matmuls
     # _scaled_mm since it has the strong constraint that for M,N,K N, K must be a multiple of 16.
     # This can cause a memory spike however so we keep this off by default.
-    pad_inner_dim: bool = False
+    pad_dimensions: bool = False
 
     # If True, emulation is used instead of hardware accelerated gemm
     emulate: bool = False
```
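As a usage note for the renamed flag, a hedged sketch (assuming the `convert_to_float8_training` entry point; in/out features deliberately not multiples of 16, mirroring the new tests):

```python
import torch
import torch.nn as nn

from torchao.float8 import Float8LinearConfig, convert_to_float8_training

# With pad_dimensions=True (previously pad_inner_dim), linears whose K or N is
# not a multiple of 16 can still go through torch._scaled_mm: operands are
# zero-padded and the output is sliced back to its logical shape.
config = Float8LinearConfig(pad_dimensions=True)

m = nn.Sequential(nn.Linear(17, 35, bias=True)).to("cuda", torch.bfloat16)
convert_to_float8_training(m, config=config)

x = torch.randn(4, 17, device="cuda", dtype=torch.bfloat16)
# Compiling this path relies on triton-lang/triton#4222, per the summary above.
y = torch.compile(m)(x)
```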

torchao/float8/float8_linear.py

Lines changed: 3 additions & 3 deletions
```diff
@@ -141,21 +141,21 @@ def __init__(self, *args, **kwargs):
                 emulate,
                 self.config.gemm_config_output.use_fast_accum,
                 False,
-                self.config.pad_inner_dim,
+                self.config.pad_dimensions,
             ),
             # grad_input
             ScaledMMConfig(
                 emulate,
                 self.config.gemm_config_grad_input.use_fast_accum,
                 False,
-                self.config.pad_inner_dim,
+                self.config.pad_dimensions,
             ),
             # grad_weight
             ScaledMMConfig(
                 emulate,
                 self.config.gemm_config_grad_weight.use_fast_accum,
                 False,
-                self.config.pad_inner_dim,
+                self.config.pad_dimensions,
             ),
         )
```

torchao/float8/float8_ops.py

Lines changed: 14 additions & 5 deletions
```diff
@@ -134,27 +134,34 @@ def preprocess_addmm(a: Float8Tensor, b: Float8Tensor):
     a_scale = a._scale
     b_data = b._data
 
+    out_shape = (a._data.size(0), b._data.size(1))
+
     scaled_mm_config = choose_scaled_mm_config(
         a._gemm_input_role,
         a._linear_mm_config,
         b._gemm_input_role,
         b._linear_mm_config,
     )
 
-    if scaled_mm_config.pad_inner_dim:
+    if scaled_mm_config.pad_dimensions:
         assert a._data.size(1) == b._data.size(
             0
         ), f"Inner dims must match for mm, got {a._data.size(1)} and {b._data.size(0)}"
         a_data = pad_tensor_for_matmul(a_data, dims=1)
-        b_data = pad_tensor_for_matmul(b_data, dims=0)
+        b_data = pad_tensor_for_matmul(b_data, dims=[0,1])
 
     if not is_row_major(a_data.stride()):
         a_data = a_data.contiguous()
     if is_row_major(b_data.stride()):
         b_data = b_data.t().contiguous().t()
     b_scale = b._scale
-    return a_data, a_scale, b_data, b_scale
 
+    return a_data, a_scale, b_data, b_scale, out_shape
+
+def postprocess_addmm(out: torch.Tensor, scaled_mm_config, out_shape):
+    if scaled_mm_config.pad_dimensions:
+        out = out[:, :out_shape[1]]
+    return out
 
 @implements([aten.mm.default, aten.matmul.default])
 def float8_mm(aten_op, args, kwargs=None):
@@ -166,7 +173,7 @@ def float8_mm(aten_op, args, kwargs=None):
     ), "Expecting both Float8Tensor for mm inputs but found {} and {}".format(
         type(a), type(b)
     )
-    a_data, a_scale, b_data, b_scale = preprocess_addmm(a, b)
+    a_data, a_scale, b_data, b_scale, out_shape = preprocess_addmm(a, b)
     output_dtype = a._orig_dtype
     scaled_mm_config = choose_scaled_mm_config(
         a._gemm_input_role,
@@ -188,6 +195,7 @@ def float8_mm(aten_op, args, kwargs=None):
         bias=None,
         use_fast_accum=scaled_mm_config.use_fast_accum,
     )
+    tensor_out = postprocess_addmm(out=tensor_out, scaled_mm_config=scaled_mm_config, out_shape=out_shape)
     return tensor_out
 
 
@@ -201,7 +209,7 @@ def float8_addmm(aten_op, args, kwargs=None):
     bias = args[0]
     a = args[1]
     b = args[2]
-    a_data, a_scale, b_data, b_scale = preprocess_addmm(a, b)
+    a_data, a_scale, b_data, b_scale, out_shape = preprocess_addmm(a, b)
     output_dtype = a._orig_dtype
     assert bias.dtype == output_dtype, "bias dtype must match output dtype"
     scaled_mm_config = choose_scaled_mm_config(
@@ -225,6 +233,7 @@ def float8_addmm(aten_op, args, kwargs=None):
         bias=bias,
         use_fast_accum=scaled_mm_config.use_fast_accum,
     )
+    tensor_out = postprocess_addmm(out=tensor_out, scaled_mm_config=scaled_mm_config, out_shape=out_shape)
     return tensor_out
```
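The context lines above also show the layout handling that padded operands still have to respect: `b_data` is forced column-major via `.t().contiguous().t()`. A small self-contained illustration, with a local `is_row_major` that mirrors the helper used here:

```python
import torch


def is_row_major(stride) -> bool:
    # A 2-D row-major tensor has stride (cols, 1); column-major has (1, rows).
    return stride[0] > stride[1] and stride[1] == 1


b = torch.randn(32, 48)          # e.g. a zero-padded (K_pad, N_pad) operand
assert is_row_major(b.stride())  # freshly padded tensors come out row-major

# .t().contiguous().t() keeps the logical (K, N) shape but re-lays out the
# data column-major, which is what the float8 mm path expects for operand b.
b_col_major = b.t().contiguous().t()
assert not is_row_major(b_col_major.stride())
assert torch.equal(b, b_col_major)
```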

torchao/float8/float8_tensor.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -53,14 +53,14 @@ class ScaledMMConfig(NamedTuple):
         emulate (bool): Whether to emulate the matmuls in fp32.
         use_fast_accum (bool): Whether to use the fast-accumulation option for scaled_mm.
         fp8_output (bool): Whether to output the result of the scaled_mm in fp8.
-        pad_inner_dim (bool): Whether to pad the inner dimension of a and b with 0s.
+        pad_dimensions (bool): Whether to pad the inner dimension of a and b with 0s.
             This is needed for matmuls not aligned to 16.
     """
 
     emulate: bool = False
     use_fast_accum: bool = False
     fp8_output: bool = False
-    pad_inner_dim: bool = False
+    pad_dimensions: bool = False
 
 
 class LinearMMConfig(NamedTuple):
```
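To show the renamed field in context, a short construction sketch mirroring the benchmark change earlier in this commit (one `ScaledMMConfig` per gemm, grouped into a `LinearMMConfig` for the output / grad_input / grad_weight gemms, as in `float8_linear.py` above):

```python
from torchao.float8.float8_tensor import LinearMMConfig, ScaledMMConfig

# pad_dimensions replaces the old pad_inner_dim field.
mm_config = ScaledMMConfig(
    emulate=False, use_fast_accum=True, fp8_output=True, pad_dimensions=True
)
linear_mm_config = LinearMMConfig(mm_config, mm_config, mm_config)
```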

torchao/float8/inference.py

Lines changed: 3 additions & 3 deletions
```diff
@@ -22,13 +22,13 @@ class Float8MMConfig(NamedTuple):
     Attributes:
         emulate (bool): Whether to emulate the matmuls in fp32.
         use_fast_accum (bool): Whether to use the fast-accumulation option for scaled_mm.
-        pad_inner_dim (bool): Whether to pad the inner dimension of a and b with 0s.
+        pad_dimensions (bool): Whether to pad the inner dimension of a and b with 0s.
             This is needed for matmuls not aligned to 16.
     """
 
     emulate: bool = False
     use_fast_accum: bool = False
-    pad_inner_dim: bool = False
+    pad_dimensions: bool = False
 
 
 def preprocess_data(
@@ -44,7 +44,7 @@ def preprocess_data(
     Returns:
         Preprocessed tensors A and B in the format for _scaled_mm.
     """
-    if scaled_mm_config.pad_inner_dim:
+    if scaled_mm_config.pad_dimensions:
         assert a_data.size(1) == b_data.size(
             0
         ), f"Inner dims must match for mm, got {a_data.size(1)} and {b_data.size(0)}"
```
0 commit comments
