Overview
- To use a ViT-family pretrained model released in another repository with MMSegmentation, you first have to convert its checkpoint.
- The scripts in mmsegmentation/tools/model_converters convert such checkpoints into the format MMSegmentation expects.
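Most of what a converter does is rename parameters so that the keys in the original checkpoint match the module names of MMSegmentation's backbone. Below is a minimal sketch of that idea. The rename rules and the `convert_keys` helper are hypothetical and only illustrate the pattern. In this sketch, tensors are passed through unchanged. For the real mapping of your model, check the matching `*2mmseg.py` script, because some converters may also reshape tensors.

```python
from collections import OrderedDict
from typing import Dict, List, Tuple

# Hypothetical rename rules: (substring in the source checkpoint, substring MMSegmentation expects).
# Illustrative only; the real rules live in the matching *2mmseg.py script.
RENAME_RULES: List[Tuple[str, str]] = [
    ("patch_embed.proj", "patch_embed.projection"),
    ("blocks.", "layers."),
    ("mlp.fc1", "ffn.layers.0.0"),
    ("mlp.fc2", "ffn.layers.1"),
]


def convert_keys(state_dict: Dict, rules: List[Tuple[str, str]] = RENAME_RULES) -> Dict:
    """Return a new state dict with renamed keys; values are passed through unchanged."""
    converted = OrderedDict()
    for key, value in state_dict.items():
        new_key = key
        for old, new in rules:  # apply every rule once, in order
            new_key = new_key.replace(old, new)
        converted[new_key] = value
    return converted
```

For example, `convert_keys({"blocks.0.mlp.fc1.bias": b})` returns a dict with the single key `layers.0.ffn.layers.0.0.bias`. Keys that match no rule, such as `cls_token`, are kept as they are.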
Example
- In mmsegmentation/tools/model_converters, use the ‘model’2mmseg.py script that matches the model you want to use (e.g. beit2mmseg.py for BEiT).
- Convert the .pth file as shown below:
# python tools/model_converters/<model>2mmseg.py <path or URL of the .pth file> <output path>
python tools/model_converters/beit2mmseg.py \
https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_base_patch16_224_pt22k_ft22kto1k.pth \
pretrain/beit_base_patch16_224_pt22k_ft22kto1k.pth
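Before training, it can help to check that the converted file really uses the new key names. This is a minimal sketch. It assumes the checkpoint is a plain dict of tensors, possibly wrapped in a `state_dict` or `model` entry. The helper names are hypothetical. The default `blocks.` prefix reflects the transformer-block naming of the original BEiT release, and you should adjust it for other models.

```python
from typing import Dict, Iterable, List, Tuple


def unwrap_state_dict(ckpt: Dict) -> Dict:
    """Return the parameter dict, unwrapping a 'state_dict' or 'model' entry if present."""
    for wrapper in ("state_dict", "model"):
        if wrapper in ckpt and isinstance(ckpt[wrapper], dict):
            return ckpt[wrapper]
    return ckpt


def leftover_keys(keys: Iterable[str], source_prefixes: Tuple[str, ...] = ("blocks.",)) -> List[str]:
    """Keys that still use the source repository's naming, i.e. were not renamed."""
    return sorted(k for k in keys if k.startswith(source_prefixes))


def compare_keys(ckpt_keys: Iterable[str], model_keys: Iterable[str]) -> Tuple[List[str], List[str]]:
    """Return (missing, unexpected): model keys absent from the checkpoint, and the reverse."""
    ckpt, model = set(ckpt_keys), set(model_keys)
    return sorted(model - ckpt), sorted(ckpt - model)
```

For example, load the converted file with `torch.load('pretrain/beit_base_patch16_224_pt22k_ft22kto1k.pth', map_location='cpu')`, pass it through `unwrap_state_dict`, and expect `leftover_keys(...)` to return an empty list.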
BEiT config
_base_ = [
    '../_base_/datasets/custom.py',
    '../_base_/default_runtime.py',
    '../_base_/schedules/schedule.py'
]
norm_cfg = dict(type='BN', requires_grad=True)
model = dict(
    type='EncoderDecoder',
    backbone=dict(
        type='BEiT',
        # converted .pth file; must point to the file written by beit2mmseg.py
        pretrained='/opt/ml/input/mmsegmentation/pretrained/beit_base_patch16_224_pt22k_ft22kto1k.pth',
        img_size=512,
        patch_size=16,
        in_channels=3,
        embed_dims=768,
        num_layers=12,
        num_heads=12,
        mlp_ratio=4,
        out_indices=(2, 5, 8, 11),
        qv_bias=True,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        # must stay LN: setting BN here raises a ValueError (BN builds a
        # BatchNorm2d layer, which rejects the 3D token tensors of the transformer)
        norm_cfg=dict(type='LN', eps=1e-06),
        act_cfg=dict(type='GELU'),
        patch_norm=False,
        final_norm=False,
        norm_eval=False,
        init_values=0.1,
    ),
    neck=dict(
        type='MultiLevelNeck',
        in_channels=[768, 768, 768, 768],
        out_channels=768,
        scales=[4, 2, 1, 0.5]),
    decode_head=dict(
        type='UPerHead',
        in_channels=[768, 768, 768, 768],
        in_index=[0, 1, 2, 3],
        pool_scales=(1, 2, 3, 6),
        channels=768,
        dropout_ratio=0.1,
        num_classes=11,
        norm_cfg=dict(type='BN', requires_grad=True),
        align_corners=False,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0, avg_non_ignore=True)),
    auxiliary_head=dict(
        type='FCNHead',
        in_channels=768,
        in_index=2,
        channels=256,
        num_convs=1,
        concat_input=False,
        dropout_ratio=0.1,
        num_classes=11,
        norm_cfg=dict(type='BN', requires_grad=True),
        align_corners=False,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4, avg_non_ignore=True)),
    train_cfg=dict(),
    # alternative: sliding-window inference (use this line instead of the next one)
    # test_cfg=dict(mode='slide', crop_size=(512, 512), stride=(341, 341)))
    test_cfg=dict(mode='whole'))
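Some of the sizes this config implies can be worked out by hand. With `img_size=512` and `patch_size=16`, the backbone produces a 32×32 token grid. `MultiLevelNeck` then resizes it by `scales=[4, 2, 1, 0.5]`, which gives the 4-level pyramid that `UPerHead` consumes. For the commented-out `slide` mode, the second helper counts the crop windows along one image axis. To the best of my knowledge, it mirrors the grid logic of MMSegmentation's slide inference, but treat it as a sketch. Both helper names are hypothetical.

```python
from typing import List, Tuple


def neck_feature_sizes(img_size: int, patch_size: int, scales: List[float]) -> List[int]:
    """Side length of each MultiLevelNeck output for a square input."""
    tokens_per_side = img_size // patch_size  # e.g. 512 // 16 = 32
    return [int(tokens_per_side * s) for s in scales]


def slide_windows(img_len: int, crop: int, stride: int) -> List[Tuple[int, int]]:
    """(start, end) of each crop window along one image axis in slide mode."""
    n_grids = max(img_len - crop + stride - 1, 0) // stride + 1
    windows = []
    for i in range(n_grids):
        start = i * stride
        end = min(start + crop, img_len)
        start = max(end - crop, 0)  # shift the last window back inside the image
        windows.append((start, end))
    return windows
```

For this config, `neck_feature_sizes(512, 16, [4, 2, 1, 0.5])` gives `[128, 64, 32, 16]`, i.e. output strides 4, 8, 16 and 32. For a 512-pixel side, `slide_windows(512, 512, 341)` yields a single window, so `slide` only differs from `whole` on inputs larger than the crop.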