
Commit eb49333

integrate new differentiable fp8 conversion funcs into Float8NoCompileLinear (#1496)
1 parent e4827f2 commit eb49333

File tree

2 files changed: +74 −46 lines changed


torchao/prototype/float8nocompile/float8nocompile_linear.py

Lines changed: 74 additions & 46 deletions
@@ -11,11 +11,13 @@
 import torch
 
 from torchao.float8.config import Float8LinearConfig
-from torchao.float8.float8_linear import manual_float8_matmul_with_args_in_float8
 from torchao.float8.float8_tensor import GemmInputRole, LinearMMConfig, ScaledMMConfig
 from torchao.prototype.float8nocompile.float8nocompile_scaling_utils import (
-    Float8NoCompileConversionFunc,
-    NoopFwToFloat8NoCompileBwDynamic,
+    ToFP8ColumnMajor,
+    ToFP8ColumnMajorT,
+    ToFP8RowAndColumnMajor,
+    ToFP8RowMajor,
+    ToFP8RowMajorT,
 )
 from torchao.prototype.float8nocompile.kernels.fp8_dynamic_tensorwise import (
     KernelAlgorithm,
@@ -69,53 +71,14 @@ def __init__(self, *args, **kwargs):
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
         # TODO(danielvegamyhre): support for FSDP once dependencies are implemented
-        input_fp8 = self.cast_input_to_float8(input)
-        weight_fp8_t = self.cast_weight_to_float8_t(self.weight)
-
-        # compute fp8 matmul
-        output = manual_float8_matmul_with_args_in_float8.apply(input_fp8, weight_fp8_t)
-
-        # cast grad_output to float8_e5m2 during backward
-        return self.cast_output_to_float8_in_bw(output)
-
-    def cast_input_to_float8(self, input: torch.Tensor) -> torch.Tensor:
-        # Duplicate the autocast logic for F.linear, so that the output
-        # of our module has the right original precision
-        if torch.is_autocast_enabled():
-            # For now, hardcode to GPU's autocast dtype
-            # if we need CPU support in the future, we can add it
-            autocast_dtype = torch.get_autocast_gpu_dtype()
-            input = input.to(autocast_dtype)
-
-        return Float8NoCompileConversionFunc.apply(
+        output = matmul_with_args_in_hp.apply(
             input,
-            self.config.cast_config_input.target_dtype,
-            self.linear_mm_config,
-            GemmInputRole.INPUT,
-            self.kernel_algo,
-        )
-
-    def cast_weight_to_float8_t(
-        self,
-        weight: torch.Tensor,
-    ) -> torch.Tensor:
-        weight_fp8 = Float8NoCompileConversionFunc.apply(
-            weight,
-            self.config.cast_config_weight.target_dtype,
-            self.linear_mm_config,
-            GemmInputRole.WEIGHT,
-            self.kernel_algo,
-        )
-        return weight_fp8.t()
-
-    def cast_output_to_float8_in_bw(self, output: torch.Tensor) -> torch.Tensor:
-        # casts grad_output to float8_e5m2 for backward
-        return NoopFwToFloat8NoCompileBwDynamic.apply(
-            output,
-            self.config.cast_config_grad_output.target_dtype,
+            self.weight,
+            self.config,
             self.linear_mm_config,
             self.kernel_algo,
         )
+        return output
 
     @classmethod
     def from_float(cls, mod, kernel_algo: KernelAlgorithm = KernelAlgorithm.ATOMIC_MAX):
@@ -140,3 +103,68 @@ def from_float(cls, mod, kernel_algo: KernelAlgorithm = KernelAlgorithm.ATOMIC_MAX):
 
         # TODO(danielvegamyhre): support for FSDP once dependencies are implemented
         return new_mod
+
+
+class matmul_with_args_in_hp(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, input_hp, weight_hp, config, linear_mm_config, kernel_algo):
+        # output = input @ weight_t
+        input_fp8_row_major, input_fp8_col_major = ToFP8RowAndColumnMajor.apply(
+            input_hp,
+            config.cast_config_input.target_dtype,
+            linear_mm_config,
+            GemmInputRole.INPUT,
+            kernel_algo,
+        )
+        weight_t_fp8_col_major = ToFP8ColumnMajorT.apply(
+            weight_hp,
+            config.cast_config_weight.target_dtype,
+            linear_mm_config,
+            GemmInputRole.WEIGHT,
+            kernel_algo,
+        )
+        output = torch.mm(input_fp8_row_major, weight_t_fp8_col_major)
+
+        # save data for backward before returning
+        ctx.save_for_backward(input_fp8_col_major, weight_hp)
+        ctx.config = config
+        ctx.linear_mm_config = linear_mm_config
+        ctx.kernel_algo = kernel_algo
+
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        input_fp8_col_major, weight_hp = ctx.saved_tensors
+
+        # cast grad output to float8_e5m2 for backward
+        grad_output_fp8_row_major = ToFP8RowMajor.apply(
+            grad_output,
+            ctx.config.cast_config_grad_output.target_dtype,
+            ctx.linear_mm_config,
+            GemmInputRole.GRAD_OUTPUT,
+            ctx.kernel_algo,
+        )
+
+        # grad_input = grad_output @ weight
+        weight_fp8_col_major = ToFP8ColumnMajor.apply(
+            weight_hp,
+            ctx.config.cast_config_weight.target_dtype,
+            ctx.linear_mm_config,
+            GemmInputRole.WEIGHT,
+            ctx.kernel_algo,
+        )
+        grad_input = torch.mm(grad_output_fp8_row_major, weight_fp8_col_major)
+
+        # grad_weight = grad_output_t @ input
+        # apparently this variant is slightly faster than `grad_weight_t = input_t @ grad_output`
+        # source: https://github.com/pytorch/ao/blob/fe5f11b2c58b452e01ba9ec7359629928b143619/torchao/float8/float8_linear.py#L84-L85
+        grad_output_t_row_major = ToFP8RowMajorT.apply(
+            grad_output,
+            ctx.config.cast_config_grad_output.target_dtype,
+            ctx.linear_mm_config,
+            GemmInputRole.GRAD_OUTPUT,
+            ctx.kernel_algo,
+        )
+        grad_weight = torch.mm(grad_output_t_row_major, input_fp8_col_major)
+        return grad_input, grad_weight, None, None, None
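
For context, a minimal usage sketch of the new path (not part of this commit). It assumes Float8NoCompileLinear and KernelAlgorithm are importable from the module paths shown in the diff above, and that a CUDA device with fp8 support is available; the layer and tensor shapes are illustrative only.

# Minimal usage sketch (assumption: CUDA device with fp8 support, import paths
# as shown in the file tree above).
import torch
import torch.nn as nn

from torchao.prototype.float8nocompile.float8nocompile_linear import (
    Float8NoCompileLinear,
)
from torchao.prototype.float8nocompile.kernels.fp8_dynamic_tensorwise import (
    KernelAlgorithm,
)

# Convert a high-precision nn.Linear; the module keeps high-precision weights,
# and the fp8 casts happen inside matmul_with_args_in_hp during forward/backward.
lin = nn.Linear(4096, 4096, bias=False, device="cuda", dtype=torch.bfloat16)
fp8_lin = Float8NoCompileLinear.from_float(lin, kernel_algo=KernelAlgorithm.ATOMIC_MAX)

x = torch.randn(16, 4096, device="cuda", dtype=torch.bfloat16, requires_grad=True)
out = fp8_lin(x)      # forward: fp8 matmul of row-major input and column-major weight_t
out.sum().backward()  # backward: grad_input and grad_weight computed via fp8 matmuls

Note that after this change the fp8 conversions live entirely inside matmul_with_args_in_hp, so the module's forward simply hands the high-precision input and weight (plus config, linear_mm_config, and kernel_algo) to the autograd function.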
