Skip to content

Commit fdf8a47

Browse files
authored
fix roll on CPU (#1840)
1 parent 0a936bf commit fdf8a47

File tree

3 files changed

+8
-3
lines changed

3 files changed

+8
-3
lines changed

mindnlp/accelerate/big_modeling.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,13 +73,13 @@ def register_empty_parameter(module, name, param):
7373
if param is not None:
7474
kwargs = module._parameters[name].__dict__
7575
kwargs["requires_grad"] = param.requires_grad
76-
module._parameters[name].assign_value(Tensor_(shape=(), dtype=module._parameters[name].dtype))
76+
module._parameters[name].assign_value(Tensor_(shape=module._parameters[name].shape, dtype=module._parameters[name].dtype))
7777
module._parameters[name].meta = True
7878

7979
def register_empty_buffer(module, name, buffer, persistent=True):
8080
old_register_buffer(module, name, buffer, persistent=persistent)
8181
if buffer is not None:
82-
module._buffers[name].assign_value(Tensor_(shape=(), dtype=module._buffers[name].dtype))
82+
module._buffers[name].assign_value(Tensor_(shape=module._parameters[name].shape, dtype=module._buffers[name].dtype))
8383
module._buffers[name].meta = True
8484

8585
try:

mindnlp/core/ops/other.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -619,9 +619,12 @@ def repeat_interleave(input, repeats, dim=None):
619619
return input.repeat(repeats, dim)
620620

621621
# roll
622+
DEVICE_TARGET = mindspore.get_context('device_target')
622623
def roll(input, shifts, dims=None):
623624
if use_pyboost():
624625
return mindspore.mint.roll(input, shifts, dims)
626+
if DEVICE_TARGET == 'CPU':
627+
return mindspore.numpy.roll(input, shifts, dims)
625628
return ops.roll(input, shifts, dims)
626629

627630
# searchsorted

mindnlp/patch.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import numpy as np
33
import mindspore
44
from mindspore import Tensor
5+
from mindnlp.configs import GENERATOR_SEED
56

67
def infer_value_for_BroadcastTo(x, shape):
78
"""Infer value for BroadcastTo op."""
@@ -18,4 +19,5 @@ def none_in_tuple_or_list(x):
1819
np_data = np.broadcast_to(x.asnumpy(), shape)
1920
return Tensor(np_data)
2021

21-
mindspore.ops.operations.manually_defined.ops_def.infer_value_for_BroadcastTo = infer_value_for_BroadcastTo
22+
if GENERATOR_SEED:
23+
mindspore.ops.operations.manually_defined.ops_def.infer_value_for_BroadcastTo = infer_value_for_BroadcastTo

0 commit comments

Comments
 (0)