Skip to content

Commit fa08659

Browse files
update runable code(wrap subloss computation in static
1 parent 20a3564 commit fa08659

File tree

2 files changed

+24
-8
lines changed

2 files changed

+24
-8
lines changed

ppsci/solver/train.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -66,19 +66,18 @@ def train_epoch_func(solver, epoch_id: int, log_freq: int):
6666

6767
# forward for every constraint, including model and equation expression
6868
with solver.autocast_context_manager():
69-
output_dict_list = solver.expr_helper(
69+
constraint_losses = solver.expr_helper(
7070
[_constraint.output_expr for _constraint in solver.constraint.values()],
7171
input_dict_list,
7272
solver.model,
73+
solver.constraint,
74+
label_dict_list,
75+
weight_dict_list,
7376
)
7477

7578
# compute loss for each constraint according to its' own output, label and weight
7679
for i, (_, _constraint) in enumerate(solver.constraint.items()):
77-
constraint_loss = _constraint.loss(
78-
output_dict_list[i],
79-
label_dict_list[i],
80-
weight_dict_list[i],
81-
)
80+
constraint_loss = constraint_losses[i]
8281
total_loss += constraint_loss
8382
loss_dict[_constraint.name] += float(constraint_loss)
8483

ppsci/utils/expression.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,15 @@ def __init__(self):
3636
super().__init__()
3737

3838
@jit.to_static
39-
def forward(self, expr_dict_list, input_dict_list, model):
39+
def forward(
40+
self,
41+
expr_dict_list,
42+
input_dict_list,
43+
model,
44+
constraint,
45+
label_dict_list,
46+
weight_dict_list,
47+
):
4048
output_dict_list = []
4149
for i, expr_dict in enumerate(expr_dict_list):
4250
# model forward
@@ -55,4 +63,13 @@ def forward(self, expr_dict_list, input_dict_list, model):
5563
# clear differentiation cache
5664
clear()
5765

58-
return output_dict_list
66+
# compute loss for each constraint according to its' own output, label and weight
67+
constraint_losses = []
68+
for i, (_, _constraint) in enumerate(constraint.items()):
69+
constraint_loss = _constraint.loss(
70+
output_dict_list[i],
71+
label_dict_list[i],
72+
weight_dict_list[i],
73+
)
74+
constraint_losses.append(constraint_loss)
75+
return constraint_losses

0 commit comments

Comments
 (0)