Skip to content

Commit 16f4797

Browse files
add _autocast_context_manager to visu.py and solver.predict
1 parent 4dcb801 commit 16f4797

File tree

2 files changed

+2
-8
lines changed

2 files changed

+2
-8
lines changed

ppsci/solver/solver.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -472,10 +472,7 @@ def predict(
472472
batch_input_dict[key].stop_gradient = False
473473

474474
# forward
475-
if self.use_amp:
476-
with amp.auto_cast(level=self.amp_level):
477-
batch_output_dict = self.model(batch_input_dict)
478-
else:
475+
with self._autocast_context_manager():
479476
batch_output_dict = self.model(batch_input_dict)
480477

481478
# collect batch data

ppsci/solver/visu.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -65,10 +65,7 @@ def visualize_func(solver, epoch_id: int):
6565
evaluator.add_target_expr(output_expr, output_key)
6666

6767
# forward
68-
if solver.use_amp:
69-
with amp.auto_cast(level=solver.amp_level):
70-
batch_output_dict = evaluator(batch_input_dict)
71-
else:
68+
with solver._autocast_context_manager():
7269
batch_output_dict = evaluator(batch_input_dict)
7370

7471
# collect batch data

0 commit comments

Comments
 (0)