Skip to content

Commit c744121

Browse files
ChongWei905ChongWei905
andauthored
fix: fix load checkpoint failure for deeplabv3 (#790)
Co-authored-by: ChongWei905 <weichong4@huawei.com>
1 parent a6d45f3 commit c744121

File tree

3 files changed

+3
-3
lines changed

3 files changed

+3
-3
lines changed

examples/seg/deeplabv3/callbacks.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ def on_train_end(self, run_context):
100100

101101
def get_segment_train_callback(args, steps_per_epoch, rank_id):
102102
callbacks = [TimeMonitor(data_size=steps_per_epoch), LossMonitor()]
103-
if rank_id == 0:
103+
if rank_id == 0 or rank_id is None:
104104
ckpt_config = CheckpointConfig(
105105
save_checkpoint_steps=args.save_steps,
106106
keep_checkpoint_max=args.keep_checkpoint_max,

examples/seg/deeplabv3/deeplabv3.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -260,7 +260,7 @@ class DeepLabInferNetwork(nn.Cell):
260260
"""
261261

262262
def __init__(self, network, input_format="NCHW"):
263-
super(DeepLabInferNetwork, self).__init__()
263+
super(DeepLabInferNetwork, self).__init__(auto_prefix=False)
264264
self.network = network
265265
self.softmax = nn.Softmax(axis=1)
266266
self.format = input_format

examples/seg/deeplabv3/train.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ def train(args):
138138
callbacks = get_segment_train_callback(args, steps_per_epoch, rank_id)
139139

140140
# eval when train
141-
if args.eval_while_train and rank_id == 0:
141+
if args.eval_while_train and (rank_id == 0 or rank_id is None):
142142
eval_model = DeepLabInferNetwork(deeplab, input_format=args.input_format)
143143
eval_dataset = create_segment_dataset(
144144
name=args.dataset,

0 commit comments

Comments
 (0)