@@ -33,7 +33,7 @@
 
 from contextlib import closing, contextmanager
 from enum import Enum
-from typing import Dict, Union, Optional, Any, OrderedDict
+from typing import Dict, Union, Optional, Any, OrderedDict, Tuple, List
 from functools import reduce
 from dataclasses import dataclass
 
@@ -46,6 +46,7 @@
 import safetensors.numpy
 from safetensors import deserialize
 
+from mindnlp.core import nn
 from mindnlp.core.nn import Parameter
 from mindnlp.configs import SUPPORT_BF16
 from .nn import Module
@@ -1548,3 +1549,95 @@ def load_checkpoint(ckpt_file_name):
                          "'filter_prefix' or 'specify_prefix' are set correctly.")
 
     return parameter_dict
+
+
+def save_model(
+    model: nn.Module, filename: str, metadata: Optional[Dict[str, str]] = None, force_contiguous: bool = True
+):
+    """
+    Saves a given model to the specified filename.
+    This method exists specifically to avoid tensor sharing issues, which are
+    not allowed in `safetensors`. [More information on tensor sharing](../torch_shared_tensors)
+
+    Args:
+        model (`nn.Module`):
+            The model to save on disk.
+        filename (`str`):
+            The filename location to save the file.
+        metadata (`Dict[str, str]`, *optional*):
+            Extra information to save along with the file.
+            Some metadata will be added for each dropped tensor.
+            This information will not be enough to recover the entire
+            shared structure, but it might help in understanding it.
+        force_contiguous (`bool`, *optional*, defaults to `True`):
+            Whether to force the state_dict to be saved with contiguous tensors.
+            This has no effect on the correctness of the model, but it
+            could change performance if a tensor's layout was chosen
+            specifically for performance reasons.
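+
+    Example (a minimal sketch; `MyModel` stands in for any `nn.Module`
+    subclass and is not part of this module):
+
+        ```python
+        model = MyModel()
+        save_model(model, "model.safetensors")
+        ```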
+    """
+    state_dict = model.state_dict()
+
+    if force_contiguous:
+        state_dict = {k: v.contiguous() for k, v in state_dict.items()}
+    try:
+        safe_save_file(state_dict, filename, metadata=metadata)
+    except ValueError as e:
+        msg = str(e)
+        msg += " Or use save_model(..., force_contiguous=True), read the docs for potential caveats."
+        raise ValueError(msg) from e
+
+
+def load_model(
+    model: nn.Module, filename: Union[str, os.PathLike], strict: bool = True
+) -> Tuple[List[str], List[str]]:
+    """
+    Loads a given file onto a model.
+    This method exists specifically to avoid tensor sharing issues, which are
+    not allowed in `safetensors`. [More information on tensor sharing](../torch_shared_tensors)
+
+    Args:
+        model (`nn.Module`):
+            The model to load onto.
+        filename (`str` or `os.PathLike`):
+            The filename location to load the file from.
+        strict (`bool`, *optional*, defaults to `True`):
+            Whether to fail if the checkpoint has missing or unexpected keys.
+            When `False`, the function simply returns the missing and unexpected names.
+
+    Returns:
+        `(missing, unexpected)`: `(List[str], List[str])`
+            `missing` are names of parameters in the model that were not
+            modified during loading.
+            `unexpected` are names of tensors present in the file but not
+            used during loading.
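+
+    Example (a minimal sketch; `MyModel` stands in for any `nn.Module`
+    subclass and is not part of this module):
+
+        ```python
+        model = MyModel()
+        missing, unexpected = load_model(model, "model.safetensors", strict=False)
+        ```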
+    """
+    state_dict = safe_load_file(filename)
+
+    missing, unexpected = model.load_state_dict(state_dict, strict=False)
+
+    if strict and (missing or unexpected):
+        missing_keys = ", ".join([f'"{k}"' for k in sorted(missing)])
+        unexpected_keys = ", ".join([f'"{k}"' for k in sorted(unexpected)])
+        error = f"Error(s) in loading state_dict for {model.__class__.__name__}:"
+        if missing:
+            error += f"\n    Missing key(s) in state_dict: {missing_keys}"
+        if unexpected:
+            error += f"\n    Unexpected key(s) in state_dict: {unexpected_keys}"
+        raise RuntimeError(error)
+    return missing, unexpected