@@ -29,17 +29,21 @@ def fn_aux(*args):
29
29
else :
30
30
fn_ = fn
31
31
32
- def value_and_grad_f (* args ):
32
+ def value_and_grad_f (* args , ** kwargs ):
33
33
_pynative_executor .set_grad_flag (True )
34
- _pynative_executor .new_graph (fn , * args )
35
- values = fn_ (* args )
36
- _pynative_executor .end_graph (fn , values , * args )
34
+ _pynative_executor .new_graph (fn , * args , ** kwargs )
35
+ values = fn_ (* args , ** kwargs )
36
+ _pynative_executor .end_graph (fn , values , * args , ** kwargs )
37
+
38
+ run_args = args
39
+ if kwargs :
40
+ run_args = args + tuple (kwargs .values ())
37
41
38
42
if GENERATOR_SEED :
39
- grads = _pynative_executor .grad (fn_ , grad_ , params_or_argnums , None , * args )
43
+ grads = _pynative_executor .grad (fn_ , grad_ , params_or_argnums , None , * run_args )
40
44
# grads = grad_(fn_, params)(*args, *params)
41
45
else :
42
- _pynative_executor .grad (fn_ , grad_ , params_or_argnums , None , * args )
46
+ _pynative_executor .grad (fn_ , grad_ , params_or_argnums , None , * run_args )
43
47
grads = _pynative_executor () # pylint: disable=not-callable
44
48
grads = tuple (mindspore .Tensor (grad ) for grad in grads )
45
49
if attach_grads :
0 commit comments