diff --git a/mindtorch/_apis/gpu.py b/mindtorch/_apis/gpu.py index f940d011b..8c11c4d8f 100644 --- a/mindtorch/_apis/gpu.py +++ b/mindtorch/_apis/gpu.py @@ -277,7 +277,9 @@ def select(condition, x, y): if isinstance(y, numbers.Number) or y.ndim == 0: y = fill_scalar(condition.shape, y, None) - return legacy.select(condition, x, y) + # return legacy.select(condition, x, y) + return add(mul(condition, x), mul(sub(1, condition), y)) + def round(input, decimals): return legacy.round(input, decimals) diff --git a/mindtorch/_apis/meta.py b/mindtorch/_apis/meta.py index 831d22e0b..cff8df0f4 100644 --- a/mindtorch/_apis/meta.py +++ b/mindtorch/_apis/meta.py @@ -246,8 +246,8 @@ def sqrt(input): __all__.append('sqrt') -def normal_float_float(mean, std, size, geneartor): - out = Tensor_(shape=size, dtype=mindtorch.float32) +def normal_float_float(mean, std, size, dtype, geneartor): + out = Tensor_(shape=size, dtype=dtype) return mindtorch.Tensor(out) @@ -371,9 +371,24 @@ def greater_equal(input, other): __all__.append('greater_equal') +def greater(input, other): + if isinstance(input, mindtorch.Tensor): + return input + return other + +def less(input, other): + if isinstance(input, mindtorch.Tensor): + return input + return other + def inplace_zero(input): return input def clone(input): return input +def select(condition, input, other): + return input + +def logical_not(input): + return input \ No newline at end of file diff --git a/tests/test_module/custom_modeling.py b/tests/test_module/custom_modeling.py index 43b8cc9f3..666cd21cf 100644 --- a/tests/test_module/custom_modeling.py +++ b/tests/test_module/custom_modeling.py @@ -1,4 +1,4 @@ -from mindnlp.core import nn +from mindtorch import nn from mindnlp.transformers import PreTrainedModel