File tree Expand file tree Collapse file tree 3 files changed +3
-3
lines changed Expand file tree Collapse file tree 3 files changed +3
-3
lines changed Original file line number Diff line number Diff line change @@ -100,7 +100,7 @@ def on_train_end(self, run_context):
100
100
101
101
def get_segment_train_callback (args , steps_per_epoch , rank_id ):
102
102
callbacks = [TimeMonitor (data_size = steps_per_epoch ), LossMonitor ()]
103
- if rank_id == 0 :
103
+ if rank_id == 0 or rank_id is None :
104
104
ckpt_config = CheckpointConfig (
105
105
save_checkpoint_steps = args .save_steps ,
106
106
keep_checkpoint_max = args .keep_checkpoint_max ,
Original file line number Diff line number Diff line change @@ -260,7 +260,7 @@ class DeepLabInferNetwork(nn.Cell):
260
260
"""
261
261
262
262
def __init__ (self , network , input_format = "NCHW" ):
263
- super (DeepLabInferNetwork , self ).__init__ ()
263
+ super (DeepLabInferNetwork , self ).__init__ (auto_prefix = False )
264
264
self .network = network
265
265
self .softmax = nn .Softmax (axis = 1 )
266
266
self .format = input_format
Original file line number Diff line number Diff line change @@ -138,7 +138,7 @@ def train(args):
138
138
callbacks = get_segment_train_callback (args , steps_per_epoch , rank_id )
139
139
140
140
# eval when train
141
- if args .eval_while_train and rank_id == 0 :
141
+ if args .eval_while_train and ( rank_id == 0 or rank_id is None ) :
142
142
eval_model = DeepLabInferNetwork (deeplab , input_format = args .input_format )
143
143
eval_dataset = create_segment_dataset (
144
144
name = args .dataset ,
You can’t perform that action at this time.
0 commit comments