@@ -51,14 +51,50 @@ def noop_detach(func, *args, **kwargs):
 # pyre-fixme[3]: Return type must be annotated.
 # pyre-fixme[2]: Parameter must be annotated.
 def _to_copy(func, *args, **kwargs):
+    if not args[0][0].is_contiguous():
+        assert args[0][0].t().is_contiguous()
+        return func(args[0][0].t()).t()
     return args[0][0].get_original_weight().to(args[1]['dtype'])

 @implements([torch.ops.aten.to.dtype])
 # pyre-fixme[3]: Return type must be annotated.
 # pyre-fixme[2]: Parameter must be annotated.
 def to_dtype(func, *args, **kwargs):
+    if not args[0][0].is_contiguous():
+        assert args[0][0].t().is_contiguous()
+        return torch.ops.aten.to.dtype(args[0][0].t(), args[0][1]).t()
     return args[0][0].get_original_weight().to(args[0][1])

+@implements([torch.ops.aten.t.default])
+# pyre-fixme[3]: Return type must be annotated.
+# pyre-fixme[2]: Parameter must be annotated.
+def t_default(func, *args, **kwargs):
+    a = args[0][0]
+    tensor_meta = SubclassTensorArgs(
+        a.size(),
+        (a.stride(1), a.stride(0)),
+        a.storage_offset(),
+        torch.bits2x4,
+        a.device,
+        a.requires_grad)
+    b = NF4Tensor(
+        tensor_meta,
+        a.block_size,
+        a.n_blocks,
+        a.scaler_block_size,
+        a.quantized_scalers,
+        a.quantization_factor,
+        a.scaler_mean,
+        a.quantized_data,
+        a.nf4)
+    return b
+
+@implements([torch.ops.aten.mm.default])
+# pyre-fixme[3]: Return type must be annotated.
+# pyre-fixme[2]: Parameter must be annotated.
+def mm_default(func, *args, **kwargs):
+    return linear_nf4(args[0][0], args[0][1])
+

 @implements(
     [
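For context on the plumbing these handlers rely on: `implements` registers each function in an op table keyed by the aten overload, and the subclass's `__torch_dispatch__` forwards the call as `(func, args, kwargs)`, which is why the handlers index `args[0][0]` to reach the first aten argument. Below is a minimal sketch of that registration pattern; the table name (`_NF4_OPS_TABLE`) and the dispatch body are assumptions for illustration, not the file's actual code.

```python
# Hypothetical registry; the real file keeps its own table that
# __torch_dispatch__ consults. Only the shape of the pattern matters here.
_NF4_OPS_TABLE = {}


def implements(aten_ops):
    """Register a handler for one or more aten ops (sketch)."""

    def decorator(func):
        for op in aten_ops:
            _NF4_OPS_TABLE[op] = func
        return func

    return decorator


# Inside the tensor subclass, dispatch would then look roughly like:
#
#     @classmethod
#     def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
#         if func in _NF4_OPS_TABLE:
#             return _NF4_OPS_TABLE[func](func, args, kwargs)
#         raise NotImplementedError(f"NF4Tensor does not support {func}")
```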
@@ -160,7 +196,8 @@ def __new__(
             tensor_meta.original_shape,
             tensor_meta.original_strides,
             tensor_meta.storage_offset,
-            dtype=tensor_meta.dtype,
+            # Picked some floating dtype, but we need dtype extensibility
+            dtype=torch.float8_e5m2fnuz,
             device=tensor_meta.device,
             requires_grad=tensor_meta.requires_grad,
         )
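On the placeholder dtype above: the dtype passed to the wrapper-subclass constructor (presumably `torch.Tensor._make_wrapper_subclass`, given the surrounding argument list) is pure metadata, so the NF4 tensor reports it externally regardless of the packed payload it stores. A minimal sketch of that effect with a throwaway subclass, not `NF4Tensor` itself:

```python
import torch


class _FakeWrapper(torch.Tensor):
    """Throwaway wrapper subclass, used only to show dtype reporting."""

    @staticmethod
    def __new__(cls, shape):
        # The dtype passed here is metadata only; no fp8 storage is allocated.
        return torch.Tensor._make_wrapper_subclass(
            cls,
            shape,
            dtype=torch.float8_e5m2fnuz,  # placeholder, as in the hunk above
            device="cpu",
            requires_grad=False,
        )


t = _FakeWrapper((4, 4))
print(t.dtype)  # torch.float8_e5m2fnuz, even though no fp8 data exists anywhere
```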
@@ -198,6 +235,7 @@ def from_tensor(
         block_size: int,
         scaler_block_size: int,
     ):
+        assert inpt_tensor.dim() <= 2
         assert inpt_tensor.dtype == torch.bfloat16
         assert (
             inpt_tensor.numel() % block_size == 0
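The new assertion restricts `from_tensor` to inputs of at most two dimensions, alongside the existing bfloat16 and block-divisibility checks. A small sketch of a weight that satisfies the preconditions visible in this hunk; the block-size values and the commented call are assumptions, since the full signature and its defaults are not shown here:

```python
import torch

block_size, scaler_block_size = 64, 256  # illustrative values, not defaults

weight = torch.randn(1024, 1024, dtype=torch.bfloat16)

# The same preconditions from_tensor asserts above:
assert weight.dim() <= 2
assert weight.dtype == torch.bfloat16
assert weight.numel() % block_size == 0

# nf4_weight = NF4Tensor.from_tensor(weight, block_size, scaler_block_size)
```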
@@ -428,7 +466,7 @@ def quantize_tensor_nearest(
     # pyre-fixme[40]: Static method `dequantize` cannot override a non-static method
     # defined in `torch._C.TensorBase`.
     def dequantize(value: torch.Tensor, nf4: torch.Tensor) -> torch.Tensor:
-        """Dequantize a nf4 value to float16 format"""
+        """Dequantize a nf4 value to bfloat16 format"""
         # return nf4.index_select(0, value)
         return nf4[value]

@@ -546,7 +584,7 @@ class LinearNF4(torch.autograd.Function):
     def forward(ctx, input: torch.Tensor, weight: NF4Tensor):
         """Save the quantized nf4 weight for backward pass"""
         ctx.nf4_weight = weight
-        return F.linear(input, weight.get_original_weight())
+        return F.linear(input, weight.to(input.dtype))

     @staticmethod
     # pyre-fixme[14]: `backward` overrides method defined in `_SingleLevelFunction`
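The `LinearNF4.forward` change composes with the `to.dtype` / `_to_copy` handlers from the first hunk: `weight.to(input.dtype)` now dequantizes the NF4 weight into the activation dtype before `F.linear` runs, transposing around the conversion when the weight view is non-contiguous. A hedged end-to-end sketch; the import path, shapes, and block sizes are assumptions rather than values taken from the diff:

```python
import torch

# Assumed import path; adjust to wherever nf4tensor.py lives in this repo.
from nf4tensor import NF4Tensor, linear_nf4

weight = torch.randn(512, 512, dtype=torch.bfloat16)
x = torch.randn(8, 512, dtype=torch.bfloat16, requires_grad=True)

# Quantize once up front (block sizes are illustrative, not defaults).
nf4_weight = NF4Tensor.from_tensor(weight, 64, 256)

# Forward is F.linear(x, nf4_weight.to(x.dtype)) per the hunk above;
# backward uses the saved ctx.nf4_weight (per the forward docstring).
out = linear_nf4(x, nf4_weight)
out.sum().backward()
```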