Skip to content

Commit e6924bb

Browse files
authored
add mindspore infer function patch (#1810)
1 parent d5ac1fe commit e6924bb

File tree

2 files changed

+21
-0
lines changed

2 files changed

+21
-0
lines changed

mindnlp/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
context.set_context(ascend_config={"precision_mode": "allow_mix_precision"})
4545

4646
from mindspore import jit as ms_jit
47+
from mindnlp import patch
4748
from mindnlp import transformers
4849
from mindnlp import dataset
4950
from mindnlp import evaluate

mindnlp/patch.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
import numpy as np
2+
import mindspore
3+
from mindspore import Tensor
4+
5+
def infer_value_for_BroadcastTo(x, shape):
6+
"""Infer value for BroadcastTo op."""
7+
def none_in_tuple_or_list(x):
8+
return isinstance(x, (tuple, list)) and None in x
9+
if shape is None or none_in_tuple_or_list(shape) or x is None:
10+
return None
11+
12+
shape = list(shape)
13+
for idx, s in enumerate(shape):
14+
if s == -1:
15+
shape[idx] = x.shape[idx]
16+
17+
np_data = np.broadcast_to(x.asnumpy(), shape)
18+
return Tensor(np_data)
19+
20+
mindspore.ops.operations.manually_defined.ops_def.infer_value_for_BroadcastTo = infer_value_for_BroadcastTo

0 commit comments

Comments
 (0)