Skip to content

Commit 756344b

Browse files
authored
value_and_grad support kwargs (#1835)
1 parent 5ce1a83 commit 756344b

File tree

1 file changed

+10
-6
lines changed

1 file changed

+10
-6
lines changed

mindnlp/core/autograd/function.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,17 +29,21 @@ def fn_aux(*args):
2929
else:
3030
fn_ = fn
3131

32-
def value_and_grad_f(*args):
32+
def value_and_grad_f(*args, **kwargs):
3333
_pynative_executor.set_grad_flag(True)
34-
_pynative_executor.new_graph(fn, *args)
35-
values = fn_(*args)
36-
_pynative_executor.end_graph(fn, values, *args)
34+
_pynative_executor.new_graph(fn, *args, **kwargs)
35+
values = fn_(*args, **kwargs)
36+
_pynative_executor.end_graph(fn, values, *args, **kwargs)
37+
38+
run_args = args
39+
if kwargs:
40+
run_args = args + tuple(kwargs.values())
3741

3842
if GENERATOR_SEED:
39-
grads = _pynative_executor.grad(fn_, grad_, params_or_argnums, None, *args)
43+
grads = _pynative_executor.grad(fn_, grad_, params_or_argnums, None, *run_args)
4044
# grads = grad_(fn_, params)(*args, *params)
4145
else:
42-
_pynative_executor.grad(fn_, grad_, params_or_argnums, None, *args)
46+
_pynative_executor.grad(fn_, grad_, params_or_argnums, None, *run_args)
4347
grads = _pynative_executor() # pylint: disable=not-callable
4448
grads = tuple(mindspore.Tensor(grad) for grad in grads)
4549
if attach_grads:

0 commit comments

Comments
 (0)