Skip to content

Init weights #46

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion examples/mineral-extract-sites-detection/config_trne.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@ train_model.py:
detectron2_config_file: ../../detectron2_config_dqry.yaml # path relative to the working_folder
model_weights:
model_zoo_checkpoint_url: COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.yaml

init_model_weights: False

# Object detection with the optimised trained model
make_detections.py:
working_directory: ./output/output_trne
Expand Down
3 changes: 2 additions & 1 deletion examples/road-surface-classification/config_rs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,8 @@ train_model.py:
detectron2_config_file: ../detectron2_config_3bands.yaml # path relative to the working_folder
model_weights:
model_zoo_checkpoint_url: COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.yaml

init_model_weights: False

make_detections.py:
working_directory: outputs_RS
log_subfolder: logs
Expand Down
3 changes: 2 additions & 1 deletion examples/swimming-pool-detection/GE/config_GE.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ train_model.py:
detectron2_config_file: '../detectron2_config_GE.yaml' # path relative to the working_folder
model_weights:
model_zoo_checkpoint_url: "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.yaml"

init_model_weights: False

make_detections.py:
working_directory: output_GE
log_subfolder: logs
Expand Down
5 changes: 3 additions & 2 deletions examples/swimming-pool-detection/NE/config_NE.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,9 @@ train_model.py:
tst: COCO_tst.json
detectron2_config_file: '../detectron2_config_NE.yaml' # path relative to the working_folder
model_weights:
model_zoo_checkpoint_url: "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.yaml"

model_zoo_checkpoint_url: "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.yaml"
init_model_weights: False

make_detections.py:
working_directory: output_NE
log_subfolder: logs
Expand Down
22 changes: 19 additions & 3 deletions scripts/train_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,12 @@ def main(cfg_file_path):
# ---- parse config file

DEBUG = cfg['debug_mode'] if 'debug_mode' in cfg.keys() else False

if 'model_zoo_checkpoint_url' in cfg['model_weights'].keys():
MODEL_ZOO_CHECKPOINT_URL = cfg['model_weights']['model_zoo_checkpoint_url']
else:
MODEL_ZOO_CHECKPOINT_URL = None
INIT_MODEL_WEIGHTS = cfg['model_weights']['init_model_weights']

# TODO: allow resuming from previous training
# if 'pth_file' in cfg['model_weights'].keys():
Expand Down Expand Up @@ -114,11 +115,15 @@ def main(cfg_file_path):
# cf. https://detectron2.readthedocs.io/modules/config.html#config-references
cfg = get_cfg()
cfg.merge_from_file(DETECTRON2_CFG_FILE)
## Get the config file of parameters to a execute a given task with a deep learning framework. Can be used to generate the default parameter value
# cfg.merge_from_file(model_zoo.get_config_file(MODEL_ZOO_CHECKPOINT_URL))
# print(cfg)
# sys.exit(1)
Comment on lines +118 to +121
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These lines should be deleted or implemented.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this one can be deleted.

cfg.OUTPUT_DIR = LOG_SUBDIR

num_classes = get_number_of_classes(COCO_FILES_DICT)

cfg.MODEL.ROI_HEADS.NUM_CLASSES=num_classes
cfg.MODEL.ROI_HEADS.NUM_CLASSES = num_classes

if DEBUG:
logger.warning('Setting a configuration for DEBUG only.')
Expand All @@ -127,8 +132,19 @@ def main(cfg_file_path):
cfg.SOLVER.MAX_ITER = 500

# ---- do training
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(MODEL_ZOO_CHECKPOINT_URL)
if INIT_MODEL_WEIGHTS:
# A common error that might occur is that the lr is too high: https://trello.com/c/BY4HtY9h#comment-673616d755b046983033e24f
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it relevant to keep a trello ref here?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've added this comment for you to know, but if you want to publish it, it should be deleted.

cfg.MODEL.WEIGHTS = ""
logger.info("The weights of the pre-trained model are reinitialize. Training from scratch.")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"reinitialized" with d

else:
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(MODEL_ZOO_CHECKPOINT_URL)
logger.info("The weights of the pre-trained model are used. Fine-tune the model.")
trainer = CocoTrainer(cfg)
## Visualize model parameters
# print(len(list(trainer.model.parameters())))
# for p in trainer.model.parameters():
# print(p)
# sys.exit(0)
Comment on lines +143 to +147
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These lines should be deleted or implemented.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

May be this one can be implemented as an option.

trainer.resume_or_load(resume=False)
trainer.train()
TRAINED_MODEL_PTH_FILE = os.path.join(LOG_SUBDIR, 'model_final.pth')
Expand Down
Loading