Skip to content

Commit 825d93c

Browse files
authored
fix ops.max (#1993)
1 parent 82379a1 commit 825d93c

File tree

1 file changed

+14
-3
lines changed

1 file changed

+14
-3
lines changed

mindnlp/core/ops/reduction.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,20 @@ def max(*args, **kwargs):
5252
if use_pyboost() and has_max:
5353
return mindspore.mint.max(*args, **kwargs)
5454

55-
input = kwargs.get('input', None) or args[0]
56-
dim = kwargs.get('dim', None) or args[1]
57-
keepdim = kwargs.get('keepdim', False) or args[2]
55+
input = kwargs.get('input', None)
56+
dim = kwargs.get('dim', None)
57+
keepdim = kwargs.get('keepdim', False)
58+
if len(args) == 1:
59+
input = args[0]
60+
elif len(args) == 2:
61+
input = args[0]
62+
dim = args[1]
63+
elif len(args) == 3:
64+
input = args[0]
65+
dim = args[1]
66+
keepdim = args[2]
67+
else:
68+
raise RuntimeError(f'need 3 inputs but got {len(args)}')
5869
out = ops.max(input, dim, keepdim)
5970
if dim is None:
6071
return out[0]

0 commit comments

Comments
 (0)