Commit f86fda9
add handling for batch dim in float8nocompile (#1512)
1 parent 457c5b1 commit f86fda9
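
This commit lets the float8nocompile prototype handle activations with a leading batch dimension: inputs are flattened to 2D before the Triton fp8 casts and matmuls, and the output (and its gradient) is reshaped back to the original leading dimensions. A minimal sketch of that flatten-and-restore pattern, using a plain matmul as a stand-in for the fp8 kernels (the helper below is illustrative and not part of the commit):

import torch

def matmul_last_dim(input_hp: torch.Tensor, weight_hp: torch.Tensor) -> torch.Tensor:
    # Flatten all leading dims so the kernel only ever sees a 2D (M, K) tensor.
    orig_shape = input_hp.shape
    input_2d = input_hp.reshape(-1, orig_shape[-1])

    # Stand-in for the fp8 row/column-major cast + scaled matmul.
    out_2d = input_2d @ weight_hp.t()

    # Restore the original leading dims, keeping the new feature dim.
    return out_2d.reshape(*orig_shape[:-1], out_2d.shape[-1])

x = torch.randn(2, 32, 16)   # (batch, seq, dim)
w = torch.randn(64, 16)      # nn.Linear-style (out_features, in_features)
assert matmul_last_dim(x, w).shape == (2, 32, 64)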

4 files changed, +138 -6 lines changed

torchao/prototype/float8nocompile/float8nocompile_linear.py

Lines changed: 30 additions & 3 deletions
@@ -80,15 +80,20 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
         return output
 
     @classmethod
-    def from_float(cls, mod, kernel_algo: KernelAlgorithm = KernelAlgorithm.ATOMIC_MAX):
+    def from_float(
+        cls,
+        mod,
+        config: Float8LinearConfig,  # only the default config is supported; non-defaults are silently ignored
+        kernel_algo: KernelAlgorithm = KernelAlgorithm.ATOMIC_MAX,
+    ):
         """
         Create an nn.Linear with fp8 compute from a regular nn.Linear
 
         Args:
             mod (torch.nn.Linear): nn.Linear to convert
-            config (Optional[Float8LinearConfig]): configuration for conversion to float8
+            config (Optional[Float8LinearConfig]): configuration for conversion to float8 (note: only the
+                default config is supported; non-defaults are silently ignored)
         """
-        config = Float8LinearConfig()
         with torch.device("meta"):
             new_mod = cls(
                 mod.in_features,
@@ -107,6 +112,10 @@ def from_float(cls, mod, kernel_algo: KernelAlgorithm = KernelAlgorithm.ATOMIC_M
 class matmul_with_args_in_hp(torch.autograd.Function):
     @staticmethod
     def forward(ctx, input_hp, weight_hp, config, linear_mm_config, kernel_algo):
+        # reshape to be 2D for triton kernels
+        orig_input_shape = input_hp.shape
+        input_hp = input_hp.reshape(-1, input_hp.shape[-1])
+
         # output = input @ weight_t
         input_fp8_row_major, input_fp8_col_major = ToFP8RowAndColumnMajor.apply(
             input_hp,
@@ -130,12 +139,24 @@ def forward(ctx, input_hp, weight_hp, config, linear_mm_config, kernel_algo):
         ctx.linear_mm_config = linear_mm_config
         ctx.kernel_algo = kernel_algo
 
+        # reshape back to expected dims
+        output = output.reshape(*orig_input_shape[:-1], output.shape[-1])
         return output
 
     @staticmethod
    def backward(ctx, grad_output):
+        # grad_output may not be contiguous in cases like:
+        # output.sum().backward(), where the grad is all 1s, so the (M, N) view of the scalar "1"
+        # results in a non-contiguous tensor with stride (0, 0).
+        if not grad_output.is_contiguous():
+            grad_output = grad_output.contiguous()
+
         input_fp8_col_major, weight_hp = ctx.saved_tensors
 
+        # reshape to be 2D for triton kernels
+        orig_grad_output_shape = grad_output.shape
+        grad_output = grad_output.reshape(-1, grad_output.shape[-1])
+
         # cast grad output to float8_e5m2 for backward
         grad_output_fp8_row_major, grad_output_t_row_major = (
             ToFP8RowMajorTAndNonT.apply(
@@ -162,4 +183,10 @@ def backward(ctx, grad_output):
         # source: https://github.com/pytorch/ao/blob/fe5f11b2c58b452e01ba9ec7359629928b143619/torchao/float8/float8_linear.py#L84-L85
         grad_weight = torch.mm(grad_output_t_row_major, input_fp8_col_major)
 
+        # reshape grad input to match the original input shape
+        grad_input = grad_input.reshape(
+            *orig_grad_output_shape[:-1], grad_input.shape[-1]
+        )
+
+        # no gradients for config, linear_mm_config, or kernel_algo
         return grad_input, grad_weight, None, None, None
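
The contiguity guard in backward covers the case described in the comment above: for a plain sum() loss, autograd effectively expands a scalar one to the output shape, producing a stride-(0, 0) view that is not contiguous. A small standalone illustration (not part of the commit) of why .contiguous() is needed before the 2D reshape:

import torch

# What sum().backward() effectively hands to backward(): a scalar 1 expanded to (M, N).
grad = torch.tensor(1.0).expand(4, 8)
print(grad.stride())          # (0, 0) -- every element aliases the same scalar
print(grad.is_contiguous())   # False
grad = grad.contiguous()      # materialize a real (4, 8) buffer for the Triton kernels
print(grad.stride())          # (8, 1)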
Lines changed: 97 additions & 0 deletions
@@ -0,0 +1,97 @@
+import pytest
+import torch
+
+from torchao.float8.config import Float8LinearConfig
+from torchao.float8.float8_linear import manual_float8_matmul_with_args_in_hp
+from torchao.float8.float8_tensor import LinearMMConfig, ScaledMMConfig
+from torchao.prototype.float8nocompile.float8nocompile_linear import (
+    matmul_with_args_in_hp,
+)
+from torchao.prototype.float8nocompile.float8nocompile_scaling_utils import (
+    KernelAlgorithm,
+)
+
+
+# unit test comparing the two implementations
+@pytest.mark.parametrize(
+    "input_shape",
+    [(32, 16), (1, 32, 16), (2, 32, 16)],
+)
+def test_matmul_with_args_in_hp(input_shape: tuple[int, ...]):
+    assert torch.cuda.is_available()
+    device = "cuda"
+
+    # high precision inputs
+    input_bf16 = torch.randn(
+        input_shape, dtype=torch.bfloat16, device=device, requires_grad=True
+    )
+    prod_input_bf16 = input_bf16.clone().detach().to(device).requires_grad_(True)
+    prototype_input_bf16 = input_bf16.clone().detach().to(device).requires_grad_(True)
+
+    # high precision weights
+    # nn.Linear stores weights in transposed form
+    weight_bf16 = torch.randn(
+        (32, input_bf16.shape[-1]),
+        dtype=torch.bfloat16,
+        device=device,
+        requires_grad=True,
+    )
+    prod_weight_bf16 = weight_bf16.clone().detach().to(device).requires_grad_(True)
+    prototype_weight_bf16 = weight_bf16.clone().detach().to(device).requires_grad_(True)
+
+    # default configs
+    config = Float8LinearConfig()
+    emulate = False
+    linear_mm_config = LinearMMConfig(
+        # output
+        ScaledMMConfig(
+            emulate,
+            config.gemm_config_output.use_fast_accum,
+            False,
+            config.pad_inner_dim,
+        ),
+        # grad_input
+        ScaledMMConfig(
+            emulate,
+            config.gemm_config_grad_input.use_fast_accum,
+            False,
+            config.pad_inner_dim,
+        ),
+        # grad_weight
+        ScaledMMConfig(
+            emulate,
+            config.gemm_config_grad_weight.use_fast_accum,
+            False,
+            config.pad_inner_dim,
+        ),
+    )
+
+    # prod forward. expects transposed weight.
+    out_prod = manual_float8_matmul_with_args_in_hp.apply(
+        prod_input_bf16, prod_weight_bf16.t(), linear_mm_config, config
+    )
+
+    # prototype forward. expects non-transposed weight.
+    out_prototype = matmul_with_args_in_hp.apply(
+        prototype_input_bf16,
+        prototype_weight_bf16,
+        config,
+        linear_mm_config,
+        KernelAlgorithm.ATOMIC_MAX,
+    )
+
+    # compare model outputs
+    assert torch.allclose(out_prod, out_prototype, atol=0, rtol=0)
+
+    out_prod.sum().backward()
+    out_prototype.sum().backward()
+
+    # compare input gradients
+    assert torch.allclose(
+        prod_input_bf16.grad, prototype_input_bf16.grad, atol=0, rtol=0
+    )
+
+    # compare weight gradients
+    assert torch.allclose(
+        prod_weight_bf16.grad, prototype_weight_bf16.grad, atol=0, rtol=0
+    )

torchao/prototype/float8nocompile/float8nocompile_linear_utils.py

Lines changed: 8 additions & 1 deletion
@@ -8,6 +8,7 @@
 
 import torch.nn as nn
 
+from torchao.float8.config import Float8LinearConfig
 from torchao.float8.float8_linear_utils import swap_linear_layers
 from torchao.prototype.float8nocompile.float8nocompile_linear import (
     Float8LinearNoCompile,
@@ -23,6 +24,7 @@
 def convert_to_float8_nocompile_training(
     module: nn.Module,
     *,
+    config: Optional[Float8LinearConfig] = None,
     module_filter_fn: Optional[Callable[[nn.Module, str], bool]] = None,
     kernel_algo: KernelAlgorithm = KernelAlgorithm.ATOMIC_MAX,
 ) -> nn.Module:
@@ -39,7 +41,12 @@ def convert_to_float8_nocompile_training(
     Returns:
         nn.Module: The modified module with swapped linear layers.
     """
-    from_float = lambda m: Float8LinearNoCompile.from_float(m, kernel_algo=kernel_algo)
+    if config is None:
+        config = Float8LinearConfig()
+
+    from_float = lambda m: Float8LinearNoCompile.from_float(
+        m, config=config, kernel_algo=kernel_algo
+    )
     return swap_linear_layers(
         module,
         from_float,
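
With the new keyword argument, callers may pass a Float8LinearConfig explicitly or omit it and get the default (which, per the docstring note in from_float, is the only supported configuration). A hedged usage sketch, assuming a toy nn.Sequential model and a CUDA device, neither of which comes from this diff:

import torch
import torch.nn as nn

from torchao.float8.config import Float8LinearConfig
from torchao.prototype.float8nocompile.float8nocompile_linear_utils import (
    convert_to_float8_nocompile_training,
)

# Toy model for illustration only; any module containing nn.Linear layers would do.
model = nn.Sequential(
    nn.Linear(32, 64, bias=False),
    nn.ReLU(),
    nn.Linear(64, 16, bias=False),
).to("cuda", dtype=torch.bfloat16)

# Equivalent to omitting config entirely, since only the default config is supported.
model = convert_to_float8_nocompile_training(model, config=Float8LinearConfig())

# Batched 3D input, which this commit adds handling for.
x = torch.randn(2, 8, 32, device="cuda", dtype=torch.bfloat16)
y = model(x)
y.sum().backward()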

torchao/prototype/float8nocompile/test/train_test.py

Lines changed: 3 additions & 2 deletions
@@ -36,7 +36,8 @@ def model2():
     return TestModel()
 
 
-def test_model_weights_and_gradients(model1, model2):
+@pytest.mark.parametrize("input_shape", [(16, 32), (1, 16, 32), (2, 16, 32)])
+def test_model_weights_and_gradients(model1, model2, input_shape: tuple[int, ...]):
     assert torch.cuda.is_available()
     device = torch.device("cuda")
 
@@ -48,7 +49,7 @@ def test_model_weights_and_gradients(model1, model2):
     convert_to_float8_nocompile_training(model1)
 
     input_tensor = torch.randn(
-        16, 32, requires_grad=True, dtype=torch.bfloat16, device=device
+        *input_shape, requires_grad=True, dtype=torch.bfloat16, device=device
     )
     input_copy1 = input_tensor.clone().detach().requires_grad_(True)
     input_copy2 = input_tensor.clone().detach().requires_grad_(True)
