
Commit c0b6ed8

committed
Added LLM models.
1 parent 6eead96 commit c0b6ed8

File tree

6 files changed: +179 −1 lines changed

README.md
STSchedule.pro
include/nns/nns.h
src/layer.cpp
src/main.cpp
src/nns/llm.cpp

README.md

Lines changed: 14 additions & 0 deletions

@@ -49,6 +49,14 @@ The "Scheduling Space Size Calculation" and "Optimal HW-tile Allocation Algorith
 
 12: pnasnet
 
+13: BERT-Large (one cell)
+
+14: GPT2-XL prefill stage (one cell)
+
+15: GPT2-XL decode stage (one cell)
+
+- Note: For the LLM models, we provide one-cell versions, since the full networks are excessively long and all of their cells are identical. To run a full network, see the comments in "nns/llm.cpp".
+
 - batch: Workload batch size.
 
 - x, y: Length of the x/y axis in the mesh.

@@ -72,6 +80,12 @@ The "Scheduling Space Size Calculation" and "Optimal HW-tile Allocation Algorith
 
 - The current running method is not elegant, and we will improve it soon.
 
+## Update History
+
+2025/01/08 Added LLM models (BERT-Large, GPT2-XL prefill/decode).
+
+2024/05/11 Initial version.
+
 ## Current Plans
 
 Since the current APIs of the main program mostly rely on cin, it is inconvenient and difficult to start with.
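An assumed usage sketch (the actual prompt order is defined by the cin calls in "main.cpp", and the binary name comes from the qmake target in "STSchedule.pro"): to schedule the new BERT-Large cell with batch 1 on an 8×8 mesh, one would answer the prompts with the network index 13 followed by the other parameters, e.g.

	echo "13 1 8 8" | ./STSchedule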

STSchedule.pro

Lines changed: 1 addition & 0 deletions

@@ -61,6 +61,7 @@ SOURCES += \
 	src/nns/gnmt.cpp \
 	src/nns/googlenet.cpp \
 	src/nns/incep_resnet.cpp \
+	src/nns/llm.cpp \
 	src/nns/pnasnet.cpp \
 	src/nns/resnet.cpp \
 	src/nns/transformer.cpp \

include/nns/nns.h

Lines changed: 10 additions & 0 deletions

@@ -27,6 +27,16 @@ extern const Network lstm;
 extern const Network transformer;
 extern const Network transformer_cell;
 
+// For the LLMs, a single block is provided for each network.
+// See the comments in "nns/llm.cpp" for more detail.
+
+// extern const Network BERT;
+extern const Network BERT_block;
+// extern const Network GPT2_prefill;
+extern const Network GPT2_prefill_block;
+// extern const Network GPT2_decode;
+extern const Network GPT2_decode_block;
+
 extern const Network PNASNet;
 
 #endif // NNS_H

src/layer.cpp

Lines changed: 1 addition & 1 deletion

@@ -66,7 +66,7 @@ vol_t ConvLayer::Workload::ofm_size(len_t batch_size) const{
 }
 
 void ConvLayer::Workload::update_op(){
-	tot_op = C*K*R*S*H*W;
+	tot_op = static_cast<access_t>(C)*K*R*S*H*W;
 }
 
 access_t ConvLayer::Workload::calc_op(len_t batch_size) const{
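The cast above is the one substantive fix outside the LLM files: with plain 32-bit dimension fields, the operation count of an LLM-sized layer wraps around before it is stored into tot_op. A minimal sketch of the failure mode, assuming uint32_t dimension fields and a 64-bit access_t (both widths are assumptions for illustration):

#include <cstdint>
#include <iostream>

using access_t = uint64_t; // assumed 64-bit operation counter

int main(){
	// Dimensions on the scale of a GPT2-XL feed-forward layer
	// (d_model = 1600, d_ff = 6400, 512 tokens).
	uint32_t C = 1600, K = 6400, R = 1, S = 1, H = 512, W = 1;

	// All-32-bit product wraps modulo 2^32: yields 947912704.
	access_t wrong = C*K*R*S*H*W;
	// Widening the first factor promotes the whole chain to 64 bits:
	// yields the true count, 5242880000.
	access_t right = static_cast<access_t>(C)*K*R*S*H*W;

	std::cout << wrong << " vs " << right << "\n";
	return 0;
}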

src/main.cpp

Lines changed: 12 additions & 0 deletions

@@ -171,6 +171,18 @@ int main(int argc, char** argv){
 		network = &PNASNet;
 		net_name="pnas";
 		break;
+	case 13:
+		network = &BERT_block;
+		net_name="bert";
+		break;
+	case 14:
+		network = &GPT2_prefill_block;
+		net_name="gpt_prefill";
+		break;
+	case 15:
+		network = &GPT2_decode_block;
+		net_name="gpt_decode";
+		break;
 	default:
 		assert(false);
 		break;

src/nns/llm.cpp

Lines changed: 141 additions & 0 deletions

@@ -0,0 +1,141 @@ (new file)
#include "nns/nns.h"

#include <cassert>

// Dimension aliases used in the Transpose layer order specifications.
typedef TransposeLayer::dim Ldims;

// Builds one multi-head attention block (numG heads of width gSize) and
// returns the id of its output layer. The Q/K/V projections and the output
// FC are modeled as 1x1 convolutions, the per-head score (QK) and context
// (QKV) matmuls as grouped convolutions with one group per head, and the
// softmax as an elementwise (PTP) layer.
static lid_t add_attention(
	Network& n, const std::string& name,
	len_t len, len_t numG, len_t gSize, len_t decode_len,
	lid_t prev
){
	lid_t Q, K, V, QK, QK_elt, QKV;
	Network::layer_set Ks;
	Q = n.add(NLAYER(name + "_Q", Conv, H=len, W=1, C=numG*gSize), {prev});

	len_t kv_len;
	if(decode_len == 0){
		// Prefill stage: K and V come entirely from the current sequence.
		kv_len = len;
		for(len_t i=0; i<numG; ++i){
			// Per-head K projection; the transpose links to it by default.
			n.add(NLAYER(name + "_K" + std::to_string(i), Conv, C=numG*gSize, K=gSize, H=len, W=1), {prev});
			K = n.add(NLAYER(name + "_Kt" + std::to_string(i), Transpose, K=len, H=gSize, W=1, order[Ldims::C]=Ldims::H, order[Ldims::H]=Ldims::C));
			Ks.push_back(K);
		}
		// K = n.add(NLAYER(name + "_K", PTP, K=numG*len, H=gSize, W=1), Ks);
		V = n.add(NLAYER(name + "_V", Conv, C=numG*gSize, H=len, W=1), {prev});
	}else{
		// Decode stage: a single new token attends to decode_len cached
		// tokens; the KV cache enters as the external inputs extK and extV.
		assert(len == 1);
		kv_len = len + decode_len;

		InputData extK(name + "_Kext", fmap_shape(numG*decode_len, gSize, 1));
		InputData extV(name + "_Vext", fmap_shape(decode_len, numG*gSize, 1));

		for(len_t i=0; i<numG; ++i){
			n.add(NLAYER(name + "_K" + std::to_string(i), Conv, C=numG*gSize, K=gSize, H=len, W=1), {prev});
			K = n.add(NLAYER(name + "_Kt" + std::to_string(i), Transpose, K=len, H=gSize, W=1, order[Ldims::C]=Ldims::H, order[Ldims::H]=Ldims::C));
			Ks.push_back(K);
		}
		// Concatenate the new K with the cached K.
		K = n.add(NLAYER(name + "_K", PTP, K=numG*kv_len, H=gSize, W=1), Ks, 0, {extK});
		Ks = {K};
		V = n.add(NLAYER(name + "_V", Conv, C=numG*gSize, H=len, W=1), {prev});
		V = n.add(NLAYER(name + "_Vt1", Transpose, K=len, H=numG*gSize, W=1, order[Ldims::C]=Ldims::H, order[Ldims::H]=Ldims::C), {V});
		V = n.add(NLAYER(name + "_Vt2", Transpose, K=numG*gSize, H=kv_len, W=1, order[Ldims::C]=Ldims::H, order[Ldims::H]=Ldims::C), {V}, 0, {extV});
	}
	QK = n.add(NLAYER(name + "_QK", GroupConv, H=len, W=1, C=numG*gSize, K=numG*kv_len, G=numG), {Q}, 0, {}, Ks);
	// Softmax over the attention scores.
	QK_elt = n.add(NLAYER(name + "_QK_elt", PTP, K=numG*kv_len, H=len, W=1), {QK});
	QKV = n.add(NLAYER(name + "_QKV", GroupConv, H=len, W=1, C=numG*kv_len, K=numG*gSize, G=numG), {QK_elt}, 0, {}, {V});
	return n.add(NLAYER(name + "_FC", Conv, H=len, W=1, C=numG*gSize), {QKV});
}
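
// A worked size check for one BERT-Large cell (numG=16, gSize=64, len=512,
// prefill): Q spans numG*gSize = 1024 channels over 512 positions; the QK
// GroupConv emits numG*kv_len = 16*512 score channels per position (one
// 512x512 score matrix per head); QKV folds them back to 1024 channels.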

// One transformer cell: attention plus a residual add, followed by a
// two-layer feed-forward network and a second residual add. The feedfwd
// layers take the default predecessor (the layer added just before them).
static lid_t add_trans_block(
	Network& n, const std::string& name,
	len_t len, len_t numG, len_t gSize, len_t ff_len, len_t decode_len,
	lid_t prev
){
	lid_t next_prev;
	next_prev = add_attention(n, name, len, numG, gSize, decode_len, prev);
	prev = n.add(NLAYER(name + "_elt1", Eltwise, K=numG*gSize, H=len, W=1, N=2), {prev, next_prev});
	n.add(NLAYER(name + "_feedfwd1", Conv, C=numG*gSize, K=ff_len, H=len, W=1));
	next_prev = n.add(NLAYER(name + "_feedfwd2", Conv, C=ff_len, K=numG*gSize, H=len, W=1));
	return n.add(NLAYER(name + "_elt2", Eltwise, K=numG*gSize, H=len, W=1, N=2), {prev, next_prev});
}

// Assembles a full transformer: word embedding, nBlock identical cells, and
// a final projection onto the vocabulary.
static Network create_transformer(
	len_t numG, len_t gSize, len_t nBlock, bool is_prefill,
	len_t vocab_len = 1000, len_t len = 512, len_t ff_len = 0
){
	// Default settings.
	if(ff_len == 0){
		ff_len = 4 * len;
	}

	// In the decode stage, a single new token attends to `len` cached tokens.
	len_t decode_len = 0;
	if(!is_prefill){
		decode_len = len;
		len = 1;
	}

	// Length of embedding
	len_t totG = numG * gSize;
	// Number of embeddings
	len_t curH = len;

	lid_t block_prev;
	Network::layer_set prevs;
	Network n;

	InputData input_layer("input_layer", fmap_shape(totG, curH, 1));
	block_prev = n.add(NLAYER("word_embed", PTP, K=totG, H=curH, W=1), {}, 0, {input_layer});
	for(len_t i=1; i<=nBlock; ++i){
		block_prev = add_trans_block(n, "block"+std::to_string(i), len, numG, gSize, ff_len, decode_len, block_prev);
	}
	n.add(NLAYER("proj", Conv, C=totG, K=vocab_len, H=curH, W=1), {block_prev});
	return n;
}
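
// Sanity check on the configurations below: BERT-Large has 24 layers with
// hidden size 16*64 = 1024 (16 heads of 64), and GPT2-XL has 48 layers with
// hidden size 25*64 = 1600 (25 heads of 64), matching the published models.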

/*
 * Since the complete networks have too many layers,
 * and all blocks in a network are identical,
 * a single block is provided for each network.
 *
 * To run a full network, one can uncomment the networks below.
 * Also remember to uncomment them in "nns/nns.h", add the networks in
 * "main.cpp", and increase MAX_BITS_IN_BS in "bitset.h".
 */

/*
 * BERT-Large
 *
 * numG = 16
 * gSize = 64
 * nBlock = 24
 * is_prefill = true
 */
// const Network BERT = create_transformer(16, 64, 24, true);
const Network BERT_block = create_transformer(16, 64, 1, true);

/*
 * GPT2-XL at the prefill stage
 *
 * numG = 25
 * gSize = 64
 * nBlock = 48
 * is_prefill = true
 */
// const Network GPT2_prefill = create_transformer(25, 64, 48, true);
const Network GPT2_prefill_block = create_transformer(25, 64, 1, true);

/*
 * GPT2-XL at the decode stage
 *
 * numG = 25
 * gSize = 64
 * nBlock = 48
 * is_prefill = false
 */
// const Network GPT2_decode = create_transformer(25, 64, 48, false);
const Network GPT2_decode_block = create_transformer(25, 64, 1, false);
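
Following the comment block at the top of these definitions, enabling the full BERT-Large network would amount to edits along these lines (a sketch only: the case index 16 and the net_name are placeholders, not part of this commit, and the required MAX_BITS_IN_BS value depends on the resulting layer count):

// In src/nns/llm.cpp, uncomment:
const Network BERT = create_transformer(16, 64, 24, true);
// In include/nns/nns.h, uncomment:
extern const Network BERT;
// In src/main.cpp, add a case to the network switch:
case 16:
	network = &BERT;
	net_name="bert_full";
	break;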
