diff --git a/configs/distill/cwd/cwd_fpn_yolox_s_yolox_s_300e_coco_example.py b/configs/distill/cwd/cwd_fpn_yolox_s_yolox_s_300e_coco_example.py new file mode 100644 index 000000000..5480d4be8 --- /dev/null +++ b/configs/distill/cwd/cwd_fpn_yolox_s_yolox_s_300e_coco_example.py @@ -0,0 +1,234 @@ +_base_ = [ + '../../_base_/schedules/mmdet/schedule_1x.py', + '../../_base_/mmdet_runtime.py' +] + +img_scale = (640, 640) # height, width + +student = dict( + type='mmdet.YOLOX', + input_size=img_scale, + random_size_range=(15, 25), + random_size_interval=10, + backbone=dict(type='CSPDarknet', deepen_factor=0.33, widen_factor=0.5), + neck=dict( + type='YOLOXPAFPN', + in_channels=[128, 256, 512], + out_channels=128, + num_csp_blocks=1), + bbox_head=dict( + type='YOLOXHead', num_classes=80, in_channels=128, feat_channels=128), + train_cfg=dict(assigner=dict(type='SimOTAAssigner', center_radius=2.5)), + # In order to align the source code, the threshold of the val phase is + # 0.01, and the threshold of the test phase is 0.001. + test_cfg=dict(score_thr=0.01, nms=dict(type='nms', iou_threshold=0.65))) + +teacher_ckpt = 'https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_s_8x8_300e_coco/yolox_s_8x8_300e_coco_20211121_095711-4592a793.pth' # noqa: E501 +teacher = dict( + init_cfg=dict(type='Pretrained', checkpoint=teacher_ckpt), + type='mmdet.YOLOX', + input_size=img_scale, + random_size_range=(15, 25), + random_size_interval=10, + backbone=dict(type='CSPDarknet', deepen_factor=0.33, widen_factor=0.5), + neck=dict( + type='YOLOXPAFPN', + in_channels=[128, 256, 512], + out_channels=128, + num_csp_blocks=1), + bbox_head=dict( + type='YOLOXHead', num_classes=80, in_channels=128, feat_channels=128), + train_cfg=dict(assigner=dict(type='SimOTAAssigner', center_radius=2.5)), + # In order to align the source code, the threshold of the val phase is + # 0.01, and the threshold of the test phase is 0.001. + test_cfg=dict(score_thr=0.01, nms=dict(type='nms', iou_threshold=0.65))) + +algorithm = dict( + type='AlignMethodDistill', + architecture=dict(type='MMDetArchitecture', model=student), + with_student_loss=True, + with_teacher_loss=False, + distiller=dict( + type='SingleTeacherDistiller', + teacher=teacher, + teacher_trainable=False, + align_methods=[ + dict(method='YOLOX._preprocess', import_module='mmdet.models') + ], + components=[ + dict( + student_module='neck.out_convs.0.conv', + teacher_module='neck.out_convs.0.conv', + losses=[ + dict( + type='ChannelWiseDivergence', + name='loss_cwd_logits', + tau=1, + loss_weight=5) + ]), + dict( + student_module='neck.out_convs.1.conv', + teacher_module='neck.out_convs.1.conv', + losses=[ + dict( + type='ChannelWiseDivergence', + name='loss_cwd_logits', + tau=1, + loss_weight=5) + ]), + dict( + student_module='neck.out_convs.2.conv', + teacher_module='neck.out_convs.2.conv', + losses=[ + dict( + type='ChannelWiseDivergence', + name='loss_cwd_logits', + tau=1, + loss_weight=5) + ]) + ])) + +# dataset settings +data_root = 'data/coco/' +dataset_type = 'CocoDataset' + +train_pipeline = [ + dict(type='Mosaic', img_scale=img_scale, pad_val=114.0), + dict( + type='RandomAffine', + scaling_ratio_range=(0.1, 2), + border=(-img_scale[0] // 2, -img_scale[1] // 2)), + dict( + type='MixUp', + img_scale=img_scale, + ratio_range=(0.8, 1.6), + pad_val=114.0), + dict(type='YOLOXHSVRandomAug'), + dict(type='RandomFlip', flip_ratio=0.5), + # According to the official implementation, multi-scale + # training is not considered here but in the + # 'mmdet/models/detectors/yolox.py'. + dict(type='Resize', img_scale=img_scale, keep_ratio=True), + dict( + type='Pad', + pad_to_square=True, + # If the image is three-channel, the pad value needs + # to be set separately for each channel. + pad_val=dict(img=(114.0, 114.0, 114.0))), + dict(type='FilterAnnotations', min_gt_bbox_wh=(1, 1), keep_empty=False), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']) +] + +train_dataset = dict( + type='MultiImageMixDataset', + dataset=dict( + type=dataset_type, + ann_file=data_root + 'annotations/instances_train2017.json', + img_prefix=data_root + 'train2017/', + pipeline=[ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations', with_bbox=True) + ], + filter_empty_gt=False, + ), + pipeline=train_pipeline) + +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=img_scale, + flip=False, + transforms=[ + dict(type='Resize', keep_ratio=True), + dict(type='RandomFlip'), + dict( + type='Pad', + pad_to_square=True, + pad_val=dict(img=(114.0, 114.0, 114.0))), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img']) + ]) +] + +data = dict( + samples_per_gpu=8, + workers_per_gpu=4, + persistent_workers=True, + train=train_dataset, + val=dict( + type=dataset_type, + ann_file=data_root + 'annotations/instances_val2017.json', + img_prefix=data_root + 'val2017/', + pipeline=test_pipeline), + test=dict( + type=dataset_type, + ann_file=data_root + 'annotations/instances_val2017.json', + img_prefix=data_root + 'val2017/', + pipeline=test_pipeline)) + +# optimizer +# default 8 gpu +optimizer = dict( + type='SGD', + lr=0.01, + momentum=0.9, + weight_decay=5e-4, + nesterov=True, + paramwise_cfg=dict(norm_decay_mult=0., bias_decay_mult=0.)) +optimizer_config = dict(grad_clip=None) + +max_epochs = 300 +num_last_epochs = 15 +resume_from = None +interval = 10 + +# learning policy +lr_config = dict( + _delete_=True, + policy='YOLOX', + warmup='exp', + by_epoch=False, + warmup_by_epoch=True, + warmup_ratio=1, + warmup_iters=5, # 5 epoch + num_last_epochs=num_last_epochs, + min_lr_ratio=0.05) + +runner = dict(type='EpochBasedRunner', max_epochs=max_epochs) + +custom_hooks = [ + dict( + type='YOLOXModeSwitchHook', + num_last_epochs=num_last_epochs, + priority=48), + dict( + type='SyncNormHook', + num_last_epochs=num_last_epochs, + interval=interval, + priority=48), + dict( + type='ExpMomentumEMAHook', + resume_from=resume_from, + momentum=0.0001, + priority=49) +] +checkpoint_config = dict(interval=interval) +evaluation = dict( + save_best='auto', + # The evaluation interval is 'interval' when running epoch is + # less than ‘max_epochs - num_last_epochs’. + # The evaluation interval is 1 when running epoch is greater than + # or equal to ‘max_epochs - num_last_epochs’. + interval=interval, + dynamic_intervals=[(max_epochs - num_last_epochs, 1)], + metric='bbox') +log_config = dict(interval=50) + +# NOTE: `auto_scale_lr` is for automatically scaling LR, +# USER SHOULD NOT CHANGE ITS VALUES. +# base_batch_size = (8 GPUs) x (8 samples per GPU) +auto_scale_lr = dict(base_batch_size=64) + +find_unused_parameters = True diff --git a/mmrazor/models/distillers/base.py b/mmrazor/models/distillers/base.py index c62294889..33c6842ac 100644 --- a/mmrazor/models/distillers/base.py +++ b/mmrazor/models/distillers/base.py @@ -50,11 +50,9 @@ def _set_method(self, method): def __enter__(self): """Rewrite the function.""" self.method_impl = eval(self.method_exec_str) - if self.method_impl: self._set_method( - function_wrapper(self.ctx, self.method_impl, self.method_str, - self.align_mode)) + function_wrapper(self.ctx, self.method_impl, self.method_str)) def __exit__(self, exc_type, exc_value, traceback): """Restore the function."""