 class FbgemmFp8Tensor(TorchAOBaseTensor):
     """
     TODO: needs padding for cutlass kernels
+    Args:
+      data_to_scale_dim: the dim mapping from float8_data to scale, e.g.
+        float8_data: (batch_size, output_channel, input_channel)
+        scale: (batch_size, output_channel) (since it's per row quantization)
+        data_to_scale_dim: {0: 0, 1: 1}
     """

     tensor_data_attrs = ["float8_data", "scale", "activation_scale_ub"]
-    tensor_attributes = ["dtype"]
+    tensor_attributes = ["data_to_scale_dim", "dtype"]

-    def __new__(cls, float8_data, scale, activation_scale_ub, dtype):
+    def __new__(cls, float8_data, scale, activation_scale_ub, data_to_scale_dim, dtype):
         shape = float8_data.shape
         kwargs = {}
         kwargs["device"] = float8_data.device
         kwargs["dtype"] = dtype
         kwargs["requires_grad"] = False
         return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]

-    def __init__(self, float8_data, scale, activation_scale_ub, dtype):
+    def __init__(
+        self, float8_data, scale, activation_scale_ub, data_to_scale_dim, dtype
+    ):
         self.float8_data = float8_data
         self.scale = scale
+        self.data_to_scale_dim = data_to_scale_dim
         self.activation_scale_ub = activation_scale_ub

     def __tensor_flatten__(self):
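The Args note above describes data_to_scale_dim as a map from float8_data dimensions to scale dimensions. A minimal sketch of that invariant for a per-row-quantized 3D weight; the shapes and the float8 dtype below are assumptions for illustration, not taken from this change, and require a PyTorch build with float8 dtypes:

import torch

# hypothetical per-row quantized weight: (batch_size, output_channel, input_channel)
float8_data = torch.empty(8, 256, 128, dtype=torch.float8_e4m3fn)
# one scale per row: (batch_size, output_channel)
scale = torch.empty(8, 256, dtype=torch.float32)
# float8_data dim 0 -> scale dim 0, float8_data dim 1 -> scale dim 1
data_to_scale_dim = {0: 0, 1: 1}
for data_dim, scale_dim in data_to_scale_dim.items():
    assert float8_data.shape[data_dim] == scale.shape[scale_dim]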
@@ -68,12 +76,12 @@ def _apply_fn_to_data(self, fn):
     def __repr__(self):
         return (
             f"{self.__class__.__name__}(weight={self.float8_data}, scale={self.scale}, "
-            f"activation_scale_ub={self.activation_scale_ub}, "
+            f"activation_scale_ub={self.activation_scale_ub}, data_to_scale_dim={self.data_to_scale_dim}, "
             f"shape={self.shape}, device={self.device}, dtype={self.dtype}, requires_grad={self.requires_grad})"
         )

     def _quantization_type(self):
-        return f"shape={self.shape}, activation_scale_ub={self.activation_scale_ub}, device={self.device}"
+        return f"shape={self.shape}, data_to_scale_dim={self.data_to_scale_dim}, activation_scale_ub={self.activation_scale_ub}, device={self.device}"

     def to(self, *args, **kwargs):
         kwargs = self._get_to_kwargs(*args, **kwargs)
@@ -82,6 +90,53 @@ def to(self, *args, **kwargs):
             self.float8_data.to(device),
             self.scale.to(device),
             self.activation_scale_ub.to(device),
+            self.data_to_scale_dim,
+            self.dtype,
+        )
+
+    def _transpose_and_reshape(self):
+        """This is added for resharding support, since the resharding logic for the model we are
+        working with only supports 2D
+        """
+        assert len(self.shape) == 3, (
+            f"Only expected to be used when the Tensor is 3D, got {len(self.shape)}"
+        )
+        dim0, dim1, dim2 = self.shape
+        # because we first transpose the weight before quantization, we'll recover the original shape
+        # by swapping dim1 and dim2
+        original_shape = (dim0, dim2, dim1)
+        # we must save this as 2D in the state dict, since loading code expects 2D weights
+        new_shape = (-1, original_shape[-1])
+        float8_data = self.float8_data
+        float8_data = float8_data.transpose(1, 2).reshape(*new_shape).contiguous()
+        data_to_scale_dim = {0: 0, 1: 1}
+        return self.__class__(
+            float8_data,
+            self.scale,
+            self.activation_scale_ub,
+            data_to_scale_dim,
+            self.dtype,
+        )
+
+    def _unflatten(self, num_experts):
+        """This is added for resharding support, since the resharding logic for the model we are
+        working with only supports 2D
+        """
+        float8_data = self.float8_data
+        dim0, dim1 = self.shape
+        float8_data = float8_data.unflatten(0, (num_experts, -1)).squeeze(dim=0)
+        data_to_scale_dim = {0: 0}
+        dim0, dim1, dim2 = float8_data.shape
+        if dim1 == self.scale.shape[1]:
+            data_to_scale_dim[1] = 1
+        else:
+            data_to_scale_dim[2] = 1
+
+        return self.__class__(
+            float8_data,
+            self.scale,
+            self.activation_scale_ub,
+            data_to_scale_dim,
             self.dtype,
         )

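A shape-only walkthrough of the resharding round trip that _transpose_and_reshape and _unflatten are added for, using plain tensors and assumed sizes rather than the subclass itself: the 3D weight is transposed and flattened to 2D for the state dict, and unflatten later recovers the 3D layout.

import torch

num_experts, dim1, dim2 = 4, 512, 256
w3d = torch.zeros(num_experts, dim1, dim2)                 # 3D quantized payload as stored
w2d = w3d.transpose(1, 2).reshape(-1, dim1).contiguous()   # saved as 2D: (num_experts * dim2, dim1)
back = w2d.unflatten(0, (num_experts, -1))                 # recovered: (num_experts, dim2, dim1)
assert back.shape == (num_experts, dim2, dim1)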
@@ -106,14 +161,18 @@ def from_float(
         else:
             w = w.t()

-        wq, w_scale = torch.ops.triton.quantize_fp8_row(w)
-        # wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_row(w)
+        data_to_scale_dim = {0: 0}
+        if w.ndim == 3:
+            data_to_scale_dim[1] = 1
+
+        wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_row(w)
         dtype = w.dtype
         del w
         return FbgemmFp8Tensor(
             wq,
             w_scale,
             activation_scale_ub=activation_scale_ub,
+            data_to_scale_dim=data_to_scale_dim,
             dtype=dtype,
         )

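For reference, a small stand-alone mirror of how the updated from_float path chooses the mapping; the helper name and weight shapes below are illustrative, not part of the change. A 2D weight only maps the row dim, while a 3D (batched/expert) weight also keeps a per-batch scale dim.

import torch

def expected_data_to_scale_dim(w: torch.Tensor) -> dict:
    mapping = {0: 0}      # rowwise: data dim 0 always has a scale dim
    if w.ndim == 3:
        mapping[1] = 1    # batched / expert weights also map dim 1
    return mapping

assert expected_data_to_scale_dim(torch.zeros(256, 128)) == {0: 0}
assert expected_data_to_scale_dim(torch.zeros(4, 256, 128)) == {0: 0, 1: 1}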
@@ -169,6 +228,8 @@ def _(func, types, args, kwargs):

     a_data = xq
     b_data = weight_tensor.float8_data
+    assert b_data.is_contiguous(), "weight for bmm must be contiguous"
+
     orig_out_features = b_data.shape[-2]

     res = torch.ops.fbgemm.f8f8bf16_rowwise_batched(
@@ -269,6 +330,63 @@ def _(func, types, args, kwargs):
     )


+@implements(aten.cat.default)
+def _(func, types, args, kwargs):
+    tensors, dim = fill_defaults(args, 2, [[], 0])
+    tensor_0 = tensors[0]
+    if dim < 0:
+        dim = tensor_0.ndim + dim
+
+    for i in range(1, len(tensors)):
+        assert tensor_0.float8_data.ndim == tensors[i].float8_data.ndim
+        assert tensor_0.scale.ndim == tensors[i].scale.ndim
+        assert tensor_0.activation_scale_ub == tensors[i].activation_scale_ub
+        assert tensor_0.data_to_scale_dim == tensors[i].data_to_scale_dim
+
+    float8_data = [t.float8_data for t in tensors]
+    scale = [t.scale for t in tensors]
+
+    # with rowwise quantization, the dimensions of float8_data and the
+    # original shape will be the same, so the original dim argument applies
+    # to float8_data
+    cat_float8_data = aten.cat.default(float8_data, dim)
+
+    # if the cat dimension has a corresponding scale dimension, then we'll concat the corresponding
+    # scale dimension, otherwise, we'll just use the existing scale
+    if dim in tensor_0.data_to_scale_dim:
+        cat_scale = aten.cat.default(scale, dim=tensor_0.data_to_scale_dim[dim])
+    else:
+        cat_scale = scale[0]
+
+    new = tensor_0.__class__(
+        cat_float8_data,
+        cat_scale,
+        tensor_0.activation_scale_ub,
+        tensor_0.data_to_scale_dim,
+        tensor_0.dtype,
+    )
+    return return_and_correct_aliasing(func, args, kwargs, new)
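A plain-tensor sketch of the scale handling in the cat override above (sizes are assumed): scales are only concatenated when the cat dim has a corresponding scale dim; otherwise the first tensor's scale is reused unchanged.

import torch

data_to_scale_dim = {0: 0, 1: 1}
d_a, d_b = torch.zeros(2, 256, 128), torch.zeros(2, 256, 128)
s_a, s_b = torch.zeros(2, 256), torch.zeros(2, 256)

dim = 1                                # cat along output channels
cat_data = torch.cat([d_a, d_b], dim)  # (2, 512, 128)
if dim in data_to_scale_dim:
    cat_scale = torch.cat([s_a, s_b], dim=data_to_scale_dim[dim])  # (2, 512)
else:
    cat_scale = s_a                    # e.g. cat along input channels leaves the scale alone
assert cat_data.shape == (2, 512, 128) and cat_scale.shape == (2, 512)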
+
+
+@implements(aten.transpose.int)
+def _(func, types, args, kwargs):
+    self, dim0, dim1 = args
+    float8_data = self.float8_data.transpose(dim0, dim1).contiguous()
+    data_to_scale_dim = self.data_to_scale_dim.copy()
+
+    if dim0 in data_to_scale_dim:
+        data_to_scale_dim[dim1] = data_to_scale_dim[dim0]
+        del data_to_scale_dim[dim0]
+    elif dim1 in data_to_scale_dim:
+        data_to_scale_dim[dim0] = data_to_scale_dim[dim1]
+        del data_to_scale_dim[dim1]
+
+    new = self.__class__(
+        float8_data, self.scale, self.activation_scale_ub, data_to_scale_dim, self.dtype
+    )
+    return return_and_correct_aliasing(func, args, kwargs, new)
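The transpose override only re-keys the data-to-scale mapping; the scale tensor itself is passed through untouched. A stand-alone version of that re-keying (the function name and inputs are illustrative):

def remap_after_transpose(data_to_scale_dim: dict, dim0: int, dim1: int) -> dict:
    mapping = data_to_scale_dim.copy()
    if dim0 in mapping:
        mapping[dim1] = mapping.pop(dim0)   # tracked dim moves to its new position
    elif dim1 in mapping:
        mapping[dim0] = mapping.pop(dim1)
    return mapping

# transposing dims 1 and 2 of a (batch, out, in) layout moves the out-channel scale dim
assert remap_after_transpose({0: 0, 1: 1}, 1, 2) == {0: 0, 2: 1}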
+
+

 to_fbgemm_fp8 = FbgemmFp8Tensor.from_float
