
Commit 2d31516

feat: Add support for distinguishing recurrent vs non-recurrent layers in hparams
Branch: GraniteFour
Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
1 parent 41f5d54 commit 2d31516
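
For hybrid models (the GraniteFour work this branch targets) only some layers carry recurrent state, so size calculations need to know which layers those are. A minimal usage sketch of the new per-layer accessors follows; it is not part of this commit and assumes a populated llama_hparams instance:

```cpp
// Minimal usage sketch (not part of this commit); assumes "llama-hparams.h"
// is available and hparams has been loaded. Non-recurrent (attention)
// layers simply contribute 0 to the total.
#include <cstdint>
#include "llama-hparams.h"

static uint64_t total_recurrent_state(const llama_hparams & hparams) {
    uint64_t total = 0;
    for (uint32_t il = 0; il < hparams.n_layer; ++il) {
        // both getters return 0 when recurrent_layer(il) is false
        total += hparams.n_embd_k_s(il) + hparams.n_embd_v_s(il);
    }
    return total;
}
```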

File tree: 2 files changed (+20, −4 lines)


src/llama-hparams.cpp

Lines changed: 12 additions & 2 deletions
@@ -65,7 +65,10 @@ uint32_t llama_hparams::n_embd_v_gqa(uint32_t il) const {
     return n_embd_head_v * n_head_kv;
 }
 
-uint32_t llama_hparams::n_embd_k_s() const {
+uint32_t llama_hparams::n_embd_k_s(uint32_t il) const {
+    if (!recurrent_layer(il)) {
+        return 0;
+    }
     if (wkv_head_size != 0) {
         // for RWKV models
         return token_shift_count * n_embd;
@@ -76,7 +79,10 @@ uint32_t llama_hparams::n_embd_k_s() const {
     return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner;
 }
 
-uint32_t llama_hparams::n_embd_v_s() const {
+uint32_t llama_hparams::n_embd_v_s(uint32_t il) const {
+    if (!recurrent_layer(il)) {
+        return 0;
+    }
     if (wkv_head_size != 0) {
         // corresponds to RWKV's wkv_states size
         return n_embd * wkv_head_size;
@@ -86,6 +92,10 @@ uint32_t llama_hparams::n_embd_v_s() const {
     return ssm_d_state * ssm_d_inner;
 }
 
+bool llama_hparams::recurrent_layer(uint32_t il) const {
+    return recurrent_layer_arr[il];
+}
+
 bool llama_hparams::is_swa(uint32_t il) const {
     if (il < n_layer) {
         return swa_layers[il];
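
Since both getters now short-circuit to 0 for non-recurrent layers, per-layer cache sizing can branch on recurrent_layer(il). Below is a hypothetical sketch combining the new accessors with the existing attention-side getters n_embd_k_gqa(il) and n_embd_v_gqa(il); the helper itself is illustrative, not code from this commit:

```cpp
// Illustrative only: report, per layer, whether the layer uses recurrent
// state or an attention KV cache, and the corresponding embedding sizes.
#include <cstdio>
#include "llama-hparams.h"

static void print_layer_sizes(const llama_hparams & hparams) {
    for (uint32_t il = 0; il < hparams.n_layer; ++il) {
        const bool rec = hparams.recurrent_layer(il);
        const uint32_t k = rec ? hparams.n_embd_k_s(il) : hparams.n_embd_k_gqa(il);
        const uint32_t v = rec ? hparams.n_embd_v_s(il) : hparams.n_embd_v_gqa(il);
        printf("layer %u: k = %u, v = %u (%s)\n",
               il, k, v, rec ? "recurrent" : "attention");
    }
}
```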

src/llama-hparams.h

Lines changed: 8 additions & 2 deletions
@@ -115,6 +115,9 @@ struct llama_hparams {
     uint32_t ssm_d_state = 0;
     uint32_t ssm_dt_rank = 0;
 
+    // for hybrid state space models
+    std::array<bool, LLAMA_MAX_LAYERS> recurrent_layer_arr;
+
     bool ssm_dt_b_c_rms = false;
 
     float f_clamp_kqv = 0.0f;
@@ -181,10 +184,13 @@ struct llama_hparams {
 
     // dimension of the rolling state embeddings
     // corresponds to Mamba's conv_states size or RWKV's token_shift states size
-    uint32_t n_embd_k_s() const;
+    uint32_t n_embd_k_s(uint32_t il = 0) const;
 
     // dimension of the recurrent state embeddings
-    uint32_t n_embd_v_s() const;
+    uint32_t n_embd_v_s(uint32_t il = 0) const;
+
+    // whether or not the given layer is recurrent (for hybrid models)
+    bool recurrent_layer(uint32_t il) const;
 
     bool is_swa(uint32_t il) const;
 };
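
The new recurrent_layer_arr flag array is declared here but left for the model loader to populate; that code is not part of this commit. A purely illustrative sketch of how it might be filled, with the hybrid/non-hybrid split and the layer pattern both assumed for the example:

```cpp
// Hypothetical initialization, not from this commit: a pure recurrent model
// marks every layer, while a hybrid model sets the flag per layer (here an
// assumed pattern that keeps attention on every 4th layer).
#include "llama-hparams.h"

static void mark_layers(llama_hparams & hparams, bool hybrid) {
    if (!hybrid) {
        hparams.recurrent_layer_arr.fill(true);          // pure Mamba/RWKV-style model
        return;
    }
    for (uint32_t il = 0; il < hparams.n_layer; ++il) {
        hparams.recurrent_layer_arr[il] = (il % 4 != 3); // illustrative pattern only
    }
}
```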
