
Commit dc565e8

fix sbert precision problem on mindnlp.sentence (#1873)
1 parent 9ec0bb5 commit dc565e8

File tree

14 files changed: +1697 additions, −143 deletions


mindnlp/peft/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -24,7 +24,7 @@
 from .peft_model import (
     PeftModel,
     PeftModelForCausalLM,
-    # PeftModelForFeatureExtraction,
+    PeftModelForFeatureExtraction,
     # PeftModelForQuestionAnswering,
     PeftModelForSeq2SeqLM,
     PeftModelForSequenceClassification,

mindnlp/peft/peft_model.py

Lines changed: 105 additions & 0 deletions
@@ -1095,3 +1095,108 @@ def _prefix_tuning_forward(
 
             output = (logits,) + outputs[2:]
             return ((loss,) + output) if loss is not None else output
+
+
+class PeftModelForFeatureExtraction(PeftModel):
+    """
+    Peft model for extracting features/embeddings from transformer models
+
+    Args:
+        model ([`~transformers.PreTrainedModel`]): Base transformer model.
+        peft_config ([`PeftConfig`]): Peft config.
+        adapter_name (`str`, *optional*): The name of the adapter, defaults to `"default"`.
+        autocast_adapter_dtype (`bool`, *optional*):
+            Whether to autocast the adapter dtype. Defaults to `True`. Right now, this will only cast adapter weights
+            using float16 and bfloat16 to float32, as this is typically required for stable training, and only affect
+            select PEFT tuners.
+
+    **Attributes**:
+        - **config** ([`~transformers.PretrainedConfig`]) -- The configuration object of the base model.
+
+    Example:
+
+        ```py
+        >>> from transformers import AutoModel
+        >>> from peft import PeftModelForFeatureExtraction, get_peft_config
+
+        >>> config = {
+        ...     "peft_type": "LORA",
+        ...     "task_type": "FEATURE_EXTRACTION",
+        ...     "inference_mode": False,
+        ...     "r": 16,
+        ...     "target_modules": ["query", "value"],
+        ...     "lora_alpha": 32,
+        ...     "lora_dropout": 0.05,
+        ...     "fan_in_fan_out": False,
+        ...     "bias": "none",
+        ... }
+        >>> peft_config = get_peft_config(config)
+        >>> model = AutoModel.from_pretrained("bert-base-cased")
+        >>> peft_model = PeftModelForFeatureExtraction(model, peft_config)
+        >>> peft_model.print_trainable_parameters()
+        ```
+    """
+
+    def __init__(self, model: nn.Module, peft_config: PeftConfig, adapter_name: str = "default", **kwargs):
+        super().__init__(model, peft_config, adapter_name, **kwargs)
+
+    def forward(
+        self,
+        input_ids=None,
+        attention_mask=None,
+        inputs_embeds=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+        task_ids=None,
+        **kwargs,
+    ):
+        peft_config = self.active_peft_config
+        if not peft_config.is_prompt_learning:
+            if peft_config.peft_type == PeftType.POLY:
+                kwargs["task_ids"] = task_ids
+
+            with self._enable_peft_forward_hooks(**kwargs):
+                kwargs = {k: v for k, v in kwargs.items() if k not in self.special_peft_forward_args}
+                return self.base_model(
+                    input_ids=input_ids,
+                    attention_mask=attention_mask,
+                    inputs_embeds=inputs_embeds,
+                    output_attentions=output_attentions,
+                    output_hidden_states=output_hidden_states,
+                    return_dict=return_dict,
+                    **kwargs,
+                )
+
+        batch_size = _get_batch_size(input_ids, inputs_embeds)
+        if attention_mask is not None:
+            # concat prompt attention mask
+            prefix_attention_mask = ops.ones(batch_size, peft_config.num_virtual_tokens)
+            attention_mask = ops.cat((prefix_attention_mask, attention_mask), dim=1)
+
+        if kwargs.get("position_ids", None) is not None:
+            warnings.warn("Position ids are not supported for parameter efficient tuning. Ignoring position ids.")
+            kwargs["position_ids"] = None
+        if kwargs.get("token_type_ids", None) is not None:
+            warnings.warn("Token type ids are not supported for parameter efficient tuning. Ignoring token type ids")
+            kwargs["token_type_ids"] = None
+        kwargs.update(
+            {
+                "attention_mask": attention_mask,
+                "output_attentions": output_attentions,
+                "output_hidden_states": output_hidden_states,
+                "return_dict": return_dict,
+            }
+        )
+
+        if peft_config.peft_type == PeftType.PREFIX_TUNING:
+            # overwrite past_kv in kwargs
+            kwargs["past_key_values"] = self.get_prompt(batch_size)
+            return self.base_model(input_ids=input_ids, **kwargs)
+        else:
+            if inputs_embeds is None:
+                inputs_embeds = self.word_embeddings(input_ids)
+            prompts = self.get_prompt(batch_size=batch_size)
+            prompts = prompts.to(inputs_embeds.dtype)
+            inputs_embeds = ops.cat((prompts, inputs_embeds), dim=1)
+            return self.base_model(inputs_embeds=inputs_embeds, **kwargs)
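
The class above, together with the export re-enabled in the first hunk, makes LoRA-style feature extraction available from the package root. A minimal usage sketch building on the docstring example; the `mindnlp.transformers` auto classes, the `return_tensors="ms"` tokenizer call, and the mean-pooling step are illustrative assumptions, not part of this diff:

```python
# Illustrative sketch only: LoRA feature extraction through the re-exported class.
# Assumes mindnlp's transformers mirror (AutoModel/AutoTokenizer) and that
# mindnlp.peft exposes get_peft_config alongside PeftModelForFeatureExtraction.
from mindnlp.transformers import AutoModel, AutoTokenizer
from mindnlp.peft import PeftModelForFeatureExtraction, get_peft_config

peft_config = get_peft_config({
    "peft_type": "LORA",
    "task_type": "FEATURE_EXTRACTION",
    "inference_mode": False,
    "r": 16,
    "target_modules": ["query", "value"],
    "lora_alpha": 32,
    "lora_dropout": 0.05,
    "fan_in_fan_out": False,
    "bias": "none",
})

tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
model = AutoModel.from_pretrained("bert-base-cased")
peft_model = PeftModelForFeatureExtraction(model, peft_config)
peft_model.print_trainable_parameters()

# For non-prompt-learning tuners such as LoRA, forward() dispatches straight
# to the wrapped base model, so the output is the usual encoder output.
inputs = tokenizer("A sentence to embed.", return_tensors="ms")
outputs = peft_model(**inputs)
token_embeddings = outputs[0]                        # last hidden state: (batch, seq_len, hidden)
sentence_embedding = token_embeddings.mean(axis=1)   # simple mean pooling (illustrative)
```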

mindnlp/sentence/models/__init__.py

Lines changed: 4 additions & 2 deletions
@@ -16,8 +16,10 @@
 
 from .transformer import Transformer
 from .pooling import Pooling
+from .normalize import Normalize
 
 __all__ = [
-    "transformer",
-    "pooling",
+    "Transformer",
+    "Pooling",
+    "Normalize",
 ]
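
The old `__all__` advertised the lowercase submodule names rather than the classes the package actually imports; the fix exports the class names, plus the new `Normalize`, so the public API and `__all__` agree. An illustrative check:

```python
# Illustrative: the class names can now be imported directly from the package,
# matching the updated __all__.
from mindnlp.sentence.models import Transformer, Pooling, Normalize
```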

mindnlp/sentence/models/normalize.py

Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
+"""normalize model"""
+from __future__ import annotations
+
+from mindspore import Tensor
+from mindnlp.core.nn import functional as F
+from mindnlp.core import nn
+
+
+class Normalize(nn.Module):
+    """This layer normalizes embeddings to unit length"""
+
+    def forward(self, features: dict[str, Tensor]) -> dict[str, Tensor]:
+        features.update({"sentence_embedding": F.normalize(features["sentence_embedding"], p=2, dim=1)})
+        return features
+
+    def save(self, output_path) -> None:
+        pass
+
+    @staticmethod
+    def load(input_path) -> Normalize:
+        return Normalize()
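
`F.normalize(..., p=2, dim=1)` rescales each row of `sentence_embedding` to unit L2 norm, which is what lets downstream cosine similarity reduce to a plain dot product. A small hand-rolled equivalence check, purely illustrative and assuming the numpy-style `Tensor` arithmetic that MindSpore provides:

```python
# Illustrative check that Normalize is row-wise L2 normalization.
import numpy as np
from mindspore import Tensor
from mindnlp.sentence.models import Normalize

emb = Tensor(np.random.randn(4, 8).astype(np.float32))
out = Normalize().forward({"sentence_embedding": emb})["sentence_embedding"]

# Manual equivalent: divide each row by its L2 norm (F.normalize's eps clamp omitted).
manual = emb / (emb ** 2).sum(axis=1, keepdims=True).sqrt()
print(np.allclose(out.asnumpy(), manual.asnumpy(), atol=1e-5))  # expected: True
```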
