
Using ViT-family pretrained models #8

@ysw2946

Overview

  • To use a ViT-family pretrained model released in another repository with MMSegmentation, you must first convert its checkpoint.
  • The scripts in mmsegmentation/tools/model_converters convert a checkpoint to the MMSegmentation format.
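
Conversion is needed largely because the parameter names stored in a checkpoint differ from one repository to another. The sketch below only illustrates that idea on a plain dictionary: the function convert_keys and the rename rules are hypothetical and are not the mapping used by any script in tools/model_converters.

```python
# Illustrative only: these rename rules are hypothetical examples,
# not the real mapping applied by the MMSegmentation converter scripts.
HYPOTHETICAL_RENAMES = [
    ("blocks.", "layers."),
    ("mlp.fc1.", "ffn.layers.0.0."),
]


def convert_keys(state_dict, renames=HYPOTHETICAL_RENAMES):
    """Return a new dict with every (old, new) substring rule applied to each key."""
    converted = {}
    for key, value in state_dict.items():
        new_key = key
        for old, new in renames:
            new_key = new_key.replace(old, new)
        converted[new_key] = value
    return converted


# Strings stand in for tensors so the example runs without PyTorch.
source = {"blocks.0.mlp.fc1.weight": "w0", "cls_token": "c"}
converted = convert_keys(source)
print(sorted(converted))
```

Keys that match no rule, such as cls_token here, pass through unchanged. A real converter may also need to reshape or reorder weights, which this sketch does not attempt.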

Example

  • In mmsegmentation/tools/model_converters, use the ‘model’2mmseg.py script that matches the model you want to use (for example, beit2mmseg.py for BEiT).

  • Convert the .pth file as shown below.

# python tools/model_converters/<model>2mmseg.py '<path to .pth file>' '<output path>'
# Note: nothing may follow a trailing backslash, or the line continuation breaks.
python tools/model_converters/beit2mmseg.py \
https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_base_patch16_224_pt22k_ft22kto1k.pth \
pretrain/beit_base_patch16_224_pt22k_ft22kto1k.pth
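
When converting several checkpoints, it may be convenient to assemble the command in Python. This is a minimal sketch: build_convert_command is a hypothetical helper, and it assumes every converter follows the ‘model’2mmseg.py naming and takes the source and destination as its two positional arguments, as in the command above.

```python
from pathlib import PurePosixPath


def build_convert_command(model, src, dst, root="tools/model_converters"):
    """Build the argument list for a <model>2mmseg.py converter (hypothetical helper)."""
    script = PurePosixPath(root) / f"{model}2mmseg.py"
    return ["python", str(script), src, dst]


cmd = build_convert_command(
    "beit",
    "https://conversationhub.blob.core.windows.net/beit-share-public/beit/"
    "beit_base_patch16_224_pt22k_ft22kto1k.pth",
    "pretrain/beit_base_patch16_224_pt22k_ft22kto1k.pth",
)
print(" ".join(cmd))
```

The resulting list can be passed to subprocess.run(cmd, check=True) from the mmsegmentation root directory. Building a list rather than a shell string avoids the line-continuation and quoting problems of a hand-typed multi-line command.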

BEiT

_base_ = [
    '../_base_/datasets/custom.py',
    '../_base_/default_runtime.py',
    '../_base_/schedules/schedule.py'
]

norm_cfg = dict(type='BN', requires_grad=True)

model = dict(
    type='EncoderDecoder',
    backbone=dict(
        type='BEiT',
        # path to the converted .pth checkpoint
        pretrained='/opt/ml/input/mmsegmentation/pretrained/beit_base_patch16_224_pt22k_ft22kto1k.pth',
        img_size=512,
        patch_size=16,
        in_channels=3,
        embed_dims=768,
        num_layers=12,
        num_heads=12,
        mlp_ratio=4,
        out_indices=(2, 5, 8, 11),
        qv_bias=True,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        norm_cfg=dict(type='LN', eps=1e-06),  # setting this to BN raises a ValueError
        act_cfg=dict(type='GELU'),
        patch_norm=False,
        final_norm=False,
        norm_eval=False,
        init_values=0.1,
        ),
    neck=dict(
        type='MultiLevelNeck',
        in_channels=[768, 768, 768, 768],
        out_channels=768,
        scales=[4, 2, 1, 0.5]),
    decode_head=dict(
        type='UPerHead',
        in_channels=[768, 768, 768, 768],
        in_index=[0, 1, 2, 3],
        pool_scales=(1, 2, 3, 6),
        channels=768,
        dropout_ratio=0.1,
        num_classes=11,
        norm_cfg=dict(type='BN', requires_grad=True),
        align_corners=False,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0, avg_non_ignore=True)),
    auxiliary_head=dict(
        type='FCNHead',
        in_channels=768,
        in_index=2,
        channels=256,
        num_convs=1,
        concat_input=False,
        dropout_ratio=0.1,
        num_classes=11,
        norm_cfg=dict(type='BN', requires_grad=True),
        align_corners=False,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4, avg_non_ignore=True)),
    train_cfg=dict(),
    # test_cfg = dict(mode='slide', crop_size=(512,512), stride=(341, 341)))
    test_cfg=dict(mode='whole'))
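
Several values in the config above have to line up: each entry in out_indices selects a backbone layer, the neck receives one feature map per index, and the decode head consumes the neck outputs. The following is a rough sanity check written for this issue, not part of MMSegmentation. check_vit_config is a hypothetical helper, and the rules it encodes are assumptions read off this particular config.

```python
def check_vit_config(backbone, neck, decode_head):
    """Return a list of problems found; an empty list means the shapes line up."""
    problems = []
    n_out = len(backbone["out_indices"])
    # Each output index must refer to an existing transformer layer (0-based).
    if max(backbone["out_indices"]) >= backbone["num_layers"]:
        problems.append("out_indices refers to a layer beyond num_layers")
    # The neck gets one embed_dims-channel feature map per output index.
    if list(neck["in_channels"]) != [backbone["embed_dims"]] * n_out:
        problems.append("neck in_channels does not match backbone outputs")
    # The neck emits one out_channels-channel map per scale.
    neck_out = [neck["out_channels"]] * len(neck["scales"])
    if list(decode_head["in_channels"]) != neck_out:
        problems.append("decode_head in_channels does not match neck outputs")
    # in_index must select existing neck outputs.
    if max(decode_head["in_index"]) >= len(neck_out):
        problems.append("decode_head in_index is out of range")
    return problems


# Values copied from the BEiT config above.
backbone = dict(embed_dims=768, num_layers=12, out_indices=(2, 5, 8, 11))
neck = dict(in_channels=[768, 768, 768, 768], out_channels=768,
            scales=[4, 2, 1, 0.5])
decode_head = dict(in_channels=[768, 768, 768, 768], in_index=[0, 1, 2, 3])

problems = check_vit_config(backbone, neck, decode_head)
print(problems or "config looks consistent")
```

Running a check like this before training can surface a mismatch earlier than the shape error raised when the model is built.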
