Skip to content

Commit c1b22f8

Browse files
authored
fix diffuers pipelines c class ut (#2084)
1 parent afa0b61 commit c1b22f8

File tree

4 files changed

+20
-4
lines changed

4 files changed

+20
-4
lines changed

mindnlp/core/_tensor.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ def __reduce_ex__(self, protocol):
126126
StubTensor.__reduce_ex__ = __reduce_ex__
127127

128128
def to_(self, *args, **kwargs):
129-
dtype_to = None
129+
dtype_to = kwargs.get("dtype", None)
130130
if len(args) == 1:
131131
if isinstance(args[0], Type):
132132
dtype_to = args[0]
@@ -675,6 +675,10 @@ def __contains__(self, item):
675675
Tensor.sub_ = ops.inplace_sub
676676
StubTensor.sub_ = ops.inplace_sub
677677

678+
Tensor.roll = ops.roll
679+
StubTensor.roll = ops.roll
680+
681+
678682
def _rebuild_from_type_v2(func, new_type, args, state):
679683
ret = func(*args)
680684
return ret

mindnlp/core/fft/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,4 +44,10 @@ def ifftn(input, s=None, dim=None, norm=None, *, out=None):
4444
def ifftshift(input, dim=None):
4545
return ops.ifftshift(input, dim)
4646

47+
def fft2(input, s=None, dim=(-2, -1), norm=None):
48+
return ops.fft2(input, s, dim, norm)
49+
50+
def ifft2(input, s=None, dim=(-2, -1), norm=None):
51+
return ops.ifft2(input, s, dim, norm)
52+
4753
__all__ = ['fft', 'fftn', 'irfft', 'rfft']

mindnlp/core/ops/array.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,8 @@ def squeeze(input, *dim, **kwargs):
273273

274274
# stack
275275
has_stack = hasattr(mindspore.mint, 'stack')
276-
def stack(tensors, dim=0, *, out=None):
276+
def stack(tensors, dim=0, *, out=None, **kwargs):
277+
dim = kwargs.pop('axis', dim)
277278
if use_pyboost() and has_stack:
278279
return call_ms_func(mindspore.mint.stack, tensors, dim, out=out)
279280
return call_ms_func(ops.stack, tensors, dim, out=out)

mindnlp/core/ops/creation.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -139,8 +139,13 @@ def linspace(start, end, steps, *, dtype=None, **kwargs):
139139
return ops.linspace(start, end, steps).to(dtype)
140140

141141
# logspace
142-
def logspace(start, end, steps, base=10.0, *, dtype=None):
143-
return ops.logspace(start, end, steps, base, dtype=dtype)
142+
has_logspace = hasattr(mindspore.mint, 'logspace')
143+
def logspace(start, end, steps, base=10.0, *, dtype=None, **kwargs):
144+
if dtype is None:
145+
dtype = get_default_dtype()
146+
if use_pyboost() and has_logspace:
147+
return mindspore.mint.logspace(start, end, steps, base, dtype=dtype)
148+
return ops.logspace(float(start), float(end), steps, int(base), dtype=dtype)
144149

145150
# eye
146151
has_eye = hasattr(mindspore.mint, 'eye')

0 commit comments

Comments
 (0)