Skip to content

Commit bcce96b

Browse files
convert.py : fix baichuan7B support (#2870)
* [Fix]: convert.py support baichuan7B
* convert.py : fix trailing whitespaces

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
1 parent 74e0cae commit bcce96b

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

convert.py

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -469,7 +469,7 @@ def to_ggml(self) -> 'UnquantizedTensor':
469469

470470
def permute_part(self, n_part: int, n_head: int) -> 'UnquantizedTensor':
471471
r = self.ndarray.shape[0] // 3
472-
return UnquantizedTensor(permute(self.ndarray[r * n_part : r * n_part + r, ...], n_head))
472+
return UnquantizedTensor(permute(self.ndarray[r * n_part : r * n_part + r, ...], n_head, n_head))
473473

474474
def part(self, n_part: int) -> 'UnquantizedTensor':
475475
r = self.ndarray.shape[0] // 3
@@ -952,9 +952,10 @@ def convert_model_names(model: LazyModel, params: Params) -> LazyModel:
952952
#tmp[f"model.layers.{i}.self_attn.v_proj.weight"] = model[f"model.layers.{i}.self_attn.v_proj.weight"]
953953
elif f"model.layers.{i}.self_attn.W_pack.weight" in model:
954954
print(f"Unpacking and permuting layer {i}")
955-
tmp[f"model.layers.{i}.self_attn.q_proj.weight"] = permute_part_lazy(model[f"model.layers.{i}.self_attn.W_pack.weight"], 0, params.n_head, params.n_head)
956-
tmp[f"model.layers.{i}.self_attn.k_proj.weight"] = permute_part_lazy(model[f"model.layers.{i}.self_attn.W_pack.weight"], 1, params.n_head, params.n_head_kv)
955+
tmp[f"model.layers.{i}.self_attn.q_proj.weight"] = permute_part_lazy(model[f"model.layers.{i}.self_attn.W_pack.weight"], 0, params.n_head)
956+
tmp[f"model.layers.{i}.self_attn.k_proj.weight"] = permute_part_lazy(model[f"model.layers.{i}.self_attn.W_pack.weight"], 1, params.n_head)
957957
tmp[f"model.layers.{i}.self_attn.v_proj.weight"] = part_lazy (model[f"model.layers.{i}.self_attn.W_pack.weight"], 2)
958+
del tmp[f"model.layers.{i}.self_attn.W_pack.weight"]
958959
else:
959960
break
960961

0 commit comments

Comments
 (0)