
Commit 54161e2

fix sbert normalize_embeddings (#1879)
1 parent 9d7f17a commit 54161e2

5 files changed: +186 −2 lines changed


mindnlp/core/serialization.py

Lines changed: 82 additions & 1 deletion
@@ -33,7 +33,7 @@
 
 from contextlib import closing, contextmanager
 from enum import Enum
-from typing import Dict, Union, Optional, Any, OrderedDict
+from typing import Dict, Union, Optional, Any, OrderedDict, Tuple, List
 from functools import reduce
 from dataclasses import dataclass
 
@@ -46,6 +46,7 @@
 import safetensors.numpy
 from safetensors import deserialize
 
+from mindnlp.core import nn
 from mindnlp.core.nn import Parameter
 from mindnlp.configs import SUPPORT_BF16
 from .nn import Module
@@ -1548,3 +1549,83 @@ def load_checkpoint(ckpt_file_name):
                          "'filter_prefix' or 'specify_prefix' are set correctly.")
 
     return parameter_dict
+
+
+def save_model(
+    model: nn.Module, filename: str, metadata: Optional[Dict[str, str]] = None, force_contiguous: bool = True
+):
+    """
+    Saves a given model to the specified filename.
+    This method exists specifically to avoid tensor sharing issues which are
+    not allowed in `safetensors`. [More information on tensor sharing](../torch_shared_tensors)
+
+    Args:
+        model (`nn.Module`):
+            The model to save on disk.
+        filename (`str`):
+            The filename location to save the file.
+        metadata (`Dict[str, str]`, *optional*):
+            Extra information to save along with the file.
+            Some metadata will be added for each dropped tensor.
+            This information will not be enough to recover the entire
+            shared structure, but it might help in understanding things.
+        force_contiguous (`boolean`, *optional*, defaults to True):
+            Forcing the state_dict to be saved as contiguous tensors.
+            This has no effect on the correctness of the model, but it
+            could potentially change performance if the layout of the tensor
+            was chosen specifically for that reason.
+    """
+    state_dict = model.state_dict()
+
+    if force_contiguous:
+        state_dict = {k: v.contiguous() for k, v in state_dict.items()}
+    try:
+        safe_save_file(state_dict, filename, metadata=metadata)
+    except ValueError as e:
+        msg = str(e)
+        msg += " Or use save_model(..., force_contiguous=True), read the docs for potential caveats."
+        raise ValueError(msg)
+
+
+def load_model(
+    model: nn.Module, filename: Union[str, os.PathLike], strict: bool = True
+) -> Tuple[List[str], List[str]]:
+    """
+    Loads a given filename onto a model.
+    This method exists specifically to avoid tensor sharing issues which are
+    not allowed in `safetensors`. [More information on tensor sharing](../torch_shared_tensors)
+
+    Args:
+        model (`nn.Module`):
+            The model to load onto.
+        filename (`str`, or `os.PathLike`):
+            The filename location to load the file from.
+        strict (`bool`, *optional*, defaults to True):
+            Whether to fail if keys are missing or unexpected ones are present.
+            When false, the function simply returns missing and unexpected names.
+        device (`Union[str, int]`, *optional*, defaults to `cpu`):
+            The device where the tensors need to be located after load.
+            Available options are all regular device locations.
+
+    Returns:
+        `(missing, unexpected): (List[str], List[str])`
+            `missing` are names in the model which were not modified during loading.
+            `unexpected` are names that are in the file, but were not used during
+            the load.
+    """
+    state_dict = safe_load_file(filename)
+    model_state_dict = model.state_dict()
+
+    missing, unexpected = model.load_state_dict(state_dict, strict=False)
+    missing = set(missing)
+
+    if strict and (missing or unexpected):
+        missing_keys = ", ".join([f'"{k}"' for k in sorted(missing)])
+        unexpected_keys = ", ".join([f'"{k}"' for k in sorted(unexpected)])
+        error = f"Error(s) in loading state_dict for {model.__class__.__name__}:"
+        if missing:
+            error += f"\n    Missing key(s) in state_dict: {missing_keys}"
+        if unexpected:
+            error += f"\n    Unexpected key(s) in state_dict: {unexpected_keys}"
+        raise RuntimeError(error)
+    return missing, unexpected
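A minimal round-trip sketch of the two new helpers, not part of the commit itself. It assumes only that mindnlp.core.nn provides a Linear layer with a state_dict() (the Dense module below already relies on this); the file name is arbitrary:

    from mindnlp.core import nn
    from mindnlp.core.serialization import save_model, load_model

    # Any nn.Module with a state_dict works; a single Linear layer keeps it small.
    net = nn.Linear(16, 4)

    # Tensors are made contiguous first (force_contiguous defaults to True),
    # then written out as a safetensors file.
    save_model(net, "toy.safetensors")

    # Weights are restored in place; with strict=True a key mismatch raises,
    # otherwise the missing/unexpected names are simply returned.
    missing, unexpected = load_model(net, "toy.safetensors", strict=True)
    print(missing, unexpected)  # empty when every key lines up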

mindnlp/sentence/models/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -17,9 +17,11 @@
 from .transformer import Transformer
 from .pooling import Pooling
 from .normalize import Normalize
+from .dense import Dense
 
 __all__ = [
     "Transformer",
     "Pooling",
     "Normalize",
+    "Dense"
 ]

mindnlp/sentence/models/dense.py

Lines changed: 96 additions & 0 deletions
@@ -0,0 +1,96 @@
+"""dense model"""
+from __future__ import annotations
+
+import json
+import os
+
+
+from mindspore import Tensor
+from mindnlp.core import nn
+from mindnlp.core.serialization import load_model as load_safetensors_model, save, load
+from mindnlp.core.serialization import save_model as save_safetensors_model
+
+from ..util import fullname, import_from_string
+
+
+class Dense(nn.Module):
+    """
+    Feed-forward function with activation function.
+
+    This layer takes a fixed-sized sentence embedding and passes it through a feed-forward layer. Can be used to generate deep averaging networks (DAN).
+
+    Args:
+        in_features: Size of the input dimension
+        out_features: Output size
+        bias: Add a bias vector
+        activation_function: Activation function applied on the
+            output
+        init_weight: Initial value for the weight matrix of the linear layer
+        init_bias: Initial value for the bias of the linear layer
+    """
+
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        bias: bool = True,
+        activation_function=nn.Tanh(),
+        init_weight: Tensor = None,
+        init_bias: Tensor = None,
+    ):
+        super().__init__()
+        self.in_features = in_features
+        self.out_features = out_features
+        self.bias = bias
+        self.activation_function = activation_function
+        self.linear = nn.Linear(in_features, out_features, bias=bias)
+
+        if init_weight is not None:
+            self.linear.weight = nn.Parameter(init_weight)
+
+        if init_bias is not None:
+            self.linear.bias = nn.Parameter(init_bias)
+
+    def forward(self, features: dict[str, Tensor]):
+        features.update({"sentence_embedding": self.activation_function(self.linear(features["sentence_embedding"]))})
+        return features
+
+    def get_sentence_embedding_dimension(self) -> int:
+        return self.out_features
+
+    def get_config_dict(self):
+        return {
+            "in_features": self.in_features,
+            "out_features": self.out_features,
+            "bias": self.bias,
+            "activation_function": fullname(self.activation_function),
+        }
+
+    def save(self, output_path, safe_serialization: bool = True) -> None:
+        with open(os.path.join(output_path, "config.json"), "w") as fOut:
+            json.dump(self.get_config_dict(), fOut)
+
+        if safe_serialization:
+            save_safetensors_model(self, os.path.join(output_path, "model.safetensors"))
+        else:
+            save(self.state_dict(), os.path.join(output_path, "pytorch_model.bin"))
+
+    def __repr__(self):
+        return f"Dense({self.get_config_dict()})"
+
+    @staticmethod
+    def load(input_path):
+        with open(os.path.join(input_path, "config.json")) as fIn:
+            config = json.load(fIn)
+
+        config["activation_function"] = import_from_string(config["activation_function"])()
+        model = Dense(**config)
+        if os.path.exists(os.path.join(input_path, "model.safetensors")):
+            load_safetensors_model(model, os.path.join(input_path, "model.safetensors"))
+        else:
+            model.load_state_dict(
+                load(
+                    os.path.join(input_path, "pytorch_model.bin"), weights_only=True
+                )
+            )
+        return model
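A small usage sketch of the new module, assuming mindnlp's nn.Module is callable like its torch counterpart and that mindspore.ops.ones is available to fabricate a dummy batch; the directory name is arbitrary:

    import os
    from mindspore import ops
    from mindnlp.core import nn
    from mindnlp.sentence.models import Dense

    # Project 768-dim sentence embeddings down to 256 dims with a Tanh on top.
    dense = Dense(in_features=768, out_features=256, activation_function=nn.Tanh())

    # forward() reads and rewrites the "sentence_embedding" entry of the features dict.
    features = {"sentence_embedding": ops.ones((2, 768))}
    out = dense(features)
    print(out["sentence_embedding"].shape)  # (2, 256)

    # Round-trip through the two on-disk formats handled by save()/load().
    os.makedirs("dense_module", exist_ok=True)
    dense.save("dense_module")               # writes config.json + model.safetensors
    restored = Dense.load("dense_module")    # rebuilds from config, then loads weights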

mindnlp/sentence/sentence_transformer.py

Lines changed: 0 additions & 1 deletion
@@ -199,7 +199,6 @@ def _load_module_class_from_ref(
     # If the class is from sentence_transformers, we can directly import it,
    # otherwise, we try to import it dynamically, and if that fails, we fall back to the default import
     if class_ref.startswith("sentence_transformers."):
-        class_ref = class_ref.replace('sentence_transformers', 'mindnlp.sentence')
         return import_from_string(class_ref)
 
     return import_from_string(class_ref)

mindnlp/sentence/util.py

Lines changed: 6 additions & 0 deletions
@@ -551,6 +551,12 @@ def import_from_string(dotted_path: str) -> type:
     >>> import_from_string('sentence_transformers.losses.MultipleNegativesRankingLoss')
     <class 'sentence_transformers.losses.MultipleNegativesRankingLoss.MultipleNegativesRankingLoss'>
     """
+    if 'sentence_transformers' in dotted_path:
+        dotted_path = dotted_path.replace('sentence_transformers', 'mindnlp.sentence')
+
+    if 'torch.nn' in dotted_path:
+        dotted_path = dotted_path.replace('torch.nn', 'mindnlp.core.nn')
+
     try:
         module_path, class_name = dotted_path.rsplit(".", 1)
     except ValueError:
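The effect of the remapping is easiest to see by calling import_from_string directly. A sketch under the assumption that mindnlp.core.nn mirrors torch's module layout (the torch.nn path shown is the kind recorded in configs saved by upstream sentence-transformers, e.g. via fullname(nn.Tanh())):

    from mindnlp.sentence.util import import_from_string

    # sentence_transformers.* dotted paths now resolve against mindnlp.sentence ...
    cls = import_from_string("sentence_transformers.models.Dense")
    print(cls)  # e.g. <class 'mindnlp.sentence.models.dense.Dense'>

    # ... and torch.nn paths resolve against mindnlp.core.nn, which is why
    # Dense.load() can reconstruct an activation recorded by upstream configs.
    act_cls = import_from_string("torch.nn.modules.activation.Tanh")
    activation = act_cls()

Because the remapping now lives here, the explicit replace() in _load_module_class_from_ref (removed above) is no longer needed: both its branches funnel into import_from_string.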
