Skip to content

Commit 898797a

Browse files
authored
load safetensor back to numpy (#1986)
1 parent fd5fa73 commit 898797a

File tree

2 files changed

+9
-13
lines changed

2 files changed

+9
-13
lines changed

.github/workflows/ci_pipeline.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ jobs:
8686
needs: pylint-check
8787
strategy:
8888
matrix:
89-
ms_version: ['2.2.14', '2.3.1']
89+
ms_version: ['2.2.14', '2.3.1', '2.4.10', '2.5.0']
9090
runs-on: ubuntu-latest
9191
steps:
9292
- uses: actions/checkout@v3

mindnlp/core/serialization.py

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1410,18 +1410,14 @@ def convert(info: dict[str, Any]):
14101410
assert end - begin == math.prod(shape) * np.dtype(numpy_dtype).itemsize
14111411
buf = byte_buf[begin:end]
14121412

1413-
try:
1414-
if info['dtype'] == 'BF16' and not SUPPORT_BF16:
1415-
raise ValueError('not support bfloat16.')
1416-
out = Tensor.convert_bytes_to_tensor(buf, tuple(shape), ms_dtype)
1417-
except:
1418-
array = np.frombuffer(buf, dtype=numpy_dtype).reshape(shape)
1419-
1420-
if array.dtype == bfloat16 and not SUPPORT_BF16:
1421-
logger.warning_once("MindSpore do not support bfloat16 dtype, we will automaticlly convert to float16")
1422-
array = array.astype(np.float16)
1423-
array = array.astype(array.dtype)
1424-
out = Tensor.from_numpy(array)
1413+
1414+
array = np.frombuffer(buf, dtype=numpy_dtype).reshape(shape)
1415+
1416+
if array.dtype == bfloat16 and not SUPPORT_BF16:
1417+
logger.warning_once("MindSpore do not support bfloat16 dtype, we will automaticlly convert to float16")
1418+
array = array.astype(np.float16)
1419+
array = array.astype(array.dtype)
1420+
out = Tensor.from_numpy(array)
14251421
return out
14261422

14271423
with open(filename, "rb") as fp:

0 commit comments

Comments
 (0)