@@ -130,7 +130,7 @@ def __new__(
130
130
cls ,
131
131
int_data : torch .Tensor ,
132
132
scale : torch .Tensor ,
133
- zero_point : torch .Tensor ,
133
+ zero : torch .Tensor ,
134
134
meta : torch .Tensor ,
135
135
_layout : Layout ,
136
136
original_shape : torch .Size ,
@@ -151,16 +151,17 @@ def __init__(
151
151
self ,
152
152
int_data : torch .Tensor ,
153
153
scale : torch .Tensor ,
154
- zero_point : torch .Tensor ,
154
+ zero : torch .Tensor ,
155
155
meta : torch .Tensor ,
156
156
_layout : Layout ,
157
157
original_shape : torch .Size ,
158
158
group_size : int ,
159
159
num_bits : int ,
160
160
):
161
161
self .int_data = int_data
162
+ self .scale_and_zero = None
162
163
self .scale = scale
163
- self .zero_point = zero_point
164
+ self .zero = zero
164
165
self .meta = meta
165
166
self ._layout = _layout
166
167
self .original_shape = original_shape
@@ -181,7 +182,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
181
182
)
182
183
183
184
def __tensor_flatten__ (self ):
184
- return ["int_data" , "scale" , "zero_point " , "meta" ], [
185
+ return ["int_data" , "scale" , "zero " , "meta" ], [
185
186
self ._layout ,
186
187
self .original_shape ,
187
188
self .group_size ,
@@ -194,13 +195,13 @@ def __tensor_unflatten__(
194
195
):
195
196
int_data = tensor_data_dict ["int_data" ]
196
197
scale = tensor_data_dict ["scale" ]
197
- zero_point = tensor_data_dict ["zero_point " ]
198
+ zero = tensor_data_dict ["zero " ]
198
199
meta = tensor_data_dict ["meta" ]
199
200
_layout , original_shape , group_size , num_bits = tensor_attributes
200
201
return cls (
201
202
int_data ,
202
203
scale ,
203
- zero_point ,
204
+ zero ,
204
205
meta ,
205
206
_layout ,
206
207
original_shape ,
@@ -223,14 +224,14 @@ def get_plain(self):
223
224
)
224
225
int_data_expanded_t = int_data_expanded .t ()
225
226
scales_expanded_t = scales_expanded .t ()
226
- return int_data_expanded_t , scales_expanded_t , self .zero_point
227
+ return int_data_expanded_t , scales_expanded_t , self .zero
227
228
228
229
@classmethod
229
230
def from_plain (
230
231
cls ,
231
232
int_data : torch .Tensor ,
232
233
scale : torch .Tensor ,
233
- zero_point : torch .Tensor ,
234
+ zero : torch .Tensor ,
234
235
_layout : Layout ,
235
236
):
236
237
from torchao .sparsity .marlin import (
@@ -291,7 +292,7 @@ def from_plain(
291
292
return cls (
292
293
marlin_24_q_w_comp ,
293
294
marlin_24_s ,
294
- zero_point ,
295
+ zero ,
295
296
meta ,
296
297
_layout ,
297
298
q_w_24 .shape ,
@@ -305,6 +306,6 @@ def get_layout(self) -> Layout:
305
306
def _apply_fn_to_data (self , fn ):
306
307
self .int_data = fn (self .int_data )
307
308
self .scale = fn (self .scale )
308
- self .zero_point = fn (self .zero_point )
309
+ self .zero = fn (self .zero )
309
310
self .meta = fn (self .meta )
310
311
return self
0 commit comments