@@ -11,11 +11,13 @@
 import torch
 
 from torchao.float8.config import Float8LinearConfig
-from torchao.float8.float8_linear import manual_float8_matmul_with_args_in_float8
 from torchao.float8.float8_tensor import GemmInputRole, LinearMMConfig, ScaledMMConfig
 from torchao.prototype.float8nocompile.float8nocompile_scaling_utils import (
-    Float8NoCompileConversionFunc,
-    NoopFwToFloat8NoCompileBwDynamic,
+    ToFP8ColumnMajor,
+    ToFP8ColumnMajorT,
+    ToFP8RowAndColumnMajor,
+    ToFP8RowMajor,
+    ToFP8RowMajorT,
 )
 from torchao.prototype.float8nocompile.kernels.fp8_dynamic_tensorwise import (
     KernelAlgorithm,
@@ -69,53 +71,14 @@ def __init__(self, *args, **kwargs):
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
         # TODO(danielvegamyhre): support for FSDP once dependencies are implemented
-        input_fp8 = self.cast_input_to_float8(input)
-        weight_fp8_t = self.cast_weight_to_float8_t(self.weight)
-
-        # compute fp8 matmul
-        output = manual_float8_matmul_with_args_in_float8.apply(input_fp8, weight_fp8_t)
-
-        # cast grad_output to float8_e5m2 during backward
-        return self.cast_output_to_float8_in_bw(output)
-
-    def cast_input_to_float8(self, input: torch.Tensor) -> torch.Tensor:
-        # Duplicate the autocast logic for F.linear, so that the output
-        # of our module has the right original precision
-        if torch.is_autocast_enabled():
-            # For now, hardcode to GPU's autocast dtype
-            # if we need CPU support in the future, we can add it
-            autocast_dtype = torch.get_autocast_gpu_dtype()
-            input = input.to(autocast_dtype)
-
-        return Float8NoCompileConversionFunc.apply(
+        output = matmul_with_args_in_hp.apply(
             input,
-            self.config.cast_config_input.target_dtype,
-            self.linear_mm_config,
-            GemmInputRole.INPUT,
-            self.kernel_algo,
-        )
-
-    def cast_weight_to_float8_t(
-        self,
-        weight: torch.Tensor,
-    ) -> torch.Tensor:
-        weight_fp8 = Float8NoCompileConversionFunc.apply(
-            weight,
-            self.config.cast_config_weight.target_dtype,
-            self.linear_mm_config,
-            GemmInputRole.WEIGHT,
-            self.kernel_algo,
-        )
-        return weight_fp8.t()
-
-    def cast_output_to_float8_in_bw(self, output: torch.Tensor) -> torch.Tensor:
-        # casts grad_output to float8_e5m2 for backward
-        return NoopFwToFloat8NoCompileBwDynamic.apply(
-            output,
-            self.config.cast_config_grad_output.target_dtype,
+            self.weight,
+            self.config,
             self.linear_mm_config,
             self.kernel_algo,
         )
+        return output
 
     @classmethod
     def from_float(cls, mod, kernel_algo: KernelAlgorithm = KernelAlgorithm.ATOMIC_MAX):
@@ -140,3 +103,68 @@ def from_float(cls, mod, kernel_algo: KernelAlgorithm = KernelAlgorithm.ATOMIC_MAX):
 
         # TODO(danielvegamyhre): support for FSDP once dependencies are implemented
         return new_mod
+
+
+class matmul_with_args_in_hp(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, input_hp, weight_hp, config, linear_mm_config, kernel_algo):
+        # output = input @ weight_t
+        input_fp8_row_major, input_fp8_col_major = ToFP8RowAndColumnMajor.apply(
+            input_hp,
+            config.cast_config_input.target_dtype,
+            linear_mm_config,
+            GemmInputRole.INPUT,
+            kernel_algo,
+        )
+        weight_t_fp8_col_major = ToFP8ColumnMajorT.apply(
+            weight_hp,
+            config.cast_config_weight.target_dtype,
+            linear_mm_config,
+            GemmInputRole.WEIGHT,
+            kernel_algo,
+        )
+        output = torch.mm(input_fp8_row_major, weight_t_fp8_col_major)
+
+        # save data for backward before returning
+        ctx.save_for_backward(input_fp8_col_major, weight_hp)
+        ctx.config = config
+        ctx.linear_mm_config = linear_mm_config
+        ctx.kernel_algo = kernel_algo
+
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        input_fp8_col_major, weight_hp = ctx.saved_tensors
+
+        # cast grad output to float8_e5m2 for backward
+        grad_output_fp8_row_major = ToFP8RowMajor.apply(
+            grad_output,
+            ctx.config.cast_config_grad_output.target_dtype,
+            ctx.linear_mm_config,
+            GemmInputRole.GRAD_OUTPUT,
+            ctx.kernel_algo,
+        )
+
+        # grad_input = grad_output @ weight
+        weight_fp8_col_major = ToFP8ColumnMajor.apply(
+            weight_hp,
+            ctx.config.cast_config_weight.target_dtype,
+            ctx.linear_mm_config,
+            GemmInputRole.WEIGHT,
+            ctx.kernel_algo,
+        )
+        grad_input = torch.mm(grad_output_fp8_row_major, weight_fp8_col_major)
+
+        # grad_weight = grad_output_t @ input
+        # apparently this variant is slightly faster than `grad_weight_t = input_t @ grad_output`
+        # source: https://github.com/pytorch/ao/blob/fe5f11b2c58b452e01ba9ec7359629928b143619/torchao/float8/float8_linear.py#L84-L85
+        grad_output_t_row_major = ToFP8RowMajorT.apply(
+            grad_output,
+            ctx.config.cast_config_grad_output.target_dtype,
+            ctx.linear_mm_config,
+            GemmInputRole.GRAD_OUTPUT,
+            ctx.kernel_algo,
+        )
+        grad_weight = torch.mm(grad_output_t_row_major, input_fp8_col_major)
+        return grad_input, grad_weight, None, None, None
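For reference, a quick high-precision sanity check of the gradient formulas the new backward implements for output = input @ weight.t(): grad_input = grad_output @ weight and grad_weight = grad_output.t() @ input. This sketch uses plain PyTorch with no float8 casting, and the shapes and variable names are illustrative only, not taken from this PR.

# Sanity check (plain PyTorch, no float8): the same matmul shapes as
# matmul_with_args_in_hp, verified against autograd.
import torch

M, K, N = 16, 32, 64
input = torch.randn(M, K, requires_grad=True)
weight = torch.randn(N, K, requires_grad=True)  # nn.Linear-style (out, in) layout

output = torch.mm(input, weight.t())
grad_output = torch.randn_like(output)
output.backward(grad_output)

# grad_input = grad_output @ weight, grad_weight = grad_output.t() @ input
assert torch.allclose(input.grad, grad_output @ weight, atol=1e-5)
assert torch.allclose(weight.grad, grad_output.t() @ input, atol=1e-5)

The fp8 version above computes the same two products, but feeds torch.mm operands that were pre-cast into the row-major / column-major layouts each GEMM expects, which is why the input is saved in column-major form for the grad_weight matmul.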