
Commit cceef44

Merge branch 'develop' into support_more_fptype
2 parents: 6e5a58c + 99a2a80

17 files changed: 48 additions and 52 deletions


paddlescience/network/grad_norm.py

Lines changed: 2 additions & 2 deletions
@@ -84,8 +84,8 @@ def get_grad_norm_loss(self, losses):
         norms = []
         for i in range(losses.shape[0]):
             grad = paddle.autograd.grad(losses[i], W, retain_graph=True)
-            norms.append(paddle.norm(self.loss_weights[i] * grad[0], p=2))
-        norms = paddle.concat(norms)
+            norms.append(paddle.norm(self.loss_weights[i] * grad[0], p=2).reshape([]))
+        norms = paddle.stack(norms)
 
         # calculate the inverse train rate
         loss_ratio = losses.numpy() / self.initial_losses
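Note: the change above makes each per-loss gradient norm an explicit 0-D scalar (`.reshape([])`) and collects them with `paddle.stack`, which adds a new leading axis, rather than `paddle.concat`, which expects inputs that are already at least 1-D. A minimal sketch of the behaviour this relies on (illustrative values, assuming a recent Paddle where scalars are 0-D tensors; not part of the commit):

import paddle

# Two 0-D scalar tensors, standing in for per-loss gradient norms.
a = paddle.to_tensor(3.0).reshape([])
b = paddle.to_tensor(4.0).reshape([])

# paddle.stack adds a leading axis, turning N scalars into a shape-[N] vector.
norms = paddle.stack([a, b])
print(norms.shape)  # [2]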

ppsci/arch/activation.py

Lines changed: 1 addition & 1 deletion
@@ -15,8 +15,8 @@
 from typing import Callable
 
 import paddle
-import paddle.nn as nn
 import paddle.nn.functional as F
+from paddle import nn
 
 act_func_dict = {
     "elu": F.elu,

ppsci/arch/base.py

Lines changed: 1 addition & 1 deletion
@@ -17,7 +17,7 @@
 
 import numpy as np
 import paddle
-import paddle.nn as nn
+from paddle import nn
 
 from ppsci.utils import logger

ppsci/arch/embedding_koopman.py

Lines changed: 1 addition & 1 deletion
@@ -21,7 +21,7 @@
 
 import numpy as np
 import paddle
-import paddle.nn as nn
+from paddle import nn
 from paddle.nn.initializer import Constant
 from paddle.nn.initializer import Uniform

ppsci/arch/mlp.py

Lines changed: 1 addition & 1 deletion
@@ -15,7 +15,7 @@
 from typing import Tuple
 from typing import Union
 
-import paddle.nn as nn
+from paddle import nn
 
 from ppsci.arch import activation as act_mod
 from ppsci.arch import base

ppsci/data/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -17,9 +17,9 @@
 from functools import partial
 
 import numpy as np
-import paddle.device as device
 import paddle.distributed as dist
-import paddle.io as io
+from paddle import device
+from paddle import io
 
 from ppsci.data import dataloader
 from ppsci.data import dataset

ppsci/equation/pde/base.py

Lines changed: 1 addition & 1 deletion
@@ -18,8 +18,8 @@
 from typing import Tuple
 
 import paddle
-import paddle.nn as nn
 import sympy
+from paddle import nn
 
 
 class PDE:

ppsci/loss/base.py

Lines changed: 1 addition & 1 deletion
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import paddle.nn as nn
+from paddle import nn
 
 
 class Loss(nn.Layer):

ppsci/metric/base.py

Lines changed: 1 addition & 1 deletion
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import paddle.nn as nn
+from paddle import nn
 
 
 class Metric(nn.Layer):

ppsci/solver/eval.py

Lines changed: 2 additions & 10 deletions
@@ -15,8 +15,7 @@
 import time
 
 import paddle
-import paddle.amp as amp
-import paddle.io as io
+from paddle import io
 
 from ppsci.solver import printer
 from ppsci.utils import expression
@@ -67,14 +66,7 @@ def eval_func(solver, epoch_id: int, log_freq: int) -> float:
             evaluator.add_target_expr(output_formula, output_name)
 
         # forward
-        if solver.use_amp:
-            with amp.auto_cast(level=solver.amp_level):
-                output_dict = evaluator(input_dict)
-                validator_loss = _validator.loss(
-                    output_dict, label_dict, weight_dict
-                )
-                loss_dict[f"loss({_validator.name})"] = float(validator_loss)
-        else:
+        with solver._autocast_context_manager():
             output_dict = evaluator(input_dict)
             validator_loss = _validator.loss(output_dict, label_dict, weight_dict)
             loss_dict[f"loss({_validator.name})"] = float(validator_loss)
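The refactor above collapses the duplicated AMP / non-AMP branches into a single `with` block driven by `solver._autocast_context_manager()`. The helper itself is not shown in this diff; a minimal sketch of what such a method could look like, reusing the `use_amp` and `amp_level` attributes from the removed branch (an assumption, not the actual implementation):

import contextlib

import paddle.amp as amp


def _autocast_context_manager(self):
    # Hypothetical helper: return an AMP autocast scope when AMP is enabled,
    # otherwise a do-nothing context, so call sites need only one `with` block.
    if self.use_amp:
        return amp.auto_cast(level=self.amp_level)
    return contextlib.nullcontext()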
