
How to load weights correctly? #6

@davimedio01

I'm using the fine-tuning model in my code, loading the weights in torch 2.6 with this approach inside __init__ of the ViT class:

[...]
        # (excerpt from ViT.__init__; assumes `import torch`,
        # `from collections import OrderedDict`, and the repo's
        # `load_state_dict` helper are available)
        self.head.weight.data.mul_(init_scale)  # type: ignore
        self.head.bias.data.mul_(init_scale)  # type: ignore

        if pretrained_path is not None:
            checkpoint = torch.load(
                pretrained_path, weights_only=False, map_location="cpu",
            )
            checkpoint_model = None
            model_key = "model|module"
            for model_key in model_key.split("|"):
                if model_key in checkpoint:
                    checkpoint_model = checkpoint[model_key]
                    print("Load state_dict by model_key = %s" % model_key)
                    break
            if checkpoint_model is None:  # type: ignore
                checkpoint_model = checkpoint  # type: ignore
            state_dict = self.state_dict()
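            # NOTE: this shape check sees the raw checkpoint keys, i.e. it
            # runs before the decoder.head.* -> head.* remapping below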
            for k in ["head.weight", "head.bias"]:
                if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:  # type: ignore
                    print(f"Removing key {k} from pretrained checkpoint")
                    del checkpoint_model[k]  # type: ignore

            all_keys = list(checkpoint_model.keys())  # type: ignore
            new_dict = OrderedDict()
            for key in all_keys:
                if key.startswith("backbone."):
                    new_dict[key[9:]] = checkpoint_model[key]  # type: ignore
                elif key.startswith("encoder."):
                    new_dict[key[8:]] = checkpoint_model[key]  # type: ignore
                elif key == "encoder.norm.weight":
                    new_dict["fc_norm.weight"] = checkpoint_model[key]  # type: ignore
                elif key == "encoder.norm.bias":
                    new_dict["fc_norm.bias"] = checkpoint_model[key]  # type: ignore
                elif key == "decoder.head.weight":
                    new_dict["head.weight"] = checkpoint_model[key]  # type: ignore
                elif key == "decoder.head.bias":
                    new_dict["head.bias"] = checkpoint_model[key]  # type: ignore
                else:
                    new_dict[key] = checkpoint_model[key]  # type: ignore
            checkpoint_model = new_dict

            # interpolate position embedding
            if "pos_embed" in checkpoint_model:
                pos_embed_checkpoint = checkpoint_model["pos_embed"]
                embedding_size = pos_embed_checkpoint.shape[-1]  # channel dim
                num_patches = self.patch_embed.num_patches
                num_extra_tokens = self.pos_embed.shape[-2] - num_patches  # 0 or 1 (cls token)

                # height (== width) for the checkpoint position embedding
                orig_size = int(
                    (
                        (pos_embed_checkpoint.shape[-2] - num_extra_tokens)
                        // (all_frames // tubelet_size)
                    )
                    ** 0.5
                )
                # height (== width) for the new position embedding
                new_size = int((num_patches // (all_frames // tubelet_size)) ** 0.5)
                # class_token and dist_token are kept unchanged
                if orig_size != new_size:
                    print(
                        "Position interpolate from %dx%d to %dx%d"
                        % (orig_size, orig_size, new_size, new_size)
                    )
                    extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
                    # only the position tokens are interpolated
                    pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
                    # B, L, C -> BT, H, W, C -> BT, C, H, W
                    pos_tokens = pos_tokens.reshape(
                        -1,
                        all_frames // tubelet_size,
                        orig_size,
                        orig_size,
                        embedding_size,
                    )
                    pos_tokens = pos_tokens.reshape(
                        -1, orig_size, orig_size, embedding_size
                    ).permute(0, 3, 1, 2)
                    pos_tokens = torch.nn.functional.interpolate(  # type: ignore
                        pos_tokens,
                        size=(new_size, new_size),
                        mode="bicubic",
                        align_corners=False,
                    )
                    # BT, C, H, W -> BT, H, W, C ->  B, T, H, W, C
                    pos_tokens = pos_tokens.permute(0, 2, 3, 1).reshape(
                        -1,
                        all_frames // tubelet_size,
                        new_size,
                        new_size,
                        embedding_size,
                    )
                    pos_tokens = pos_tokens.flatten(1, 3)  # B, L, C
                    new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
                    checkpoint_model["pos_embed"] = new_pos_embed

            # self.load_state_dict(checkpoint['model'], strict=False)
            load_state_dict(self, checkpoint_model, prefix="")
[...]
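
For reference, this is a minimal way to dump the tensor shapes stored in the checkpoint before loading, to see which widths the pretrained weights actually have (a sketch, assuming the same adamae_k400.pth file and the same "model" wrapper-key fallback as above):

import torch

ckpt = torch.load("adamae_k400.pth", weights_only=False, map_location="cpu")
state = ckpt.get("model", ckpt)  # same model_key fallback as in the loader
for name, tensor in state.items():
    print(name, tuple(tensor.shape))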

My output looks like this:

Load state_dict by model_key = model
Weights from pretrained model not used in VisionTransformer: ['mask_token', 'pos_embed_probs', 'get_token_probs.0.norm1.weight', 'get_token_probs.0.norm1.bias', 'get_token_probs.0.attn.qkv.weight', 'get_token_probs.0.attn.proj.weight', 'get_token_probs.0.attn.proj.bias', 'get_token_probs.0.norm2.weight', 'get_token_probs.0.norm2.bias', 'get_token_probs.0.mlp.fc1.weight', 'get_token_probs.0.mlp.fc1.bias', 'get_token_probs.0.mlp.fc2.weight', 'get_token_probs.0.mlp.fc2.bias', 'get_token_probs.1.weight', 'get_token_probs.1.bias', 'decoder.blocks.0.norm1.weight', 'decoder.blocks.0.norm1.bias', 'decoder.blocks.0.attn.q_bias', 'decoder.blocks.0.attn.v_bias', 'decoder.blocks.0.attn.qkv.weight', 'decoder.blocks.0.attn.proj.weight', 'decoder.blocks.0.attn.proj.bias', 'decoder.blocks.0.norm2.weight', 'decoder.blocks.0.norm2.bias', 'decoder.blocks.0.mlp.fc1.weight', 'decoder.blocks.0.mlp.fc1.bias', 'decoder.blocks.0.mlp.fc2.weight', 'decoder.blocks.0.mlp.fc2.bias', 'decoder.blocks.1.norm1.weight', 'decoder.blocks.1.norm1.bias', 'decoder.blocks.1.attn.q_bias', 'decoder.blocks.1.attn.v_bias', 'decoder.blocks.1.attn.qkv.weight', 'decoder.blocks.1.attn.proj.weight', 'decoder.blocks.1.attn.proj.bias', 'decoder.blocks.1.norm2.weight', 'decoder.blocks.1.norm2.bias', 'decoder.blocks.1.mlp.fc1.weight', 'decoder.blocks.1.mlp.fc1.bias', 'decoder.blocks.1.mlp.fc2.weight', 'decoder.blocks.1.mlp.fc2.bias', 'decoder.blocks.2.norm1.weight', 'decoder.blocks.2.norm1.bias', 'decoder.blocks.2.attn.q_bias', 'decoder.blocks.2.attn.v_bias', 'decoder.blocks.2.attn.qkv.weight', 'decoder.blocks.2.attn.proj.weight', 'decoder.blocks.2.attn.proj.bias', 'decoder.blocks.2.norm2.weight', 'decoder.blocks.2.norm2.bias', 'decoder.blocks.2.mlp.fc1.weight', 'decoder.blocks.2.mlp.fc1.bias', 'decoder.blocks.2.mlp.fc2.weight', 'decoder.blocks.2.mlp.fc2.bias', 'decoder.blocks.3.norm1.weight', 'decoder.blocks.3.norm1.bias', 'decoder.blocks.3.attn.q_bias', 'decoder.blocks.3.attn.v_bias', 'decoder.blocks.3.attn.qkv.weight', 'decoder.blocks.3.attn.proj.weight', 'decoder.blocks.3.attn.proj.bias', 'decoder.blocks.3.norm2.weight', 'decoder.blocks.3.norm2.bias', 'decoder.blocks.3.mlp.fc1.weight', 'decoder.blocks.3.mlp.fc1.bias', 'decoder.blocks.3.mlp.fc2.weight', 'decoder.blocks.3.mlp.fc2.bias', 'encoder_to_decoder.weight', 'norm.weight', 'norm.bias']

size mismatch for fc_norm.weight: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([768]).
size mismatch for fc_norm.bias: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([768]).
size mismatch for head.weight: copying a param with shape torch.Size([1536, 384]) from checkpoint, the shape in current model is torch.Size([400, 768]).
size mismatch for head.bias: copying a param with shape torch.Size([1536]) from checkpoint, the shape in current model is torch.Size([400]).

So, what am I doing wrong? I'm using the released weights ("adamae_k400.pth") and almost the same logic as fine_tunning_main, but the fc_norm and head weights and biases come out wrong...
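
One thing I noticed while debugging (an assumption on my part, if AdaMAE follows the VideoMAE convention): the [1536, 384] head in the checkpoint looks like the pre-training decoder's pixel-reconstruction head rather than a classification head, since

# hypothetical sanity check: 16x16 patches, tubelet_size = 2, RGB input
tubelet_size, patch_size, in_chans = 2, 16, 3
print(tubelet_size * patch_size * patch_size * in_chans)  # 1536

and 384 would then be the decoder width rather than the encoder's 768, so mapping decoder.head.* onto head.* could never match the [400, 768] classification head.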
