
Commit a72d54c

fix transformers p class ut (#2082)
1 parent a1bdf39 commit a72d54c

File tree

10 files changed: +149 -13 lines changed

mindnlp/core/_tensor.py

Lines changed: 3 additions & 0 deletions
@@ -655,6 +655,9 @@ def __contains__(self, item):
 Tensor.scatter_reduce_ = ops.inplace_scatter_reduce
 StubTensor.scatter_reduce_ = ops.inplace_scatter_reduce

+Tensor.exponential_ = ops.inplace_exponential
+StubTensor.exponential_ = ops.inplace_exponential
+
 def _rebuild_from_type_v2(func, new_type, args, state):
     ret = func(*args)
     return ret
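
For reference, a minimal usage sketch of the newly bound method, assuming mindnlp.core exposes a torch-style `empty` constructor; the binding forwards to `ops.inplace_exponential` (see mindnlp/core/ops/inplace.py below).

    from mindnlp import core

    t = core.empty(4, 4)           # uninitialized tensor, shape (4, 4)
    t.exponential_(lambd=0.5)      # filled in place with Exponential(rate=0.5) samples
    print(t.mean())                # tends toward 1 / 0.5 = 2 as the tensor grows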

mindnlp/core/backends/cuda/__init__.py

Lines changed: 46 additions & 1 deletion
@@ -1,3 +1,6 @@
+import contextlib
+from typing_extensions import deprecated
+
 class cuBLASModule:
     # def __getattr__(self, name):
     #     if name == "allow_tf32":
@@ -26,4 +29,46 @@ class cuBLASModule:
     #     raise AttributeError("Unknown attribute " + name)
     pass

-matmul = cuBLASModule()
+matmul = cuBLASModule()
+
+@contextlib.contextmanager
+@deprecated(
+    (
+        "`torch.backends.cuda.sdp_kernel()` is deprecated. "
+        "In the future, this context manager will be removed. "
+        "Please see `torch.nn.attention.sdpa_kernel()` for the new context manager, "
+        "with updated signature."
+    ),
+    category=FutureWarning,
+)
+def sdp_kernel(
+    enable_flash: bool = True,
+    enable_math: bool = True,
+    enable_mem_efficient: bool = True,
+    enable_cudnn: bool = True,
+):
+    r"""
+    .. warning:: This flag is beta and subject to change.
+
+    This context manager can be used to temporarily enable or disable any of the three backends for scaled dot product attention.
+    Upon exiting the context manager, the previous state of the flags will be restored.
+    """
+    # from torch.nn.attention import sdpa_kernel
+
+    # backend_list = []
+    # if enable_flash:
+    #     backend_list.append(SDPBackend.FLASH_ATTENTION)
+    # if enable_mem_efficient:
+    #     backend_list.append(SDPBackend.EFFICIENT_ATTENTION)
+    # if enable_math:
+    #     backend_list.append(SDPBackend.MATH)
+    # if enable_cudnn:
+    #     backend_list.append(SDPBackend.CUDNN_ATTENTION)
+
+    # with sdpa_kernel(backend_list) as context:
+    #     try:
+    #         yield context
+    #     finally:
+    #         pass
+
+    pass
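
As background, `sdp_kernel` follows the standard `contextlib.contextmanager` pattern for a deprecated, no-op compatibility shim. The sketch below is illustrative only (hypothetical names, not the repository's code): a generator-based context manager must yield exactly once, and in this sketch the deprecation warning fires when the manager is entered.

    import contextlib
    import warnings

    @contextlib.contextmanager
    def legacy_flag(enabled=True):     # hypothetical stand-in for sdp_kernel
        warnings.warn("legacy_flag() is deprecated", FutureWarning)
        try:
            yield                      # hand control to the with-block body
        finally:
            pass                       # previous state would be restored here

    with legacy_flag():
        pass                           # body runs with the (no-op) flag applied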

mindnlp/core/compiler/__init__.py

Lines changed: 3 additions & 1 deletion
@@ -10,4 +10,6 @@ def staging_specialize(*args, **kwargs):

     if fn is not None:
         return wrap_func(fn)
-    return wrap_func
+    return wrap_func
+
+def reset(): pass
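
A quick usage sketch for the new stub, assuming the module is imported via the path shown above; it is presumably kept for parity with torch.compiler.reset().

    from mindnlp.core import compiler

    compiler.reset()    # no-op stub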

mindnlp/core/nn/utils/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -2,3 +2,4 @@
 from . import parametrizations
 from .weight_norm import *
 from .clip_grad import *
+from .init import skip_init

mindnlp/core/nn/utils/init.py

Lines changed: 54 additions & 0 deletions
@@ -0,0 +1,54 @@
+# mypy: allow-untyped-defs
+import inspect
+
+from mindnlp import core
+
+def skip_init(module_cls, *args, **kwargs):
+    r"""
+    Given a module class object and args / kwargs, instantiate the module without initializing parameters / buffers.
+
+    This can be useful if initialization is slow or if custom initialization will
+    be performed, making the default initialization unnecessary. There are some caveats to this, due to
+    the way this function is implemented:
+
+    1. The module must accept a `device` arg in its constructor that is passed to any parameters
+       or buffers created during construction.
+
+    2. The module must not perform any computation on parameters in its constructor except
+       initialization (i.e. functions from :mod:`torch.nn.init`).
+
+    If these conditions are satisfied, the module can be instantiated with parameter / buffer values
+    uninitialized, as if having been created using :func:`torch.empty`.
+
+    Args:
+        module_cls: Class object; should be a subclass of :class:`torch.nn.Module`
+        args: args to pass to the module's constructor
+        kwargs: kwargs to pass to the module's constructor
+
+    Returns:
+        Instantiated module with uninitialized parameters / buffers
+
+    Example::
+
+        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
+        >>> import torch
+        >>> m = torch.nn.utils.skip_init(torch.nn.Linear, 5, 1)
+        >>> m.weight
+        Parameter containing:
+        tensor([[0.0000e+00, 1.5846e+29, 7.8307e+00, 2.5250e-29, 1.1210e-44]],
+               requires_grad=True)
+        >>> m2 = torch.nn.utils.skip_init(torch.nn.Linear, in_features=6, out_features=1)
+        >>> m2.weight
+        Parameter containing:
+        tensor([[-1.4677e+24,  4.5915e-41,  1.4013e-45,  0.0000e+00, -1.4677e+24,
+                  4.5915e-41]], requires_grad=True)
+
+    """
+    if not issubclass(module_cls, core.nn.Module):
+        raise RuntimeError(f"Expected a Module; got {module_cls}")
+    if "device" not in inspect.signature(module_cls).parameters:
+        raise RuntimeError("Module must support a 'device' arg to skip initialization")
+
+    final_device = kwargs.pop("device", "cpu")
+    kwargs["device"] = "meta"
+    return module_cls(*args, **kwargs).to_empty(device=final_device)
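
A hedged usage sketch in the mindnlp namespace, assuming `core.nn.Linear` mirrors torch.nn.Linear closely enough to accept a `device` constructor argument (the requirement the function enforces via inspect.signature):

    from mindnlp import core

    m = core.nn.utils.skip_init(core.nn.Linear, 5, 1)
    print(m.weight.shape)    # (1, 5); values are uninitialized, as if created by core.empty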

mindnlp/core/ops/array.py

Lines changed: 2 additions & 2 deletions
@@ -36,8 +36,8 @@ def concat(tensors, dim=0, *, out=None, **kwargs):
     return cat(tensors, dim, out=out, **kwargs)

 # concatenate
-def concatenate(tensors, dim=0, out=None):
-    return cat(tensors, dim, out=out)
+def concatenate(tensors, dim=0, out=None, **kwargs):
+    return cat(tensors, dim, out=out, **kwargs)

 # conj
 def conj(input):
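
A small sketch of what the widened signature allows: extra keyword arguments are now forwarded to `cat` instead of raising a TypeError. This assumes `concatenate` is re-exported from mindnlp.core.ops and that `ones`/`zeros` follow the torch-style constructors.

    from mindnlp import core
    from mindnlp.core import ops

    a = core.ones(2, 3)
    b = core.zeros(2, 3)
    c = ops.concatenate([a, b], dim=0)    # shape (4, 3); kwargs pass straight through to cat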

mindnlp/core/ops/creation.py

Lines changed: 3 additions & 3 deletions
@@ -104,9 +104,9 @@ def arange(start=0, end=None, step=1, *, dtype=None, device=None):
     if ON_ORANGE_PI and dtype in (None, mindspore.int64):
         dtype = mindspore.int32
     if use_pyboost() and has_arange:
-        start = start.item() if isinstance(start, mindspore.Tensor) else start
-        end = end.item() if isinstance(end, mindspore.Tensor) else end
-        step = step.item() if isinstance(step, mindspore.Tensor) else step
+        start = start.item() if isinstance(start, (mindspore.Tensor, np.integer)) else start
+        end = end.item() if isinstance(end, (mindspore.Tensor, np.integer)) else end
+        step = step.item() if isinstance(step, (mindspore.Tensor, np.integer)) else step
         return mindspore.mint.arange(start, end, step, dtype=dtype)

     start = mindspore.Tensor(start) if not isinstance(start, mindspore.Tensor) else start
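
A short sketch of the case this change covers: numpy integer scalars (e.g. from shape arithmetic) are now unwrapped with .item() on the pyboost path instead of being passed through as np.integer objects. It assumes `arange` is exposed on mindnlp.core with torch-style semantics.

    import numpy as np
    from mindnlp import core

    n = np.int64(5)            # a numpy integer bound, not a Python int
    x = core.arange(n)         # now handled the same as a Python int or a Tensor bound
    print(x)                   # [0 1 2 3 4]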

mindnlp/core/ops/inplace.py

Lines changed: 22 additions & 1 deletion
@@ -137,6 +137,26 @@ def inplace_scatter_reduce(input, dim, index, src, reduce, *, include_self=True)
         reduce = "add"
     return inplace_scatter_src_reduce_op(input, dim, index, src, reduce)

+def inplace_exponential(tensor, lambd=1.0):
+    """
+    In-place exponential-distribution sampling (similar to Tensor.exponential_).
+    :param tensor: target tensor to fill
+    :param lambd: rate parameter (λ > 0)
+    :return: the modified tensor (the original tensor is overwritten)
+    """
+    assert lambd > 0, "lambd must be greater than 0"
+
+    # Draw uniform random numbers with the same shape as the target tensor
+    u = core.rand_like(tensor)
+
+    # Numerical safeguard
+    u = u.clamp(min=core.finfo(u.dtype).eps, max=1.0)
+
+    # Assign using the inverse-transform method
+    tensor.data = -core.log(1 - u) / lambd
+
+    return tensor
+
 __all__ = [
     'inplace_copy',
     'inplace_zero',
@@ -152,5 +172,6 @@ def inplace_scatter_reduce(input, dim, index, src, reduce, *, include_self=True)
     'inplace_fill_diagonal',
     'inplace_triu',
     'inplace_round',
-    'inplace_scatter_reduce'
+    'inplace_scatter_reduce',
+    'inplace_exponential'
 ]
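
The new helper is inverse-transform sampling: if U is uniform on (0, 1), then -log(1 - U) / lambd follows an Exponential(lambd) distribution with mean 1 / lambd. A quick numpy check of that identity (a standalone sketch, not repository code):

    import numpy as np

    lambd = 2.0
    u = np.random.rand(1_000_000)
    x = -np.log(1 - u) / lambd
    print(x.mean())            # approximately 1 / lambd = 0.5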

mindnlp/core/ops/other.py

Lines changed: 1 addition & 1 deletion
@@ -64,7 +64,7 @@ def manual_expand(tensor, shape):


 def broadcast_to(input, *shape):
-    if isinstance(shape[0], tuple):
+    if isinstance(shape[0], (list, tuple)):
         shape = shape[0]
     if ON_ORANGE_PI and not use_pyboost():
         # return input.expand(mindspore.tensor(shape))
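
A small sketch of what the relaxed check enables: the target shape may now be passed as a Python list rather than only as a tuple or unpacked integers. It assumes `broadcast_to` and `ones` are exposed on mindnlp.core with torch-style names.

    from mindnlp import core

    x = core.ones(1, 3)
    y = core.broadcast_to(x, [4, 3])    # a list shape is normalized just like a tuple
    print(y.shape)                      # (4, 3)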

mindnlp/core/ops/random.py

Lines changed: 14 additions & 4 deletions
@@ -47,14 +47,16 @@ def multinomial(input, num_samples, replacement=False, *, generator=None):

     vals = div(log(random_uniform), input + 1e-10)
     _, samples = topk(vals, num_samples)
-
+
     return samples.astype(mindspore.int64)

 # normal
 has_normal = hasattr(mindspore.mint, 'normal')
 def normal(mean=0.0, std=1.0, size=None, *, generator=None, out=None):
     if use_pyboost() and has_normal:
-        return call_ms_func(mindspore.mint.normal, float(mean), float(std), size, generator, out=out)
+        mean = float(mean) if isinstance(mean, int) else mean
+        mean = float(std) if isinstance(std, int) else std
+        return call_ms_func(mindspore.mint.normal, mean, std, size, generator, out=out)
     if size is None:
         if isinstance(mean, mindspore.Tensor):
             size = mean.shape
@@ -90,8 +92,11 @@ def rand_like(input, *, dtype=None):
 has_randint = hasattr(mindspore.mint, 'randint')
 def randint(*args, **kwargs):
     device = kwargs.pop('device', None)
+    low = kwargs.pop('low', None)
     high = kwargs.pop('high', None)
     size = kwargs.pop('size', None)
+    if low is not None:
+        args += (low,)
     if high is not None:
         args += (high,)

@@ -112,11 +117,16 @@ def randint_like(*args, **kwargs):
 has_randn = hasattr(mindspore.mint, 'randn')
 def randn(*size, generator=None, dtype=None, **kwargs):
     size = kwargs.pop('size', size)
+    new_size = ()
+    for s in size:
+        if isinstance(s, np.integer):
+            s = s.item()
+        new_size += (s,)
     if dtype is None:
         dtype = get_default_dtype()
     if use_pyboost() and has_randn:
-        return mindspore.mint.randn(*size, generator=generator, dtype=dtype)
-    return ops.randn(*size, dtype=dtype)
+        return mindspore.mint.randn(*new_size, generator=generator, dtype=dtype)
+    return ops.randn(*new_size, dtype=dtype)

 # randn_like
 has_randn_like = hasattr(mindspore.mint, 'randn_like')
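
A usage sketch of the call patterns the randint and randn changes cover, assuming these functions are exposed on mindnlp.core with torch-style names:

    import numpy as np
    from mindnlp import core

    r = core.randint(low=0, high=10, size=(2, 2))   # `low` passed as a keyword is now forwarded
    z = core.randn(np.int64(2), np.int64(3))        # numpy integer dims are unwrapped to Python ints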
