Problems encountered while loading the model #5

@pyygg

Description

When loading the weights shared by the authors, you may see a lot of layer-initialization messages; these can safely be ignored. However, if the warnings bother you, it is recommended to first initialize the model from RoBERTa base and then load the shared weights, like this:

import torch
import torch.nn as nn
from transformers import AutoModel

# Initialize from RoBERTa base, then resize the token-type embeddings to 4 types
bert = AutoModel.from_pretrained("#The location of RoBERTa base")
bert.config.type_vocab_size = 4
bert.embeddings.token_type_embeddings = nn.Embedding(
    bert.config.type_vocab_size, bert.config.hidden_size
)
bert._init_weights(bert.embeddings.token_type_embeddings)

# Wrap the encoder in the MelBERT classification head, then load the shared checkpoint
model = AutoModelForSequenceClassification_SPV_MIP(Model=bert, config=bert.config, num_labels=2)
output_model_file = "#The location of Melbert/pytorch_model.bin"
state_dict = torch.load(output_model_file)
del state_dict["encoder.embeddings.position_ids"]  # stale buffer not present in the rebuilt model
if hasattr(model, "module"):
    model.module.load_state_dict(state_dict, strict=False)
else:
    model.load_state_dict(state_dict, strict=False)
model.to(device)
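Since load_state_dict(..., strict=False) silently tolerates key mismatches, it is worth inspecting the returned result to confirm that only the expected entries were skipped. A minimal sketch with toy modules (class names and the spare position_ids buffer are illustrative, not from the MelBERT code; note that strict=False still raises on shape mismatches, it only tolerates missing or unexpected keys):

```python
import torch
import torch.nn as nn

# Hypothetical stand-ins: Saved mimics the checkpointed model,
# Rebuilt mimics the freshly initialized model with an extra layer.
class Saved(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(8, 2)

class Rebuilt(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(8, 2)
        self.token_type_embeddings = nn.Embedding(4, 8)

state_dict = Saved().state_dict()
state_dict["position_ids"] = torch.arange(8)  # stale buffer, like the one deleted above

result = Rebuilt().load_state_dict(state_dict, strict=False)
print(result.missing_keys)     # layers left at their fresh initialization
print(result.unexpected_keys)  # checkpoint entries the model has no slot for
```

If missing_keys or unexpected_keys contains anything beyond what you deliberately resized or deleted, the checkpoint and the model architecture are out of sync.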
