Skip to content

Commit 57949c0

Browse files
refine expression code and support train/eval/visu for euler_beam in jit.to_static
1 parent 21e5ec1 commit 57949c0

File tree

9 files changed

+136
-62
lines changed

9 files changed

+136
-62
lines changed

examples/darcy/darcy2d.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,12 @@
1313
# limitations under the License.
1414

1515
import numpy as np
16-
from paddle import fluid
1716

1817
import ppsci
1918
from ppsci.utils import config
2019
from ppsci.utils import logger
2120

2221
if __name__ == "__main__":
23-
fluid.core._set_prim_all_enabled(True)
24-
2522
args = config.parse_args()
2623
# set random seed for reproducibility
2724
ppsci.utils.misc.set_random_seed(42)

examples/euler_beam/euler_beam.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,11 +126,12 @@ def u_solution_func(out):
126126
epochs=EPOCHS,
127127
iters_per_epoch=ITERS_PER_EPOCH,
128128
eval_during_train=True,
129-
eval_freq=1000,
129+
eval_freq=10,
130130
equation=equation,
131131
geom=geom,
132132
validator=validator,
133133
visualizer=visualizer,
134+
to_static=args.to_static,
134135
)
135136
# train model
136137
solver.train()

ppsci/solver/eval.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
from paddle import io
1919

2020
from ppsci.solver import printer
21-
from ppsci.utils import expression
2221
from ppsci.utils import misc
2322
from ppsci.utils import profiler
2423

@@ -62,10 +61,14 @@ def _eval_by_dataset(solver, epoch_id: int, log_freq: int) -> float:
6261

6362
# forward
6463
with solver.autocast_context_manager(), solver.no_grad_context_manager():
65-
output_dict = solver.expr_helper(
66-
_validator.output_expr, input_dict, solver.model
64+
output_dict, validator_loss = solver.forward_helper.eval_forward(
65+
_validator.output_expr,
66+
input_dict,
67+
solver.model,
68+
_validator,
69+
label_dict,
70+
weight_dict,
6771
)
68-
validator_loss = _validator.loss(output_dict, label_dict, weight_dict)
6972

7073
loss_dict[f"loss({_validator.name})"] = float(validator_loss)
7174

@@ -179,16 +182,16 @@ def _eval_by_batch(solver, epoch_id: int, log_freq: int) -> float:
179182

180183
for v in input_dict.values():
181184
v.stop_gradient = False
182-
evaluator = expression.ExpressionSolver(
183-
_validator.input_keys, _validator.output_keys, solver.model
184-
)
185-
for output_name, output_formula in _validator.output_expr.items():
186-
evaluator.add_target_expr(output_formula, output_name)
187-
188185
# forward
189186
with solver.autocast_context_manager(), solver.no_grad_context_manager():
190-
output_dict = evaluator(input_dict)
191-
validator_loss = _validator.loss(output_dict, label_dict, weight_dict)
187+
output_dict, validator_loss = solver.forward_helper.eval_forward(
188+
_validator.output_expr,
189+
input_dict,
190+
solver.model,
191+
_validator,
192+
label_dict,
193+
weight_dict,
194+
)
192195

193196
loss_dict[f"loss({_validator.name})"] = float(validator_loss)
194197

ppsci/solver/solver.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ class Solver:
7474
compute_metric_by_batch (bool, optional): Whether calculate metrics after each batch during evaluate. Defaults to False.
7575
eval_with_no_grad (bool, optional): Whether set `stop_gradient=True` for every Tensor if no differentiation
7676
involved during computation, generally for save GPU memory and accelerate computing. Defaults to False.
77+
to_static (bool, optional): Whether enable to_static for forward pass. Defaults to False.
7778
7879
Examples:
7980
>>> import ppsci
@@ -129,6 +130,7 @@ def __init__(
129130
checkpoint_path: Optional[str] = None,
130131
compute_metric_by_batch: bool = False,
131132
eval_with_no_grad: bool = False,
133+
to_static: bool = False,
132134
):
133135
# set model
134136
self.model = model
@@ -253,7 +255,12 @@ def __init__(
253255
if logger._logger is not None:
254256
logger.info(f"Using paddlepaddle {paddle_version} on device {self.device}")
255257

256-
self.expr_helper = expression.ExpressionSolver()
258+
self.forward_helper = expression.ExpressionSolver()
259+
260+
# whether enable static for forward pass, default to Fals
261+
if to_static:
262+
jit.enable_to_static(to_static)
263+
logger.info("Enable to_static for forward computation.")
257264

258265
@staticmethod
259266
def from_config(cfg: Dict[str, Any]) -> Solver:

ppsci/solver/train.py

Lines changed: 16 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,9 @@ def train_epoch_func(solver, epoch_id: int, log_freq: int):
3939
batch_cost = 0
4040
reader_tic = time.perf_counter()
4141

42-
input_dict_list = []
43-
label_dict_list = []
44-
weight_dict_list = []
42+
input_dicts = []
43+
label_dicts = []
44+
weight_dicts = []
4545
for _, _constraint in solver.constraint.items():
4646
input_dict, label_dict, weight_dict = next(_constraint.data_iter)
4747
# profile code below
@@ -51,35 +51,31 @@ def train_epoch_func(solver, epoch_id: int, log_freq: int):
5151
for key in solver.train_time_info:
5252
solver.train_time_info[key].reset()
5353
reader_cost += time.perf_counter() - reader_tic
54-
total_batch_size += next(iter(input_dict.values())).shape[0]
54+
for v in input_dict.values():
55+
v.stop_gradient = False
5556

5657
# gather each constraint's input, label, weight to a list
57-
input_dict_list.append(input_dict)
58-
label_dict_list.append(label_dict)
59-
weight_dict_list.append(weight_dict)
60-
58+
input_dicts.append(input_dict)
59+
label_dicts.append(label_dict)
60+
weight_dicts.append(weight_dict)
61+
total_batch_size += next(iter(input_dict.values())).shape[0]
6162
reader_tic = time.perf_counter()
6263

63-
for x in input_dict_list:
64-
for v in x.values():
65-
v.stop_gradient = False
66-
6764
# forward for every constraint, including model and equation expression
6865
with solver.autocast_context_manager():
69-
constraint_losses = solver.expr_helper(
66+
constraint_losses = solver.forward_helper.train_forward(
7067
[_constraint.output_expr for _constraint in solver.constraint.values()],
71-
input_dict_list,
68+
input_dicts,
7269
solver.model,
7370
solver.constraint,
74-
label_dict_list,
75-
weight_dict_list,
71+
label_dicts,
72+
weight_dicts,
7673
)
7774

7875
# compute loss for each constraint according to its' own output, label and weight
79-
for i, (_, _constraint) in enumerate(solver.constraint.items()):
80-
constraint_loss = constraint_losses[i]
81-
total_loss += constraint_loss
82-
loss_dict[_constraint.name] += float(constraint_loss)
76+
for i, _constraint in enumerate(solver.constraint.values()):
77+
total_loss += constraint_losses[i]
78+
loss_dict[_constraint.name] += float(constraint_losses[i])
8379

8480
if solver.update_freq > 1:
8581
total_loss = total_loss / solver.update_freq

ppsci/solver/visu.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717

1818
import paddle
1919

20-
from ppsci.utils import expression
2120
from ppsci.utils import misc
2221

2322

@@ -54,7 +53,7 @@ def visualize_func(solver, epoch_id: int):
5453

5554
# forward
5655
with solver.no_grad_context_manager():
57-
batch_output_dict = solver.expr_helper(
56+
batch_output_dict = solver.forward_helper.visu_forward(
5857
_visualizer.output_expr, batch_input_dict, solver.model
5958
)
6059

ppsci/utils/config.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,11 @@ def parse_args():
179179
parser = argparse.ArgumentParser("paddlescience running script")
180180
parser.add_argument("-e", "--epochs", type=int, help="training epochs")
181181
parser.add_argument("-o", "--output_dir", type=str, help="output directory")
182+
parser.add_argument(
183+
"--to_static",
184+
action="store_true",
185+
help="whether enable to_static for forward computation",
186+
)
182187

183188
args = parser.parse_args()
184189
return args

ppsci/utils/expression.py

Lines changed: 86 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -12,64 +12,128 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from typing import TYPE_CHECKING
16+
from typing import Callable
17+
from typing import Dict
18+
from typing import Tuple
19+
20+
import paddle
1521
from paddle import jit
1622
from paddle import nn
1723

24+
if TYPE_CHECKING:
25+
from ppsci import constraint
26+
from ppsci import validate
27+
1828
from ppsci.autodiff import clear
1929

2030

2131
class ExpressionSolver(nn.Layer):
2232
"""Expression Solver
2333
24-
Args:
25-
input_keys (Dict[str]):Names of input keys.
26-
output_keys (Dict[str]):Names of output keys.
27-
model (nn.Layer): Model to get output variables from input variables.
28-
2934
Examples:
3035
>>> import ppsci
3136
>>> model = ppsci.arch.MLP(("x", "y"), ("u", "v"), 5, 128)
32-
>>> expr_solver = ExpressionSolver(("x", "y"), ("u", "v"), model)
37+
>>> expr_solver = ExpressionSolver()
3338
"""
3439

3540
def __init__(self):
3641
super().__init__()
3742

3843
@jit.to_static
39-
def forward(
44+
def train_forward(
4045
self,
41-
expr_dict_list,
42-
input_dict_list,
43-
model,
44-
constraint,
45-
label_dict_list,
46-
weight_dict_list,
46+
expr_dicts: Tuple[Dict[str, Callable], ...],
47+
input_dicts: Tuple[Dict[str, paddle.Tensor], ...],
48+
model: nn.Layer,
49+
constraint: Dict[str, "constraint.Constraint"],
50+
label_dicts: Tuple[Dict[str, paddle.Tensor], ...],
51+
weight_dicts: Tuple[Dict[str, paddle.Tensor], ...],
4752
):
48-
output_dict_list = []
49-
for i, expr_dict in enumerate(expr_dict_list):
53+
output_dicts = []
54+
for i, expr_dict in enumerate(expr_dicts):
5055
# model forward
5156
if callable(next(iter(expr_dict.values()))):
52-
output_dict = model(input_dict_list[i])
57+
output_dict = model(input_dicts[i])
5358

5459
# equation forward
5560
for name, expr in expr_dict.items():
61+
if name not in label_dicts[i]:
62+
continue
5663
if callable(expr):
57-
output_dict[name] = expr({**output_dict, **input_dict_list[i]})
64+
output_dict[name] = expr({**output_dict, **input_dicts[i]})
5865
else:
5966
raise TypeError(f"expr type({type(expr)}) is invalid")
6067

61-
output_dict_list.append(output_dict)
68+
output_dicts.append(output_dict)
6269

6370
# clear differentiation cache
6471
clear()
6572

6673
# compute loss for each constraint according to its' own output, label and weight
6774
constraint_losses = []
68-
for i, (_, _constraint) in enumerate(constraint.items()):
75+
for i, _constraint in enumerate(constraint.values()):
6976
constraint_loss = _constraint.loss(
70-
output_dict_list[i],
71-
label_dict_list[i],
72-
weight_dict_list[i],
77+
output_dicts[i],
78+
label_dicts[i],
79+
weight_dicts[i],
7380
)
7481
constraint_losses.append(constraint_loss)
7582
return constraint_losses
83+
84+
@jit.to_static
85+
def eval_forward(
86+
self,
87+
expr_dict: Dict[str, Callable],
88+
input_dict: Dict[str, Callable],
89+
model: nn.Layer,
90+
validator: "validate.Validator",
91+
label_dict: Dict[str, Callable],
92+
weight_dict: Dict[str, Callable],
93+
):
94+
# model forward
95+
if callable(next(iter(expr_dict.values()))):
96+
output_dict = model(input_dict)
97+
98+
# equation forward
99+
for name, expr in expr_dict.items():
100+
if name not in label_dict:
101+
continue
102+
if callable(expr):
103+
output_dict[name] = expr({**output_dict, **input_dict})
104+
else:
105+
raise TypeError(f"expr type({type(expr)}) is invalid")
106+
107+
# clear differentiation cache
108+
clear()
109+
110+
# compute loss for each validator according to its' own output, label and weight
111+
validator_loss = validator.loss(
112+
output_dict,
113+
label_dict,
114+
weight_dict,
115+
)
116+
return output_dict, validator_loss
117+
118+
def visu_forward(
119+
self,
120+
expr_dict: Dict[str, Callable],
121+
input_dict: Dict[str, Callable],
122+
model: nn.Layer,
123+
):
124+
# model forward
125+
if callable(next(iter(expr_dict.values()))):
126+
output_dict = model(input_dict)
127+
128+
# equation forward
129+
for name, expr in expr_dict.items():
130+
if callable(expr):
131+
output_dict[name] = expr({**output_dict, **input_dict})
132+
else:
133+
raise TypeError(f"expr type({type(expr)}) is invalid")
134+
135+
# clear differentiation cache
136+
clear()
137+
138+
# compute loss for each validator according to its' own output, label and weight
139+
return output_dict

ppsci/utils/misc.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -255,10 +255,12 @@ def function_with_eval_state(self, *args, **kwargs):
255255
self.model.eval()
256256

257257
# run func in eval mode
258-
func(self, *args, **kwargs)
258+
result = func(self, *args, **kwargs)
259259

260260
# restore state
261261
if train_state:
262262
self.model.train()
263263

264+
return result
265+
264266
return function_with_eval_state

0 commit comments

Comments
 (0)