Skip to content

When training the model with LoRA, the ref_model in NPO is None. #7

@sev777

Description

@sev777

forget_outputs_oracle = self.ref_model(inputs['input_ids'], labels=inputs['labels'], attention_mask=inputs['attention_mask'])

forget_outputs_oracle = self.ref_model(inputs['input_ids'], labels=inputs['labels'], attention_mask=inputs['attention_mask'])
How do I initialize the ref_model in NPO with LoRA?

`
def create_ref_model(
    model_args: "ModelArguments", finetuning_args: "FinetuningArguments", add_valuehead: bool = False
) -> Optional[Union["PreTrainedModel", "AutoModelForCausalLMWithValueHead"]]:
    r"""
    Creates reference model for PPO/DPO training. Evaluation mode is not supported.

    The valuehead parameter is randomly initialized since it is useless for PPO training.

    Three cases are handled:

    1. ``finetuning_args.ref_model`` is set: a separate, non-trainable model is loaded from
       that path, using ``finetuning_args.ref_model_adapters`` and
       ``finetuning_args.ref_model_quantization_bit`` in place of the policy model's values.
    2. ``finetuning_args.ref_model`` is None and ``finetuning_args.finetuning_type == "lora"``:
       nothing is loaded and ``None`` is returned.
    3. Otherwise: a second, non-trainable copy is loaded with the policy model's own arguments.

    Args:
        model_args: Arguments of the policy model; used as the template for the reference model.
        finetuning_args: Fine-tuning arguments; ``ref_model``, ``ref_model_adapters``,
            ``ref_model_quantization_bit`` and ``finetuning_type`` are read here.
        add_valuehead: Forwarded unchanged to ``load_model``.

    Returns:
        The loaded reference model, or ``None`` in case 2 above.

    NOTE(review): because case 2 returns ``None``, a trainer that calls the result directly
    (e.g. ``self.ref_model(...)``) must handle ``None`` itself -- presumably by running the
    policy model with its adapters disabled, or by setting ``finetuning_args.ref_model``
    explicitly. Confirm against the trainer implementation.
    """
    if finetuning_args.ref_model is not None:
        # Start from the policy model's arguments and override only the reference-specific fields.
        ref_model_args_dict = model_args.to_dict()
        ref_model_args_dict.update(
            dict(
                model_name_or_path=finetuning_args.ref_model,
                adapter_name_or_path=finetuning_args.ref_model_adapters,
                quantization_bit=finetuning_args.ref_model_quantization_bit,
            )
        )
        ref_model_args = ModelArguments(**ref_model_args_dict)
        ref_finetuning_args = FinetuningArguments(finetuning_type="lora")
        tokenizer = load_tokenizer(ref_model_args)["tokenizer"]
        ref_model = load_model(
            tokenizer, ref_model_args, ref_finetuning_args, is_trainable=False, add_valuehead=add_valuehead
        )
        logger.info("Created reference model from {}".format(finetuning_args.ref_model))
    else:
        if finetuning_args.finetuning_type == "lora":
            # No separate reference model is created for LoRA; the caller receives None.
            ref_model = None
        else:
            tokenizer = load_tokenizer(model_args)["tokenizer"]
            ref_model = load_model(
                tokenizer, model_args, finetuning_args, is_trainable=False, add_valuehead=add_valuehead
            )
            logger.info("Created reference model from the model itself.")

    return ref_model

`

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions