File tree Expand file tree Collapse file tree 2 files changed +2
-8
lines changed Expand file tree Collapse file tree 2 files changed +2
-8
lines changed Original file line number Diff line number Diff line change @@ -472,10 +472,7 @@ def predict(
472
472
batch_input_dict [key ].stop_gradient = False
473
473
474
474
# forward
475
- if self .use_amp :
476
- with amp .auto_cast (level = self .amp_level ):
477
- batch_output_dict = self .model (batch_input_dict )
478
- else :
475
+ with self ._autocast_context_manager ():
479
476
batch_output_dict = self .model (batch_input_dict )
480
477
481
478
# collect batch data
Original file line number Diff line number Diff line change @@ -65,10 +65,7 @@ def visualize_func(solver, epoch_id: int):
65
65
evaluator .add_target_expr (output_expr , output_key )
66
66
67
67
# forward
68
- if solver .use_amp :
69
- with amp .auto_cast (level = solver .amp_level ):
70
- batch_output_dict = evaluator (batch_input_dict )
71
- else :
68
+ with solver ._autocast_context_manager ():
72
69
batch_output_dict = evaluator (batch_input_dict )
73
70
74
71
# collect batch data
You can’t perform that action at this time.
0 commit comments