Commit ef64a3b

fix optimizer args as same dtype on GPU (#1842)
1 parent a4b0072 · commit ef64a3b

File tree

1 file changed (+19, -0 lines)

mindnlp/core/ops/optim.py

Lines changed: 19 additions & 0 deletions
@@ -1,19 +1,38 @@
 """optim op"""
+import mindspore
 from mindspore import ops
 from mindspore.ops._primitive_cache import _get_cache_prim
 
+DEVICE_TARGET = mindspore.get_context('device_target')
+
 _adadelta = ops.ApplyAdadelta()
 def raw_adadelta(param, square_avg, acc_delta, lr, rho, eps, grad):
     return _adadelta(param, square_avg, acc_delta, lr, rho, eps, grad)
 
 _adam = ops.Adam()
 def raw_adam(param, exp_avg, exp_avg_sq, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad):
     # var, m, v, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad
+    if DEVICE_TARGET == 'GPU' and param.dtype != mindspore.float32:
+        beta1_power, beta2_power, lr, beta1, beta2, epsilon = mindspore.tensor(beta1_power, dtype=param.dtype), \
+            mindspore.tensor(beta2_power, dtype=param.dtype), \
+            mindspore.tensor(lr, dtype=param.dtype), \
+            mindspore.tensor(beta1, dtype=param.dtype), \
+            mindspore.tensor(beta2, dtype=param.dtype), \
+            mindspore.tensor(epsilon, dtype=param.dtype)
     return _adam(param, exp_avg, exp_avg_sq, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad)
 
 _adam_amsgrad = ops.ApplyAdamWithAmsgradV2()
 def raw_adam_amsgrad(param, exp_avg, exp_avg_sq, max_exp_avg_sq, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad):
     # var, m, v, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad
+
+    if DEVICE_TARGET == 'GPU' and param.dtype != mindspore.float32:
+        beta1_power, beta2_power, lr, beta1, beta2, epsilon = mindspore.tensor(beta1_power, dtype=param.dtype), \
+            mindspore.tensor(beta2_power, dtype=param.dtype), \
+            mindspore.tensor(lr, dtype=param.dtype), \
+            mindspore.tensor(beta1, dtype=param.dtype), \
+            mindspore.tensor(beta2, dtype=param.dtype), \
+            mindspore.tensor(epsilon, dtype=param.dtype)
+
     return _adam_amsgrad(param, exp_avg, exp_avg_sq, max_exp_avg_sq,
                          beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad)
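For context, here is a minimal sketch of how the patched raw_adam might be exercised with half-precision parameters on a GPU target. The float16 dtype, the tensor shapes and values, and the import path mindnlp.core.ops.optim are illustrative assumptions, not part of this commit.

# Illustrative sketch, not part of the commit: calling raw_adam with float16
# parameters on GPU. The patch wraps the Python-float hyperparameters as
# tensors of param.dtype before they reach ops.Adam, avoiding a dtype mismatch
# against the half-precision state tensors.
import numpy as np
import mindspore
from mindnlp.core.ops.optim import raw_adam  # assumed import path

mindspore.set_context(device_target='GPU')   # assumes a GPU build of MindSpore

dtype = mindspore.float16                    # assumed half-precision training
param = mindspore.Parameter(mindspore.tensor(np.ones((2, 2)), dtype=dtype))
exp_avg = mindspore.Parameter(mindspore.tensor(np.zeros((2, 2)), dtype=dtype))
exp_avg_sq = mindspore.Parameter(mindspore.tensor(np.zeros((2, 2)), dtype=dtype))
grad = mindspore.tensor(np.full((2, 2), 0.1), dtype=dtype)

# Hyperparameters passed as plain Python floats; on GPU with non-float32
# params they are now cast to param.dtype inside raw_adam.
raw_adam(param, exp_avg, exp_avg_sq,
         0.9, 0.999, 1e-3,   # beta1_power, beta2_power, lr
         0.9, 0.999, 1e-8,   # beta1, beta2, epsilon
         grad)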
