Skip to content

Commit 2ac2d08

Browse files
authored
Fix(ckpt): fix llama2 loading function (#276)
1 parent db97782 commit 2ac2d08

File tree

1 file changed

+0
-6
lines changed

1 file changed

+0
-6
lines changed

internlm/checkpoint/load_funcs.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -148,12 +148,6 @@ def load_hf_llama_pretrained_weights(folder, model):
148148
if f"model.layers.{layer_ids}.self_attn.rotary_emb.inv_freq" in states:
149149
states.pop(f"model.layers.{layer_ids}.self_attn.rotary_emb.inv_freq")
150150

151-
if gpc.config.model_type in ("LLAMA2",):
152-
w2 = states.pop(f"layers.{i}.feed_forward.w2.weight")
153-
w3 = states.pop(f"layers.{i}.feed_forward.w3.weight")
154-
states[f"layers.{i}.feed_forward.w2.weight"] = w3
155-
states[f"layers.{i}.feed_forward.w3.weight"] = w2
156-
157151
for name in list(states.keys()):
158152
if name.startswith(f"layers.{i}"):
159153
current_states[name.replace(f".{i}.", f".{idx}.")] = states.pop(name)

0 commit comments

Comments (0)