Skip to content

Commit b035948

Browse files
authored
fix o class ut (#2079)
1 parent 2e1809c commit b035948

File tree

4 files changed

+13
-6
lines changed

4 files changed

+13
-6
lines changed

mindnlp/core/nn/functional.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -350,7 +350,7 @@ def _nllloss_nd(input, target, weight=None, ingore_index=-100, reduction='mean')
350350
raise ValueError(f"input bacth_size should be equal to target batch_size, but got {input.shape[0]} and "
351351
f"{target.shape[0]}")
352352
if input_dim == 1 or input_dim == 2:
353-
return nllloss_impl(input, target, weight, reduction, ingore_index)[0]
353+
return nllloss_impl(input.float(), target, weight.float(), reduction, ingore_index)[0]
354354
if input_dim == 4:
355355
return nllloss_2d_op(input, target, weight, reduction, ingore_index)[0]
356356
# input_dim==3 or input_dim>4
@@ -1490,7 +1490,7 @@ def pixel_shuffle(input, upscale_factor):
14901490
return ops.pixel_shuffle(input, upscale_factor)
14911491

14921492
def pixel_unshuffle(input, downscale_factor):
1493-
return ops.pixel_shuffle(input, downscale_factor)
1493+
return ops.pixel_unshuffle(input, downscale_factor)
14941494

14951495
def grid_sample(input, grid, mode='bilinear', padding_mode='zeros', align_corners=False):
14961496
if use_pyboost():

mindnlp/core/ops/comparison.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77

88
from ._inner import call_ms_func
99

10-
sort_out = namedtuple('stor_out', ['sorted', 'indices'])
10+
sort_out = namedtuple('sort_out', ['sorted', 'indices'])
11+
topk_out = namedtuple('topk_out', ['values', 'indices'])
1112
# allclose
1213
has_allclose = hasattr(mindspore.mint, 'allclose')
1314
def allclose(input, other, rtol=1e-05, atol=1e-08, equal_nan=False):
@@ -173,15 +174,18 @@ def not_equal(input, other):
173174
def sort(input, *, dim=-1, descending=False, stable=False):
174175
if use_pyboost() and has_sort:
175176
out = mindspore.mint.sort(input, dim=dim, descending=descending, stable=stable)
176-
out = ops.sort(input, dim, descending)
177+
else:
178+
out = ops.sort(input, dim, descending)
177179
return sort_out(sorted=out[0], indices=out[1])
178180

179181
# topk
180182
has_topk = hasattr(mindspore.mint, 'topk')
181183
def topk(input, k, dim=-1, largest=True, sorted=True):
182184
if use_pyboost() and has_topk:
183-
return mindspore.mint.topk(input, k, dim, largest, sorted)
184-
return ops.topk(input, k, dim, largest, sorted)
185+
out = mindspore.mint.topk(input, k, dim, largest, sorted)
186+
else:
187+
out = ops.topk(input, k, dim, largest, sorted)
188+
return topk_out(values=out[0], indices=out[1])
185189

186190
# msort
187191
def msort(input):

mindnlp/core/serialization.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -903,6 +903,7 @@ def _maybe_decode_ascii(bytes_str: Union[bytes, str]) -> str:
903903

904904

905905
dtype_map = {
906+
"DoubleStorage": np.float64,
906907
"HalfStorage": np.float16,
907908
"FloatStorage": np.float32,
908909
'BFloat16Storage': bfloat16,
@@ -912,6 +913,7 @@ def _maybe_decode_ascii(bytes_str: Union[bytes, str]) -> str:
912913
}
913914

914915
storage_map = {
916+
mindspore.float64: "DoubleStorage",
915917
mindspore.float16: "HalfStorage",
916918
mindspore.float32: "FloatStorage",
917919
mindspore.bfloat16: 'BFloat16Storage',

mindnlp/core/types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ def __init__(self, type=None, index=None):
3030
_target = type.type
3131
_id = type.index
3232
else:
33+
print(type)
3334
raise TypeError("core.device(): `type` must be type of 'str' or 'core.device'.")
3435
else:
3536
raise ValueError("core.device(): `type` can not be None")

0 commit comments

Comments
 (0)