
Commit e7b26a4

perf: optimize performance for pynative (#730)
add `jit` decorator to optimizer
1 parent 1b305f1 commit e7b26a4
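
The same five-line change is applied to each of the four optimizer files below: import `jit` (falling back to its pre-2.0 name `ms_function` on older MindSpore releases) and decorate the optimizer's `construct` method, so the update step runs as a compiled graph even when the rest of the network executes in PyNative mode. A minimal sketch of the pattern in isolation; `ToySGD` is hypothetical and only stands in for the real AdamW/Adan/Lion/NAdam classes:

import mindspore as ms
from mindspore import nn, ops

# Compatibility shim: `ms_function` was renamed to `jit` in MindSpore 2.0,
# so prefer the new name and fall back on older releases.
try:
    from mindspore import jit
except ImportError:
    from mindspore import ms_function as jit


class ToySGD(nn.Optimizer):
    """Hypothetical optimizer illustrating where the decorator goes."""

    def __init__(self, params, learning_rate=0.01):
        super().__init__(learning_rate, params)

    @jit  # compiles the whole update step into one graph
    def construct(self, gradients):
        lr = self.get_lr()
        for param, grad in zip(self._parameters, gradients):
            ops.assign_sub(param, lr * grad)
        return True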

File tree

4 files changed: +24 additions, −0 deletions


mindcv/optim/adamw.py

Lines changed: 6 additions & 0 deletions
@@ -9,6 +9,11 @@
 from mindspore.nn.optim import Optimizer
 from mindspore.nn.optim.optimizer import opt_init_args_register
 
+try:
+    from mindspore import jit
+except ImportError:
+    from mindspore import ms_function as jit
+
 
 def _check_param_value(beta1, beta2, eps, prim_name):
     """Check the type of inputs."""
@@ -154,6 +159,7 @@ def __init__(
         self.reciprocal_scale = Tensor(1.0 / loss_scale, ms.float32)
         self.clip = clip
 
+    @jit
     def construct(self, gradients):
         lr = self.get_lr()
         gradients = scale_grad(gradients, self.reciprocal_scale)
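
A quick smoke test of the decorated optimizer under PyNative mode. This is a hedged sketch that assumes the class exported from mindcv/optim/adamw.py is named `AdamW` and accepts the usual `nn.Optimizer`-style arguments:

import numpy as np
import mindspore as ms
from mindspore import nn, Tensor

from mindcv.optim.adamw import AdamW  # assumption: exported class name

ms.set_context(mode=ms.PYNATIVE_MODE)

net = nn.Dense(4, 2)
opt = AdamW(net.trainable_params(), learning_rate=1e-3)

def forward(x, label):
    logits = net(x)
    return ((logits - label) ** 2).mean()

# Differentiate w.r.t. the optimizer's parameters only.
grad_fn = ms.value_and_grad(forward, None, opt.parameters)

x = Tensor(np.random.randn(8, 4), ms.float32)
label = Tensor(np.random.randn(8, 2), ms.float32)
loss, grads = grad_fn(x, label)
opt(grads)  # construct() now executes as a compiled graph via @jit

The first optimizer call pays a one-time graph-compilation cost; subsequent steps replay the cached graph.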

mindcv/optim/adan.py

Lines changed: 6 additions & 0 deletions
@@ -5,6 +5,11 @@
 from mindspore.common.tensor import Tensor
 from mindspore.nn.optim.optimizer import Optimizer, opt_init_args_register
 
+try:
+    from mindspore import jit
+except ImportError:
+    from mindspore import ms_function as jit
+
 _adan_opt = ops.MultitypeFuncGraph("adan_opt")
 
 
@@ -144,6 +149,7 @@ def __init__(
 
         self.weight_decay = Tensor(weight_decay, mstype.float32)
 
+    @jit
     def construct(self, gradients):
         params = self._parameters
         moment1 = self.moment1

mindcv/optim/lion.py

Lines changed: 6 additions & 0 deletions
@@ -8,6 +8,11 @@
 from mindspore.nn.optim import Optimizer
 from mindspore.nn.optim.optimizer import opt_init_args_register
 
+try:
+    from mindspore import jit
+except ImportError:
+    from mindspore import ms_function as jit
+
 
 def _check_param_value(beta1, beta2, prim_name):
     """Check the type of inputs."""
@@ -142,6 +147,7 @@ def __init__(
         self.reciprocal_scale = Tensor(1.0 / loss_scale, ms.float32)
         self.clip = clip
 
+    @jit
     def construct(self, gradients):
         lr = self.get_lr()
         gradients = scale_grad(gradients, self.reciprocal_scale)

mindcv/optim/nadam.py

Lines changed: 6 additions & 0 deletions
@@ -9,6 +9,11 @@
 from mindspore.nn.optim import Optimizer
 from mindspore.nn.optim.optimizer import opt_init_args_register
 
+try:
+    from mindspore import jit
+except ImportError:
+    from mindspore import ms_function as jit
+
 
 def _check_param_value(beta1, beta2, eps, prim_name):
     """Check the type of inputs."""
@@ -48,6 +53,7 @@ def __init__(
         self.mu_schedule = Parameter(initializer(1, [1], ms.float32), name="mu_schedule")
         self.beta2_power = Parameter(initializer(1, [1], ms.float32), name="beta2_power")
 
+    @jit
     def construct(self, gradients):
         lr = self.get_lr()
         params = self.parameters
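
Why this helps PyNative: without `@jit`, every small op inside `construct` (per-parameter moment updates, scaling, clipping) is dispatched eagerly on each step; with the decorator, the update is compiled once and replayed as a single graph. A rough way to observe the effect, continuing the smoke test above; numbers vary by hardware and MindSpore version, and the first decorated call includes the one-time compile cost, hence the warm-up:

import time

for _ in range(5):  # warm-up: triggers @jit graph compilation
    loss, grads = grad_fn(x, label)
    opt(grads)

steps = 100
start = time.perf_counter()
for _ in range(steps):
    loss, grads = grad_fn(x, label)
    opt(grads)
print(f"avg step time: {(time.perf_counter() - start) / steps * 1e3:.2f} ms")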
