
Commit 331d939

jerryzh168 authored and facebook-github-bot committed
Fixes for running GPTQ in executorch (#58)
Summary:
Pull Request resolved: #58

att

Reviewed By: cpuhrsch

Differential Revision: D54885767

fbshipit-source-id: 331af7c1e6fdb2fc8202f1dc8a34e0a42b1d6314
1 parent c1b564a commit 331d939

File tree

2 files changed: +40 -5 lines changed


torchao/quantization/GPTQ.py

Lines changed: 2 additions & 3 deletions

@@ -93,9 +93,8 @@ def setup_cache_padded_seq_input_pos_max_seq_length_for_prefill(
     input_pos = torch.arange(0, T, device=device)
 
     # no caches in executorch llama2 7b model?
-    print("setting up cache")
-    with torch.device(device):
-        model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
+    # with torch.device(device):
+    # model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
 
     return seq, input_pos, max_seq_length

torchao/quantization/quant_api.py

Lines changed: 38 additions & 2 deletions

@@ -394,6 +394,40 @@ def quantize(
     return model
 
 
+def linear_forward_8da4w(
+    x, weight_int8, scales, zeros, out_features, group_size, precision
+):
+    x = per_token_dynamic_quant(x)
+    # TODO: verify and remove following reshape code
+    # origin_x_size = x.size()
+    # x = x.reshape(-1, origin_x_size[-1])
+
+    # TODO: better API
+    # weight_int8 = torch.ops.quantized_decomposed.unpack_int4_to_int8(weight_int4packed)
+    n_bit = 4
+    quant_min = -(2 ** (n_bit - 1))
+    quant_max = 2 ** (n_bit - 1) - 1
+    w_dq = torch.ops.quantized_decomposed.dequantize_per_channel_group(
+        weight_int8,
+        scales,
+        zeros,
+        quant_min,
+        quant_max,
+        torch.int8,
+        group_size,
+        precision,
+    )
+
+    # x = x.to(torch.float16)
+    # w_dq = w_dq.to(torch.float16)
+    c = torch.nn.functional.linear(x, w_dq)
+
+    # new_shape = origin_x_size[:-1] + (out_features,)
+    # c = c.reshape(new_shape)
+
+    return c
+
+
 class Int8DynActInt4WeightLinear(torch.nn.Module):
     __constants__ = ["in_features", "out_features"]
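For reference, the per-channel-group dequantization that linear_forward_8da4w delegates to torch.ops.quantized_decomposed.dequantize_per_channel_group can be sketched in plain PyTorch roughly as follows. This is only an illustration of the expected semantics (one affine scale/zero-point pair per group along the input dimension), not the op's actual implementation; the function name and tensor shapes below are assumptions for the sketch.

import torch

def dequantize_per_channel_group_ref(w_int8, scales, zeros, group_size, out_dtype=torch.float32):
    # Assumed shapes (illustration only):
    #   w_int8: [out_features, in_features], int8 codes in the 4-bit range [-8, 7]
    #   scales: [out_features, in_features // group_size]
    #   zeros:  [out_features, in_features // group_size]
    out_features, in_features = w_int8.shape
    w = w_int8.to(out_dtype).reshape(out_features, in_features // group_size, group_size)
    s = scales.to(out_dtype).unsqueeze(-1)  # broadcast one scale per group
    z = zeros.to(out_dtype).unsqueeze(-1)   # broadcast one zero point per group
    return ((w - z) * s).reshape(out_features, in_features)

With that, the body above reduces to torch.nn.functional.linear(per_token_dynamic_quant(x), w_dq), where w_dq is this group-wise affine dequantization of the stored int8 codes.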

@@ -433,6 +467,7 @@ def __init__(
         self.in_features = in_features
         self.out_features = out_features
         assert not bias, "require bias=False"
+        # TODO: align groupsize naming
         self.group_size = group_size
         # Precision of the activation which also indicates
         # output precision of the dynamically quantized linear layer
@@ -469,10 +504,11 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
             self.scales,
             self.zeros,
             self.out_features,
-            self.groupsize,
+            self.group_size,
             self.precision,
         )
 
+
 from math import gcd
 from functools import reduce
@@ -630,7 +666,7 @@ def _convert_for_runtime(self, model):
             model,
             self.groupsize,
             self.padding_allowed,
-            torch.int8,
+            self.precision,
             self.precision,
         )
         return model
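As a rough end-to-end illustration of the inputs this path consumes, here is a hedged sketch that group-quantizes a float weight into int8-stored 4-bit codes, then dequantizes group-wise and applies the linear, mirroring linear_forward_8da4w minus the dynamic per-token activation quantization. The helper name, the 512x256 weight, and group_size=128 are illustrative assumptions, not part of this commit.

import torch
import torch.nn.functional as F

def group_quantize_4bit_ref(w, group_size):
    # Asymmetric per-channel-group 4-bit quantization (illustrative reference).
    # w: [out_features, in_features] float; returns int8 codes plus per-group scales/zeros.
    out_f, in_f = w.shape
    g = w.reshape(out_f, in_f // group_size, group_size)
    w_min, w_max = g.amin(dim=-1), g.amax(dim=-1)
    quant_min, quant_max = -8, 7  # signed 4-bit range
    scales = (w_max - w_min).clamp(min=1e-6) / (quant_max - quant_min)
    zeros = quant_min - torch.round(w_min / scales)
    q = torch.round(g / scales.unsqueeze(-1)) + zeros.unsqueeze(-1)
    q = q.clamp(quant_min, quant_max).to(torch.int8)
    return q.reshape(out_f, in_f), scales, zeros

x = torch.randn(2, 256)
w = torch.randn(512, 256)
w_int8, scales, zeros = group_quantize_4bit_ref(w, group_size=128)

# Dequantize group-wise ((q - zero) * scale) and apply the linear, as in the diff above.
w_dq = (w_int8.reshape(512, -1, 128).float() - zeros.unsqueeze(-1)) * scales.unsqueeze(-1)
y = F.linear(x, w_dq.reshape(512, 256))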
