Skip to content

Commit 635e300

Browse files
simplify expression_solver for to_static
1 parent 407551f commit 635e300

File tree

5 files changed

+27
-140
lines changed

5 files changed

+27
-140
lines changed

ppsci/solver/eval.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -59,15 +59,12 @@ def eval_by_dataset(solver, epoch_id: int, log_freq: int) -> float:
5959

6060
for v in input_dict.values():
6161
v.stop_gradient = False
62-
evaluator = expression.ExpressionSolver(
63-
_validator.input_keys, _validator.output_keys, solver.model
64-
)
65-
for output_name, output_formula in _validator.output_expr.items():
66-
evaluator.add_target_expr(output_formula, output_name)
6762

6863
# forward
6964
with solver.autocast_context_manager(), solver.no_grad_context_manager():
70-
output_dict = evaluator(input_dict)
65+
output_dict = solver.expr_helper(
66+
_validator.output_expr, input_dict, solver.model
67+
)
7168
validator_loss = _validator.loss(output_dict, label_dict, weight_dict)
7269

7370
loss_dict[f"loss({_validator.name})"] = float(validator_loss)

ppsci/solver/solver.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636

3737
import ppsci
3838
from ppsci.utils import config
39+
from ppsci.utils import expression
3940
from ppsci.utils import logger
4041
from ppsci.utils import misc
4142
from ppsci.utils import save_load
@@ -246,6 +247,8 @@ def __init__(
246247
if logger._logger is not None:
247248
logger.info(f"Using paddlepaddle {paddle_version} on device {self.device}")
248249

250+
self.expr_helper = expression.ExpressionSolver()
251+
249252
@staticmethod
250253
def from_config(cfg: Dict[str, Any]) -> Solver:
251254
"""Initialize solver from given config.

ppsci/solver/train.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -52,15 +52,12 @@ def train_epoch_func(solver, epoch_id: int, log_freq: int):
5252

5353
for v in input_dict.values():
5454
v.stop_gradient = False
55-
evaluator = expression.ExpressionSolver(
56-
_constraint.input_keys, _constraint.output_keys, solver.model
57-
)
58-
for output_name, output_formula in _constraint.output_expr.items():
59-
evaluator.add_target_expr(output_formula, output_name)
6055

6156
# forward for every constraint
6257
with solver.autocast_context_manager():
63-
output_dict = evaluator(input_dict)
58+
output_dict = solver.expr_helper(
59+
_constraint.output_expr, input_dict, solver.model
60+
)
6461
constraint_loss = _constraint.loss(output_dict, label_dict, weight_dict)
6562
total_loss += constraint_loss
6663

ppsci/solver/visu.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -53,15 +53,11 @@ def visualize_func(solver, epoch_id: int):
5353
batch_input_dict[key] = input_dict[key][st:ed]
5454
batch_input_dict[key].stop_gradient = False
5555

56-
evaluator = expression.ExpressionSolver(
57-
_visualizer.input_keys, _visualizer.output_keys, solver.model
58-
)
59-
for output_key, output_expr in _visualizer.output_expr.items():
60-
evaluator.add_target_expr(output_expr, output_key)
61-
6256
# forward
6357
with solver.autocast_context_manager():
64-
batch_output_dict = evaluator(batch_input_dict)
58+
batch_output_dict = solver.expr_helper(
59+
_visualizer.output_expr, batch_input_dict, solver.model
60+
)
6561

6662
# collect batch data
6763
for key, batch_input in batch_input_dict.items():

ppsci/utils/expression.py

Lines changed: 15 additions & 121 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,9 @@
1515
from typing import Callable
1616
from typing import Union
1717

18-
import paddle
19-
import sympy
2018
from paddle import nn
2119

2220
from ppsci.autodiff import clear
23-
from ppsci.autodiff import hessian
24-
from ppsci.autodiff import jacobian
2521

2622

2723
class ExpressionSolver(nn.Layer):
@@ -38,127 +34,25 @@ class ExpressionSolver(nn.Layer):
3834
>>> expr_solver = ExpressionSolver(("x", "y"), ("u", "v"), model)
3935
"""
4036

41-
def __init__(self, input_keys, output_keys, model):
37+
def __init__(self):
4238
super().__init__()
43-
self.input_keys = input_keys
44-
self.output_keys = output_keys
45-
self.model = model
46-
self.expr_dict = {}
47-
self.output_dict = {}
48-
49-
def solve_expr(self, expr: sympy.Basic) -> Union[float, paddle.Tensor]:
50-
"""Evaluates the value of the expression recursively in the expression tree
51-
by post-order traversal.
52-
53-
Args:
54-
expr (sympy.Basic): Expression.
55-
56-
Returns:
57-
Union[float, paddle.Tensor]: Value of current expression `expr`.
58-
"""
59-
# already computed in output_dict(including input data)
60-
if getattr(expr, "name", None) in self.output_dict:
61-
return self.output_dict[expr.name]
62-
63-
# compute output from model
64-
if isinstance(expr, sympy.Symbol):
65-
if expr.name in self.model.output_keys:
66-
out_dict = self.model(self.output_dict)
67-
self.output_dict.update(out_dict)
68-
return self.output_dict[expr.name]
69-
else:
70-
raise ValueError(f"varname {expr.name} not exist!")
71-
72-
# compute output from model
73-
elif isinstance(expr, sympy.Function):
74-
out_dict = self.model(self.output_dict)
75-
self.output_dict.update(out_dict)
76-
return self.output_dict[expr.name]
77-
78-
# compute derivative
79-
elif isinstance(expr, sympy.Derivative):
80-
ys = self.solve_expr(expr.args[0])
81-
ys_name = expr.args[0].name
82-
if ys_name not in self.output_dict:
83-
self.output_dict[ys_name] = ys
84-
xs = self.solve_expr(expr.args[1][0])
85-
xs_name = expr.args[1][0].name
86-
if xs_name not in self.output_dict:
87-
self.output_dict[xs_name] = xs
88-
order = expr.args[1][1]
89-
if order == 1:
90-
der = jacobian(self.output_dict[ys_name], self.output_dict[xs_name])
91-
der_name = f"{ys_name}__{xs_name}"
92-
elif order == 2:
93-
der = hessian(self.output_dict[ys_name], self.output_dict[xs_name])
94-
der_name = f"{ys_name}__{xs_name}__{xs_name}"
95-
else:
96-
raise NotImplementedError(
97-
f"Expression {expr} has derivative order({order}) >=3, "
98-
f"which is not implemented yet"
99-
)
100-
if der_name not in self.output_dict:
101-
self.output_dict[der_name] = der
102-
return der
103-
104-
# return single python number directly for leaf node
105-
elif isinstance(expr, sympy.Number):
106-
return float(expr)
107-
108-
# compute sub-nodes value and merge by addition
109-
elif isinstance(expr, sympy.Add):
110-
results = [self.solve_expr(arg) for arg in expr.args]
111-
out = results[0]
112-
for i in range(1, len(results)):
113-
out = out + results[i]
114-
return out
115-
116-
# compute sub-nodes value and merge by multiplication
117-
elif isinstance(expr, sympy.Mul):
118-
results = [self.solve_expr(arg) for arg in expr.args]
119-
out = results[0]
120-
for i in range(1, len(results)):
121-
out = out * results[i]
122-
return out
123-
124-
# compute sub-nodes value and merge by power
125-
elif isinstance(expr, sympy.Pow):
126-
results = [self.solve_expr(arg) for arg in expr.args]
127-
return results[0] ** results[1]
128-
else:
129-
raise ValueError(
130-
f"Expression {expr} of type({type(expr)}) can't be solved yet."
131-
)
132-
133-
def forward(self, input_dict):
134-
self.output_dict = input_dict
135-
if callable(next(iter(self.expr_dict.values()))):
136-
model_output_dict = self.model(input_dict)
137-
self.output_dict.update(model_output_dict)
138-
139-
for name, expr in self.expr_dict.items():
140-
if isinstance(expr, sympy.Basic):
141-
self.output_dict[name] = self.solve_expr(expr)
142-
elif callable(expr):
143-
self.output_dict[name] = expr(self.output_dict)
39+
40+
def forward(self, expr_dict, input_dict, model):
41+
output_dict = {k: v for k, v in input_dict.items()}
42+
43+
# model forward
44+
if callable(next(iter(expr_dict.values()))):
45+
model_output_dict = model(input_dict)
46+
output_dict.update(model_output_dict)
47+
48+
# equation forward
49+
for name, expr in expr_dict.items():
50+
if callable(expr):
51+
output_dict[name] = expr(output_dict)
14452
else:
14553
raise TypeError(f"expr type({type(expr)}) is invalid")
14654

14755
# clear differentiation cache
14856
clear()
14957

150-
return {k: self.output_dict[k] for k in self.output_keys}
151-
152-
def add_target_expr(self, expr: Callable, expr_name: str):
153-
"""Add an expression `expr` named `expr_name` to
154-
155-
Args:
156-
expr (Callable): Callable function for computing an expression.
157-
expr_name (str): Name of expression.
158-
"""
159-
self.expr_dict[expr_name] = expr
160-
161-
def __str__(self):
162-
return f"input: {self.input_keys}, output: {self.output_keys}\n" + "\n".join(
163-
[f"{name} = {expr}" for name, expr in self.expr_dict.items()]
164-
)
58+
return output_dict

0 commit comments

Comments
 (0)