Skip to content

Commit eb22b06

Browse files
fix logger
1 parent 7aacf09 commit eb22b06

File tree

2 files changed

+9
-6
lines changed

2 files changed

+9
-6
lines changed

ppsci/solver/solver.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,10 @@ def __init__(
219219
if isinstance(loaded_metric, dict):
220220
self.best_metric.update(loaded_metric)
221221

222+
# init logger without FileHandler if not initialized before
223+
if logger._logger is None:
224+
logger.init_logger("ppsci", None)
225+
222226
# choosing an appropriate training function for different optimizers
223227
if isinstance(self.optimizer, optim.LBFGS):
224228
self.train_epoch_func = ppsci.solver.train.train_LBFGS_epoch_func
@@ -252,8 +256,7 @@ def __init__(
252256
if version.Version(paddle.__version__) != version.Version("0.0.0")
253257
else f"develop({paddle.version.commit[:7]})"
254258
)
255-
if logger._logger is not None:
256-
logger.info(f"Using paddlepaddle {paddle_version} on device {self.device}")
259+
logger.info(f"Using paddlepaddle {paddle_version} on device {self.device}")
257260

258261
self.forward_helper = expression.ExpressionSolver()
259262

ppsci/utils/expression.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -85,11 +85,11 @@ def train_forward(
8585
def eval_forward(
8686
self,
8787
expr_dict: Dict[str, Callable],
88-
input_dict: Dict[str, Callable],
88+
input_dict: Dict[str, paddle.Tensor],
8989
model: nn.Layer,
9090
validator: "validate.Validator",
91-
label_dict: Dict[str, Callable],
92-
weight_dict: Dict[str, Callable],
91+
label_dict: Dict[str, paddle.Tensor],
92+
weight_dict: Dict[str, paddle.Tensor],
9393
):
9494
# model forward
9595
if callable(next(iter(expr_dict.values()))):
@@ -118,7 +118,7 @@ def eval_forward(
118118
def visu_forward(
119119
self,
120120
expr_dict: Dict[str, Callable],
121-
input_dict: Dict[str, Callable],
121+
input_dict: Dict[str, paddle.Tensor],
122122
model: nn.Layer,
123123
):
124124
# model forward

0 commit comments

Comments
 (0)