Skip to content

Commit 5ad1f19

Browse files
Merge pull request #325 from HydrogenSulfate/okok
support @jit.to_static for euler_beam example
2 parents 78d1612 + a3438c9 commit 5ad1f19

File tree

9 files changed

+228
-247
lines changed

9 files changed

+228
-247
lines changed

examples/euler_beam/euler_beam.py

Lines changed: 17 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
import numpy as np
15+
from paddle import fluid
1616

1717
import ppsci
1818
from ppsci.autodiff import hessian
@@ -22,6 +22,9 @@
2222

2323
if __name__ == "__main__":
2424
args = config.parse_args()
25+
# enable computation for fourth-order differentiation of matmul
26+
fluid.core.set_prim_eager_enabled(True)
27+
fluid.core._set_prim_all_enabled(True)
2528
# set random seed for reproducibility
2629
ppsci.utils.misc.set_random_seed(42)
2730
# set training hyper-parameters
@@ -56,49 +59,24 @@
5659
random="Hammersley",
5760
name="EQ",
5861
)
59-
bc1 = ppsci.constraint.BoundaryConstraint(
60-
{"u0": lambda d: d["u"]},
61-
{"u0": 0},
62-
geom["interval"],
63-
{**dataloader_cfg, "batch_size": 1},
64-
ppsci.loss.MSELoss("sum"),
65-
criteria=lambda x: np.isclose(x, 0.0),
66-
name="BC1",
67-
)
68-
bc2 = ppsci.constraint.BoundaryConstraint(
69-
{"u__x": lambda d: jacobian(d["u"], d["x"])},
70-
{"u__x": 0},
71-
geom["interval"],
72-
{**dataloader_cfg, "batch_size": 1},
73-
ppsci.loss.MSELoss("sum"),
74-
criteria=lambda x: np.isclose(x, 0.0),
75-
name="BC2",
76-
)
77-
bc3 = ppsci.constraint.BoundaryConstraint(
78-
{"u__x__x": lambda d: hessian(d["u"], d["x"])},
79-
{"u__x__x": 0},
80-
geom["interval"],
81-
{**dataloader_cfg, "batch_size": 1},
82-
ppsci.loss.MSELoss("sum"),
83-
criteria=lambda x: np.isclose(x, 1.0),
84-
name="BC3",
85-
)
86-
bc4 = ppsci.constraint.BoundaryConstraint(
87-
{"u__x__x__x": lambda d: jacobian(hessian(d["u"], d["x"]), d["x"])},
88-
{"u__x__x__x": 0},
62+
bc = ppsci.constraint.BoundaryConstraint(
63+
{
64+
"u0": lambda d: d["u"][0:1],
65+
"u__x": lambda d: jacobian(d["u"], d["x"])[1:2],
66+
"u__x__x": lambda d: hessian(d["u"], d["x"])[2:3],
67+
"u__x__x__x": lambda d: jacobian(hessian(d["u"], d["x"]), d["x"])[3:4],
68+
},
69+
{"u0": 0, "u__x": 0, "u__x__x": 0, "u__x__x__x": 0},
8970
geom["interval"],
90-
{**dataloader_cfg, "batch_size": 1},
71+
{**dataloader_cfg, "batch_size": 4},
9172
ppsci.loss.MSELoss("sum"),
92-
criteria=lambda x: np.isclose(x, 1.0),
93-
name="BC4",
73+
evenly=True,
74+
name="BC",
9475
)
9576
# wrap constraints together
9677
constraint = {
9778
pde_constraint.name: pde_constraint,
98-
bc1.name: bc1,
99-
bc2.name: bc2,
100-
bc3.name: bc3,
101-
bc4.name: bc4,
79+
bc.name: bc,
10280
}
10381

10482
# set optimizer
@@ -153,6 +131,7 @@ def u_solution_func(out):
153131
geom=geom,
154132
validator=validator,
155133
visualizer=visualizer,
134+
to_static=args.to_static,
156135
)
157136
# train model
158137
solver.train()

ppsci/solver/eval.py

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
from paddle import io
1919

2020
from ppsci.solver import printer
21-
from ppsci.utils import expression
2221
from ppsci.utils import misc
2322
from ppsci.utils import profiler
2423

@@ -59,17 +58,17 @@ def _eval_by_dataset(solver, epoch_id: int, log_freq: int) -> float:
5958

6059
for v in input_dict.values():
6160
v.stop_gradient = False
62-
evaluator = expression.ExpressionSolver(
63-
_validator.input_keys, _validator.output_keys, solver.model
64-
)
65-
for output_name, output_formula in _validator.output_expr.items():
66-
if output_name in label_dict:
67-
evaluator.add_target_expr(output_formula, output_name)
6861

6962
# forward
7063
with solver.autocast_context_manager(), solver.no_grad_context_manager():
71-
output_dict = evaluator(input_dict)
72-
validator_loss = _validator.loss(output_dict, label_dict, weight_dict)
64+
output_dict, validator_loss = solver.forward_helper.eval_forward(
65+
_validator.output_expr,
66+
input_dict,
67+
solver.model,
68+
_validator,
69+
label_dict,
70+
weight_dict,
71+
)
7372

7473
loss_dict[f"loss({_validator.name})"] = float(validator_loss)
7574

@@ -183,16 +182,16 @@ def _eval_by_batch(solver, epoch_id: int, log_freq: int) -> float:
183182

184183
for v in input_dict.values():
185184
v.stop_gradient = False
186-
evaluator = expression.ExpressionSolver(
187-
_validator.input_keys, _validator.output_keys, solver.model
188-
)
189-
for output_name, output_formula in _validator.output_expr.items():
190-
evaluator.add_target_expr(output_formula, output_name)
191-
192185
# forward
193186
with solver.autocast_context_manager(), solver.no_grad_context_manager():
194-
output_dict = evaluator(input_dict)
195-
validator_loss = _validator.loss(output_dict, label_dict, weight_dict)
187+
output_dict, validator_loss = solver.forward_helper.eval_forward(
188+
_validator.output_expr,
189+
input_dict,
190+
solver.model,
191+
_validator,
192+
label_dict,
193+
weight_dict,
194+
)
196195

197196
loss_dict[f"loss({_validator.name})"] = float(validator_loss)
198197

ppsci/solver/solver.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737

3838
import ppsci
3939
from ppsci.utils import config
40+
from ppsci.utils import expression
4041
from ppsci.utils import logger
4142
from ppsci.utils import misc
4243
from ppsci.utils import save_load
@@ -73,6 +74,7 @@ class Solver:
7374
compute_metric_by_batch (bool, optional): Whether calculate metrics after each batch during evaluate. Defaults to False.
7475
eval_with_no_grad (bool, optional): Whether set `stop_gradient=True` for every Tensor if no differentiation
7576
involved during computation, generally for save GPU memory and accelerate computing. Defaults to False.
77+
to_static (bool, optional): Whether enable to_static for forward pass. Defaults to False.
7678
7779
Examples:
7880
>>> import ppsci
@@ -97,7 +99,7 @@ class Solver:
9799
... "./output",
98100
... opt,
99101
... None,
100-
... )
102+
... ) # doctest: +SKIP
101103
"""
102104

103105
def __init__(
@@ -128,6 +130,7 @@ def __init__(
128130
checkpoint_path: Optional[str] = None,
129131
compute_metric_by_batch: bool = False,
130132
eval_with_no_grad: bool = False,
133+
to_static: bool = False,
131134
):
132135
# set model
133136
self.model = model
@@ -216,6 +219,10 @@ def __init__(
216219
if isinstance(loaded_metric, dict):
217220
self.best_metric.update(loaded_metric)
218221

222+
# init logger without FileHandler if not initialized before
223+
if logger._logger is None:
224+
logger.init_logger("ppsci", None)
225+
219226
# choosing an appropriate training function for different optimizers
220227
if isinstance(self.optimizer, optim.LBFGS):
221228
self.train_epoch_func = ppsci.solver.train.train_LBFGS_epoch_func
@@ -249,8 +256,13 @@ def __init__(
249256
if version.Version(paddle.__version__) != version.Version("0.0.0")
250257
else f"develop({paddle.version.commit[:7]})"
251258
)
252-
if logger._logger is not None:
253-
logger.info(f"Using paddlepaddle {paddle_version} on device {self.device}")
259+
logger.info(f"Using paddlepaddle {paddle_version} on device {self.device}")
260+
261+
self.forward_helper = expression.ExpressionSolver()
262+
263+
# whether enable static for forward pass, default to Fals
264+
jit.enable_to_static(to_static)
265+
logger.info(f"Set to_static={to_static} for forward computation.")
254266

255267
@staticmethod
256268
def from_config(cfg: Dict[str, Any]) -> Solver:

ppsci/solver/train.py

Lines changed: 24 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -38,36 +38,44 @@ def train_epoch_func(solver, epoch_id: int, log_freq: int):
3838
reader_cost = 0
3939
batch_cost = 0
4040
reader_tic = time.perf_counter()
41+
42+
input_dicts = []
43+
label_dicts = []
44+
weight_dicts = []
4145
for _, _constraint in solver.constraint.items():
4246
input_dict, label_dict, weight_dict = next(_constraint.data_iter)
43-
4447
# profile code below
4548
# profiler.add_profiler_step(solver.cfg["profiler_options"])
4649
if iter_id == 5:
4750
# 5 step for warmup
4851
for key in solver.train_time_info:
4952
solver.train_time_info[key].reset()
5053
reader_cost += time.perf_counter() - reader_tic
51-
total_batch_size += next(iter(input_dict.values())).shape[0]
52-
5354
for v in input_dict.values():
5455
v.stop_gradient = False
55-
evaluator = expression.ExpressionSolver(
56-
_constraint.input_keys, _constraint.output_keys, solver.model
57-
)
58-
for output_name, output_formula in _constraint.output_expr.items():
59-
if output_name in label_dict:
60-
evaluator.add_target_expr(output_formula, output_name)
6156

62-
# forward for every constraint
63-
with solver.autocast_context_manager():
64-
output_dict = evaluator(input_dict)
65-
constraint_loss = _constraint.loss(output_dict, label_dict, weight_dict)
66-
total_loss += constraint_loss
57+
# gather each constraint's input, label, weight to a list
58+
input_dicts.append(input_dict)
59+
label_dicts.append(label_dict)
60+
weight_dicts.append(weight_dict)
61+
total_batch_size += next(iter(input_dict.values())).shape[0]
62+
reader_tic = time.perf_counter()
6763

68-
loss_dict[_constraint.name] = float(constraint_loss)
64+
# forward for every constraint, including model and equation expression
65+
with solver.autocast_context_manager():
66+
constraint_losses = solver.forward_helper.train_forward(
67+
[_constraint.output_expr for _constraint in solver.constraint.values()],
68+
input_dicts,
69+
solver.model,
70+
solver.constraint,
71+
label_dicts,
72+
weight_dicts,
73+
)
6974

70-
reader_tic = time.perf_counter()
75+
# compute loss for each constraint according to its' own output, label and weight
76+
for i, _constraint in enumerate(solver.constraint.values()):
77+
total_loss += constraint_losses[i]
78+
loss_dict[_constraint.name] += float(constraint_losses[i])
7179

7280
if solver.update_freq > 1:
7381
total_loss = total_loss / solver.update_freq

ppsci/solver/visu.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717

1818
import paddle
1919

20-
from ppsci.utils import expression
2120
from ppsci.utils import misc
2221

2322

@@ -52,15 +51,11 @@ def visualize_func(solver, epoch_id: int):
5251
batch_input_dict[key] = input_dict[key][st:ed]
5352
batch_input_dict[key].stop_gradient = False
5453

55-
evaluator = expression.ExpressionSolver(
56-
_visualizer.input_keys, _visualizer.output_keys, solver.model
57-
)
58-
for output_key, output_expr in _visualizer.output_expr.items():
59-
evaluator.add_target_expr(output_expr, output_key)
60-
6154
# forward
62-
with solver.autocast_context_manager(), solver.no_grad_context_manager():
63-
batch_output_dict = evaluator(batch_input_dict)
55+
with solver.no_grad_context_manager():
56+
batch_output_dict = solver.forward_helper.visu_forward(
57+
_visualizer.output_expr, batch_input_dict, solver.model
58+
)
6459

6560
# collect batch data
6661
for key, batch_input in batch_input_dict.items():

0 commit comments

Comments
 (0)