Commit 801af03
[sparse] marlin fixes (#2305)
* [sparse] marlin fixes

  Summary: This PR updates sparse-marlin to stop using CPU tensors and makes it compatible with Int4WeightOnly.

  Test Plan:

  ```
  pytest test/sparsity/test_marlin.py
  ```

* ruff check
1 parent 152a8e3 commit 801af03
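For context, sparse-marlin is reached through the int4 weight-only quantization path. A minimal usage sketch, assuming torchao's public API around this commit (`quantize_`, `int4_weight_only`, `MarlinSparseLayout`; the layout is expected to handle 2:4 pruning of the weight during packing):

```python
# Minimal usage sketch (API names taken from the torchao repo; treat as an
# assumption about the public surface at this commit, not a guaranteed recipe).
import torch
from torchao.quantization import quantize_, int4_weight_only
from torchao.dtypes import MarlinSparseLayout

model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).half().cuda()

# Route int4 weight-only quantization through the sparse-marlin layout; with
# this fix the packed tensors stay on their original device instead of
# bouncing through CPU.
quantize_(model, int4_weight_only(layout=MarlinSparseLayout()))

out = model(torch.randn(16, 1024, dtype=torch.float16, device="cuda"))
```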

File tree

2 files changed (+15 −16 lines)


torchao/dtypes/uintx/marlin_sparse_layout.py

Lines changed: 11 additions & 10 deletions
@@ -130,7 +130,7 @@ def __new__(
         cls,
         int_data: torch.Tensor,
         scale: torch.Tensor,
-        zero_point: torch.Tensor,
+        zero: torch.Tensor,
         meta: torch.Tensor,
         _layout: Layout,
         original_shape: torch.Size,
@@ -151,16 +151,17 @@ def __init__(
         self,
         int_data: torch.Tensor,
         scale: torch.Tensor,
-        zero_point: torch.Tensor,
+        zero: torch.Tensor,
         meta: torch.Tensor,
         _layout: Layout,
         original_shape: torch.Size,
         group_size: int,
         num_bits: int,
     ):
         self.int_data = int_data
+        self.scale_and_zero = None
         self.scale = scale
-        self.zero_point = zero_point
+        self.zero = zero
         self.meta = meta
         self._layout = _layout
         self.original_shape = original_shape
@@ -181,7 +182,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
         )
 
     def __tensor_flatten__(self):
-        return ["int_data", "scale", "zero_point", "meta"], [
+        return ["int_data", "scale", "zero", "meta"], [
             self._layout,
             self.original_shape,
             self.group_size,
@@ -194,13 +195,13 @@ def __tensor_unflatten__(
     ):
         int_data = tensor_data_dict["int_data"]
         scale = tensor_data_dict["scale"]
-        zero_point = tensor_data_dict["zero_point"]
+        zero = tensor_data_dict["zero"]
         meta = tensor_data_dict["meta"]
         _layout, original_shape, group_size, num_bits = tensor_attributes
         return cls(
             int_data,
             scale,
-            zero_point,
+            zero,
             meta,
             _layout,
             original_shape,
@@ -223,14 +224,14 @@ def get_plain(self):
         )
         int_data_expanded_t = int_data_expanded.t()
         scales_expanded_t = scales_expanded.t()
-        return int_data_expanded_t, scales_expanded_t, self.zero_point
+        return int_data_expanded_t, scales_expanded_t, self.zero
 
     @classmethod
     def from_plain(
         cls,
         int_data: torch.Tensor,
         scale: torch.Tensor,
-        zero_point: torch.Tensor,
+        zero: torch.Tensor,
         _layout: Layout,
     ):
         from torchao.sparsity.marlin import (
@@ -291,7 +292,7 @@ def from_plain(
         return cls(
             marlin_24_q_w_comp,
             marlin_24_s,
-            zero_point,
+            zero,
             meta,
             _layout,
             q_w_24.shape,
@@ -305,6 +306,6 @@ def get_layout(self) -> Layout:
     def _apply_fn_to_data(self, fn):
         self.int_data = fn(self.int_data)
         self.scale = fn(self.scale)
-        self.zero_point = fn(self.zero_point)
+        self.zero = fn(self.zero)
         self.meta = fn(self.meta)
         return self
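The rename from `zero_point` to `zero` matters because `__tensor_flatten__` must return the exact attribute names that `__tensor_unflatten__` (and torch.compile's subclass tracing) will look up on the instance. A minimal sketch of that flatten/unflatten contract, using a hypothetical `PackedTensor` stand-in rather than the real marlin tensor impl; `_make_wrapper_subclass` is PyTorch's (private but widely used) constructor for wrapper subclasses:

```python
# Sketch of the __tensor_flatten__ / __tensor_unflatten__ contract.
# PackedTensor is a hypothetical example class, not the torchao implementation.
import torch

class PackedTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, int_data, scale, zero):
        # Wrapper subclass: outer tensor carries shape/dtype metadata only.
        return torch.Tensor._make_wrapper_subclass(
            cls, int_data.shape, dtype=scale.dtype
        )

    def __init__(self, int_data, scale, zero):
        self.int_data = int_data
        self.scale = scale
        self.zero = zero

    def __tensor_flatten__(self):
        # The strings here must match the attribute names set in __init__;
        # unflattening fetches each inner tensor from the instance by name.
        return ["int_data", "scale", "zero"], []

    @classmethod
    def __tensor_unflatten__(cls, tensor_data_dict, tensor_attributes,
                             outer_size, outer_stride):
        return cls(
            tensor_data_dict["int_data"],
            tensor_data_dict["scale"],
            tensor_data_dict["zero"],
        )
```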

torchao/sparsity/marlin/__init__.py

Lines changed: 4 additions & 6 deletions
@@ -226,11 +226,10 @@ def _to_marlin_weights(
 
     # Pack
     pack_factor = utils.get_pack_factor(num_bits)
-    orig_device = q_w.device
 
     # Original implementation uses numpy + uint32 but we need to use int64 because torch.uint32
     # does not support rshift_cpu.
-    q_w = q_w.cpu().to(torch.int64)
+    q_w = q_w.to(torch.int64)
     q_packed = torch.zeros(
         (q_w.shape[0], q_w.shape[1] // pack_factor),
         dtype=torch.int64,
@@ -239,7 +238,7 @@ def _to_marlin_weights(
     for i in range(pack_factor):
         q_packed |= q_w[:, i::pack_factor] << (num_bits * i)
 
-    q_packed = q_packed.to(orig_device, dtype=torch.int32)
+    q_packed = q_packed.to(dtype=torch.int32)
     return q_packed
 
 
@@ -259,12 +258,11 @@ def _from_marlin_weights(
     perm_24, _, _ = utils.get_reverse_perms_24(num_bits)
 
     pack_factor = utils.get_pack_factor(num_bits)
-    orig_device = q_packed.device
 
     # Unpack from marlin format.
     # Original implementation uses numpy + uint32 but we need to use int64 because torch.uint32
     # does not support rshift_cpu.
-    q_packed = q_packed.cpu().to(torch.int64)
+    q_packed = q_packed.to(torch.int64)
     q_w_unpacked = torch.zeros(
         (q_packed.shape[0], q_packed.shape[1] * pack_factor),
         dtype=torch.int64,
@@ -275,7 +273,7 @@ def _from_marlin_weights(
         (1 << num_bits) - 1
     )
 
-    q_w_unpacked = q_w_unpacked.to(orig_device, dtype=torch.int32)
+    q_w_unpacked = q_w_unpacked.to(dtype=torch.int32)
 
     q_w_comp = utils.reverse_marlin_permute_weights(
         q_w_unpacked, size_k, size_n, perm_24
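The packing loop above interleaves `pack_factor` low-bit values into each 32-bit word, where `pack_factor` is presumably `32 // num_bits` for uint32 storage (the real value comes from `utils.get_pack_factor`). A self-contained round-trip sketch of that bit-packing scheme, with the marlin permutations omitted:

```python
# Round-trip sketch of the bit-packing used above (permutations omitted).
import torch

num_bits = 4
pack_factor = 32 // num_bits  # assumption: matches utils.get_pack_factor

# Random 4-bit values; int64 sidesteps torch.uint32's missing rshift support.
q_w = torch.randint(0, 1 << num_bits, (8, 64), dtype=torch.int64)

# Pack: strided columns i, i + pack_factor, ... share one output word.
q_packed = torch.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), dtype=torch.int64)
for i in range(pack_factor):
    q_packed |= q_w[:, i::pack_factor] << (num_bits * i)
q_packed = q_packed.to(dtype=torch.int32)  # final storage, as in the diff

# Unpack: shift each word right and mask off num_bits at a time. Sign
# extension from the int32 -> int64 cast is harmless because of the mask.
packed64 = q_packed.to(torch.int64)
q_unpacked = torch.zeros_like(q_w)
for i in range(pack_factor):
    q_unpacked[:, i::pack_factor] = (packed64 >> (num_bits * i)) & ((1 << num_bits) - 1)

assert torch.equal(q_unpacked, q_w)  # lossless round trip
```

Since every tensor here is created on the current device and no `.cpu()` hop is needed, the same loop works unchanged on CUDA tensors, which is the point of dropping `orig_device` in this commit.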
