Skip to content

Commit 620356d

Browse files
Initial prototype of differentiable _scaled_grouped_mm function (#1969)
1 parent e52867a commit 620356d

File tree

3 files changed

+560
-0
lines changed

3 files changed

+560
-0
lines changed
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from torchao.prototype.scaled_grouped_mm.scaled_grouped_mm import (
2+
_scaled_grouped_mm as _scaled_grouped_mm,
3+
)
Lines changed: 361 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,361 @@
1+
from typing import Optional, Tuple
2+
3+
import torch
4+
5+
from torchao.float8.config import ScalingGranularity
6+
from torchao.float8.float8_utils import tensor_to_scale, to_fp8_saturated
7+
8+
9+
def _scaled_grouped_mm(
10+
A: torch.Tensor,
11+
B_t: torch.Tensor,
12+
offs: torch.Tensor,
13+
out_dtype: Optional[torch.dtype] = None,
14+
) -> torch.Tensor:
15+
"""
16+
This function performs dynamic float8 quantization with row-wise scaling
17+
on the input tensors A and B, then performs a scaled grouped GEMM and returns the results.
18+
19+
Args:
20+
A (bf16/float32 torch.Tensor): The first high-precision input tensor, which must be a 2D tensor of shape (M * num_groups, K)
21+
and in row-major memory layout.
22+
B_t (bf16/float32 torch.Tensor): The second high-precision input tensor which must be 3D, which must be shape (B, K, N)
23+
and in column-major memory layout.
24+
offs (int32 torch.Tensor): The offsets to use to mark the starting index of each group along dim0 of the A tensor.
25+
out_dtype (Optional[torch.dtype]): The dtype of the output tensor. Currently only torch.bfloat16 is supported.
26+
"""
27+
return _Float8GroupedMM.apply(
28+
A,
29+
B_t,
30+
offs,
31+
out_dtype,
32+
)
33+
34+
35+
class _Float8GroupedMM(torch.autograd.Function):
36+
"""Differentiable implementation of grouped GEMM with dynamic float8 quantization."""
37+
38+
@staticmethod
39+
def forward(
40+
ctx,
41+
A: torch.Tensor,
42+
B_t: torch.Tensor,
43+
offs: torch.Tensor,
44+
out_dtype: Optional[torch.dtype] = None,
45+
) -> torch.Tensor:
46+
# torchao _scaled_grouped_mm only supports A=2D, B=3D.
47+
assert A.ndim == 2, "A must be 2D"
48+
assert B_t.ndim == 3, "B must be 3D"
49+
50+
assert (
51+
A.size(-1) % 16 == 0
52+
), f"A must have a last dim divisible by 16, but got shape: {A.shape}"
53+
assert (
54+
B_t.size(-2) % 16 == 0 and B_t.size(-1) % 16 == 0
55+
), f"B must have last 2 dims divisible by 16, but got shape: {B_t.shape}"
56+
57+
# Assert input tensors are in high-precision dtypes.
58+
assert (
59+
A.dtype == torch.float32 or A.dtype == torch.bfloat16
60+
), "A must be float32 or bfloat16"
61+
assert (
62+
B_t.dtype == torch.float32 or B_t.dtype == torch.bfloat16
63+
), "B must be float32 or bfloat16"
64+
assert offs.dtype == torch.int32, "offs must be int32"
65+
66+
# Assert A and B dims are compatible for a scaled grouped GEMM.
67+
assert A.size(-1) == B_t.size(
68+
-2
69+
), f"shape {A.shape} and {B_t.shape} are not compatible for _scaled_grouped_mm"
70+
71+
# The left operand in the scaled grouped GEMM must be row-major due to hardware requirements.
72+
assert not _is_column_major(A), "A must be row-major"
73+
74+
# Due to hardware requirements, the right operand in a scaled grouped GEMM must be column-major.
75+
assert _is_column_major(B_t), "B must be column-major"
76+
77+
# Convert high precision input tensor to float8, row-major for left operand of grouped GEMM.
78+
# A shape: (M, K)
79+
# A_scales shape: (M,1)
80+
A_scales = tensor_to_scale(
81+
A,
82+
torch.float8_e4m3fn,
83+
scaling_granularity=ScalingGranularity.AXISWISE,
84+
axiswise_dim=-1,
85+
round_scales_to_power_of_2=True,
86+
)
87+
A_scaled = A.to(torch.float32) * A_scales
88+
A_fp8_row_major = to_fp8_saturated(A_scaled, torch.float8_e4m3fn)
89+
90+
# Convert B to float8, column-major for right operand of grouped GEMM.
91+
# B shape: (B, K, N)
92+
# B scales must be computed rowwise keeping the outer/final dim, so:
93+
# B_scales shape: (B, 1, N)
94+
B_t_scales = tensor_to_scale(
95+
B_t,
96+
torch.float8_e4m3fn,
97+
scaling_granularity=ScalingGranularity.AXISWISE,
98+
axiswise_dim=-2,
99+
round_scales_to_power_of_2=True,
100+
)
101+
B_t_scaled = B_t.to(torch.float32) * B_t_scales
102+
B_t_fp8_col_major = to_fp8_saturated(B_t_scaled, torch.float8_e4m3fn)
103+
104+
# Precompute non-transposed B column-major for backward, to save memory by storing the
105+
# low precision B tensor instead of the high precision B tensor.
106+
# In the backward this is needed for grad_A: grad_output @ B.
107+
B = B_t.contiguous().transpose(-2, -1)
108+
109+
# - B shape: (B, K, N)
110+
# - B scales must be computed rowwise keeping the outer/final dim, so:
111+
# - B_scale shape: (B, 1, N)
112+
B_scales = tensor_to_scale(
113+
B,
114+
torch.float8_e4m3fn,
115+
scaling_granularity=ScalingGranularity.AXISWISE,
116+
axiswise_dim=-2,
117+
round_scales_to_power_of_2=True,
118+
)
119+
B_scaled = B.to(torch.float32) * B_scales
120+
B_fp8_col_major = to_fp8_saturated(B_scaled, torch.float8_e4m3fn)
121+
122+
# Store what we need for backward.
123+
ctx.save_for_backward(A, B_fp8_col_major, B_scales, offs)
124+
ctx.out_dtype = out_dtype
125+
126+
# Perform scaled grouped GEMM and return result.
127+
# output shape: scaled grouped mm of (M,K) @ (B,K,N) = (M,N)
128+
return torch._scaled_grouped_mm(
129+
A_fp8_row_major,
130+
B_t_fp8_col_major,
131+
A_scales.squeeze().reciprocal(),
132+
B_t_scales.squeeze().reciprocal(),
133+
offs,
134+
out_dtype=out_dtype,
135+
use_fast_accum=True,
136+
)
137+
138+
@staticmethod
139+
def backward(ctx, grad_output: torch.Tensor):
140+
A, B_fp8_col_major, B_scales, offs = ctx.saved_tensors
141+
out_dtype = ctx.out_dtype
142+
143+
# Convert grad_output to float8, row-major for left operand of grouped GEMM
144+
# needed for grad_A: grad_output @ B
145+
#
146+
# grad_output shape: (M, N)
147+
# grad_output_scale shape: (M, 1)
148+
grad_output_scales = tensor_to_scale(
149+
grad_output,
150+
torch.float8_e4m3fn,
151+
scaling_granularity=ScalingGranularity.AXISWISE,
152+
axiswise_dim=-1,
153+
round_scales_to_power_of_2=True,
154+
)
155+
grad_output_scaled = grad_output.to(torch.float32) * grad_output_scales
156+
grad_output_fp8_row_major = to_fp8_saturated(
157+
grad_output_scaled, torch.float8_e4m3fn
158+
)
159+
160+
# Compute grad_A.
161+
#
162+
# grad_A = grad_output @ B
163+
# grad_A = scaled grouped mm of (M,N) @ (B,N,K) = (M,K)
164+
grad_A = torch._scaled_grouped_mm(
165+
grad_output_fp8_row_major,
166+
B_fp8_col_major,
167+
grad_output_scales.squeeze().reciprocal(),
168+
B_scales.squeeze().reciprocal(),
169+
offs,
170+
out_dtype=out_dtype,
171+
use_fast_accum=True,
172+
)
173+
174+
# Convert tranpose of grad_output to float8, row-major for left operand of grouped GEMM
175+
# needed for grad_B: grad_output_t @ A
176+
grad_output_t_row_major = grad_output.transpose(-2, -1).contiguous()
177+
178+
# Convert A to float8, column-major for right operand of grouped GEMM:
179+
# needed for grad_B: grad_output @ A
180+
A_col_major = A.transpose(-2, -1).contiguous().transpose(-2, -1)
181+
182+
# grad_B is a special case. both operands of the grouped gemm will be 2D with offsets determing the "groups."
183+
# Compute scales for grad_output_t and A, which are both 2D tensors with offsets which define the "jagged" groups.
184+
grad_output_t_fp8_row_major, grad_output_t_scales = (
185+
_to_2d_jagged_float8_tensor_rowwise(
186+
grad_output_t_row_major,
187+
offs,
188+
target_dtype=torch.float8_e4m3fn,
189+
round_scales_to_power_of_2=True,
190+
)
191+
)
192+
A_fp8_col_major, A_scales = _to_2d_jagged_float8_tensor_colwise(
193+
A_col_major,
194+
offs,
195+
target_dtype=torch.float8_e4m3fn,
196+
round_scales_to_power_of_2=True,
197+
)
198+
199+
# Compute grad_B = grad_output_t @ A.
200+
# grad_B = grad_output_t @ A
201+
# grad_B = (N,M) @ (M,K) = (N,K)
202+
grad_B = torch._scaled_grouped_mm(
203+
grad_output_t_fp8_row_major,
204+
A_fp8_col_major,
205+
grad_output_t_scales.reciprocal(),
206+
A_scales.reciprocal(),
207+
offs,
208+
out_dtype=out_dtype,
209+
use_fast_accum=True,
210+
)
211+
return grad_A, grad_B.transpose(-2, -1), None, None, None, None
212+
213+
214+
def _to_2d_jagged_float8_tensor_colwise(
215+
A_col_major: torch.Tensor,
216+
offs: torch.Tensor,
217+
target_dtype: torch.dtype = torch.float8_e4m3fn,
218+
round_scales_to_power_of_2: bool = False,
219+
) -> Tuple[torch.Tensor, torch.Tensor]:
220+
"""
221+
This function converts the 2D input tensor A to a jagged float8 tensor,
222+
with scales computed along *logical columns* for each group individually,
223+
where groups are determined based on the offsets.
224+
225+
For the right operand of a normal scaled GEMM, the rowwise scales are computed over logical columns.
226+
(i.e., a tensor of (K,N) will have scales of shape (1,N).
227+
228+
However, for a 2D right operand of a grouped GEMM, these logical columns go through multiple distinct
229+
groups/subtensors, for which we want to compute scales individually. So we cannot take one set of scales
230+
along the logical columns and apply it to the entire tensor.
231+
232+
Instead, we need to compute scales for each subtensor individually. For a tensor of shape (K,N) this results
233+
in scales of shape (1,N * num_groups).
234+
235+
Args:
236+
A (torch.Tensor): The input tensor to be converted to a jagged float8 tensor.
237+
238+
Returns:
239+
A tuple containing the jagged float8 tensor and the scales used for the conversion.
240+
"""
241+
assert A_col_major.ndim == 2, "A must be 2D"
242+
243+
num_groups = offs.numel()
244+
A_fp8_col_major = torch.empty_like(A_col_major, dtype=target_dtype)
245+
A_scales = torch.empty(
246+
A_fp8_col_major.size(1) * num_groups,
247+
dtype=torch.float32,
248+
device=A_fp8_col_major.device,
249+
)
250+
251+
start_idx = 0
252+
next_scale_idx = 0
253+
for end_idx in offs.tolist():
254+
# Get the subtensor of A for this group, fetching the next group of rows, with all columns for each.
255+
subtensor = A_col_major[start_idx:end_idx, :] # (local_group_size, K)
256+
257+
# Compute local rowwise scales for this subtensor, which are along logical columns for the right operand.
258+
subtensor_scales = tensor_to_scale(
259+
subtensor,
260+
target_dtype,
261+
scaling_granularity=ScalingGranularity.AXISWISE,
262+
axiswise_dim=0,
263+
round_scales_to_power_of_2=round_scales_to_power_of_2,
264+
)
265+
266+
# Apply scales to subtensor and convert to float8.
267+
tensor_scaled = subtensor.to(torch.float32) * subtensor_scales
268+
float8_subtensor = to_fp8_saturated(tensor_scaled, target_dtype)
269+
270+
# Store this portion of the resulting float8 tensor and scales.
271+
A_fp8_col_major[start_idx:end_idx, :] = float8_subtensor
272+
A_scales[next_scale_idx : next_scale_idx + subtensor_scales.numel()] = (
273+
subtensor_scales.squeeze()
274+
)
275+
276+
# Update start index for next group.
277+
start_idx = end_idx
278+
next_scale_idx += subtensor_scales.numel()
279+
280+
return A_fp8_col_major, A_scales
281+
282+
283+
def _to_2d_jagged_float8_tensor_rowwise(
284+
x: torch.Tensor,
285+
offs: torch.Tensor,
286+
target_dtype: torch.dtype,
287+
round_scales_to_power_of_2: bool = False,
288+
) -> Tuple[torch.Tensor, torch.Tensor]:
289+
"""
290+
This function converts the 2D input tensor to a jagged float8 tensor,
291+
with scales computed along *logical rows* for each group individually,
292+
where groups are determined based on the offsets.
293+
294+
For a 2D *left* operand of a normal scaled GEMM, the rowwise scales are computed over logical rows.
295+
(i.e., a tensor of (M,K) will have scales of shape (M,1).
296+
297+
However, for a 2D left operand of a grouped GEMM, these logical rows go through multiple distinct
298+
groups/subtensors, for which we want to compute scales individually. So we cannot take one set of scales
299+
along the logical rows and apply it to the entire tensor.
300+
301+
Instead, we need to compute scales for each subtensor individually. For a tensor of shape (M,K) this results
302+
in scales of shape (M * num_groups, 1).
303+
304+
Args:
305+
A (torch.Tensor): The input tensor to be converted to a jagged float8 tensor.
306+
307+
Returns:
308+
A tuple containing the jagged float8 tensor and the scales used for the conversion.
309+
"""
310+
assert x.ndim == 2, "input tensor must be 2D"
311+
312+
num_groups = offs.numel()
313+
x_fp8 = torch.empty_like(x, dtype=target_dtype)
314+
x_scales = torch.empty(
315+
x_fp8.size(0) * num_groups, dtype=torch.float32, device=x_fp8.device
316+
)
317+
318+
start_idx = 0
319+
next_scale_idx = 0
320+
for end_idx in offs.tolist():
321+
# Get the subtensor of A for this group, fetching all rows with the next group of rows.
322+
subtensor = x[:, start_idx:end_idx] # (M, local_group_size)
323+
324+
# Compute local rowwise scales for this subtensor, which are along logical rows for the left operand.
325+
subtensor_scales = tensor_to_scale(
326+
subtensor,
327+
target_dtype,
328+
scaling_granularity=ScalingGranularity.AXISWISE,
329+
axiswise_dim=-1,
330+
round_scales_to_power_of_2=round_scales_to_power_of_2,
331+
)
332+
333+
# Apply scales to subtensor and convert to float8.
334+
tensor_scaled = subtensor.to(torch.float32) * subtensor_scales
335+
float8_subtensor = to_fp8_saturated(tensor_scaled, target_dtype)
336+
337+
# Store this portion of the resulting float8 tensor and scales.
338+
x_fp8[:, start_idx:end_idx] = float8_subtensor
339+
x_scales[next_scale_idx : next_scale_idx + subtensor_scales.numel()] = (
340+
subtensor_scales.squeeze()
341+
)
342+
343+
# Update start index for next group.
344+
start_idx = end_idx
345+
next_scale_idx += subtensor_scales.numel()
346+
347+
return x_fp8, x_scales
348+
349+
350+
def _is_column_major(x: torch.Tensor) -> bool:
351+
"""
352+
This function checks if the input tensor is column-major.
353+
354+
Args:
355+
x (torch.Tensor): The input tensor to be checked.
356+
357+
Returns:
358+
A boolean indicating whether the input tensor is column-major.
359+
"""
360+
assert x.ndim == 2 or x.ndim == 3, "input tensor must be 2D or 3D"
361+
return x.stride(-2) == 1 and x.stride(-1) > 1

0 commit comments

Comments
 (0)