"""optim op"""
import mindspore
from mindspore import ops
from mindspore.ops._primitive_cache import _get_cache_prim
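
# `_get_cache_prim` is presumably used further down in this module; a minimal
# sketch of the usual caching pattern (illustrative only, not executed here):
#
#     cast = _get_cache_prim(ops.Cast)()      # returns a cached ops.Cast instance
#     y = cast(x, mindspore.float16)          # later calls reuse the same primitive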

# Backend queried once at import time; the raw ops below branch on it.
DEVICE_TARGET = mindspore.get_context('device_target')

_adadelta = ops.ApplyAdadelta()
def raw_adadelta(param, square_avg, acc_delta, lr, rho, eps, grad):
    """Apply a single Adadelta step in place via ops.ApplyAdadelta."""
    return _adadelta(param, square_avg, acc_delta, lr, rho, eps, grad)

_adam = ops.Adam()
def raw_adam(param, exp_avg, exp_avg_sq, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad):
    """Apply a single Adam step in place via ops.Adam."""
    # Argument order: var, m, v, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad
    if DEVICE_TARGET == 'GPU' and param.dtype != mindspore.float32:
        # Non-float32 parameters on GPU: pass the scalar hyperparameters as tensors
        # of the parameter dtype so the kernel receives consistent dtypes.
        beta1_power, beta2_power, lr, beta1, beta2, epsilon = (
            mindspore.tensor(beta1_power, dtype=param.dtype),
            mindspore.tensor(beta2_power, dtype=param.dtype),
            mindspore.tensor(lr, dtype=param.dtype),
            mindspore.tensor(beta1, dtype=param.dtype),
            mindspore.tensor(beta2, dtype=param.dtype),
            mindspore.tensor(epsilon, dtype=param.dtype),
        )
    return _adam(param, exp_avg, exp_avg_sq, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad)

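# Illustrative usage of raw_adam (a sketch only; names, shapes and hyperparameter
# values are assumed, not part of this module). The state tensors are updated in
# place, so they are normally Parameter objects:
#
#     import numpy as np
#     from mindspore import Parameter, Tensor
#     param      = Parameter(Tensor(np.zeros((4, 4), np.float32)))
#     exp_avg    = Parameter(Tensor(np.zeros((4, 4), np.float32)))
#     exp_avg_sq = Parameter(Tensor(np.zeros((4, 4), np.float32)))
#     grad = Tensor(np.ones((4, 4), np.float32))
#     raw_adam(param, exp_avg, exp_avg_sq, 0.9, 0.999, 1e-3, 0.9, 0.999, 1e-8, grad)
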
_adam_amsgrad = ops.ApplyAdamWithAmsgradV2()
def raw_adam_amsgrad(param, exp_avg, exp_avg_sq, max_exp_avg_sq, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad):
    """Apply a single AMSGrad-variant Adam step in place via ops.ApplyAdamWithAmsgradV2."""
    # Argument order: var, m, v, vhat, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad
    if DEVICE_TARGET == 'GPU' and param.dtype != mindspore.float32:
        # Same dtype handling as raw_adam: keep the scalar hyperparameters in the
        # parameter dtype so the GPU kernel receives consistent dtypes.
        beta1_power, beta2_power, lr, beta1, beta2, epsilon = (
            mindspore.tensor(beta1_power, dtype=param.dtype),
            mindspore.tensor(beta2_power, dtype=param.dtype),
            mindspore.tensor(lr, dtype=param.dtype),
            mindspore.tensor(beta1, dtype=param.dtype),
            mindspore.tensor(beta2, dtype=param.dtype),
            mindspore.tensor(epsilon, dtype=param.dtype),
        )
    return _adam_amsgrad(param, exp_avg, exp_avg_sq, max_exp_avg_sq,
                         beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad)
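
# The dtype cast above is duplicated in raw_adam and raw_adam_amsgrad. A possible
# refactor (a sketch only; `_cast_hyperparams` is a hypothetical helper, not part
# of this module) would be:
#
#     def _cast_hyperparams(dtype, *scalars):
#         return tuple(mindspore.tensor(s, dtype=dtype) for s in scalars)
#
#     beta1_power, beta2_power, lr, beta1, beta2, epsilon = _cast_hyperparams(
#         param.dtype, beta1_power, beta2_power, lr, beta1, beta2, epsilon)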