I'm using the fine-tuning model in my code with this approach to load the weights under torch 2.6, inside `__init__` of the ViT class:
```python
# [...]
self.head.weight.data.mul_(init_scale)  # type: ignore
self.head.bias.data.mul_(init_scale)  # type: ignore
if pretrained_path is not None:
    checkpoint = torch.load(
        pretrained_path, weights_only=False, map_location="cpu",
    )
    checkpoint_model = None
    model_key = "model|module"
    for model_key in model_key.split("|"):
        if model_key in checkpoint:
            checkpoint_model = checkpoint[model_key]
            print("Load state_dict by model_key = %s" % model_key)
            break
    if checkpoint_model is None:  # type: ignore
        checkpoint_model = checkpoint  # type: ignore
    state_dict = self.state_dict()
    for k in ["head.weight", "head.bias"]:
        if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:  # type: ignore
            print(f"Removing key {k} from pretrained checkpoint")
            del checkpoint_model[k]  # type: ignore
    all_keys = list(checkpoint_model.keys())  # type: ignore
    new_dict = OrderedDict()
    for key in all_keys:
        if key.startswith("backbone."):
            new_dict[key[9:]] = checkpoint_model[key]  # type: ignore
        elif key.startswith("encoder."):
            new_dict[key[8:]] = checkpoint_model[key]  # type: ignore
        elif key == "encoder.norm.weight":
            new_dict["fc_norm.weight"] = checkpoint_model[key]  # type: ignore
        elif key == "encoder.norm.bias":
            new_dict["fc_norm.bias"] = checkpoint_model[key]  # type: ignore
        elif key == "decoder.head.weight":
            new_dict["head.weight"] = checkpoint_model[key]  # type: ignore
        elif key == "decoder.head.bias":
            new_dict["head.bias"] = checkpoint_model[key]  # type: ignore
        else:
            new_dict[key] = checkpoint_model[key]  # type: ignore
    checkpoint_model = new_dict
    # interpolate position embedding
    if "pos_embed" in checkpoint_model:
        pos_embed_checkpoint = checkpoint_model["pos_embed"]
        embedding_size = pos_embed_checkpoint.shape[-1]  # channel dim
        num_patches = self.patch_embed.num_patches
        num_extra_tokens = self.pos_embed.shape[-2] - num_patches  # 0/1
        # height (== width) for the checkpoint position embedding
        orig_size = int(
            (
                (pos_embed_checkpoint.shape[-2] - num_extra_tokens)
                // (all_frames // tubelet_size)
            )
            ** 0.5
        )
        # height (== width) for the new position embedding
        new_size = int((num_patches // (all_frames // tubelet_size)) ** 0.5)
        # class_token and dist_token are kept unchanged
        if orig_size != new_size:
            print(
                "Position interpolate from %dx%d to %dx%d"
                % (orig_size, orig_size, new_size, new_size)
            )
            extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
            # only the position tokens are interpolated
            pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
            # B, L, C -> BT, H, W, C -> BT, C, H, W
            pos_tokens = pos_tokens.reshape(
                -1,
                all_frames // tubelet_size,
                orig_size,
                orig_size,
                embedding_size,
            )
            pos_tokens = pos_tokens.reshape(
                -1, orig_size, orig_size, embedding_size
            ).permute(0, 3, 1, 2)
            pos_tokens = torch.nn.functional.interpolate(  # type: ignore
                pos_tokens,
                size=(new_size, new_size),
                mode="bicubic",
                align_corners=False,
            )
            # BT, C, H, W -> BT, H, W, C -> B, T, H, W, C
            pos_tokens = pos_tokens.permute(0, 2, 3, 1).reshape(
                -1,
                all_frames // tubelet_size,
                new_size,
                new_size,
                embedding_size,
            )
            pos_tokens = pos_tokens.flatten(1, 3)  # B, L, C
            new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
            checkpoint_model["pos_embed"] = new_pos_embed
    # self.load_state_dict(checkpoint['model'], strict=False)
    load_state_dict(self, checkpoint_model, prefix="")
# [...]
```
My output looks like:

```
Load state_dict by model_key = model
Weights from pretrained model not used in VisionTransformer: ['mask_token', 'pos_embed_probs', 'get_token_probs.0.norm1.weight', 'get_token_probs.0.norm1.bias', 'get_token_probs.0.attn.qkv.weight', 'get_token_probs.0.attn.proj.weight', 'get_token_probs.0.attn.proj.bias', 'get_token_probs.0.norm2.weight', 'get_token_probs.0.norm2.bias', 'get_token_probs.0.mlp.fc1.weight', 'get_token_probs.0.mlp.fc1.bias', 'get_token_probs.0.mlp.fc2.weight', 'get_token_probs.0.mlp.fc2.bias', 'get_token_probs.1.weight', 'get_token_probs.1.bias', 'decoder.blocks.0.norm1.weight', 'decoder.blocks.0.norm1.bias', 'decoder.blocks.0.attn.q_bias', 'decoder.blocks.0.attn.v_bias', 'decoder.blocks.0.attn.qkv.weight', 'decoder.blocks.0.attn.proj.weight', 'decoder.blocks.0.attn.proj.bias', 'decoder.blocks.0.norm2.weight', 'decoder.blocks.0.norm2.bias', 'decoder.blocks.0.mlp.fc1.weight', 'decoder.blocks.0.mlp.fc1.bias', 'decoder.blocks.0.mlp.fc2.weight', 'decoder.blocks.0.mlp.fc2.bias', 'decoder.blocks.1.norm1.weight', 'decoder.blocks.1.norm1.bias', 'decoder.blocks.1.attn.q_bias', 'decoder.blocks.1.attn.v_bias', 'decoder.blocks.1.attn.qkv.weight', 'decoder.blocks.1.attn.proj.weight', 'decoder.blocks.1.attn.proj.bias', 'decoder.blocks.1.norm2.weight', 'decoder.blocks.1.norm2.bias', 'decoder.blocks.1.mlp.fc1.weight', 'decoder.blocks.1.mlp.fc1.bias', 'decoder.blocks.1.mlp.fc2.weight', 'decoder.blocks.1.mlp.fc2.bias', 'decoder.blocks.2.norm1.weight', 'decoder.blocks.2.norm1.bias', 'decoder.blocks.2.attn.q_bias', 'decoder.blocks.2.attn.v_bias', 'decoder.blocks.2.attn.qkv.weight', 'decoder.blocks.2.attn.proj.weight', 'decoder.blocks.2.attn.proj.bias', 'decoder.blocks.2.norm2.weight', 'decoder.blocks.2.norm2.bias', 'decoder.blocks.2.mlp.fc1.weight', 'decoder.blocks.2.mlp.fc1.bias', 'decoder.blocks.2.mlp.fc2.weight', 'decoder.blocks.2.mlp.fc2.bias', 'decoder.blocks.3.norm1.weight', 'decoder.blocks.3.norm1.bias', 'decoder.blocks.3.attn.q_bias', 'decoder.blocks.3.attn.v_bias', 'decoder.blocks.3.attn.qkv.weight', 'decoder.blocks.3.attn.proj.weight', 'decoder.blocks.3.attn.proj.bias', 'decoder.blocks.3.norm2.weight', 'decoder.blocks.3.norm2.bias', 'decoder.blocks.3.mlp.fc1.weight', 'decoder.blocks.3.mlp.fc1.bias', 'decoder.blocks.3.mlp.fc2.weight', 'decoder.blocks.3.mlp.fc2.bias', 'encoder_to_decoder.weight', 'norm.weight', 'norm.bias']
size mismatch for fc_norm.weight: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([768]).
size mismatch for fc_norm.bias: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([768]).
size mismatch for head.weight: copying a param with shape torch.Size([1536, 384]) from checkpoint, the shape in current model is torch.Size([400, 768]).
size mismatch for head.bias: copying a param with shape torch.Size([1536]) from checkpoint, the shape in current model is torch.Size([400]).
```
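For debugging, here is a minimal sketch that prints where every norm/head tensor in the checkpoint would land, together with its shape. The `remap` helper below is hypothetical (not part of the repo); it just mirrors the branch order of the remapping loop above, and the file name is the released checkpoint:

```python
import torch

checkpoint = torch.load("adamae_k400.pth", weights_only=False, map_location="cpu")
src = checkpoint["model"] if "model" in checkpoint else checkpoint


def remap(key: str) -> str:
    # Mirrors the branch order of the remapping loop above. Note that the
    # startswith("encoder.") test fires before the exact-match tests, so
    # "encoder.norm.weight" maps to "norm.weight", never to "fc_norm.weight".
    if key.startswith("backbone."):
        return key[len("backbone."):]
    if key.startswith("encoder."):
        return key[len("encoder."):]
    if key == "decoder.head.weight":
        return "head.weight"
    if key == "decoder.head.bias":
        return "head.bias"
    return key


for key, tensor in src.items():
    if "norm" in key or "head" in key:
        print(f"{key} -> {remap(key)}: {tuple(tensor.shape)}")
```

Comparing this printout against `self.state_dict()` should show which source key produces each mismatched target.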
So, what am I doing wrong? I'm using the released weights ("adamae_k400.pth") and almost the same logic as in fine_tunning_main, but the fc_norm and head weights and biases load incorrectly...
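Side note on the torch 2.6 part: since PyTorch 2.6, `torch.load` defaults to `weights_only=True`, which is why the snippet above passes `weights_only=False` explicitly. A sketch of a stricter variant, assuming the fallback is only needed when the checkpoint pickles extra Python objects:

```python
import torch

try:
    # PyTorch >= 2.6 default: weights_only=True rejects arbitrary pickled objects.
    checkpoint = torch.load("adamae_k400.pth", map_location="cpu")
except Exception:
    # Fall back only if the checkpoint stores non-tensor objects
    # (e.g. an argparse.Namespace saved alongside the weights).
    checkpoint = torch.load("adamae_k400.pth", weights_only=False, map_location="cpu")
```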