
Commit 65ba0dc

Update on "Add GPTQQuantizer"
Summary: Implement GPTQQuantizer with the unified quantizer API

Test Plan: python test/quantization/test_quant_api.py

Reviewers:
Subscribers:
Tasks:
Tags:

[ghstack-poisoned]
1 parent 7bc824a commit 65ba0dc
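
For orientation, the end-to-end flow the new test exercises looks roughly like the sketch below. The model-loading lines mirror the diff in test_quant_api.py; the GPTQQuantizer import path and constructor arguments (other than calibration_seq_length and pad_calibration_inputs, which do appear in the test) are illustrative placeholders, not the actual signature.

import torch
from pathlib import Path
from sentencepiece import SentencePieceProcessor

from model import Transformer
# Import path below is an assumption; use wherever GPTQQuantizer lives in this repo.
from torchao.quantization.quant_api import GPTQQuantizer

# Load a Llama-2 7B checkpoint in the gpt-fast layout, as test_gptq does.
checkpoint_path = Path("../gpt-fast/checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth")
model = Transformer.from_name(checkpoint_path.parent.name)
checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
model.load_state_dict(checkpoint, assign=True)
model = model.to(dtype=torch.bfloat16, device="cuda")

tokenizer = SentencePieceProcessor(model_file=str(checkpoint_path.parent / "tokenizer.model"))

# Calibration settings: the names come from the test, the values are placeholders.
calibration_seq_length = 100
pad_calibration_inputs = False

quantizer = GPTQQuantizer(
    tokenizer,                # hypothetical argument order
    calibration_seq_length,
    pad_calibration_inputs,
)
model = quantizer.quantize(model)  # unified quantizer API: quantize() returns the model
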

File tree

2 files changed: +254 -2 lines


test/quantization/model.py

Lines changed: 246 additions & 0 deletions
@@ -0,0 +1,246 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn
from torch import Tensor
from torch.nn import functional as F


def find_multiple(n: int, k: int) -> int:
    if n % k == 0:
        return n
    return n + k - (n % k)

@dataclass
class ModelArgs:
    block_size: int = 2048
    vocab_size: int = 32000
    n_layer: int = 32
    n_head: int = 32
    dim: int = 4096
    intermediate_size: int = None
    n_local_heads: int = -1
    head_dim: int = 64
    rope_base: float = 10000
    norm_eps: float = 1e-5

    def __post_init__(self):
        if self.n_local_heads == -1:
            self.n_local_heads = self.n_head
        if self.intermediate_size is None:
            hidden_dim = 4 * self.dim
            n_hidden = int(2 * hidden_dim / 3)
            self.intermediate_size = find_multiple(n_hidden, 256)
        self.head_dim = self.dim // self.n_head

    @classmethod
    def from_name(cls, name: str):
        if name in transformer_configs:
            return cls(**transformer_configs[name])
        # fuzzy search
        config = [config for config in transformer_configs if config in str(name).upper() or config in str(name)]

        # We may have two or more configs matched (e.g. "7B" and "Mistral-7B"). Find the best config match,
        # take the longer name (as it has more symbols matched)
        if len(config) > 1:
            config.sort(key=len, reverse=True)
            assert len(config[0]) != len(config[1]), name  # make sure only one 'best' match

        return cls(**transformer_configs[config[0]])


transformer_configs = {
    "CodeLlama-7b-Python-hf": dict(block_size=16384, vocab_size=32000, n_layer=32, dim=4096, rope_base=1000000),
    "7B": dict(n_layer=32, n_head=32, dim=4096),
    "13B": dict(n_layer=40, n_head=40, dim=5120),
    "30B": dict(n_layer=60, n_head=52, dim=6656),
    "34B": dict(n_layer=48, n_head=64, dim=8192, vocab_size=32000, n_local_heads=8, intermediate_size=22016, rope_base=1000000),  # CodeLlama-34B-Python-hf
    "70B": dict(n_layer=80, n_head=64, dim=8192, n_local_heads=8, intermediate_size=28672),
    "Mistral-7B": dict(n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, vocab_size=32000),
}

class KVCache(nn.Module):
    def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, dtype=torch.bfloat16):
        super().__init__()
        cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
        self.register_buffer('k_cache', torch.zeros(cache_shape, dtype=dtype))
        self.register_buffer('v_cache', torch.zeros(cache_shape, dtype=dtype))

    def update(self, input_pos, k_val, v_val):
        # input_pos: [S], k_val: [B, H, S, D]
        assert input_pos.shape[0] == k_val.shape[2]

        k_out = self.k_cache
        v_out = self.v_cache
        k_out[:, :, input_pos] = k_val
        v_out[:, :, input_pos] = v_val

        return k_out, v_out

class Transformer(nn.Module):
    def __init__(self, config: ModelArgs) -> None:
        super().__init__()
        self.config = config
        self.vocab_size = self.config.vocab_size

        self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim)
        self.layers = nn.ModuleList(TransformerBlock(config) for _ in range(config.n_layer))
        self.norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.output = nn.Linear(config.dim, config.vocab_size, bias=False)

        self.freqs_cis: Optional[Tensor] = None
        self.mask_cache: Optional[Tensor] = None
        self.max_batch_size = -1
        self.max_seq_length = -1

    def setup_caches(self, max_batch_size, max_seq_length):
        if self.max_seq_length >= max_seq_length and self.max_batch_size >= max_batch_size:
            return
        head_dim = self.config.dim // self.config.n_head
        max_seq_length = find_multiple(max_seq_length, 8)
        self.max_seq_length = max_seq_length
        self.max_batch_size = max_batch_size
        for b in self.layers:
            b.attention.kv_cache = KVCache(max_batch_size, max_seq_length, self.config.n_local_heads, head_dim)

        self.freqs_cis = precompute_freqs_cis(self.config.block_size, self.config.dim // self.config.n_head, self.config.rope_base)
        self.causal_mask = torch.tril(torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool))

    def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
        assert self.freqs_cis is not None, "Caches must be initialized first"
        mask = self.causal_mask[None, None, input_pos]
        freqs_cis = self.freqs_cis[input_pos]
        x = self.tok_embeddings(idx)

        for i, layer in enumerate(self.layers):
            x = layer(x, input_pos, freqs_cis, mask)
        x = self.norm(x)
        logits = self.output(x)
        return logits

    @classmethod
    def from_name(cls, name: str):
        return cls(ModelArgs.from_name(name))


class TransformerBlock(nn.Module):
    def __init__(self, config: ModelArgs) -> None:
        super().__init__()
        self.attention = Attention(config)
        self.feed_forward = FeedForward(config)
        self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
        self.attention_norm = RMSNorm(config.dim, config.norm_eps)

    def forward(self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: Tensor) -> Tensor:
        h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos)
        out = h + self.feed_forward(self.ffn_norm(h))
        return out


class Attention(nn.Module):
    def __init__(self, config: ModelArgs):
        super().__init__()
        assert config.dim % config.n_head == 0

        total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
        # key, query, value projections for all heads, but in a batch
        self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False)
        self.wo = nn.Linear(config.dim, config.dim, bias=False)
        self.kv_cache = None

        self.n_head = config.n_head
        self.head_dim = config.head_dim
        self.n_local_heads = config.n_local_heads
        self.dim = config.dim
        self._register_load_state_dict_pre_hook(self.load_hook)

    def load_hook(self, state_dict, prefix, *args):
        if prefix + "wq.weight" in state_dict:
            wq = state_dict.pop(prefix + "wq.weight")
            wk = state_dict.pop(prefix + "wk.weight")
            wv = state_dict.pop(prefix + "wv.weight")
            state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])

    def forward(self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
        bsz, seqlen, _ = x.shape

        kv_size = self.n_local_heads * self.head_dim
        q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)

        q = q.view(bsz, seqlen, self.n_head, self.head_dim)
        k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim)

        q = apply_rotary_emb(q, freqs_cis)
        k = apply_rotary_emb(k, freqs_cis)

        q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))

        if self.kv_cache is not None:
            k, v = self.kv_cache.update(input_pos, k, v)

        k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
        v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
        y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)

        y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)

        y = self.wo(y)
        return y


class FeedForward(nn.Module):
    def __init__(self, config: ModelArgs) -> None:
        super().__init__()
        self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False)
        self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False)
        self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False)

    def forward(self, x: Tensor) -> Tensor:
        return self.w2(F.silu(self.w1(x)) * self.w3(x))


class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)

    def forward(self, x: Tensor) -> Tensor:
        output = self._norm(x.float()).type_as(x)
        return output * self.weight


def precompute_freqs_cis(
    seq_len: int, n_elem: int, base: int = 10000
) -> Tensor:
    freqs = 1.0 / (base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem))
    t = torch.arange(seq_len, device=freqs.device)
    freqs = torch.outer(t, freqs)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
    return cache.to(dtype=torch.bfloat16)


def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
    xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
    freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2)
    x_out2 = torch.stack(
        [
            xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
            xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
        ],
        -1,
    )

    x_out2 = x_out2.flatten(3)
    return x_out2.type_as(x)
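
Since the new model.py is self-contained, a quick way to sanity-check it is a minimal decoding-style call. The sketch below is illustrative only: the "7B" config name comes from transformer_configs above, but the batch size, sequence lengths, and random token ids are made up, the weights are left at their random initialization, and a CUDA device with enough memory for a bfloat16 7B model is assumed.

import torch
from model import Transformer

# Build a model from a registered config; weights stay randomly initialized here.
model = Transformer.from_name("7B").to(device="cuda", dtype=torch.bfloat16)
model.eval()

# Allocate the KV caches, RoPE table, and causal mask on the target device.
with torch.device("cuda"):
    model.setup_caches(max_batch_size=1, max_seq_length=128)

# Prefill with a (batch, seq_len) tensor of token ids and their positions.
prompt = torch.randint(0, model.config.vocab_size, (1, 16), device="cuda")
input_pos = torch.arange(16, device="cuda")
with torch.no_grad():
    logits = model(prompt, input_pos)   # shape: (1, 16, vocab_size)
next_token = logits[:, -1].argmax(dim=-1)

The test reaches the same entry point via Transformer.from_name(checkpoint_path.parent.name): the directory name "Llama-2-7b-chat-hf" fuzzy-matches the "7B" entry in transformer_configs.
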

test/quantization/test_quant_api.py

Lines changed: 8 additions & 2 deletions
@@ -26,6 +26,7 @@
 )
 from pathlib import Path
 from sentencepiece import SentencePieceProcessor
+from model import Transformer


 def dynamic_quant(model, example_inputs):

@@ -131,8 +132,13 @@ def test_dynamic_quant_gpu_unified_api_eager_mode_impl(self):

     def test_gptq(self):
         # should be similar to TorchCompileDynamicQuantizer
-        m = M().eval()
+        precision = torch.bfloat16
+        device = "cuda"
         checkpoint_path = Path("../gpt-fast/checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth")
+        model = Transformer.from_name(checkpoint_path.parent.name)
+        checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
+        model.load_state_dict(checkpoint, assign=True)
+        model = model.to(dtype=precision, device=device)
         tokenizer_path = checkpoint_path.parent / "tokenizer.model"
         assert tokenizer_path.is_file(), tokenizer_path
         tokenizer = SentencePieceProcessor(  # pyre-ignore[28]

@@ -155,7 +161,7 @@ def test_gptq(self):
             calibration_seq_length,
             pad_calibration_inputs,
         )
-        m = quantizer.quantize(m)
+        model = quantizer.quantize(model)

 if __name__ == "__main__":
     unittest.main()
