commit ab20e2e (1 parent: 9ea7a99)
mindnlp/transformers/models/llama/modeling_llama.py
@@ -849,7 +849,7 @@ def forward(
         hidden_states = outputs[0]
         if self.config.pretraining_tp > 1:
-            lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
+            lm_head_slices = ops.split(self.lm_head.weight,self.vocab_size // self.config.pretraining_tp, dim=0)
             logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
             logits = ops.cat(logits, dim=-1)
         else:
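
The change swaps the Tensor.split method call for the functional ops.split when slicing the lm_head weight under pretraining_tp tensor parallelism. For reference, below is a minimal numpy sketch of the pattern this code implements: split the output-projection weight row-wise, project the hidden states against each slice, then concatenate the per-slice logits. The shapes and names are illustrative assumptions, not mindnlp's API.

# Illustrative numpy sketch (assumed shapes, not mindnlp's API) of the
# tensor-parallel lm_head pattern touched by this commit: split the
# projection weight along dim 0, apply each slice, concatenate the logits.
import numpy as np

pretraining_tp = 4                      # number of tensor-parallel slices
hidden_size, vocab_size = 64, 128       # toy sizes chosen for the example
hidden_states = np.random.randn(2, 5, hidden_size)         # (batch, seq, hidden)
lm_head_weight = np.random.randn(vocab_size, hidden_size)  # (vocab, hidden), F.linear layout

# Reference: a single full projection.
full_logits = hidden_states @ lm_head_weight.T              # (batch, seq, vocab)

# Sliced version: the model code passes the slice size
# (vocab_size // pretraining_tp); np.split takes the number of
# sections instead, which is equivalent when vocab divides evenly.
lm_head_slices = np.split(lm_head_weight, pretraining_tp, axis=0)
logits = np.concatenate([hidden_states @ w.T for w in lm_head_slices], axis=-1)

assert np.allclose(full_logits, logits)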