Skip to content

Commit 841294e

Browse files
authored
[quantization] llm quant (#222)
* [quantization] init commit for llm quant * [tests] delete dequantize * [misc] refactor * [llm_quant] add readme and fix typo * [llm_quant] add readme and fix typo
1 parent edf2852 commit 841294e

File tree

8 files changed

+652
-0
lines changed

8 files changed

+652
-0
lines changed

tinynn/llm_quant/README.md

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
# LLM QUANT
2+
3+
## 安装依赖
4+
5+
- PyTorch: tested on PyTorch 1.13 & CUDA 11.6
6+
- transformers: tested on v4.28.1
7+
- easyquant: 需要到[Releases](https://github.com/alibaba/TinyNeuralNetwork/releases)手动下载安装包进行安装, 提供权重动态解压和动态量化的cuda加速kernel
8+
9+
## 量化模式
10+
11+
- 8bit仅权重量化: 权重压缩为8-bit,显存需求降低,计算时还原为FP16进行计算,相比于FP16的模型推理存在额外开销。模型精度几乎没有影响。
12+
- 4bit仅权重量化: 权重压缩为4-bit,显存需求大幅度降低, 计算时还原为FP16进行计算,相比于FP16的模型推理存在额外开销。模型精度下降较严重。
13+
- token-wise动态量化: 权重压缩为8-bit, 激活值运行时动态量化为8-bit, 结合easyquant库的int8 GEMM可以有效提升推理性能。在Llama-family模型中精度小幅度下降,基本没有影响。
14+
15+
## Llama 量化
16+
我们对llama模型进行了详细的量化分析和测试,推荐使用8-bit的动态量化,其可以有效提升推理速度并降低显存需求,同时精度几乎不受影响。
17+
18+
| 量化模式 | wikitext2(ppl⬇️) | 推理性能(ms/token) <br/>GPU:2080Ti | 推理性能(ms/token)<br/> GPU:T4 | 模型占用显存(GB) |
19+
|-------------------------|------------------|--------------------------------|----------------------------|------------|
20+
| llama-7b fp16 | 5.68 | - | 61.5882 | 12.90 |
21+
| llama-7b weight8 | 5.68 | 68.6845 | 151.1209 | 7.10 |
22+
| llama-7b token-wise动态量化 | 5.82(+0.14) | 43.0228 | 47.1649 | 7.09 |
23+
| llama-7b weight4 | 6.5657(+0.89) | 63.7035 | 141.1330 | 3.99 |
24+
25+
> 除了模型占用显存外,在模型推理过程中还存在激活值和上下文的显存占用,需要预留1~2GB的额外显存。
26+
27+
## 未来工作
28+
29+
- 4-bit量化精度恢复及加速推理
30+
- 8-bit静态量化

tinynn/llm_quant/__init__.py

Whitespace-only changes.

tinynn/llm_quant/examples/chatglm.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
# This script is based on https://github.com/THUDM/ChatGLM-6B
2+
import signal
3+
import os
4+
import torch
5+
from transformers import AutoModel, AutoTokenizer
6+
7+
from tinynn.llm_quant.modules import quant_fc
8+
9+
10+
def basic_usage(model_path='THUDM/chatglm-6b', quant_mod='dynamic'):
11+
model = AutoModel.from_pretrained(model_path, trust_remote_code=True).half()
12+
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
13+
device = torch.device('cuda')
14+
15+
# Do quantization.
16+
if quant_mod != 'fp16':
17+
quant_fc(model, quant_mod=quant_mod)
18+
model.to(device)
19+
20+
clear_command = 'clear'
21+
stop_stream = False
22+
23+
def build_prompt(history):
24+
prompt = "欢迎使用 ChatGLM-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序"
25+
for query, response in history:
26+
prompt += f"\n\n用户:{query}"
27+
prompt += f"\n\nChatGLM-6B:{response}"
28+
return prompt
29+
30+
def signal_handler(signal, frame):
31+
global stop_stream
32+
stop_stream = True
33+
34+
history = []
35+
print("欢迎使用 ChatGLM-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序")
36+
while True:
37+
query = input("\n用户:")
38+
if query.strip() == "stop":
39+
break
40+
if query.strip() == "clear":
41+
history = []
42+
os.system(clear_command)
43+
print("欢迎使用 ChatGLM-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序")
44+
continue
45+
count = 0
46+
for response, history in model.stream_chat(tokenizer, query, history=history):
47+
if stop_stream:
48+
stop_stream = False
49+
break
50+
else:
51+
count += 1
52+
if count % 8 == 0:
53+
os.system(clear_command)
54+
print(build_prompt(history), flush=True)
55+
signal.signal(signal.SIGINT, signal_handler)
56+
os.system(clear_command)
57+
print(build_prompt(history), flush=True)
58+
59+
60+
if __name__ == '__main__':
61+
basic_usage()

tinynn/llm_quant/examples/llama.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
import torch
2+
3+
from transformers import AutoModelForCausalLM, AutoTokenizer
4+
5+
from tinynn.llm_quant.modules import quant_fc
6+
7+
8+
def basic_usage(model_path='huggyllama/llama-7b', quant_mod='dynamic'):
9+
device = torch.device('cuda')
10+
11+
# load LLM model from huggingface or local path
12+
model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
13+
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
14+
15+
# Do quantization.
16+
if quant_mod != 'fp16':
17+
# If your LLM model is Llama-family, you can set fuse_qkv to fuse qkv linear and scaled-dot-product-attention.
18+
quant_fc(model, quant_mod=quant_mod, fuse_qkv=True)
19+
model.to(device)
20+
21+
prompt = "Building a website can be done in 10 simple steps:\n"
22+
input_ids = tokenizer(prompt, return_tensors="pt").input_ids
23+
input_ids = input_ids.to(device)
24+
25+
generated_ids = model.generate(
26+
input_ids,
27+
max_new_tokens=1024,
28+
do_sample=True,
29+
top_k=1,
30+
top_p=0.95,
31+
temperature=0.8,
32+
repetition_penalty=1.2,
33+
use_cache=True,
34+
)
35+
36+
outputs = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
37+
for output in outputs:
38+
print(output)
39+
40+
41+
if __name__ == '__main__':
42+
basic_usage()

tinynn/llm_quant/llama.py

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
import math
2+
from typing import Optional, Tuple
3+
from distutils.version import LooseVersion
4+
5+
import torch
6+
import torch.nn as nn
7+
import torch.nn.functional as F
8+
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
9+
from transformers.modeling_utils import set_module_tensor_to_device
10+
11+
12+
class LlamaAttentionFused(nn.Module):
13+
def __init__(self, origin_attention):
14+
super().__init__()
15+
self.config = origin_attention.config
16+
self.hidden_size = origin_attention.hidden_size
17+
self.num_heads = origin_attention.num_heads
18+
self.head_dim = origin_attention.head_dim
19+
self.max_position_embeddings = origin_attention.max_position_embeddings
20+
21+
self.qkv_proj = nn.Linear(
22+
origin_attention.hidden_size, origin_attention.num_heads * origin_attention.head_dim * 3, bias=False
23+
)
24+
fused_weight = torch.cat(
25+
[
26+
fc_node.weight.data
27+
for fc_node in [origin_attention.q_proj, origin_attention.k_proj, origin_attention.v_proj]
28+
],
29+
dim=0,
30+
)
31+
set_module_tensor_to_device(
32+
self.qkv_proj, 'weight', fused_weight.device, value=fused_weight, dtype=fused_weight.dtype
33+
)
34+
self.o_proj = origin_attention.o_proj
35+
self.rotary_emb = origin_attention.rotary_emb
36+
37+
origin_attention.q_proj = None
38+
origin_attention.k_proj = None
39+
origin_attention.v_proj = None
40+
41+
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
42+
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
43+
44+
def forward(
45+
self,
46+
hidden_states: torch.Tensor,
47+
attention_mask: Optional[torch.Tensor] = None,
48+
position_ids: Optional[torch.LongTensor] = None,
49+
past_key_value: Optional[Tuple[torch.Tensor]] = None,
50+
output_attentions: bool = False,
51+
use_cache: bool = False,
52+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
53+
bsz, q_len, _ = hidden_states.size()
54+
# use fused fc output to get qkv states
55+
qkv_states = self.qkv_proj(hidden_states).view(bsz, q_len, self.num_heads * 3, self.head_dim).transpose(1, 2)
56+
(query_states, key_states, value_states) = torch.chunk(qkv_states, 3, 1)
57+
58+
is_causal = past_key_value is None
59+
60+
kv_seq_len = key_states.shape[-2]
61+
if past_key_value is not None:
62+
kv_seq_len += past_key_value[0].shape[-2]
63+
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
64+
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
65+
# [bsz, nh, t, hd]
66+
67+
if past_key_value is not None:
68+
# reuse k, v, self_attention
69+
key_states = torch.cat([past_key_value[0], key_states], dim=2)
70+
value_states = torch.cat([past_key_value[1], value_states], dim=2)
71+
72+
past_key_value = (key_states, value_states) if use_cache else None
73+
if LooseVersion(torch.__version__) == LooseVersion('1.13.0'):
74+
with torch.backends.cuda.sdp_kernel(enable_math=False):
75+
attn_output, attn_weights = F._scaled_dot_product_attention(
76+
query_states, key_states, value_states, is_causal=is_causal
77+
)
78+
elif LooseVersion(torch.__version__) >= LooseVersion('2.0.0'):
79+
with torch.backends.cuda.sdp_kernel(enable_math=False):
80+
attn_output, attn_weights = F.scaled_dot_product_attention(
81+
query_states, key_states, value_states, is_causal=is_causal
82+
)
83+
else:
84+
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
85+
86+
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
87+
raise ValueError(
88+
f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is"
89+
f" {attn_weights.size()}"
90+
)
91+
92+
if attention_mask is not None:
93+
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
94+
raise ValueError(
95+
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is"
96+
f" {attention_mask.size()}"
97+
)
98+
attn_weights = attn_weights + attention_mask
99+
attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
100+
101+
# upcast attention to fp32
102+
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
103+
attn_output = torch.matmul(attn_weights, value_states)
104+
105+
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
106+
raise ValueError(
107+
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
108+
f" {attn_output.size()}"
109+
)
110+
del query_states, key_states, value_states
111+
112+
attn_output = attn_output.transpose(1, 2)
113+
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
114+
115+
attn_output = self.o_proj(attn_output)
116+
117+
if not output_attentions:
118+
attn_weights = None
119+
120+
return attn_output, attn_weights, past_key_value

0 commit comments

Comments
 (0)