class FbgemmFp8Tensor(TorchAOBaseTensor):
    """
+    Float8 Rowwise Quantized (weight) Tensor, with float8 rowwise dynamic quantization for activation.

    TODO: needs padding for cutlass kernels
+
+    Tensor Attributes:
+        float8_data: float8 raw data, dtype torchao.float8.config.e4m3_dtype
+        scale: the rowwise scale for the float8 Tensor
+        activation_scale_ub: upper bound for the activation scale, used during dynamic quantization of the activation
+
+    Non-Tensor Attributes:
+        dtype: Original Tensor dtype
    """

    tensor_data_attrs = ["float8_data", "scale", "activation_scale_ub"]
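A hedged sketch of what the attributes above typically hold for a 2D weight: rowwise float8 quantization keeps one scale per row plus the float8 payload. This is an illustration in plain PyTorch, not the fbgemm quantization path this tensor actually uses, and torch.float8_e4m3fn stands in for torchao.float8.config.e4m3_dtype:

import torch

def rowwise_fp8_quantize_sketch(w: torch.Tensor, float8_dtype=torch.float8_e4m3fn):
    # one scale per row, chosen so the row's max magnitude maps to the float8 max (448 for e4m3fn)
    max_val = torch.finfo(float8_dtype).max
    scale = w.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) / max_val  # (N, 1)
    float8_data = (w / scale).clamp(-max_val, max_val).to(float8_dtype)    # (N, K)
    return float8_data, scale

w = torch.randn(4096, 1024)
float8_data, scale = rowwise_fp8_quantize_sketch(w)
print(float8_data.dtype, scale.shape)  # torch.float8_e4m3fn torch.Size([4096, 1])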
@@ -40,7 +49,9 @@ def __new__(cls, float8_data, scale, activation_scale_ub, dtype):
        kwargs["requires_grad"] = False
        return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]

-    def __init__(self, float8_data, scale, activation_scale_ub, dtype):
+    def __init__(
+        self, float8_data, scale, activation_scale_ub, dtype
+    ):
        self.float8_data = float8_data
        self.scale = scale
        self.activation_scale_ub = activation_scale_ub
@@ -85,6 +96,47 @@ def to(self, *args, **kwargs):
            self.dtype,
        )

+    def _transpose_and_reshape(self):
+        """This is added for resharding support, since the resharding logic for the model we are
+        working with only supports 2D
+        """
+        assert len(self.shape) == 3, (
+            f"Only expected to be used when the Tensor is 3D, got {len(self.shape)}"
+        )
+        dim0, dim1, dim2 = self.shape
+        # because we first transpose the weight before quantization, we'll recover the original shape
+        # by swapping dim1 and dim2
+        original_shape = (dim0, dim2, dim1)
+        # we must save this as 2D in the state dict, since the loading code expects 2D weights
+        new_shape = (-1, original_shape[-1])
+        float8_data = self.float8_data
+        float8_data = float8_data.transpose(1, 2).reshape(*new_shape).contiguous()
+        scale = self.scale.transpose(1, 2).reshape(*new_shape).contiguous()
+        return self.__class__(
+            float8_data,
+            scale,
+            self.activation_scale_ub,
+            self.dtype,
+        )
+
+    def _unflatten(self, num_experts):
+        """This is added for resharding support, since the resharding logic for the model we are
+        working with only supports 2D
+        """
+        float8_data = self.float8_data
+        scale = self.scale
+        dim0, dim1 = self.shape
+        float8_data = float8_data.unflatten(0, (num_experts, -1)).squeeze(dim=0)
+        scale = scale.unflatten(0, (num_experts, -1)).squeeze(dim=0)
+        dim0, dim1, dim2 = float8_data.shape
+
+        return self.__class__(
+            float8_data,
+            scale,
+            self.activation_scale_ub,
+            self.dtype,
+        )
+
    @classmethod
    def from_float(
        cls,
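For intuition about the two resharding helpers added above, here is a minimal shape-only sketch on plain float tensors (stand-ins for the float8 payload; the shapes and the round trip shown are illustrative assumptions, not code from this PR):

import torch

num_experts, N, K = 4, 128, 256
# 3D expert weight payload, analogous to float8_data with shape (E, N, K)
float8_data = torch.randn(num_experts, N, K)

# _transpose_and_reshape: swap the last two dims, then flatten to 2D for the state dict
flat = float8_data.transpose(1, 2).reshape(-1, N).contiguous()    # (E * K, N)

# _unflatten: recover the expert dimension from the 2D view
# (the squeeze is a no-op unless num_experts == 1, matching the helper above)
recovered = flat.unflatten(0, (num_experts, -1)).squeeze(dim=0)   # (E, K, N)

print(flat.shape, recovered.shape)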
@@ -106,8 +158,10 @@ def from_float(
        else:
            w = w.t()

-        wq, w_scale = torch.ops.triton.quantize_fp8_row(w)
-        # wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_row(w)
+        wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_row(w)
+        # add a last dimension for per row quantization to align the rank of
+        # w_scale and wq
+        w_scale = w_scale.unsqueeze(-1).contiguous()
        dtype = w.dtype
        del w
        return FbgemmFp8Tensor(
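To make the rank-alignment comment concrete: for a 2D weight, per-row quantization yields one scale per row (rank 1), and unsqueeze(-1) gives it a trailing singleton dim so it lines up with wq. A small sketch with made-up shapes, not the fbgemm op itself:

import torch

wq = torch.randn(4096, 1024).to(torch.float8_e4m3fn)  # stand-in for the quantized weight, one row per output channel
w_scale = torch.rand(4096)                             # one scale per row, rank 1

w_scale = w_scale.unsqueeze(-1).contiguous()           # shape (4096, 1), same rank as wq
# the trailing dim lets a rowwise dequant broadcast cleanly:
w_dequant = wq.to(torch.float32) * w_scale             # (4096, 1024)
print(w_dequant.shape)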
@@ -133,18 +187,18 @@ def _(func, types, args, kwargs):
    # not used
    num_tokens = torch.empty([input_tensor.size(0)], device=input_tensor.device)
-    xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(
+    a_data, a_scale = torch.ops.fbgemm.quantize_fp8_per_row(
        input_tensor, num_tokens, weight_tensor.activation_scale_ub
    )

-    a_data = xq
    b_data = weight_tensor.float8_data
+    b_scale = weight_tensor.scale.squeeze(-1)

    res = torch.ops.fbgemm.f8f8bf16_rowwise(
        a_data,
        b_data,
-        x_scale,
-        weight_tensor.scale,
+        a_scale,
+        b_scale,
        use_fast_accum=True,
    )
    res = res.reshape(*orig_act_size[:-1], orig_out_features)
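For reference, a hedged plain-PyTorch sketch of the math that a rowwise-scaled fp8 GEMM like f8f8bf16_rowwise computes: dequantize each operand with its per-row scale, matmul, emit bf16. Names, shapes, and the (N, K) weight layout are assumptions for illustration; the real kernel fuses all of this:

import torch

def rowwise_f8f8_matmul_reference(a_data, a_scale, b_data, b_scale):
    # dequantize activation (M, K) and weight (N, K) with their per-row scales
    a = a_data.to(torch.float32) * a_scale.unsqueeze(-1)
    b = b_data.to(torch.float32) * b_scale.unsqueeze(-1)
    # matmul in fp32, return bf16 like the fused kernel
    return (a @ b.t()).to(torch.bfloat16)

a_data = torch.randn(8, 64).to(torch.float8_e4m3fn)
b_data = torch.randn(16, 64).to(torch.float8_e4m3fn)
a_scale, b_scale = torch.rand(8), torch.rand(16)
print(rowwise_f8f8_matmul_reference(a_data, a_scale, b_data, b_scale).shape)  # torch.Size([8, 16])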
@@ -163,19 +217,21 @@ def _(func, types, args, kwargs):
    orig_act_size = input_tensor.size()
    # not used
    num_tokens = torch.empty([input_tensor.size(0)], device=input_tensor.device)
-    xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(
+    a_data, a_scale = torch.ops.fbgemm.quantize_fp8_per_row(
        input_tensor, num_tokens, weight_tensor.activation_scale_ub
    )

-    a_data = xq
    b_data = weight_tensor.float8_data
+    b_scale = weight_tensor.scale.squeeze(-1)
+    assert b_data.is_contiguous(), "weight for bmm must be contiguous"
+
    orig_out_features = b_data.shape[-2]

    res = torch.ops.fbgemm.f8f8bf16_rowwise_batched(
        a_data,
        b_data,
-        x_scale,
-        weight_tensor.scale,
+        a_scale,
+        b_scale,
    )
    res = res.reshape(*orig_act_size[:-1], orig_out_features)
    return res
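The bmm path follows the same pattern per batch entry; a hedged batched analogue of the sketch above, with the (B, N, K) weight layout inferred from orig_out_features = b_data.shape[-2] (illustrative only):

import torch

def rowwise_f8f8_bmm_reference(a_data, a_scale, b_data, b_scale):
    # per-batch rowwise dequant, then batched matmul: (B, M, K) x (B, N, K) -> (B, M, N)
    a = a_data.to(torch.float32) * a_scale.unsqueeze(-1)
    b = b_data.to(torch.float32) * b_scale.unsqueeze(-1)
    return torch.bmm(a, b.transpose(-2, -1)).to(torch.bfloat16)

a_data = torch.randn(2, 8, 64).to(torch.float8_e4m3fn)
b_data = torch.randn(2, 16, 64).to(torch.float8_e4m3fn)
a_scale, b_scale = torch.rand(2, 8), torch.rand(2, 16)
print(rowwise_f8f8_bmm_reference(a_data, a_scale, b_data, b_scale).shape)  # torch.Size([2, 8, 16])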
@@ -269,6 +325,52 @@ def _(func, types, args, kwargs):
    )


+@implements(aten.cat.default)
+def _(func, types, args, kwargs):
+    tensors, dim = fill_defaults(args, 2, [[], 0])
+    tensor_0 = tensors[0]
+    if dim < 0:
+        dim = tensor_0.ndim + dim
+
+    for i in range(1, len(tensors)):
+        assert tensor_0.float8_data.ndim == tensors[i].float8_data.ndim
+        assert tensor_0.scale.ndim == tensors[i].scale.ndim
+        assert tensor_0.activation_scale_ub == tensors[i].activation_scale_ub
+
+    float8_datas = [t.float8_data for t in tensors]
+    scales = [t.scale for t in tensors]
+
+    # with rowwise quantization, the dimension of float8_data and the
+    # original shape will be the same, so the original dim argument applies
+    # to float8_data
+    cat_float8_data = aten.cat.default(float8_datas, dim)
+
+    if dim != 2:
+        cat_scale = aten.cat.default(scales, dim=dim)
+    else:
+        cat_scale = scales[0]
+
+    new = tensor_0.__class__(
+        cat_float8_data,
+        cat_scale,
+        tensor_0.activation_scale_ub,
+        tensor_0.dtype,
+    )
+    return return_and_correct_aliasing(func, args, kwargs, new)
+
+
+@implements(aten.transpose.int)
+def _(func, types, args, kwargs):
+    self, dim0, dim1 = args
+    float8_data = self.float8_data.transpose(dim0, dim1).contiguous()
+    scale = self.scale.transpose(dim0, dim1).contiguous()
+
+    new = self.__class__(
+        float8_data, scale, self.activation_scale_ub, self.dtype
+    )
+    return return_and_correct_aliasing(func, args, kwargs, new)
+
+
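A short sketch of the scale handling in the cat handler above, on plain tensors with made-up (E, N, K) shapes: concatenating along dims other than the last grows the set of rows, so the scales are concatenated too; concatenating along the last dim keeps the rows fixed, so a single shared per-row scale is reused (assuming the pieces were quantized with the same rowwise scales):

import torch

d0, d1 = torch.randn(2, 4, 8), torch.randn(2, 4, 8)   # stand-ins for float8_data, shape (E, N, K)
s0, s1 = torch.rand(2, 4, 1), torch.rand(2, 4, 1)      # rowwise scales, shape (E, N, 1)

# cat along dim 0 (or 1): more rows, so the scales are concatenated as well
cat_data_dim0 = torch.cat([d0, d1], dim=0)   # (4, 4, 8)
cat_scale_dim0 = torch.cat([s0, s1], dim=0)  # (4, 4, 1)

# cat along dim 2: same rows, so one per-row scale still covers the result
cat_data_dim2 = torch.cat([d0, d1], dim=2)   # (2, 4, 16)
cat_scale_dim2 = s0                          # shared rowwise scale is reused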
to_fbgemm_fp8 = FbgemmFp8Tensor.from_float