Skip to content

Commit d859a1b

Browse files
authored
fix TensorPy for mindspore 2.5 (#1961)
1 parent ab20e2e commit d859a1b

File tree

4 files changed

+22
-4
lines changed

4 files changed

+22
-4
lines changed

mindnlp/accelerate/big_modeling.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
"""big modeling"""
22
from contextlib import contextmanager
3-
from mindspore._c_expression import Tensor as Tensor_ # pylint: disable=no-name-in-module
3+
try:
4+
from mindspore._c_expression import TensorPy as Tensor_ # pylint: disable=no-name-in-module
5+
except:
6+
from mindspore._c_expression import Tensor as Tensor_ # pylint: disable=no-name-in-module
7+
48
from mindnlp.utils.testing_utils import parse_flag_from_env
59
from mindnlp.core import nn
610

mindnlp/core/ops/creation.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,12 @@
11
"""creation ops"""
22
import numpy as np
33
import mindspore
4-
from mindspore._c_expression import Tensor as CTensor # pylint: disable=no-name-in-module, import-error
4+
try:
5+
from mindspore._c_expression import TensorPy as CTensor # pylint: disable=no-name-in-module
6+
except:
7+
from mindspore._c_expression import Tensor as CTensor # pylint: disable=no-name-in-module
8+
9+
510
from mindspore import ops
611
from mindspore.ops._primitive_cache import _get_cache_prim
712
from mindnlp.configs import use_pyboost, ON_ORANGE_PI

mindnlp/parallel/comm_func.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
11
"""communication functional api."""
22
from mindspore import ops, Tensor
3-
from mindspore._c_expression import Tensor as Tensor_ # pylint: disable=no-name-in-module
43
from mindspore.ops.operations._inner_ops import Send, Receive
54
from mindspore.communication import GlobalComm, get_group_rank_from_world_rank
65
from mindspore.ops._primitive_cache import _get_cache_prim
6+
try:
7+
from mindspore._c_expression import TensorPy as Tensor_ # pylint: disable=no-name-in-module
8+
except:
9+
from mindspore._c_expression import Tensor as Tensor_ # pylint: disable=no-name-in-module
10+
711

812
def isend(tensor, dst=0, group=GlobalComm.WORLD_COMM_GROUP, tag=0):
913
"""

mindnlp/transformers/models/jamba/modeling_jamba.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,12 @@
2323
from typing import Any, Dict, List, Optional, Tuple, Union
2424

2525
import mindspore
26-
from mindspore._c_expression import Tensor as RawTensor # pylint: disable=no-name-in-module
26+
try:
27+
from mindspore._c_expression import TensorPy as RawTensor # pylint: disable=no-name-in-module
28+
except:
29+
from mindspore._c_expression import Tensor as RawTensor # pylint: disable=no-name-in-module
30+
31+
2732
import mindnlp.core.nn.functional as F
2833
from mindnlp.core import nn, ops
2934
from mindnlp.core.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

0 commit comments

Comments
 (0)