
Commit 101c039

[float8 moe training] make using triton kernels for per-group scaling configurable (#2405)

* improve moe training benchmarking
* lint
* readability improvements
* grab use_triton from args instead of class attribute
* add comment
1 parent c561d26 commit 101c039

File tree: 4 files changed, +122 -31 lines changed

torchao/prototype/moe_training/benchmarks/benchmark_scaled_grouped_mm.py
Lines changed: 51 additions & 21 deletions

@@ -4,7 +4,7 @@
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
 # this benchmarking script is a modified version of the original script from: https://github.com/drisspg/transformer_nuggets/blob/main/transformer_nuggets/utils/benchmark.py
-
+import argparse
 import itertools
 import time
 from dataclasses import dataclass
@@ -31,7 +31,9 @@ class ExperimentConfig:
 
 @dataclass(frozen=True)
 class ExperimentResult:
-    time_us: float
+    torch_time_us: float
+    triton_time_us: bool
+    triton_speedup: float
 
 
 @dataclass(frozen=True)
@@ -41,12 +43,14 @@ class Experiment:
 
 
 def get_configs() -> List[ExperimentConfig]:
-    A_shapes = [(2**8, 4096), (2**12, 4096), (2**16, 4096)]
-    B_shapes = [(4, 4096, 4096), (8, 4096, 4096), (16, 4096, 4096)]
+    A_shapes = [(2**8, 8192), (2**12, 8192), (2**16, 8192)]
+    B_shapes = [(4, 8192, 8192), (8, 8192, 8192), (16, 8192, 8192)]
     high_precision_dtypes = [torch.bfloat16]
     configs = []
     for A_shape, B_shape, high_precision_dtype in itertools.product(
-        A_shapes, B_shapes, high_precision_dtypes
+        A_shapes,
+        B_shapes,
+        high_precision_dtypes,
     ):
         configs.append(
             ExperimentConfig(
@@ -58,7 +62,9 @@ def get_configs() -> List[ExperimentConfig]:
     return configs
 
 
-def run_experiment(config: ExperimentConfig) -> ExperimentResult:
+def run_experiment(
+    config: ExperimentConfig, args: argparse.Namespace
+) -> ExperimentResult:
     # define test inputs
     A = torch.randn(
         *config.A_shape,
@@ -92,26 +98,46 @@ def warmup(func, *args, **kwargs):
         for _ in range(10):
             func(*args, **kwargs)
 
-    def forward_backward(A, B_t, offs):
-        out = _scaled_grouped_mm(A, B_t, offs=offs, out_dtype=torch.bfloat16)
+    def forward_backward(A, B_t, offs, use_triton=True):
+        out = _scaled_grouped_mm(
+            A,
+            B_t,
+            offs=offs,
+            out_dtype=torch.bfloat16,
+            use_triton_for_per_group_scales=use_triton,
+        )
         out.sum().backward()
+        torch.cuda.synchronize()
 
-    # bench triton
-    warmup(forward_backward, A, B_t, offs)
+    # benchmark torch
+    torch_func = torch.compile(forward_backward) if args.compile else forward_backward
+    warmup(torch_func, A, B_t, offs, use_triton=False)
     start_time_ns = time.perf_counter_ns()
-    forward_backward(A, B_t, offs)
-    time_ns = time.perf_counter_ns() - start_time_ns
-    time_us = time_ns / 1e3
+    torch_func(A, B_t, offs, use_triton=False)
+    torch_time_ns = time.perf_counter_ns() - start_time_ns
+    torch_time_us = torch_time_ns / 1e3
 
-    return ExperimentResult(time_us=time_us)
+    # benchmark triton
+    warmup(forward_backward, A, B_t, offs, use_triton=True)
+    start_time_ns = time.perf_counter_ns()
+    forward_backward(A, B_t, offs, use_triton=True)
+    triton_time_ns = time.perf_counter_ns() - start_time_ns
+    triton_time_us = triton_time_ns / 1e3
+
+    return ExperimentResult(
+        torch_time_us=round(torch_time_us, 3),
+        triton_time_us=round(triton_time_us, 3),
+        triton_speedup=round(torch_time_us / triton_time_us, 3),
+    )
 
 
 def print_results(experiments: List[Experiment]):
     headers = [
         "A_shape",
         "B_shape",
-        "high_precision_dtype",
-        "time_us",
+        "torch_time_us",
+        "triton_time_us",
+        "triton_speedup",
     ]
     rows = []
     for experiment in experiments:
@@ -121,24 +147,28 @@ def print_results(experiments: List[Experiment]):
             [
                 A_shape,
                 B_shape,
-                experiment.config.high_precision_dtype,
-                experiment.result.time_us,
+                experiment.result.torch_time_us,
+                experiment.result.triton_time_us,
+                experiment.result.triton_speedup,
             ]
         )
     print(tabulate(rows, headers=headers))
 
 
-def main():
+def main(args: argparse.Namespace):
    torch.random.manual_seed(123)
    configs = get_configs()
    results = []
    for config in tqdm(configs):
-        result = run_experiment(config)
+        result = run_experiment(config, args)
        results.append(Experiment(config=config, result=result))

    # Use Tabulate to print results
    print_results(results)


 if __name__ == "__main__":
-    main()
+    arg_parser = argparse.ArgumentParser()
+    arg_parser.add_argument("--compile", action="store_true")
+    args = arg_parser.parse_args()
+    main(args)
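For each configuration, run_experiment now times one forward+backward pass through _scaled_grouped_mm twice, once with the torch per-group-scaling path and once with the triton kernels, and reports the ratio as triton_speedup. The sketch below mirrors that comparison; the input tensors A, B_t, and offs are placeholders here (the script builds them from each ExperimentConfig), so treat their construction as an assumption.

# A minimal sketch of the torch-vs-triton timing comparison done by run_experiment.
# A, B_t, and offs are placeholders; the script derives them from ExperimentConfig.
import time

import torch

from torchao.prototype.moe_training.scaled_grouped_mm import _scaled_grouped_mm


def fwd_bwd(A, B_t, offs, use_triton):
    # one forward + backward through the scaled grouped GEMM
    out = _scaled_grouped_mm(
        A,
        B_t,
        offs=offs,
        out_dtype=torch.bfloat16,
        use_triton_for_per_group_scales=use_triton,
    )
    out.sum().backward()
    torch.cuda.synchronize()


def time_us(fn, *args, **kwargs):
    # wall-clock time of a single call, in microseconds
    start = time.perf_counter_ns()
    fn(*args, **kwargs)
    return (time.perf_counter_ns() - start) / 1e3


# torch_us = time_us(fwd_bwd, A, B_t, offs, use_triton=False)
# triton_us = time_us(fwd_bwd, A, B_t, offs, use_triton=True)
# print(f"triton speedup: {torch_us / triton_us:.3f}x")

Running the script with the new flag, e.g. python benchmark_scaled_grouped_mm.py --compile, additionally wraps the torch path in torch.compile before timing; the triton path is always benchmarked in eager mode.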

torchao/prototype/moe_training/conversion_utils.py
Lines changed: 8 additions & 3 deletions

@@ -28,7 +28,8 @@ class MoETrainingConfig(AOBaseConfig):
     For all other ops, ScaledGroupedMMTensor behaves like a regular torch.Tensor.
     """
 
-    pass
+    # temporary config flag for testing/benchmarking, will remove before graduating out of prototype
+    use_triton_for_per_group_scales: bool = True
 
 
 @register_quantize_module_handler(MoETrainingConfig)
@@ -46,14 +47,15 @@ def _moe_training_transform(
     Returns:
         nn.Module: The modified module with swapped parameters.
     """
-    out = _swap_params(module)
+    out = _swap_params(module, config=config)
     return out
 
 
 def _swap_params(
     module: nn.Module,
     *,
     module_filter_fn: Optional[Callable[[nn.Module, str], bool]] = None,
+    config: Optional[MoETrainingConfig] = None,
 ) -> nn.Module:
     """
     Recurses through the nn.Module, recursively swapping the data tensor of
@@ -69,6 +71,7 @@ def _swap_params(
     Returns:
         nn.Module: The modified module with swapped linear layers.
     """
+    use_triton = config.use_triton_for_per_group_scales if config is not None else False
     if isinstance(module, nn.Parameter) and (
         module_filter_fn is None or module_filter_fn(module, "")
     ):
@@ -77,7 +80,9 @@
             f"Does not support a root nn.Parameter with children: {module}"
         )
         if not isinstance(module.data, ScaledGroupedMMTensor):
-            new_data = ScaledGroupedMMTensor(module.data)
+            new_data = ScaledGroupedMMTensor(
+                module.data, use_triton_for_per_group_scales=use_triton
+            )
             return nn.Parameter(new_data, requires_grad=module.requires_grad)
         return module
 
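With the flag on MoETrainingConfig, the scaling backend can be chosen at conversion time. Below is a rough usage sketch under stated assumptions: the model variable is hypothetical, and applying the config through torchao's quantize_ entry point (for which this handler is registered) may also require a module filter in practice.

# Hypothetical sketch: select the torch (non-triton) per-group scaling path at conversion time.
from torchao.prototype.moe_training.conversion_utils import MoETrainingConfig
from torchao.quantization import quantize_

config = MoETrainingConfig(use_triton_for_per_group_scales=False)

# `model` is assumed to be an MoE module whose expert weight nn.Parameters should be
# swapped for ScaledGroupedMMTensor; pass a filter_fn if only some submodules qualify.
# quantize_(model, config)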

torchao/prototype/moe_training/scaled_grouped_mm.py
Lines changed: 43 additions & 5 deletions

@@ -14,14 +14,19 @@
     triton_fp8_col_major_jagged_colwise_scales,
     triton_fp8_row_major_jagged_rowwise_scales,
 )
-from torchao.prototype.moe_training.utils import _is_column_major
+from torchao.prototype.moe_training.utils import (
+    _is_column_major,
+    _to_2d_jagged_float8_tensor_colwise,
+    _to_2d_jagged_float8_tensor_rowwise,
+)
 
 
 def _scaled_grouped_mm(
     A: torch.Tensor,
     B_t: torch.Tensor,
     offs: torch.Tensor,
     out_dtype: Optional[torch.dtype] = torch.bfloat16,
+    use_triton_for_per_group_scales: bool = True,
 ) -> torch.Tensor:
     """
     This function performs dynamic float8 quantization with row-wise scaling
@@ -34,6 +39,7 @@ def _scaled_grouped_mm(
             and in column-major memory layout.
         offs (int32 torch.Tensor): The offsets to use to mark the starting index of each group along dim0 of the A tensor.
         out_dtype (Optional[torch.dtype]): The dtype of the output tensor. Currently only torch.bfloat16 is supported.
+        use_triton_for_per_group_scales (bool): Whether to use custom triton kernels to compute per-group scales. Default is True.
     """
     return _Float8GroupedMM.apply(
         A,
@@ -53,6 +59,7 @@ def forward(
         B_t: torch.Tensor,
         offs: torch.Tensor,
         out_dtype: Optional[torch.dtype] = torch.bfloat16,
+        use_triton_for_per_group_scales: bool = True,
     ) -> torch.Tensor:
         # torchao _scaled_grouped_mm only supports A=2D, B=3D.
         assert A.ndim == 2, "A must be 2D"
@@ -136,9 +143,16 @@ def forward(
         # Store what we need for backward.
         ctx.save_for_backward(A, B_fp8_col_major, B_scales, offs)
         ctx.out_dtype = out_dtype
+        ctx.use_triton_for_per_group_scales = use_triton_for_per_group_scales
 
         # Perform scaled grouped GEMM and return result.
         # output shape: scaled grouped mm of (M,K) @ (B,K,N) = (M,N)
+        assert not _is_column_major(A_fp8_row_major), (
+            "A must be row-major for output = A @ B"
+        )
+        assert _is_column_major(B_t_fp8_col_major), (
+            "B must be column-major for output = A @ B"
+        )
         return torch._scaled_grouped_mm(
             A_fp8_row_major,
             B_t_fp8_col_major,
@@ -153,6 +167,7 @@ def forward(
     def backward(ctx, grad_output: torch.Tensor):
         A, B_fp8_col_major, B_scales, offs = ctx.saved_tensors
         out_dtype = ctx.out_dtype
+        use_triton_for_per_group_scales = ctx.use_triton_for_per_group_scales
 
         # Convert grad_output to float8, row-major for left operand of grouped GEMM
         # needed for grad_A: grad_output @ B
@@ -175,6 +190,12 @@ def backward(ctx, grad_output: torch.Tensor):
         #
         # grad_A = grad_output @ B
         # grad_A = scaled grouped mm of (M,N) @ (B,N,K) = (M,K)
+        assert not _is_column_major(grad_output_fp8_row_major), (
+            "grad_output must be row-major for grad_A = grad_output @ B"
+        )
+        assert _is_column_major(B_fp8_col_major), (
+            "B must be column-major for grad_A = grad_output @ B"
+        )
         grad_A = torch._scaled_grouped_mm(
             grad_output_fp8_row_major,
             B_fp8_col_major,
@@ -195,25 +216,42 @@ def backward(ctx, grad_output: torch.Tensor):
 
         # grad_B is a special case. both operands of the grouped gemm will be 2D with offsets determing the "groups."
         # Compute scales for grad_output_t and A, which are both 2D tensors with offsets which define the "jagged" groups.
+        per_group_rowwise_scale_func = (
+            triton_fp8_row_major_jagged_rowwise_scales
+            if use_triton_for_per_group_scales
+            else _to_2d_jagged_float8_tensor_rowwise
+        )
+        per_group_colwise_scale_func = (
+            triton_fp8_col_major_jagged_colwise_scales
+            if use_triton_for_per_group_scales
+            else _to_2d_jagged_float8_tensor_colwise
+        )
+
         grad_output_t_fp8_row_major, grad_output_t_scales = (
-            triton_fp8_row_major_jagged_rowwise_scales(
+            per_group_rowwise_scale_func(
                 grad_output_t_row_major,
                 offs,
-                output_dtype=torch.float8_e4m3fn,
+                torch.float8_e4m3fn,
                 round_scales_to_power_of_2=True,
             )
         )
 
-        A_fp8_col_major, A_scales = triton_fp8_col_major_jagged_colwise_scales(
+        A_fp8_col_major, A_scales = per_group_colwise_scale_func(
             A_col_major,
             offs,
-            output_dtype=torch.float8_e4m3fn,
+            torch.float8_e4m3fn,
             round_scales_to_power_of_2=True,
         )
 
         # Compute grad_B = grad_output_t @ A.
         # grad_B = grad_output_t @ A
         # grad_B = (N,M) @ (M,K) = (N,K)
+        assert not _is_column_major(grad_output_t_fp8_row_major), (
+            "grad_output_t must be row-major for grad_B = grad_output_t @ A"
+        )
+        assert _is_column_major(A_fp8_col_major), (
+            "A must be column-major for grad_B = grad_output_t @ A"
+        )
         grad_B = torch._scaled_grouped_mm(
             grad_output_t_fp8_row_major,
             A_fp8_col_major,
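The added asserts document the operand layouts torch._scaled_grouped_mm expects: a row-major left operand and a column-major right operand. The snippet below only illustrates the stride pattern behind that distinction (shapes are arbitrary), and assumes _is_column_major keys off strides like these.

# Row-major vs column-major at a glance: transposing the last two dims of a
# contiguous tensor yields a column-major view without copying any data.
import torch

B = torch.randn(4, 1024, 2048, dtype=torch.bfloat16)
B_t = B.transpose(-2, -1)

print(B.stride())    # (2097152, 2048, 1)  -> last dim is contiguous: row-major
print(B_t.stride())  # (2097152, 1, 2048)  -> second-to-last dim is contiguous: column-major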

torchao/prototype/moe_training/tensor.py
Lines changed: 20 additions & 2 deletions

@@ -12,9 +12,16 @@ class ScaledGroupedMMTensor(torch.Tensor):
 
     grouped_mm_func_name = "_grouped_mm"
     offs_arg_name = "offs"
+    use_triton_for_per_group_scales = True
 
-    def __init__(self, data: torch.Tensor):
+    def __init__(
+        self, data: torch.Tensor, use_triton_for_per_group_scales: bool = True
+    ):
         self._data = data
+        self._use_triton_for_per_group_scales = use_triton_for_per_group_scales
+
+    def __repr__(self):
+        return f"ScaledGroupedMMTensor(use_triton_for_per_group_scales={self._use_triton_for_per_group_scales}, {self._data})"
 
     @classmethod
     def __torch_function__(cls, func, types, args, kwargs={}):
@@ -31,5 +38,16 @@ def __torch_function__(cls, func, types, args, kwargs={}):
             B_is_3d = B.dim() == 3
             has_offs = kwargs.get(cls.offs_arg_name) is not None
             if A_is_2d and B_is_3d and has_offs:
-                return _scaled_grouped_mm(*args, **kwargs)
+                # prefer to use B to check use_triton, as that will be the weight/nn.Parameter
+                # that is converted to ScaledGroupedMMTensor
+                use_triton = (
+                    B._use_triton_for_per_group_scales
+                    if isinstance(B, cls)
+                    else A._use_triton_for_per_group_scales
+                )
+                return _scaled_grouped_mm(
+                    *args,
+                    use_triton_for_per_group_scales=use_triton,
+                    **kwargs,
+                )
         return super().__torch_function__(func, types, args, kwargs)
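End to end, _swap_params wraps each matching expert weight in a ScaledGroupedMMTensor carrying the flag, and __torch_function__ above reads it back when torch._grouped_mm is invoked on that weight. A brief sketch of the wrapping, with an illustrative weight shape:

# Sketch of the parameter wrapping performed by _swap_params: once wrapped, grouped-mm
# calls on the weight are routed through _scaled_grouped_mm with the stored setting.
import torch

from torchao.prototype.moe_training.tensor import ScaledGroupedMMTensor

w = torch.randn(4, 1024, 2048, dtype=torch.bfloat16)  # illustrative expert weight
w_scaled = torch.nn.Parameter(
    ScaledGroupedMMTensor(w, use_triton_for_per_group_scales=False),
    requires_grad=True,
)
print(w_scaled)  # ScaledGroupedMMTensor.__repr__ includes the use_triton_for_per_group_scales setting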
