 class FbgemmFp8Tensor(TorchAOBaseTensor):
     """
     TODO: needs padding for cutlass kernels
+    Args:
+      data_to_scale_dim: the dim mapping from float8_data to scale, e.g.
+        float8_data: (batch_size, output_channel, input_channel)
+        scale: (batch_size, output_channel) (since it's per row quantization)
+        data_to_scale_dim: {0: 0, 1: 1}
     """
 
     tensor_data_attrs = ["float8_data", "scale", "activation_scale_ub"]
-    tensor_attributes = ["dtype"]
+    tensor_attributes = ["data_to_scale_dim", "dtype"]
 
-    def __new__(cls, float8_data, scale, activation_scale_ub, dtype):
+    def __new__(cls, float8_data, scale, activation_scale_ub, data_to_scale_dim, dtype):
         shape = float8_data.shape
         kwargs = {}
         kwargs["device"] = float8_data.device
         kwargs["dtype"] = dtype
         kwargs["requires_grad"] = False
         return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]
 
-    def __init__(self, float8_data, scale, activation_scale_ub, dtype):
+    def __init__(self, float8_data, scale, activation_scale_ub, data_to_scale_dim, dtype):
         self.float8_data = float8_data
         self.scale = scale
+        self.data_to_scale_dim = data_to_scale_dim
         self.activation_scale_ub = activation_scale_ub
 
     def __tensor_flatten__(self):
@@ -68,12 +74,12 @@ def _apply_fn_to_data(self, fn):
     def __repr__(self):
         return (
             f"{self.__class__.__name__}(weight={self.float8_data}, scale={self.scale}, "
-            f"activation_scale_ub={self.activation_scale_ub}, "
+            f"activation_scale_ub={self.activation_scale_ub}, data_to_scale_dim={self.data_to_scale_dim}, "
             f"shape={self.shape}, device={self.device}, dtype={self.dtype}, requires_grad={self.requires_grad})"
         )
 
     def _quantization_type(self):
-        return f"shape={self.shape}, activation_scale_ub={self.activation_scale_ub}, device={self.device}"
+        return f"shape={self.shape}, data_to_scale_dim={self.data_to_scale_dim}, activation_scale_ub={self.activation_scale_ub}, device={self.device}"
 
     def to(self, *args, **kwargs):
         kwargs = self._get_to_kwargs(*args, **kwargs)
@@ -82,9 +88,57 @@ def to(self, *args, **kwargs):
             self.float8_data.to(device),
             self.scale.to(device),
             self.activation_scale_ub.to(device),
+            self.data_to_scale_dim,
             self.dtype,
         )
 
+    def _transpose_and_reshape(self):
+        """This is added for resharding support, since the resharding logic for the model we are
+        working with only supports 2D weights.
+        """
+        assert len(self.shape) == 3, f"Only expected to be used when the Tensor is 3D, got {len(self.shape)}"
+        dim0, dim1, dim2 = self.shape
+        # because we first transpose the weight before quantization, we'll recover the original shape
+        # by swapping dim1 and dim2
+        original_shape = (dim0, dim2, dim1)
+        # we must save this as 2D in the state dict, since the loading code expects 2D weights
+        new_shape = (-1, original_shape[-1])
+        float8_data = self.float8_data
+        float8_data = float8_data.transpose(1, 2).reshape(*new_shape).contiguous()
+        data_to_scale_dim = {0: 0, 1: 1}
+        return self.__class__(
+            float8_data,
+            self.scale,
+            self.activation_scale_ub,
+            data_to_scale_dim,
+            self.dtype,
+        )
+
+    def _unflatten(self, num_experts):
+        """This is added for resharding support, since the resharding logic for the model we are
+        working with only supports 2D weights.
+        """
+        float8_data = self.float8_data
+        dim0, dim1 = self.shape
+        float8_data = float8_data.unflatten(0, (num_experts, -1)).squeeze(dim=0)
+        data_to_scale_dim = {0: 0}
+        dim0, dim1, dim2 = float8_data.shape
+        if dim1 == self.scale.shape[1]:
+            data_to_scale_dim[1] = 1
+        else:
+            data_to_scale_dim[2] = 1
+
+        return self.__class__(
+            float8_data,
+            self.scale,
+            self.activation_scale_ub,
+            data_to_scale_dim,
+            self.dtype,
+        )
+
     @classmethod
     def from_float(
         cls,
@@ -106,14 +160,18 @@ def from_float(
         else:
             w = w.t()
 
-        wq, w_scale = torch.ops.triton.quantize_fp8_row(w)
-        # wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_row(w)
+        data_to_scale_dim = {0: 0}
+        if w.ndim == 3:
+            data_to_scale_dim[1] = 1
+
+        wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_row(w)
         dtype = w.dtype
         del w
         return FbgemmFp8Tensor(
             wq,
             w_scale,
             activation_scale_ub=activation_scale_ub,
+            data_to_scale_dim=data_to_scale_dim,
             dtype=dtype,
         )
 
@@ -169,6 +227,8 @@ def _(func, types, args, kwargs):
 
     a_data = xq
     b_data = weight_tensor.float8_data
+    assert b_data.is_contiguous(), "weight for bmm must be contiguous"
+
     orig_out_features = b_data.shape[-2]
 
     res = torch.ops.fbgemm.f8f8bf16_rowwise_batched(
@@ -269,6 +329,65 @@ def _(func, types, args, kwargs):
     )
 
 
+@implements(aten.cat.default)
+def _(func, types, args, kwargs):
+    tensors, dim = fill_defaults(args, 2, [[], 0])
+    tensor_0 = tensors[0]
+    if dim < 0:
+        dim = tensor_0.ndim + dim
+
+    for i in range(1, len(tensors)):
+        assert tensor_0.float8_data.ndim == tensors[i].float8_data.ndim
+        assert tensor_0.scale.ndim == tensors[i].scale.ndim
+        assert tensor_0.activation_scale_ub == tensors[i].activation_scale_ub
+        assert tensor_0.data_to_scale_dim == tensors[i].data_to_scale_dim
+
+    float8_data = [t.float8_data for t in tensors]
+    scale = [t.scale for t in tensors]
+
+    # with rowwise quantization, float8_data has the same number of dimensions as
+    # the original shape, so the original dim argument applies to float8_data
+    cat_float8_data = aten.cat.default(float8_data, dim)
+
+    # if the cat dimension has a corresponding scale dimension, concat along that
+    # scale dimension; otherwise just reuse the existing scale
+    if dim in tensor_0.data_to_scale_dim:
+        cat_scale = aten.cat.default(scale, dim=tensor_0.data_to_scale_dim[dim])
+    else:
+        cat_scale = scale[0]
+
+    new = tensor_0.__class__(
+        cat_float8_data, cat_scale, tensor_0.activation_scale_ub, tensor_0.data_to_scale_dim, tensor_0.dtype
+    )
+    return return_and_correct_aliasing(func, args, kwargs, new)
+
+
+@implements(aten.transpose.int)
+def _(func, types, args, kwargs):
+    self, dim0, dim1 = args
+    float8_data = self.float8_data.transpose(dim0, dim1).contiguous()
+    data_to_scale_dim = self.data_to_scale_dim.copy()
+
+    if dim0 in data_to_scale_dim:
+        data_to_scale_dim[dim1] = data_to_scale_dim[dim0]
+        del data_to_scale_dim[dim0]
+    elif dim1 in data_to_scale_dim:
+        data_to_scale_dim[dim0] = data_to_scale_dim[dim1]
+        del data_to_scale_dim[dim1]
+
+    new = self.__class__(
+        float8_data,
+        self.scale,
+        self.activation_scale_ub,
+        data_to_scale_dim,
+        self.dtype,
+    )
+    return return_and_correct_aliasing(func, args, kwargs, new)
+
 to_fbgemm_fp8 = FbgemmFp8Tensor.from_float
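
For context on the new `data_to_scale_dim` attribute, here is a minimal sketch of the rowwise-quantization shape convention it encodes, written in plain PyTorch rather than with the fbgemm kernels; the `rowwise_quantize_fp8` helper below is hypothetical and only mirrors the shapes described in the class docstring.

import torch

# hypothetical helper: rowwise fp8 quantization over the last dim of a 3D weight,
# matching the (batch_size, output_channel, input_channel) layout in the docstring
def rowwise_quantize_fp8(w: torch.Tensor):
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    max_abs = w.abs().amax(dim=-1)                 # shape: (batch_size, output_channel)
    scale = (max_abs / fp8_max).clamp(min=1e-12)
    wq = (w / scale.unsqueeze(-1)).to(torch.float8_e4m3fn)
    return wq, scale

w = torch.randn(4, 128, 256)                       # (batch_size, output_channel, input_channel)
wq, scale = rowwise_quantize_fp8(w)
print(wq.shape, scale.shape)                       # torch.Size([4, 128, 256]) torch.Size([4, 128])

# dim 0 of wq lines up with dim 0 of scale, dim 1 with dim 1, and the reduced
# input_channel dim has no scale dim, hence:
data_to_scale_dim = {0: 0, 1: 1}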
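Likewise, the key remapping done by the `aten.transpose.int` override can be exercised in isolation; this standalone sketch assumes at most one of the two transposed dims carries a scale dim, which is the case the override handles.

# standalone sketch of the data_to_scale_dim remapping done in the transpose override
def remap_after_transpose(data_to_scale_dim: dict, dim0: int, dim1: int) -> dict:
    remapped = data_to_scale_dim.copy()
    if dim0 in remapped:
        remapped[dim1] = remapped.pop(dim0)
    elif dim1 in remapped:
        remapped[dim0] = remapped.pop(dim1)
    return remapped

# e.g. a 3D weight with scale dims on data dims 0 and 1, transposed over dims 1 and 2:
print(remap_after_transpose({0: 0, 1: 1}, 1, 2))   # {0: 0, 2: 1}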