15
15
from typing import Callable
16
16
from typing import Union
17
17
18
- import paddle
19
- import sympy
20
18
from paddle import nn
21
19
22
20
from ppsci .autodiff import clear
23
- from ppsci .autodiff import hessian
24
- from ppsci .autodiff import jacobian
25
21
26
22
27
23
class ExpressionSolver (nn .Layer ):
@@ -38,127 +34,25 @@ class ExpressionSolver(nn.Layer):
38
34
>>> expr_solver = ExpressionSolver(("x", "y"), ("u", "v"), model)
39
35
"""
40
36
41
- def __init__ (self , input_keys , output_keys , model ):
37
+ def __init__ (self ):
42
38
super ().__init__ ()
43
- self .input_keys = input_keys
44
- self .output_keys = output_keys
45
- self .model = model
46
- self .expr_dict = {}
47
- self .output_dict = {}
48
-
49
- def solve_expr (self , expr : sympy .Basic ) -> Union [float , paddle .Tensor ]:
50
- """Evaluates the value of the expression recursively in the expression tree
51
- by post-order traversal.
52
-
53
- Args:
54
- expr (sympy.Basic): Expression.
55
-
56
- Returns:
57
- Union[float, paddle.Tensor]: Value of current expression `expr`.
58
- """
59
- # already computed in output_dict(including input data)
60
- if getattr (expr , "name" , None ) in self .output_dict :
61
- return self .output_dict [expr .name ]
62
-
63
- # compute output from model
64
- if isinstance (expr , sympy .Symbol ):
65
- if expr .name in self .model .output_keys :
66
- out_dict = self .model (self .output_dict )
67
- self .output_dict .update (out_dict )
68
- return self .output_dict [expr .name ]
69
- else :
70
- raise ValueError (f"varname { expr .name } not exist!" )
71
-
72
- # compute output from model
73
- elif isinstance (expr , sympy .Function ):
74
- out_dict = self .model (self .output_dict )
75
- self .output_dict .update (out_dict )
76
- return self .output_dict [expr .name ]
77
-
78
- # compute derivative
79
- elif isinstance (expr , sympy .Derivative ):
80
- ys = self .solve_expr (expr .args [0 ])
81
- ys_name = expr .args [0 ].name
82
- if ys_name not in self .output_dict :
83
- self .output_dict [ys_name ] = ys
84
- xs = self .solve_expr (expr .args [1 ][0 ])
85
- xs_name = expr .args [1 ][0 ].name
86
- if xs_name not in self .output_dict :
87
- self .output_dict [xs_name ] = xs
88
- order = expr .args [1 ][1 ]
89
- if order == 1 :
90
- der = jacobian (self .output_dict [ys_name ], self .output_dict [xs_name ])
91
- der_name = f"{ ys_name } __{ xs_name } "
92
- elif order == 2 :
93
- der = hessian (self .output_dict [ys_name ], self .output_dict [xs_name ])
94
- der_name = f"{ ys_name } __{ xs_name } __{ xs_name } "
95
- else :
96
- raise NotImplementedError (
97
- f"Expression { expr } has derivative order({ order } ) >=3, "
98
- f"which is not implemented yet"
99
- )
100
- if der_name not in self .output_dict :
101
- self .output_dict [der_name ] = der
102
- return der
103
-
104
- # return single python number directly for leaf node
105
- elif isinstance (expr , sympy .Number ):
106
- return float (expr )
107
-
108
- # compute sub-nodes value and merge by addition
109
- elif isinstance (expr , sympy .Add ):
110
- results = [self .solve_expr (arg ) for arg in expr .args ]
111
- out = results [0 ]
112
- for i in range (1 , len (results )):
113
- out = out + results [i ]
114
- return out
115
-
116
- # compute sub-nodes value and merge by multiplication
117
- elif isinstance (expr , sympy .Mul ):
118
- results = [self .solve_expr (arg ) for arg in expr .args ]
119
- out = results [0 ]
120
- for i in range (1 , len (results )):
121
- out = out * results [i ]
122
- return out
123
-
124
- # compute sub-nodes value and merge by power
125
- elif isinstance (expr , sympy .Pow ):
126
- results = [self .solve_expr (arg ) for arg in expr .args ]
127
- return results [0 ] ** results [1 ]
128
- else :
129
- raise ValueError (
130
- f"Expression { expr } of type({ type (expr )} ) can't be solved yet."
131
- )
132
-
133
- def forward (self , input_dict ):
134
- self .output_dict = input_dict
135
- if callable (next (iter (self .expr_dict .values ()))):
136
- model_output_dict = self .model (input_dict )
137
- self .output_dict .update (model_output_dict )
138
-
139
- for name , expr in self .expr_dict .items ():
140
- if isinstance (expr , sympy .Basic ):
141
- self .output_dict [name ] = self .solve_expr (expr )
142
- elif callable (expr ):
143
- self .output_dict [name ] = expr (self .output_dict )
39
+
40
+ def forward (self , expr_dict , input_dict , model ):
41
+ output_dict = {k : v for k , v in input_dict .items ()}
42
+
43
+ # model forward
44
+ if callable (next (iter (expr_dict .values ()))):
45
+ model_output_dict = model (input_dict )
46
+ output_dict .update (model_output_dict )
47
+
48
+ # equation forward
49
+ for name , expr in expr_dict .items ():
50
+ if callable (expr ):
51
+ output_dict [name ] = expr (output_dict )
144
52
else :
145
53
raise TypeError (f"expr type({ type (expr )} ) is invalid" )
146
54
147
55
# clear differentiation cache
148
56
clear ()
149
57
150
- return {k : self .output_dict [k ] for k in self .output_keys }
151
-
152
- def add_target_expr (self , expr : Callable , expr_name : str ):
153
- """Add an expression `expr` named `expr_name` to
154
-
155
- Args:
156
- expr (Callable): Callable function for computing an expression.
157
- expr_name (str): Name of expression.
158
- """
159
- self .expr_dict [expr_name ] = expr
160
-
161
- def __str__ (self ):
162
- return f"input: { self .input_keys } , output: { self .output_keys } \n " + "\n " .join (
163
- [f"{ name } = { expr } " for name , expr in self .expr_dict .items ()]
164
- )
58
+ return output_dict
0 commit comments