diff --git a/mindnlp/core/ops/reduction.py b/mindnlp/core/ops/reduction.py index 6aaa563c7..86b7558bf 100644 --- a/mindnlp/core/ops/reduction.py +++ b/mindnlp/core/ops/reduction.py @@ -52,9 +52,20 @@ def max(*args, **kwargs): if use_pyboost() and has_max: return mindspore.mint.max(*args, **kwargs) - input = kwargs.get('input', None) or args[0] - dim = kwargs.get('dim', None) or args[1] - keepdim = kwargs.get('keepdim', False) or args[2] + input = kwargs.get('input', None) + dim = kwargs.get('dim', None) + keepdim = kwargs.get('keepdim', False) + if len(args) == 1: + input = args[0] + elif len(args) == 2: + input = args[0] + dim = args[1] + elif len(args) == 3: + input = args[0] + dim = args[1] + keepdim = args[2] + else: + raise RuntimeError(f'need 3 inputs but got {len(args)}') out = ops.max(input, dim, keepdim) if dim is None: return out[0]