
Zhwang/llama #1


Open · wants to merge 324 commits into base: main
Commits (324) · changes from all commits
b5e84a2
commit
sfc-gh-zhwang Aug 22, 2023
c934345
commit
sfc-gh-zhwang Aug 22, 2023
47b1abb
commit
sfc-gh-zhwang Aug 22, 2023
f8c78b3
commit
sfc-gh-zhwang Aug 22, 2023
ad3956e
commit
sfc-gh-zhwang Aug 22, 2023
2b0c601
commit
sfc-gh-zhwang Aug 22, 2023
25aa5a2
commit
sfc-gh-zhwang Aug 22, 2023
c0d4362
commit
sfc-gh-zhwang Aug 22, 2023
5292a4d
commit
sfc-gh-zhwang Aug 22, 2023
73d69da
commit
sfc-gh-zhwang Aug 22, 2023
cd532be
commit
sfc-gh-zhwang Aug 22, 2023
319a34a
commit
sfc-gh-zhwang Aug 22, 2023
2c8727b
commit
sfc-gh-zhwang Aug 22, 2023
b749248
commit
sfc-gh-zhwang Aug 22, 2023
0829bea
commit
sfc-gh-zhwang Aug 22, 2023
7592b58
commit
sfc-gh-zhwang Aug 22, 2023
35d48ab
commit
sfc-gh-zhwang Aug 22, 2023
3acbe5b
commit
sfc-gh-zhwang Aug 22, 2023
333a040
commit
sfc-gh-zhwang Aug 22, 2023
08cb870
commit
sfc-gh-zhwang Aug 22, 2023
f347915
commit
sfc-gh-zhwang Aug 22, 2023
d02ed5e
commit
sfc-gh-zhwang Aug 22, 2023
1d32ea3
commit
sfc-gh-zhwang Aug 22, 2023
041bab8
commit
sfc-gh-zhwang Aug 22, 2023
d617cd0
commit
sfc-gh-zhwang Aug 22, 2023
eb83205
commit
sfc-gh-zhwang Aug 22, 2023
18ec23e
commit
sfc-gh-zhwang Aug 22, 2023
ec7247f
commit
sfc-gh-zhwang Aug 22, 2023
d8db97d
commit
sfc-gh-zhwang Aug 22, 2023
a276e3e
commit
sfc-gh-zhwang Aug 22, 2023
627162a
commit
sfc-gh-zhwang Aug 22, 2023
136f8be
commit
sfc-gh-zhwang Aug 22, 2023
fe7436d
commit
sfc-gh-zhwang Aug 22, 2023
7aacd4e
commit
sfc-gh-zhwang Aug 22, 2023
b23a8ad
commit
sfc-gh-zhwang Aug 22, 2023
6a7926a
commit
sfc-gh-zhwang Aug 22, 2023
35fa364
commit
sfc-gh-zhwang Aug 22, 2023
de4b804
commit
sfc-gh-zhwang Aug 22, 2023
58ee364
commit
sfc-gh-zhwang Aug 22, 2023
053d4dd
commit
sfc-gh-zhwang Aug 22, 2023
15635be
commit
sfc-gh-zhwang Aug 22, 2023
6629d42
commit
sfc-gh-zhwang Aug 22, 2023
26b636e
commit
sfc-gh-zhwang Aug 22, 2023
eabf3b1
commit
sfc-gh-zhwang Aug 22, 2023
82bd6fe
commit
sfc-gh-zhwang Aug 22, 2023
aa70dc7
commit
sfc-gh-zhwang Aug 22, 2023
b496c72
commit
sfc-gh-zhwang Aug 22, 2023
c6c6593
commit
sfc-gh-zhwang Aug 22, 2023
d8006a3
commit
sfc-gh-zhwang Aug 22, 2023
d061a90
commit
sfc-gh-zhwang Aug 22, 2023
a4da43e
commit
sfc-gh-zhwang Aug 22, 2023
3b9b391
commit
sfc-gh-zhwang Aug 22, 2023
b867796
commit
sfc-gh-zhwang Aug 22, 2023
3e0fcc0
commit
sfc-gh-zhwang Aug 22, 2023
36600a0
commit
sfc-gh-zhwang Aug 22, 2023
207ebac
commit
sfc-gh-zhwang Aug 22, 2023
853d33f
commit
sfc-gh-zhwang Aug 22, 2023
fb3f7d7
commit
sfc-gh-zhwang Aug 22, 2023
5509fab
commit
sfc-gh-zhwang Aug 22, 2023
7093bd0
commit
sfc-gh-zhwang Aug 22, 2023
60f7b51
commit
sfc-gh-zhwang Aug 22, 2023
ac9708d
commit
sfc-gh-zhwang Aug 22, 2023
7dce9e9
commit
sfc-gh-zhwang Aug 22, 2023
f45ca53
commit
sfc-gh-zhwang Aug 22, 2023
ddc28aa
commit
sfc-gh-zhwang Aug 22, 2023
73fe4a9
commit
sfc-gh-zhwang Aug 22, 2023
6916ed2
commit
sfc-gh-zhwang Aug 22, 2023
ce104bf
commit
sfc-gh-zhwang Aug 22, 2023
2c6f715
commit
sfc-gh-zhwang Aug 22, 2023
1678909
commit
sfc-gh-zhwang Aug 22, 2023
a1e86fd
commit
sfc-gh-zhwang Aug 22, 2023
b7fed9f
commit
sfc-gh-zhwang Aug 22, 2023
796138c
commit
sfc-gh-zhwang Aug 22, 2023
b5781d0
commit
sfc-gh-zhwang Aug 22, 2023
42722d5
commit
sfc-gh-zhwang Aug 22, 2023
e16d3a1
commit
sfc-gh-zhwang Aug 22, 2023
77eb018
commit
sfc-gh-zhwang Aug 22, 2023
2aa46d0
commit
sfc-gh-zhwang Aug 22, 2023
7262ffc
commit
sfc-gh-zhwang Aug 22, 2023
c1e13f3
commit
sfc-gh-zhwang Aug 22, 2023
622e7ca
commit
sfc-gh-zhwang Aug 22, 2023
82e7cce
commit
sfc-gh-zhwang Aug 22, 2023
3a0fb5c
commit
sfc-gh-zhwang Aug 22, 2023
4ee9f9f
commit
sfc-gh-zhwang Aug 22, 2023
ea911ba
commit
sfc-gh-zhwang Aug 22, 2023
7fa2edf
commit
sfc-gh-zhwang Aug 22, 2023
b7c5868
commit
sfc-gh-zhwang Aug 22, 2023
9e06972
commit
sfc-gh-zhwang Aug 22, 2023
31c95f1
commit
sfc-gh-zhwang Aug 22, 2023
d1b24af
commit
sfc-gh-zhwang Aug 22, 2023
b8638b1
commit
sfc-gh-zhwang Aug 22, 2023
2d8d350
commit
sfc-gh-zhwang Aug 22, 2023
0aa8de5
commit
sfc-gh-zhwang Aug 22, 2023
22fb3d3
commit
sfc-gh-zhwang Aug 22, 2023
da543ff
commit
sfc-gh-zhwang Aug 22, 2023
33e1dab
commit
sfc-gh-zhwang Aug 22, 2023
3d241e1
commit
sfc-gh-zhwang Aug 22, 2023
992f074
commit
sfc-gh-zhwang Aug 22, 2023
bb6c923
commit
sfc-gh-zhwang Aug 22, 2023
51c9578
commit
sfc-gh-zhwang Aug 22, 2023
8655809
commit
sfc-gh-zhwang Aug 22, 2023
6139b28
commit
sfc-gh-zhwang Aug 22, 2023
5abd726
commit
sfc-gh-zhwang Aug 22, 2023
827e5cb
commit
sfc-gh-zhwang Aug 22, 2023
c72d47b
commit
sfc-gh-zhwang Aug 22, 2023
111e405
commit
sfc-gh-zhwang Aug 22, 2023
9075221
commit
sfc-gh-zhwang Aug 23, 2023
e355935
commit
sfc-gh-zhwang Aug 23, 2023
a06be03
commit
sfc-gh-zhwang Aug 23, 2023
faef3b7
commit
sfc-gh-zhwang Aug 23, 2023
9b93cbe
commit
sfc-gh-zhwang Aug 23, 2023
058be11
commit
sfc-gh-zhwang Aug 23, 2023
b7c0ae6
commit
sfc-gh-zhwang Aug 23, 2023
755b6a8
commit
sfc-gh-zhwang Aug 23, 2023
a8c5310
commit
sfc-gh-zhwang Aug 23, 2023
d482d09
commit
sfc-gh-zhwang Aug 23, 2023
92f289c
commit
sfc-gh-zhwang Aug 23, 2023
060bc9b
commit
sfc-gh-zhwang Aug 23, 2023
8cf53aa
commit
sfc-gh-zhwang Aug 23, 2023
446bc39
commit
sfc-gh-zhwang Aug 23, 2023
7ccd7c7
commit
sfc-gh-zhwang Aug 23, 2023
77f1f47
commit
sfc-gh-zhwang Aug 23, 2023
83aea9b
commit
sfc-gh-zhwang Aug 23, 2023
ce22fbb
commit
sfc-gh-zhwang Aug 23, 2023
0b848bd
commit
sfc-gh-zhwang Aug 23, 2023
cdaf2b3
commit
sfc-gh-zhwang Aug 23, 2023
0dc6792
commit
sfc-gh-zhwang Aug 23, 2023
6571282
commit
sfc-gh-zhwang Aug 23, 2023
4253f06
commit
sfc-gh-zhwang Aug 23, 2023
9da2d6e
commit
sfc-gh-zhwang Aug 23, 2023
27d094f
commit
sfc-gh-zhwang Aug 23, 2023
fb44b13
commit
sfc-gh-zhwang Aug 23, 2023
9e629ce
commit
sfc-gh-zhwang Aug 23, 2023
0fcbabb
commit
sfc-gh-zhwang Aug 23, 2023
1de9077
commit
sfc-gh-zhwang Aug 23, 2023
735d2b0
commit
sfc-gh-zhwang Aug 23, 2023
b03e266
commit
sfc-gh-zhwang Aug 23, 2023
bb66d0c
commit
sfc-gh-zhwang Aug 23, 2023
c0d443e
commit
sfc-gh-zhwang Aug 23, 2023
db53c29
commit
sfc-gh-zhwang Aug 23, 2023
4cb84a8
commit
sfc-gh-zhwang Aug 23, 2023
4cd1183
commit
sfc-gh-zhwang Aug 23, 2023
c4f9955
commit
sfc-gh-zhwang Aug 23, 2023
2193d4c
commit
sfc-gh-zhwang Aug 23, 2023
5c6847a
commit
sfc-gh-zhwang Aug 23, 2023
af14c22
commit
sfc-gh-zhwang Aug 23, 2023
0de0b7a
commit
sfc-gh-zhwang Aug 23, 2023
9662064
commit
sfc-gh-zhwang Aug 23, 2023
123d0d9
commit
sfc-gh-zhwang Aug 23, 2023
46a7a2f
commit
sfc-gh-zhwang Aug 23, 2023
cdf9600
commit
sfc-gh-zhwang Aug 23, 2023
8a280d3
commit
sfc-gh-zhwang Aug 23, 2023
d9744d1
commit
sfc-gh-zhwang Aug 23, 2023
5cf599d
commit
sfc-gh-zhwang Aug 23, 2023
0f02dd9
commit
sfc-gh-zhwang Aug 23, 2023
158510f
commit
sfc-gh-zhwang Aug 23, 2023
cb9d056
commit
sfc-gh-zhwang Aug 23, 2023
301df67
commit
sfc-gh-zhwang Aug 23, 2023
045b486
commit
sfc-gh-zhwang Aug 23, 2023
8a51c2c
commit
sfc-gh-zhwang Aug 23, 2023
2e0c8c1
commit
sfc-gh-zhwang Aug 23, 2023
a7d2ba6
commit
sfc-gh-zhwang Aug 23, 2023
d24fef1
commit
sfc-gh-zhwang Aug 23, 2023
dd15172
commit
sfc-gh-zhwang Aug 23, 2023
0888d5d
commit
sfc-gh-zhwang Aug 23, 2023
ac7631a
commit
sfc-gh-zhwang Aug 23, 2023
9032f52
commit
sfc-gh-zhwang Aug 23, 2023
1f848b9
commit
sfc-gh-zhwang Aug 23, 2023
c2b5108
commit
sfc-gh-zhwang Aug 23, 2023
175f312
commit
sfc-gh-zhwang Aug 23, 2023
6409867
commit
sfc-gh-zhwang Aug 23, 2023
32485d2
commit
sfc-gh-zhwang Aug 23, 2023
4aa3bc3
commit
sfc-gh-zhwang Aug 23, 2023
601c721
commit
sfc-gh-zhwang Aug 23, 2023
30d9857
commit
sfc-gh-zhwang Aug 23, 2023
f0faf76
commit
sfc-gh-zhwang Aug 23, 2023
3175af0
commit
sfc-gh-zhwang Aug 23, 2023
b10e2f8
commit
sfc-gh-zhwang Aug 23, 2023
ed4d0e7
commit
sfc-gh-zhwang Aug 23, 2023
69b3c8f
commit
sfc-gh-zhwang Aug 23, 2023
582e440
commit
sfc-gh-zhwang Aug 23, 2023
1df8afa
commit
sfc-gh-zhwang Aug 23, 2023
22cb8ef
commit
sfc-gh-zhwang Aug 23, 2023
732c94a
commit
sfc-gh-zhwang Aug 23, 2023
edba9c1
commit
sfc-gh-zhwang Aug 24, 2023
4a9b871
commit
sfc-gh-zhwang Aug 24, 2023
cd12fed
commit
sfc-gh-zhwang Aug 24, 2023
b977a02
commit
sfc-gh-zhwang Aug 24, 2023
fc87201
commit
sfc-gh-zhwang Aug 24, 2023
3cc8b56
commit
sfc-gh-zhwang Aug 24, 2023
0d7f0de
commit
sfc-gh-zhwang Aug 24, 2023
6c2913e
commit
sfc-gh-zhwang Aug 24, 2023
5a28f0e
commit
sfc-gh-zhwang Aug 24, 2023
adf26d4
commit
sfc-gh-zhwang Aug 24, 2023
f6e3403
commit
sfc-gh-zhwang Aug 24, 2023
789605a
commit
sfc-gh-zhwang Aug 24, 2023
a906bcd
commit
sfc-gh-zhwang Aug 24, 2023
462bd5b
commit
sfc-gh-zhwang Aug 24, 2023
7f616d3
commit
sfc-gh-zhwang Aug 24, 2023
687f24d
commit
sfc-gh-zhwang Aug 24, 2023
9060add
commit
sfc-gh-zhwang Aug 24, 2023
d1c749a
commit
sfc-gh-zhwang Aug 24, 2023
7461e2a
commit
sfc-gh-zhwang Aug 24, 2023
a1d5548
commit
sfc-gh-zhwang Aug 24, 2023
8944364
commit
sfc-gh-zhwang Aug 24, 2023
5c1199d
commit
sfc-gh-zhwang Aug 24, 2023
8bc46a8
commit
sfc-gh-zhwang Aug 24, 2023
64caf0d
commit
sfc-gh-zhwang Aug 24, 2023
8caaf75
commit
sfc-gh-zhwang Aug 24, 2023
bb42351
commit
sfc-gh-zhwang Aug 24, 2023
7851116
commit
sfc-gh-zhwang Aug 24, 2023
dfa7ede
commit
sfc-gh-zhwang Aug 24, 2023
788f6ed
commit
sfc-gh-zhwang Aug 24, 2023
f7c03e7
commit
sfc-gh-zhwang Aug 24, 2023
5409b2f
commit
sfc-gh-zhwang Aug 24, 2023
be463a6
commit
sfc-gh-zhwang Aug 24, 2023
e0035a1
commit
sfc-gh-zhwang Aug 24, 2023
31176a7
commit
sfc-gh-zhwang Aug 24, 2023
0a04b43
commit
sfc-gh-zhwang Aug 24, 2023
1deb0d4
commit
sfc-gh-zhwang Aug 24, 2023
5ea044b
commit
sfc-gh-zhwang Aug 24, 2023
2d332b3
commit
sfc-gh-zhwang Aug 24, 2023
566f242
commit
sfc-gh-zhwang Aug 24, 2023
192fab6
commit
sfc-gh-zhwang Aug 24, 2023
0dd149e
commit
sfc-gh-zhwang Aug 24, 2023
a1dd054
commit
sfc-gh-zhwang Aug 25, 2023
cfce313
commit
sfc-gh-zhwang Aug 30, 2023
bc888ab
commit
sfc-gh-zhwang Aug 30, 2023
13ce7a0
commit
sfc-gh-zhwang Aug 30, 2023
3d24499
commit
sfc-gh-zhwang Aug 30, 2023
375c46a
commit
sfc-gh-zhwang Aug 30, 2023
11ef80a
commit
sfc-gh-zhwang Aug 30, 2023
69dfc1b
commit
sfc-gh-zhwang Aug 30, 2023
b830aec
commit
sfc-gh-zhwang Aug 30, 2023
7eaf572
commit
sfc-gh-zhwang Aug 30, 2023
20449c8
commit
sfc-gh-zhwang Aug 30, 2023
18631a6
commit
sfc-gh-zhwang Aug 30, 2023
289a67c
commit
sfc-gh-zhwang Aug 30, 2023
6327ac5
commit
sfc-gh-zhwang Aug 30, 2023
d7bfc81
commit
sfc-gh-zhwang Aug 30, 2023
228dfa0
commit
sfc-gh-zhwang Aug 30, 2023
6298bba
commit
sfc-gh-zhwang Aug 30, 2023
857bb32
commit
sfc-gh-zhwang Sep 5, 2023
1570400
commit
sfc-gh-zhwang Sep 5, 2023
31a1d05
commit
sfc-gh-zhwang Sep 5, 2023
67e8a47
commit
sfc-gh-zhwang Sep 5, 2023
990e4f9
commit
sfc-gh-zhwang Sep 5, 2023
c7dcc6c
commit
sfc-gh-zhwang Sep 5, 2023
6f369d1
commit
sfc-gh-zhwang Sep 5, 2023
4d72624
commit
sfc-gh-zhwang Sep 5, 2023
27 changes: 25 additions & 2 deletions .vscode/settings.json
@@ -67,6 +67,29 @@
"unordered_set": "cpp",
"future": "cpp",
"cfenv": "cpp",
"typeindex": "cpp"
"typeindex": "cpp",
"locale": "cpp",
"__mutex_base": "cpp",
"__config": "cpp",
"__bit_reference": "cpp",
"__bits": "cpp",
"__debug": "cpp",
"__errc": "cpp",
"__hash_table": "cpp",
"__locale": "cpp",
"__node_handle": "cpp",
"__split_buffer": "cpp",
"__threading_support": "cpp",
"__tree": "cpp",
"__tuple": "cpp",
"__verbose_abort": "cpp",
"bit": "cpp",
"ios": "cpp",
"stack": "cpp",
"variant": "cpp",
"__nullptr": "cpp",
"__string": "cpp",
"compare": "cpp",
"concepts": "cpp"
}
}
}
7 changes: 6 additions & 1 deletion CMakeLists.txt
@@ -13,7 +13,7 @@
# limitations under the License.
cmake_minimum_required(VERSION 3.8 FATAL_ERROR) # for PyTorch extensions, version should be greater than 3.13
project(FasterTransformer LANGUAGES CXX CUDA)

option(BUILD_MULTI_GPU "Enable multi GPU support" ON)
find_package(CUDA 10.2 REQUIRED)

if(${CUDA_VERSION_MAJOR} VERSION_GREATER_EQUAL "11")
@@ -328,6 +328,8 @@ add_library(transformer-shared SHARED
$<TARGET_OBJECTS:FfnLayer>
$<TARGET_OBJECTS:FusedAttentionLayer>
$<TARGET_OBJECTS:GptContextAttentionLayer>
$<TARGET_OBJECTS:LlamaContextAttentionLayer>
$<TARGET_OBJECTS:LlamaDecoderSelfAttentionLayer>
$<TARGET_OBJECTS:GptJ>
$<TARGET_OBJECTS:GptJContextDecoder>
$<TARGET_OBJECTS:GptJDecoder>
@@ -362,6 +364,8 @@ add_library(transformer-shared SHARED
$<TARGET_OBJECTS:T5EncoderTritonBackend>
$<TARGET_OBJECTS:TensorParallelDecoderCrossAttentionLayer>
$<TARGET_OBJECTS:TensorParallelDecoderSelfAttentionLayer>
$<TARGET_OBJECTS:TensorParallelLlamaDecoderSelfAttentionLayer>
$<TARGET_OBJECTS:TensorParallelLlamaContextAttentionLayer>
$<TARGET_OBJECTS:TensorParallelGeluFfnLayer>
$<TARGET_OBJECTS:TensorParallelSiluFfnLayer>
$<TARGET_OBJECTS:TensorParallelGptContextAttentionLayer>
@@ -393,6 +397,7 @@ add_library(transformer-shared SHARED
$<TARGET_OBJECTS:fpA_intB_gemm>
$<TARGET_OBJECTS:gen_relative_pos_bias>
$<TARGET_OBJECTS:gpt_kernels>
$<TARGET_OBJECTS:repeat_kv_kernels>
$<TARGET_OBJECTS:int8_gemm>
$<TARGET_OBJECTS:layernorm_int8_kernels>
$<TARGET_OBJECTS:layernorm_kernels>
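The CMake changes above register the new Llama-specific attention layers and a repeat_kv kernel target. The repeat_kv_kernels name suggests the Llama path expands the shared key/value heads of grouped-query attention to match the query head count; the kernel itself is not shown in this diff, so the following NumPy sketch only illustrates the standard repeat_kv semantics (the function name and tensor layout are assumptions, not taken from this PR):

import numpy as np

def repeat_kv(kv, n_rep):
    # kv: [batch, kv_head_num, seq_len, head_size]
    # Repeat each key/value head n_rep times so the result has
    # kv_head_num * n_rep heads, matching the query tensor's head_num.
    if n_rep == 1:
        return kv
    b, kv_heads, s, d = kv.shape
    expanded = np.broadcast_to(kv[:, :, None, :, :], (b, kv_heads, n_rep, s, d))
    return expanded.reshape(b, kv_heads * n_rep, s, d)

# For example, Llama-2-70B uses head_num = 64 and kv_head_num = 8, so n_rep = 8.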
75 changes: 57 additions & 18 deletions examples/cpp/llama/huggingface_llama_convert.py
@@ -17,20 +17,15 @@
import numpy as np
from pathlib import Path

import torch
import os
from transformers import LlamaForCausalLM

# using numpy extension: https://github.com/GreenWaves-Technologies/bfloat16
# install the library with `pip install bfloat16`
from bfloat16 import bfloat16
from transformers import LlamaForCausalLM, AutoConfig

def get_weight_data_type(data_type):
if data_type == "fp32":
return np.float32
elif data_type == "fp16":
return np.float16
elif data_type == "bf16":
return bfloat16
else:
assert False, f"Invalid weight data type {data_type}"

@@ -69,10 +64,30 @@ def split_and_convert(args):
assert(i_gpu_num % t_gpu_num == 0)

factor = (int)(i_gpu_num / t_gpu_num)

# load position_embedding from rank 0
# model = torch.load(ckpt_name)
model = LlamaForCausalLM.from_pretrained(args.in_file)
print(f'load model from {args.in_file}')
# model = LlamaForCausalLM.from_pretrained(args.in_file, device_map='auto')
config = AutoConfig.from_pretrained(args.in_file)
# num_layers = 3
# config.num_hidden_layers = num_layers
print(config)
state_dict = {}
for f in os.listdir(args.in_file):
if not f.endswith('.bin'):
continue
w = torch.load(os.path.join(args.in_file, f), map_location='cpu')
keys = list(w.keys())
for k in keys:
if 'model.layers.' not in k:
continue
l = int(k.split('.')[2])
if l < config.num_hidden_layers:
continue
del w[k]
state_dict.update(w)

model = LlamaForCausalLM.from_pretrained(None, config=config, state_dict=state_dict)
hf_config = vars(model.config)
print(f"hf_config: {hf_config}")

@@ -82,8 +97,9 @@

hidden_size = hf_config["hidden_size"]
head_num = hf_config["num_attention_heads"]
kv_head_num = hf_config["num_key_value_heads"]
head_size = hidden_size // head_num
num_layers = hf_config["num_hidden_layers"]
# num_layers = hf_config["num_hidden_layers"]


np_weight_data_type = get_weight_data_type(args.weight_data_type)
Expand All @@ -94,6 +110,7 @@ def split_and_convert(args):
config['llama'] = {}
config['llama']['model_name'] = model_name
config['llama']["head_num"] = str(head_num)
config['llama']["kv_head_num"] = str(kv_head_num)
config['llama']["size_per_head"] = str(head_size)
config['llama']["inter_size"] = str(hf_config["intermediate_size"])
config['llama']["num_layer"] = str(num_layers)
@@ -127,14 +144,36 @@
# first merge QKV into a single weight
# concat direct to FT shape: [hidden_size, 3, head_num, head_size]
# copied from huggingface_gptj_ckpt_convert.py
qkv_weights = np.stack([
param_to_weights(model.state_dict()[f'model.layers.{l}.self_attn.q_proj.weight']),
param_to_weights(model.state_dict()[f'model.layers.{l}.self_attn.k_proj.weight']),
param_to_weights(model.state_dict()[f'model.layers.{l}.self_attn.v_proj.weight']),
])
qkv_weights = np.transpose(qkv_weights, (2, 0, 1))
qkv_weights_base_name = f'model.layers.{l}.attention.query_key_value.weight'
split_and_convert_process(saved_dir, factor, qkv_weights_base_name, qkv_weights)
# qkv_weights = np.stack([
# param_to_weights(model.state_dict()[f'model.layers.{l}.self_attn.q_proj.weight']),
# param_to_weights(model.state_dict()[f'model.layers.{l}.self_attn.k_proj.weight']),
# param_to_weights(model.state_dict()[f'model.layers.{l}.self_attn.v_proj.weight']),
# ])
# qkv_weights = np.transpose(qkv_weights, (2, 0, 1))
q_proj = param_to_weights(model.state_dict()[f'model.layers.{l}.self_attn.q_proj.weight'])
k_proj = param_to_weights(model.state_dict()[f'model.layers.{l}.self_attn.k_proj.weight'])
v_proj = param_to_weights(model.state_dict()[f'model.layers.{l}.self_attn.v_proj.weight'])
q_proj = np.split(q_proj, factor, axis=0)
k_proj = np.split(k_proj, factor, axis=0)
v_proj = np.split(v_proj, factor, axis=0)
for j in range(factor):
qkv_weights = np.concatenate((q_proj[j], k_proj[j], v_proj[j]), axis=0)
print(qkv_weights.shape)
# qkv_weights = np.transpose(qkv_weights, (2, 0, 1))
qkv_weights = np.transpose(qkv_weights)
qkv_weights_base_name = f'model.layers.{l}.attention.query_key_value.weight'
saved_path = saved_dir + "/" + qkv_weights_base_name + ".%d.bin" % j
qkv_weights.tofile(saved_path)
# qkv_weights = np.concatenate((
# param_to_weights(model.state_dict()[f'model.layers.{l}.self_attn.q_proj.weight']),
# param_to_weights(model.state_dict()[f'model.layers.{l}.self_attn.k_proj.weight']),
# param_to_weights(model.state_dict()[f'model.layers.{l}.self_attn.v_proj.weight']),
# ), axis=0)
# print(qkv_weights.shape)
# # qkv_weights = np.transpose(qkv_weights, (2, 0, 1))
# qkv_weights = np.transpose(qkv_weights)
# qkv_weights_base_name = f'model.layers.{l}.attention.query_key_value.weight'
# split_and_convert_process(saved_dir, factor, qkv_weights_base_name, qkv_weights)

# attention dense
o_weight = param_to_weights(model.state_dict()[f'model.layers.{l}.self_attn.o_proj.weight']).T
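In the converter changes above, each layer's q/k/v projection weights are now split across tensor-parallel ranks first and then packed into a single query_key_value matrix per rank, rather than stacked and split as before. A minimal standalone sketch of that per-rank packing step, mirroring the loop in the diff (the helper name is hypothetical, and the inputs are assumed to be NumPy arrays already converted by param_to_weights):

import numpy as np

def pack_qkv_per_rank(q_proj, k_proj, v_proj, factor, saved_dir, layer):
    # Slice each projection along axis 0 into `factor` tensor-parallel shards.
    q_shards = np.split(q_proj, factor, axis=0)
    k_shards = np.split(k_proj, factor, axis=0)
    v_shards = np.split(v_proj, factor, axis=0)
    for j in range(factor):
        # Fuse the rank-local Q, K and V rows, then transpose to the
        # layout the converter writes out (as done in the diff above).
        qkv = np.concatenate((q_shards[j], k_shards[j], v_shards[j]), axis=0)
        qkv = np.transpose(qkv)
        name = f"model.layers.{layer}.attention.query_key_value.weight.{j}.bin"
        qkv.tofile(f"{saved_dir}/{name}")

With grouped-query attention the K and V shards have fewer rows than the Q shard, which is why the fused weight can no longer be built with a simple np.stack of equally shaped matrices.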