
Commit ab20e2e

fix model llama for split function in line 852 (#1941)
1 parent 9ea7a99 commit ab20e2e

File tree

1 file changed: +1 −1 lines changed


mindnlp/transformers/models/llama/modeling_llama.py

Lines changed: 1 addition & 1 deletion
@@ -849,7 +849,7 @@ def forward(
 
         hidden_states = outputs[0]
         if self.config.pretraining_tp > 1:
-            lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
+            lm_head_slices = ops.split(self.lm_head.weight,self.vocab_size // self.config.pretraining_tp, dim=0)
             logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
            logits = ops.cat(logits, dim=-1)
        else:
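
The change swaps the PyTorch-style Tensor.split method call for mindnlp's module-level ops.split, leaving the tensor-parallel lm_head computation otherwise unchanged: the weight is cut into pretraining_tp row slices, each slice yields a shard of the logits, and the shards are concatenated along the vocab dimension. Below is a minimal sketch of what this branch computes, written in plain PyTorch as a stand-in for mindnlp's torch-like ops/F API; the toy sizes and variable names are illustrative, not mindnlp code.

import torch
import torch.nn.functional as F

vocab_size, hidden_size, pretraining_tp = 8, 4, 2   # toy sizes (illustrative)
weight = torch.randn(vocab_size, hidden_size)       # stand-in for lm_head.weight
hidden_states = torch.randn(3, hidden_size)         # stand-in for outputs[0]

# Module-level split, mirroring the fixed line: cut the weight into
# pretraining_tp row slices of vocab_size // pretraining_tp rows each.
lm_head_slices = torch.split(weight, vocab_size // pretraining_tp, dim=0)

# Each slice projects hidden_states to a shard of the logits;
# concatenating along the last dim rebuilds the full vocab dimension.
logits = torch.cat([F.linear(hidden_states, s) for s in lm_head_slices], dim=-1)

# Sanity check: the sliced computation matches one full projection
# with the unsplit weight.
assert torch.allclose(logits, F.linear(hidden_states, weight), atol=1e-5)

The sliced path exists to mirror how the output projection was sharded during tensor-parallel pretraining; numerically it reproduces a single full matmul, as the assert above demonstrates.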
