Skip to content

Commit 162922d

Browse files
update to_static code for darcy
1 parent 635e300 commit 162922d

File tree

3 files changed

+8
-3
lines changed

3 files changed

+8
-3
lines changed

examples/darcy/darcy2d.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,15 @@
1313
# limitations under the License.
1414

1515
import numpy as np
16+
from paddle import fluid
1617

1718
import ppsci
1819
from ppsci.utils import config
1920
from ppsci.utils import logger
2021

2122
if __name__ == "__main__":
23+
fluid.core._set_prim_all_enabled(True)
24+
2225
args = config.parse_args()
2326
# set random seed for reproducibility
2427
ppsci.utils.misc.set_random_seed(42)

ppsci/arch/base.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,9 @@ def split_to_dict(
7878
Returns:
7979
Dict[str, paddle.Tensor]: Dict contains tensor.
8080
"""
81+
# TODO: num_or_sections must > 1 in static, but 1 is allowed in dygraph.
82+
if len(keys) == 1:
83+
return {key: data_tensor for i, key in enumerate(keys)}
8184
data = paddle.split(data_tensor, len(keys), axis=axis)
8285
return {key: data[i] for i, key in enumerate(keys)}
8386

ppsci/utils/expression.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from typing import Callable
16-
from typing import Union
17-
15+
from paddle import jit
1816
from paddle import nn
1917

2018
from ppsci.autodiff import clear
@@ -37,6 +35,7 @@ class ExpressionSolver(nn.Layer):
3735
def __init__(self):
3836
super().__init__()
3937

38+
@jit.to_static
4039
def forward(self, expr_dict, input_dict, model):
4140
output_dict = {k: v for k, v in input_dict.items()}
4241

0 commit comments

Comments
 (0)