
Commit b8206d7

[scaled grouped mm] integrate triton kernels into differentiable scaled grouped mm (#2077)
1 parent 9af2a45 commit b8206d7

File tree

9 files changed: +337 −153 lines

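For context, here is a minimal sketch of the differentiable API this commit wires the triton kernels into, following the call pattern used by the benchmark added in this commit. The shapes, group offsets, and device are illustrative assumptions, not requirements of the API.

import torch

from torchao.prototype.scaled_grouped_mm import _scaled_grouped_mm

# A: (total_tokens, K) row-major, with jagged groups stacked along dim 0
A = torch.randn(256, 4096, dtype=torch.bfloat16, device="cuda", requires_grad=True)
# B_t: (n_groups, K, N), the transpose of per-group weights in the last two dims
B_t = torch.randn(
    4, 4096, 4096, dtype=torch.bfloat16, device="cuda", requires_grad=True
).transpose(-2, -1)
# offs marks the end row of each group in A: here, 4 equal groups of 64 rows
offs = torch.arange(64, 257, 64, device="cuda", dtype=torch.int32)

out = _scaled_grouped_mm(A, B_t, offs=offs, out_dtype=torch.bfloat16)
out.sum().backward()  # differentiable: gradients flow back to the inputs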

test/prototype/scaled_grouped_mm/__init__.py

Whitespace-only changes.

torchao/prototype/scaled_grouped_mm/kernels/test_jagged_float8_scales.py renamed to test/prototype/scaled_grouped_mm/test_kernels.py

Lines changed: 14 additions & 2 deletions
@@ -7,15 +7,27 @@
 import pytest
 import torch
 
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
+
+# We need to skip before doing any imports which would use triton, since
+# triton won't be available on CPU builds and torch < 2.5
+if not (
+    TORCH_VERSION_AT_LEAST_2_5
+    and torch.cuda.is_available()
+    and torch.cuda.get_device_capability()[0] >= 9
+):
+    pytest.skip("Unsupported PyTorch version", allow_module_level=True)
+
+
 from torchao.prototype.scaled_grouped_mm.kernels.jagged_float8_scales import (
     triton_fp8_col_major_jagged_colwise_scales,
     triton_fp8_row_major_jagged_rowwise_scales,
 )
-from torchao.prototype.scaled_grouped_mm.scaled_grouped_mm import (
+from torchao.prototype.scaled_grouped_mm.utils import (
+    _is_column_major,
     _to_2d_jagged_float8_tensor_colwise,
     _to_2d_jagged_float8_tensor_rowwise,
 )
-from torchao.prototype.scaled_grouped_mm.utils import _is_column_major
 
 
 @pytest.mark.parametrize("round_scales_to_power_of_2", [True, False])
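For readers unfamiliar with the guard above: pytest.skip(..., allow_module_level=True) aborts collection of the entire module, so any import that would pull in triton must appear after the guard rather than at the top of the file. A minimal generic sketch of the pattern (the check and reason string here are illustrative):

import pytest
import torch

# environment checks run first, before any import that requires triton
if not torch.cuda.is_available():
    pytest.skip("requires CUDA", allow_module_level=True)

# imports that require triton are only reached when the guard passes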

torchao/prototype/scaled_grouped_mm/test_scaled_grouped_mm.py renamed to test/prototype/scaled_grouped_mm/test_scaled_grouped_mm.py

Lines changed: 12 additions & 2 deletions
@@ -7,6 +7,18 @@
 import pytest
 import torch
 
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
+
+# We need to skip before doing any imports which would use triton, since
+# triton won't be available on CPU builds and torch < 2.5
+if not (
+    TORCH_VERSION_AT_LEAST_2_5
+    and torch.cuda.is_available()
+    and torch.cuda.get_device_capability()[0] >= 9
+):
+    pytest.skip("Unsupported PyTorch version", allow_module_level=True)
+
+
 from torchao.float8.config import (
     Float8LinearConfig,
     Float8LinearRecipeName,

@@ -19,7 +31,6 @@
 )
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 def test_valid_scaled_grouped_mm_2d_3d():
     out_dtype = torch.bfloat16
     device = "cuda"

@@ -73,7 +84,6 @@ def test_valid_scaled_grouped_mm_2d_3d():
     assert torch.equal(b_t.grad, ref_b_t.grad)
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.parametrize("m", [16, 17])
 @pytest.mark.parametrize("k", [16, 18])
 @pytest.mark.parametrize("n", [32, 33])

torchao/prototype/scaled_grouped_mm/kernels/benchmark.py renamed to torchao/prototype/scaled_grouped_mm/benchmarks/benchmark_kernels.py

Lines changed: 1 addition & 1 deletion
@@ -18,7 +18,7 @@
     triton_fp8_col_major_jagged_colwise_scales,
     triton_fp8_row_major_jagged_rowwise_scales,
 )
-from torchao.prototype.scaled_grouped_mm.scaled_grouped_mm import (
+from torchao.prototype.scaled_grouped_mm.utils import (
     _to_2d_jagged_float8_tensor_colwise,
     _to_2d_jagged_float8_tensor_rowwise,
 )
Lines changed: 144 additions & 0 deletions
@@ -0,0 +1,144 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+# this benchmarking script is a modified version of the original script from: https://github.com/drisspg/transformer_nuggets/blob/main/transformer_nuggets/utils/benchmark.py
+
+import itertools
+import time
+from dataclasses import dataclass
+from typing import List
+
+import torch
+from tabulate import tabulate
+from tqdm import tqdm
+
+from torchao.prototype.scaled_grouped_mm import _scaled_grouped_mm
+
+device = torch.device("cuda")
+
+# Needed since changing args to function causes recompiles
+torch._dynamo.config.cache_size_limit = 1000
+
+
+@dataclass(frozen=True)
+class ExperimentConfig:
+    high_precision_dtype: torch.dtype
+    A_shape: tuple[int]
+    B_shape: tuple[int]
+
+
+@dataclass(frozen=True)
+class ExperimentResult:
+    time_us: float
+
+
+@dataclass(frozen=True)
+class Experiment:
+    config: ExperimentConfig
+    result: ExperimentResult
+
+
+def get_configs() -> List[ExperimentConfig]:
+    A_shapes = [(2**8, 4096), (2**12, 4096), (2**16, 4096)]
+    B_shapes = [(4, 4096, 4096), (8, 4096, 4096), (16, 4096, 4096)]
+    high_precision_dtypes = [torch.bfloat16]
+    configs = []
+    for A_shape, B_shape, high_precision_dtype in itertools.product(
+        A_shapes, B_shapes, high_precision_dtypes
+    ):
+        configs.append(
+            ExperimentConfig(
+                A_shape=A_shape,
+                B_shape=B_shape,
+                high_precision_dtype=high_precision_dtype,
+            )
+        )
+    return configs
+
+
+def run_experiment(config: ExperimentConfig) -> ExperimentResult:
+    # define test inputs
+    A = torch.randn(
+        *config.A_shape,
+        dtype=config.high_precision_dtype,
+        device=device,
+        requires_grad=True,
+    )
+    B_t = torch.randn(
+        *config.B_shape,
+        dtype=config.high_precision_dtype,
+        device=device,
+        requires_grad=True,
+    ).transpose(-2, -1)
+
+    # - configure input to be row-major with groups divided along the column dimension,
+    # representing the left operand of grad_weight = grad_output_t @ input
+    # that occurs in the backward pass of the differentiable scaled grouped mm.
+    # - the transposed tensor in col-major format with groups along the row dimension,
+    # which represents the right operand.
+    n_groups = config.B_shape[0]
+    group_size = A.shape[0] // n_groups
+    offs = torch.arange(
+        group_size,
+        group_size * n_groups + 1,
+        group_size,
+        device=device,
+        dtype=torch.int32,
+    )
+
+    def warmup(func, *args, **kwargs):
+        for _ in range(10):
+            func(*args, **kwargs)
+
+    def forward_backward(A, B_t, offs):
+        out = _scaled_grouped_mm(A, B_t, offs=offs, out_dtype=torch.bfloat16)
+        out.sum().backward()
+
+    # bench triton
+    warmup(forward_backward, A, B_t, offs)
+    start_time_ns = time.perf_counter_ns()
+    forward_backward(A, B_t, offs)
+    time_ns = time.perf_counter_ns() - start_time_ns
+    time_us = time_ns / 1e3
+
+    return ExperimentResult(time_us=time_us)
+
+
+def print_results(experiments: List[Experiment]):
+    headers = [
+        "A_shape",
+        "B_shape",
+        "high_precision_dtype",
+        "time_us",
+    ]
+    rows = []
+    for experiment in experiments:
+        A_shape = f"({experiment.config.A_shape[0]}, {experiment.config.A_shape[1]})"
+        B_shape = f"({experiment.config.B_shape[0]}, {experiment.config.B_shape[1]}, {experiment.config.B_shape[2]})"
+        rows.append(
+            [
+                A_shape,
+                B_shape,
+                experiment.config.high_precision_dtype,
+                experiment.result.time_us,
+            ]
+        )
+    print(tabulate(rows, headers=headers))
+
+
+def main():
+    torch.random.manual_seed(123)
+    configs = get_configs()
+    results = []
+    for config in tqdm(configs):
+        result = run_experiment(config)
+        results.append(Experiment(config=config, result=result))
+
+    # Use Tabulate to print results
+    print_results(results)
+
+
+if __name__ == "__main__":
+    main()
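To make the jagged layout in run_experiment concrete, a worked example with the smallest config above: A_shape = (2**8, 4096) and B_shape = (4, 4096, 4096) give n_groups = 4 and group_size = 256 // 4 = 64, so offs marks the end row of each group along A's first dimension:

import torch

offs = torch.arange(64, 64 * 4 + 1, 64, dtype=torch.int32)
print(offs)  # tensor([ 64, 128, 192, 256], dtype=torch.int32)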
Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
+from torchao.prototype.scaled_grouped_mm.kernels.jagged_float8_scales import (
+    triton_fp8_col_major_jagged_colwise_scales as triton_fp8_col_major_jagged_colwise_scales,
+)
+from torchao.prototype.scaled_grouped_mm.kernels.jagged_float8_scales import (
+    triton_fp8_row_major_jagged_rowwise_scales as triton_fp8_row_major_jagged_rowwise_scales,
+)
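Aside: the "import X as X" form above is the conventional way to mark a name as an explicit re-export in a package __init__, so strict type checkers (e.g. mypy with --no-implicit-reexport) treat it as public rather than flagging it as a private import. A tiny generic illustration:

from math import sqrt as sqrt  # the redundant-looking "as" marks sqrt as re-exported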

torchao/prototype/scaled_grouped_mm/kernels/jagged_float8_scales.py

Lines changed: 10 additions & 2 deletions
@@ -160,7 +160,11 @@ def _triton_fp8_row_major_jagged_rowwise_scales(
         data = tl.load(input_ptr + block_offs, mask=block_mask, other=0.0).to(
             input_dtype
         )
-        amax_buffer = tl.maximum(amax_buffer, tl.max(tl.abs(data), axis=1))
+        # we need to cast back to input dtype since triton promotes bf16 to fp32:
+        # https://github.com/triton-lang/triton/blob/981e987eed9053b952f81153bc0779c99d8c642e/python/triton/language/standard.py#L173
+        amax_buffer = tl.maximum(amax_buffer, tl.max(tl.abs(data), axis=1)).to(
+            input_dtype
+        )
 
     # compute rowwise scales for this group. round scales to nearest power of 2.
     amax_buffer = amax_buffer.to(tl.float64)

@@ -317,7 +321,11 @@ def _triton_fp8_col_major_jagged_colwise_scales(
         data = tl.load(input_ptr + block_offs, mask=block_mask, other=0.0).to(
             input_dtype
         )
-        amax_buffer = tl.maximum(amax_buffer, tl.max(tl.abs(data), axis=0))
+        # we need to cast back to input dtype since triton promotes bf16 to fp32:
+        # https://github.com/triton-lang/triton/blob/981e987eed9053b952f81153bc0779c99d8c642e/python/triton/language/standard.py#L173
+        amax_buffer = tl.maximum(amax_buffer, tl.max(tl.abs(data), axis=0)).to(
+            input_dtype
+        )
 
     # compute rowwise scales for this group.
     amax_buffer = amax_buffer.to(tl.float64)
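The change above only affects the amax accumulator's dtype; the scale computation itself is unchanged. As a reference point, a minimal eager-mode PyTorch sketch of rowwise fp8 scale computation with optional power-of-2 rounding, approximating what these triton kernels compute per jagged group (this assumes the e4m3 format and is not the kernel itself):

import torch

def rowwise_fp8_scales(x: torch.Tensor, round_to_power_of_2: bool = True) -> torch.Tensor:
    # per-row amax, computed in float64 like the kernel's amax_buffer.to(tl.float64)
    amax = x.abs().amax(dim=1, keepdim=True).to(torch.float64).clamp(min=1e-12)
    fp8_max = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3
    scales = (fp8_max / amax).to(torch.float32)
    if round_to_power_of_2:
        # round down to the nearest power of 2 so scaling only touches the exponent
        scales = torch.exp2(torch.floor(torch.log2(scales)))
    return scales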
