Commit 6285685

fix: resolve the issue of MiniCPM not being registered (#1821)

1 parent 0e60754 commit 6285685

File tree

2 files changed (+80, -68 lines)

mindnlp/transformers/models/auto/modeling_auto.py

Lines changed: 5 additions & 0 deletions
@@ -161,6 +161,7 @@
 ("mega", "MegaModel"),
 ("megatron-bert", "MegatronBertModel"),
 ("mgp-str", "MgpstrForSceneTextRecognition"),
+('minicpm', 'MiniCPMModel'),
 ("mistral", "MistralModel"),
 ("mixtral", "MixtralModel"),
 ("mobilebert", "MobileBertModel"),
@@ -318,6 +319,7 @@
 ("mamba", "MambaForCausalLM"),
 ("mega", "MegaForMaskedLM"),
 ("megatron-bert", "MegatronBertForPreTraining"),
+('minicpm', 'MiniCPMForCausalLM'),
 ("mllama", "MllamaForConditionalGeneration"),
 ("mobilebert", "MobileBertForPreTraining"),
 ("mpnet", "MPNetForMaskedLM"),
@@ -404,6 +406,7 @@
 ("marian", "MarianMTModel"),
 ("mega", "MegaForMaskedLM"),
 ("megatron-bert", "MegatronBertForCausalLM"),
+('minicpm', 'MiniCPMForCausalLM'),
 ("mobilebert", "MobileBertForMaskedLM"),
 ("mpnet", "MPNetForMaskedLM"),
 ("mpt", "MptForCausalLM"),
@@ -488,6 +491,7 @@
 ("mbart", "MBartForCausalLM"),
 ("mega", "MegaForCausalLM"),
 ("megatron-bert", "MegatronBertForCausalLM"),
+('minicpm', 'MiniCPMForCausalLM'),
 ("mistral", "MistralForCausalLM"),
 ("mixtral", "MixtralForCausalLM"),
 ("mllama", "MllamaForCausalLM"),
@@ -937,6 +941,7 @@
 ("mbart", "MBartForSequenceClassification"),
 ("mega", "MegaForSequenceClassification"),
 ("megatron-bert", "MegatronBertForSequenceClassification"),
+('minicpm', 'MiniCPMForSequenceClassification'),
 ("mistral", "MistralForSequenceClassification"),
 ("mixtral", "MixtralForSequenceClassification"),
 ("mobilebert", "MobileBertForSequenceClassification"),

mindnlp/transformers/models/minicpm/modeling_minicpm.py

Lines changed: 75 additions & 68 deletions
@@ -46,16 +46,17 @@

 _CONFIG_FOR_DOC = "MiniCPMConfig"

+
 def rms_layernorm(hidden: mindspore.Tensor, weight: mindspore.Tensor, eps: float):
 """
 Args:
 hidden (mindspore.Tensor): The input tensor to be normalized.
 weight (mindspore.Tensor): The weight tensor applied to the normalized input.
 eps (float): A small value added to the variance to avoid division by zero.
-
+
 Returns:
 None: This function does not return a value. It operates in place on the 'hidden' tensor.
-
+
 Raises:
 ValueError: If the 'hidden' tensor or 'weight' tensor is not of type mindspore.Tensor.
 TypeError: If the 'eps' parameter is not of type float.
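For context, the rms_layernorm helper whose docstring is reflowed above implements RMS layer normalization. A minimal sketch of that computation, assuming the usual formulation rather than the exact in-repo code:

# Sketch of standard RMS layer normalization; illustrative, not the commit's code.
import mindspore
from mindspore import ops

def rms_layernorm_sketch(hidden: mindspore.Tensor, weight: mindspore.Tensor, eps: float):
    old_dtype = hidden.dtype
    hidden_fp32 = hidden.to(mindspore.float32)
    variance = hidden_fp32.pow(2).mean(axis=-1, keep_dims=True)  # mean of squares over the hidden dim
    normed = hidden_fp32 * ops.rsqrt(variance + eps)             # scale by the reciprocal RMS
    return (weight * normed).to(old_dtype)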
@@ -67,11 +68,10 @@ def rms_layernorm(hidden: mindspore.Tensor, weight: mindspore.Tensor, eps: float


 class MiniCPMRMSNorm(nn.Module):
-
 """
-MiniCPMRMSNorm is a custom layer normalization module designed to mimic the functionality of T5LayerNorm.
+MiniCPMRMSNorm is a custom layer normalization module designed to mimic the functionality of T5LayerNorm.
 It performs RMS-based layer normalization on the input hidden states using the provided weight and epsilon value.
-
+
 Parameters:
 hidden_size (int): The size of the hidden states being normalized.
 eps (float, optional): A small value added to the variance to prevent division by zero. Default is 1e-06.
@@ -87,6 +87,7 @@ class MiniCPMRMSNorm(nn.Module):
 __init__: Initializes the MiniCPMRMSNorm instance with the given hidden size and epsilon.
 forward: Applies RMS-based layer normalization on the input hidden states using the weight and epsilon.
 """
+
 def __init__(self, hidden_size, eps=1e-6):
 """
 MiniCPMRMSNorm is equivalent to T5LayerNorm
@@ -117,7 +118,6 @@ def forward(self, hidden_states):


 class MiniCPMRotaryEmbedding(nn.Module):
-
 """
 MiniCPMRotaryEmbedding is a class that represents a rotary positional embedding layer for neural networks.
 It inherits from nn.Module and provides methods for initializing the embedding layer, setting cosine and sine cache,
@@ -128,6 +128,7 @@ class MiniCPMRotaryEmbedding(nn.Module):
 cosine and sine values for positional embeddings.
 The forward method generates the positional embeddings based on the input data and the specified sequence length.
 """
+
 def __init__(self, dim, max_position_embeddings=2048, base=10000):
 """
 Initializes a new instance of the MiniCPMRotaryEmbedding class.
@@ -212,6 +213,7 @@ def forward(self, x, seq_len=None):

 class MiniCPMLinearScalingRotaryEmbedding(MiniCPMRotaryEmbedding):
 """MiniCPMRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
+
 def __init__(self, dim, max_position_embeddings=2048, base=10000, scaling_factor=1.0):
 """
 Initializes an instance of MiniCPMLinearScalingRotaryEmbedding.
@@ -260,6 +262,7 @@ def _set_cos_sin_cache(self, seq_len, dtype):

 class MiniCPMDynamicNTKScalingRotaryEmbedding(MiniCPMRotaryEmbedding):
 """MiniCPMRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
+
 def __init__(self, dim, max_position_embeddings=2048, base=10000, scaling_factor=1.0):
 """
 Initializes a new instance of the MiniCPMDynamicNTKScalingRotaryEmbedding class.
@@ -302,7 +305,7 @@ def _set_cos_sin_cache(self, seq_len, dtype):

 if seq_len > self.max_position_embeddings:
 base = self.base * (
-(self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
+(self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
 ) ** (self.dim / (self.dim - 2))
 inv_freq = 1.0 / (base ** (ops.arange(0, self.dim, 2).float() / self.dim))
 self.inv_freq = inv_freq
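The last hunk above only re-indents the Dynamic NTK expression, but it is the interesting line: when seq_len exceeds max_position_embeddings, the RoPE base is enlarged, which lowers the rotation frequencies and stretches the usable context. A small worked sketch with illustrative numbers (not code from the commit):

# Plain-Python illustration of the Dynamic NTK base rescaling shown above.
dim, base, max_pos, scaling_factor = 64, 10000.0, 2048, 1.0
for seq_len in (2048, 4096, 8192):
    if seq_len > max_pos:
        scaled = base * ((scaling_factor * seq_len / max_pos) - (scaling_factor - 1)) ** (dim / (dim - 2))
    else:
        scaled = base
    print(seq_len, round(scaled))  # approximately 10000, 20452, 41829 for these inputs
# The larger base then feeds inv_freq = 1 / scaled ** (arange(0, dim, 2) / dim),
# after which the cos/sin caches are rebuilt as in the diff context above.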
@@ -316,6 +319,7 @@ def _set_cos_sin_cache(self, seq_len, dtype):
 self.cos_cached = emb.cos().to(dtype)
 self.sin_cached = emb.sin().to(dtype)

+
 def rotate_half(x):
 """Rotates half the hidden dims of the input."""
 # x1 = x[..., : x.shape[-1] // 2]
@@ -358,8 +362,8 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
 k_embed = (k_fp32 * cos) + (rotate_half(k_fp32) * sin)
 return q_embed.to(dtype=orig_dtype), k_embed.to(dtype=orig_dtype)

-class MiniCPMMLP(nn.Module):

+class MiniCPMMLP(nn.Module):
 """
 MiniCPMMLP is a neural network model that implements a specific variant of a Multi-Layer Perceptron (MLP)
 architecture for deep learning tasks.
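The hunks above move blank lines around rotate_half and MiniCPMMLP. For readers unfamiliar with the rotary helpers they touch, here is a minimal sketch of the standard rotate_half / apply_rotary_pos_emb pattern; it is illustrative, not the commit's exact code, which also handles position_ids gathering and the fp32 cast shown in the diff context:

from mindspore import ops

def rotate_half_sketch(x):
    # swap the two halves of the last dimension, negating the second half
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return ops.cat((-x2, x1), axis=-1)

def apply_rotary_sketch(q, k, cos, sin):
    # cos/sin are assumed to be already indexed by the positions in use
    q_embed = (q * cos) + (rotate_half_sketch(q) * sin)
    k_embed = (k * cos) + (rotate_half_sketch(k) * sin)
    return q_embed, k_embed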
@@ -385,6 +389,7 @@ class MiniCPMMLP(nn.Module):
 Returns:
 down_proj: The output tensor resulting from the forward pass computation of the MiniCPMMLP model.
 """
+
 def __init__(self, config):
 """
 Initializes a MiniCPMMLP object with the provided configuration.
@@ -458,6 +463,7 @@ def repeat_kv(hidden_states: mindspore.Tensor, n_rep: int) -> mindspore.Tensor:

 class MiniCPMAttention(nn.Module):
 """Multi-headed attention from 'Attention Is All You Need' paper"""
+
 def __init__(self, config: MiniCPMConfig, layer_idx: Optional[int] = None):
 """
 Initializes an instance of the MiniCPMAttention class.
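The second hunk above sits right after the repeat_kv helper named in its header, which MiniCPMAttention uses to share a reduced set of key/value heads across all query heads (grouped-query attention). A hedged sketch of what such a helper typically does, not the commit's code:

# Expand (batch, num_kv_heads, seq, head_dim) to (batch, num_kv_heads * n_rep, seq, head_dim)
# so fewer KV heads can serve every query head.
import mindspore
from mindspore import ops

def repeat_kv_sketch(hidden_states: mindspore.Tensor, n_rep: int) -> mindspore.Tensor:
    batch, num_kv_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    expanded = ops.broadcast_to(
        ops.expand_dims(hidden_states, 2), (batch, num_kv_heads, n_rep, slen, head_dim)
    )
    return expanded.reshape(batch, num_kv_heads * n_rep, slen, head_dim)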
@@ -594,14 +600,14 @@ def _shape(self, tensor: mindspore.Tensor, seq_len: int, bsz: int):
 return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).swapaxes(1, 2)

 def forward(
-self,
-hidden_states: mindspore.Tensor,
-attention_mask: Optional[mindspore.Tensor] = None,
-position_ids: Optional[mindspore.Tensor] = None,
-past_key_value: Optional[Cache] = None,
-output_attentions: bool = False,
-use_cache: bool = False,
-**kwargs,
+self,
+hidden_states: mindspore.Tensor,
+attention_mask: Optional[mindspore.Tensor] = None,
+position_ids: Optional[mindspore.Tensor] = None,
+past_key_value: Optional[Cache] = None,
+output_attentions: bool = False,
+use_cache: bool = False,
+**kwargs,
 ) -> Tuple[mindspore.Tensor, Optional[mindspore.Tensor], Optional[Tuple[mindspore.Tensor]]]:
 '''
 This method forwards the MiniCPMAttention layer.
@@ -730,7 +736,6 @@ def forward(


 class MiniCPMDecoderLayer(nn.Module):
-
 """
 MiniCPMDecoderLayer represents a single layer of the MiniCPM (Minimalist Conditional Pretrained Model) decoder.
 This class is responsible for processing input hidden states through self-attention mechanism and MLP
@@ -767,6 +772,7 @@ class MiniCPMDecoderLayer(nn.Module):
 If 'padding_mask' is passed as a keyword argument in kwargs, a deprecation warning will be issued.
 It is recommended to use 'attention_mask' instead.
 """
+
 def __init__(self, config: MiniCPMConfig, layer_idx: int):
 """
 Initializes a new instance of MiniCPMDecoderLayer.
@@ -796,14 +802,14 @@ def __init__(self, config: MiniCPMConfig, layer_idx: int):
 self.num_hidden_layers = config.num_hidden_layers

 def forward(
-self,
-hidden_states: mindspore.Tensor,
-attention_mask: Optional[mindspore.Tensor] = None,
-position_ids: Optional[mindspore.Tensor] = None,
-past_key_value: Optional[Tuple[mindspore.Tensor]] = None,
-output_attentions: Optional[bool] = False,
-use_cache: Optional[bool] = False,
-**kwargs,
+self,
+hidden_states: mindspore.Tensor,
+attention_mask: Optional[mindspore.Tensor] = None,
+position_ids: Optional[mindspore.Tensor] = None,
+past_key_value: Optional[Tuple[mindspore.Tensor]] = None,
+output_attentions: Optional[bool] = False,
+use_cache: Optional[bool] = False,
+**kwargs,
 ) -> Tuple[mindspore.Tensor, Optional[Tuple[mindspore.Tensor, mindspore.Tensor]]]:
 """
 Args:
@@ -858,7 +864,6 @@ def forward(


 class MiniCPMPreTrainedModel(PreTrainedModel):
-
 """
 Represents a pre-trained mini version of CPM (Code-PM) model for various NLP tasks.
 This class inherits from PreTrainedModel and provides functionality to initialize weights for different types
@@ -916,6 +921,7 @@ class MiniCPMModel(MiniCPMPreTrainedModel):
 Args:
 config: MiniCPMConfig
 """
+
 def __init__(self, config: MiniCPMConfig):
 """
 Initializes a MiniCPMModel instance with the provided configuration.
@@ -995,16 +1001,16 @@ def set_input_embeddings(self, new_embeddings):
 self.embed_tokens = new_embeddings

 def forward(
-self,
-input_ids: mindspore.Tensor = None,
-attention_mask: Optional[mindspore.Tensor] = None,
-position_ids: Optional[mindspore.Tensor] = None,
-past_key_values: Optional[List[mindspore.Tensor]] = None,
-inputs_embeds: Optional[mindspore.Tensor] = None,
-use_cache: Optional[bool] = None,
-output_attentions: Optional[bool] = None,
-output_hidden_states: Optional[bool] = None,
-return_dict: Optional[bool] = None,
+self,
+input_ids: mindspore.Tensor = None,
+attention_mask: Optional[mindspore.Tensor] = None,
+position_ids: Optional[mindspore.Tensor] = None,
+past_key_values: Optional[List[mindspore.Tensor]] = None,
+inputs_embeds: Optional[mindspore.Tensor] = None,
+use_cache: Optional[bool] = None,
+output_attentions: Optional[bool] = None,
+output_hidden_states: Optional[bool] = None,
+return_dict: Optional[bool] = None,
 ) -> Union[Tuple, BaseModelOutputWithPast]:
 """
 Constructs the MiniCPMModel.
@@ -1299,17 +1305,17 @@ def get_decoder(self):
 return self.model

 def forward(
-self,
-input_ids: mindspore.Tensor = None,
-attention_mask: Optional[mindspore.Tensor] = None,
-position_ids: Optional[mindspore.Tensor] = None,
-past_key_values: Optional[List[mindspore.Tensor]] = None,
-inputs_embeds: Optional[mindspore.Tensor] = None,
-labels: Optional[mindspore.Tensor] = None,
-use_cache: Optional[bool] = None,
-output_attentions: Optional[bool] = None,
-output_hidden_states: Optional[bool] = None,
-return_dict: Optional[bool] = None,
+self,
+input_ids: mindspore.Tensor = None,
+attention_mask: Optional[mindspore.Tensor] = None,
+position_ids: Optional[mindspore.Tensor] = None,
+past_key_values: Optional[List[mindspore.Tensor]] = None,
+inputs_embeds: Optional[mindspore.Tensor] = None,
+labels: Optional[mindspore.Tensor] = None,
+use_cache: Optional[bool] = None,
+output_attentions: Optional[bool] = None,
+output_hidden_states: Optional[bool] = None,
+return_dict: Optional[bool] = None,
 ) -> Union[Tuple, CausalLMOutputWithPast]:
 r"""
 Args:
@@ -1389,7 +1395,7 @@ def forward(
 )

 def prepare_inputs_for_generation(
-self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
+self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
 ):
 """
 Prepare inputs for generation.
@@ -1428,7 +1434,7 @@ def prepare_inputs_for_generation(
 # some of the inputs are exclusivelly passed as part of the cache (e.g. when passing input_embeds as
 # input)
 if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
-input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
+input_ids = input_ids[:, -(attention_mask.shape[1] - past_length):]
 # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
 # input_ids based on the past_length.
 elif past_length < input_ids.shape[1]:
@@ -1437,19 +1443,19 @@ def prepare_inputs_for_generation(

 # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
 if (
-max_cache_length is not None
-and attention_mask is not None
-and cache_length + input_ids.shape[1] > max_cache_length
+max_cache_length is not None
+and attention_mask is not None
+and cache_length + input_ids.shape[1] > max_cache_length
 ):
 attention_mask = attention_mask[:, -max_cache_length:]

 position_ids = kwargs.get("position_ids", None)
 if attention_mask is not None and position_ids is None:
 # create position_ids on the fly for batch generation
-position_ids = attention_mask.long().cumsum(-1) - 1
+position_ids = attention_mask.long().int().cumsum(-1) - 1
 position_ids = position_ids.masked_fill(attention_mask == 0, 1)
 if past_key_values:
-position_ids = position_ids[:, -input_ids.shape[1] :]
+position_ids = position_ids[:, -input_ids.shape[1]:]

 # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
 if inputs_embeds is not None and past_key_values is None:
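Besides tightening slice spacing, the hunk above makes one small functional change: position_ids are now built with attention_mask.long().int().cumsum(-1) - 1, i.e. the mask is cast down to int32 before the cumulative sum. The surrounding lines show the usual trick for batch generation with left padding; a toy illustration with made-up values (not code from the commit):

# cumsum over the attention mask yields 0-based positions for real tokens;
# padded slots are then overwritten with 1, mirroring the masked_fill line above.
import mindspore
from mindspore import Tensor

attention_mask = Tensor([[0, 0, 1, 1, 1]], mindspore.int32)   # one left-padded row
position_ids = attention_mask.cumsum(-1) - 1                  # [[-1, -1, 0, 1, 2]]
position_ids = position_ids.masked_fill(attention_mask == 0, 1)
print(position_ids)                                           # [[1, 1, 0, 1, 2]]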
@@ -1524,10 +1530,10 @@ def chat(self, tokenizer, query: str, history: List[Dict] = None, role: str = "u
 history = []
 if logits_processor:
 gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p,
-"temperature": temperature, "logits_processor": logits_processor, **kwargs}
+"temperature": temperature, "logits_processor": logits_processor, **kwargs}
 else:
 gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p,
-"temperature": temperature, "logits_processor": logits_processor, **kwargs}
+"temperature": temperature, "logits_processor": logits_processor, **kwargs}

 history.append({"role": role, "content": query})
 history_str = tokenizer.apply_chat_template(history, tokenize=False, add_generation_prompt=False)
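The chat hunk above only re-indents the gen_kwargs continuation lines. For orientation, a hedged usage sketch of this helper, reusing the model and tokenizer from the earlier Auto-class sketch; the (response, history) return shape follows the upstream MiniCPM chat convention and is an assumption, not something shown in this diff:

# Illustrative only; argument names come from the signature visible in the hunk header.
response, history = model.chat(tokenizer, "Introduce yourself in one sentence.", history=[])
print(response)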
@@ -1544,7 +1550,6 @@ def chat(self, tokenizer, query: str, history: List[Dict] = None, role: str = "u


 class MiniCPMForSequenceClassification(MiniCPMPreTrainedModel):
-
 """
 MiniCPMForSequenceClassification is a Python class that represents a fine-tuning model for sequence classification
 tasks based on the MiniCPM architecture. It inherits from the MiniCPMPreTrainedModel class and provides methods for
@@ -1584,6 +1589,7 @@ class MiniCPMForSequenceClassification(MiniCPMPreTrainedModel):
 This class inherits from MiniCPMPreTrainedModel and extends its functionality to support sequence
 classification tasks.
 """
+
 def __init__(self, config):
 """
 Initializes a new instance of the MiniCPMForSequenceClassification class.
@@ -1640,17 +1646,17 @@ def set_input_embeddings(self, new_embeddings):
 self.model.embed_tokens = new_embeddings

 def forward(
-self,
-input_ids: mindspore.Tensor = None,
-attention_mask: Optional[mindspore.Tensor] = None,
-position_ids: Optional[mindspore.Tensor] = None,
-past_key_values: Optional[List[mindspore.Tensor]] = None,
-inputs_embeds: Optional[mindspore.Tensor] = None,
-labels: Optional[mindspore.Tensor] = None,
-use_cache: Optional[bool] = None,
-output_attentions: Optional[bool] = None,
-output_hidden_states: Optional[bool] = None,
-return_dict: Optional[bool] = None,
+self,
+input_ids: mindspore.Tensor = None,
+attention_mask: Optional[mindspore.Tensor] = None,
+position_ids: Optional[mindspore.Tensor] = None,
+past_key_values: Optional[List[mindspore.Tensor]] = None,
+inputs_embeds: Optional[mindspore.Tensor] = None,
+labels: Optional[mindspore.Tensor] = None,
+use_cache: Optional[bool] = None,
+output_attentions: Optional[bool] = None,
+output_hidden_states: Optional[bool] = None,
+return_dict: Optional[bool] = None,
 ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
 r"""
 Args:
@@ -1723,6 +1729,7 @@ def forward(
 attentions=transformer_outputs.attentions,
 )

+
 __all__ = [
 'MiniCPMModel',
 'MiniCPMPreTrainedModel',
