Skip to content

Commit 20a3564

Browse files
update newest static code for euler beam
1 parent 741e200 commit 20a3564

File tree

2 files changed

+49
-31
lines changed

2 files changed

+49
-31
lines changed

ppsci/solver/train.py

Lines changed: 29 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,12 @@ def train_epoch_func(solver, epoch_id: int, log_freq: int):
3838
reader_cost = 0
3939
batch_cost = 0
4040
reader_tic = time.perf_counter()
41+
42+
input_dict_list = []
43+
label_dict_list = []
44+
weight_dict_list = []
4145
for _, _constraint in solver.constraint.items():
4246
input_dict, label_dict, weight_dict = next(_constraint.data_iter)
43-
4447
# profile code below
4548
# profiler.add_profiler_step(solver.cfg["profiler_options"])
4649
if iter_id == 5:
@@ -50,20 +53,34 @@ def train_epoch_func(solver, epoch_id: int, log_freq: int):
5053
reader_cost += time.perf_counter() - reader_tic
5154
total_batch_size.append(next(iter(input_dict.values())).shape[0])
5255

53-
for v in input_dict.values():
54-
v.stop_gradient = False
56+
# gather each constraint's input, label, weight to a list
57+
input_dict_list.append(input_dict)
58+
label_dict_list.append(label_dict)
59+
weight_dict_list.append(weight_dict)
5560

56-
# forward for every constraint
57-
with solver.autocast_context_manager():
58-
output_dict = solver.expr_helper(
59-
_constraint.output_expr, input_dict, solver.model
60-
)
61-
constraint_loss = _constraint.loss(output_dict, label_dict, weight_dict)
62-
total_loss += constraint_loss
61+
reader_tic = time.perf_counter()
6362

64-
loss_dict[_constraint.name] += float(constraint_loss)
63+
for x in input_dict_list:
64+
for v in x.values():
65+
v.stop_gradient = False
6566

66-
reader_tic = time.perf_counter()
67+
# forward for every constraint, including model and equation expression
68+
with solver.autocast_context_manager():
69+
output_dict_list = solver.expr_helper(
70+
[_constraint.output_expr for _constraint in solver.constraint.values()],
71+
input_dict_list,
72+
solver.model,
73+
)
74+
75+
# compute loss for each constraint according to its' own output, label and weight
76+
for i, (_, _constraint) in enumerate(solver.constraint.items()):
77+
constraint_loss = _constraint.loss(
78+
output_dict_list[i],
79+
label_dict_list[i],
80+
weight_dict_list[i],
81+
)
82+
total_loss += constraint_loss
83+
loss_dict[_constraint.name] += float(constraint_loss)
6784

6885
if solver.update_freq > 1:
6986
total_loss = total_loss / solver.update_freq

ppsci/utils/expression.py

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -36,22 +36,23 @@ def __init__(self):
3636
super().__init__()
3737

3838
@jit.to_static
39-
def forward(self, expr_dict, input_dict, model):
40-
output_dict = {k: v for k, v in input_dict.items()}
41-
42-
# model forward
43-
if callable(next(iter(expr_dict.values()))):
44-
model_output_dict = model(input_dict)
45-
output_dict.update(model_output_dict)
46-
47-
# equation forward
48-
for name, expr in expr_dict.items():
49-
if callable(expr):
50-
output_dict[name] = expr(output_dict)
51-
else:
52-
raise TypeError(f"expr type({type(expr)}) is invalid")
53-
54-
# clear differentiation cache
55-
clear()
56-
57-
return output_dict
39+
def forward(self, expr_dict_list, input_dict_list, model):
40+
output_dict_list = []
41+
for i, expr_dict in enumerate(expr_dict_list):
42+
# model forward
43+
if callable(next(iter(expr_dict.values()))):
44+
output_dict = model(input_dict_list[i])
45+
46+
# equation forward
47+
for name, expr in expr_dict.items():
48+
if callable(expr):
49+
output_dict[name] = expr({**output_dict, **input_dict_list[i]})
50+
else:
51+
raise TypeError(f"expr type({type(expr)}) is invalid")
52+
53+
output_dict_list.append(output_dict)
54+
55+
# clear differentiation cache
56+
clear()
57+
58+
return output_dict_list

0 commit comments

Comments
 (0)