Description
When importing the weights shared by the authors, you may see a lot of layer-initialization messages (these can safely be ignored). If you want a clean initialization, first build the model from RoBERTa-base and then load the shared weights, like this:

import torch
import torch.nn as nn
from transformers import AutoModel

# Build the encoder from RoBERTa-base and enlarge the token-type embedding table
bert = AutoModel.from_pretrained("#The location of RoBERTa base")
bert.config.type_vocab_size = 4
bert.embeddings.token_type_embeddings = nn.Embedding(
    bert.config.type_vocab_size, bert.config.hidden_size
)
bert._init_weights(bert.embeddings.token_type_embeddings)

model = AutoModelForSequenceClassification_SPV_MIP(Model=bert, config=bert.config, num_labels=2)

# Load the shared MelBERT checkpoint; drop the stale position_ids buffer
# before loading, and use strict=False to tolerate remaining name mismatches
output_model_file = "#The location of MelBERT/pytorch_model.bin"
state_dict = torch.load(output_model_file)
del state_dict["encoder.embeddings.position_ids"]
if hasattr(model, "module"):
    model.module.load_state_dict(state_dict)
else:
    model.load_state_dict(state_dict, strict=False)
model.to(device)
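The two tricks above (resizing an embedding table before loading, then loading with strict=False after deleting a stale key) can be sketched with plain torch, independent of the MelBERT classes. The TinyEncoder module and its field names here are hypothetical stand-ins, not the actual model:

```python
import torch
import torch.nn as nn

class TinyEncoder(nn.Module):
    # Hypothetical stand-in for the real encoder
    def __init__(self, type_vocab_size=2, hidden_size=8):
        super().__init__()
        self.token_type_embeddings = nn.Embedding(type_vocab_size, hidden_size)
        self.dense = nn.Linear(hidden_size, hidden_size)

model = TinyEncoder()

# Step 1: enlarge token_type_embeddings (2 -> 4 rows), as done for bert above.
model.token_type_embeddings = nn.Embedding(4, 8)

# Step 2: a checkpoint saved with the larger table, plus a stale buffer
# (standing in for "encoder.embeddings.position_ids") that we delete
# before loading; strict=False reports anything still unmatched.
state_dict = TinyEncoder(type_vocab_size=4).state_dict()
state_dict["position_ids"] = torch.arange(16)  # stale key
del state_dict["position_ids"]

result = model.load_state_dict(state_dict, strict=False)
print(result.missing_keys, result.unexpected_keys)  # both empty after the fix
```

With strict=False, load_state_dict returns the lists of missing and unexpected keys instead of raising, which is a convenient way to verify that nothing important was silently skipped.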