@@ -1095,3 +1095,108 @@ def _prefix_tuning_forward(
 
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output
+
+
+class PeftModelForFeatureExtraction(PeftModel):
+    """
+    Peft model for extracting features/embeddings from transformer models.
+
+    Args:
+        model ([`~transformers.PreTrainedModel`]): Base transformer model.
+        peft_config ([`PeftConfig`]): Peft config.
+        adapter_name (`str`, *optional*): The name of the adapter, defaults to `"default"`.
+        autocast_adapter_dtype (`bool`, *optional*):
+            Whether to autocast the adapter dtype. Defaults to `True`. Right now, this will only cast adapter weights
+            using float16 and bfloat16 to float32, as this is typically required for stable training, and only affects
+            select PEFT tuners.
+
+    **Attributes**:
+        - **config** ([`~transformers.PretrainedConfig`]) -- The configuration object of the base model.
+
+    Example:
+
+        ```py
+        >>> from transformers import AutoModel
+        >>> from peft import PeftModelForFeatureExtraction, get_peft_config
+
+        >>> config = {
+        ...     "peft_type": "LORA",
+        ...     "task_type": "FEATURE_EXTRACTION",
+        ...     "inference_mode": False,
+        ...     "r": 16,
+        ...     "target_modules": ["query", "value"],
+        ...     "lora_alpha": 32,
+        ...     "lora_dropout": 0.05,
+        ...     "fan_in_fan_out": False,
+        ...     "bias": "none",
+        ... }
+        >>> peft_config = get_peft_config(config)
+        >>> model = AutoModel.from_pretrained("bert-base-cased")
+        >>> peft_model = PeftModelForFeatureExtraction(model, peft_config)
+        >>> peft_model.print_trainable_parameters()
+        ```
+    """
+
+    def __init__(self, model: nn.Module, peft_config: PeftConfig, adapter_name: str = "default", **kwargs):
+        super().__init__(model, peft_config, adapter_name, **kwargs)
+
+    def forward(
+        self,
+        input_ids=None,
+        attention_mask=None,
+        inputs_embeds=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+        task_ids=None,
+        **kwargs,
+    ):
+        peft_config = self.active_peft_config
+        if not peft_config.is_prompt_learning:
+            if peft_config.peft_type == PeftType.POLY:
+                kwargs["task_ids"] = task_ids
+
+            with self._enable_peft_forward_hooks(**kwargs):
+                kwargs = {k: v for k, v in kwargs.items() if k not in self.special_peft_forward_args}
+                return self.base_model(
+                    input_ids=input_ids,
+                    attention_mask=attention_mask,
+                    inputs_embeds=inputs_embeds,
+                    output_attentions=output_attentions,
+                    output_hidden_states=output_hidden_states,
+                    return_dict=return_dict,
+                    **kwargs,
+                )
+
+        batch_size = _get_batch_size(input_ids, inputs_embeds)
+        if attention_mask is not None:
+            # concat prompt attention mask
+            prefix_attention_mask = ops.ones(batch_size, peft_config.num_virtual_tokens)
+            attention_mask = ops.cat((prefix_attention_mask, attention_mask), dim=1)
+
+        if kwargs.get("position_ids", None) is not None:
+            warnings.warn("Position ids are not supported for parameter efficient tuning. Ignoring position ids.")
+            kwargs["position_ids"] = None
+        if kwargs.get("token_type_ids", None) is not None:
+            warnings.warn("Token type ids are not supported for parameter efficient tuning. Ignoring token type ids.")
+            kwargs["token_type_ids"] = None
+        kwargs.update(
+            {
+                "attention_mask": attention_mask,
+                "output_attentions": output_attentions,
+                "output_hidden_states": output_hidden_states,
+                "return_dict": return_dict,
+            }
+        )
+
+        if peft_config.peft_type == PeftType.PREFIX_TUNING:
+            # overwrite past_kv in kwargs
+            kwargs["past_key_values"] = self.get_prompt(batch_size)
+            return self.base_model(input_ids=input_ids, **kwargs)
+        else:
+            if inputs_embeds is None:
+                inputs_embeds = self.word_embeddings(input_ids)
+            prompts = self.get_prompt(batch_size=batch_size)
+            prompts = prompts.to(inputs_embeds.dtype)
+            inputs_embeds = ops.cat((prompts, inputs_embeds), dim=1)
+            return self.base_model(inputs_embeds=inputs_embeds, **kwargs)
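Not part of the diff itself, but as a quick illustration of what the new class returns, here is a minimal sketch that continues the `peft_model` built in the docstring example above. It assumes a PyTorch-backed `transformers` setup (as in that example) and a tokenizer for the same `bert-base-cased` checkpoint; the `[CLS]` pooling is one illustrative choice, not something the class prescribes.

```py
# Minimal sketch (assumptions: PyTorch tensors, the `peft_model` and checkpoint
# from the docstring example above).
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
batch = tokenizer(["PEFT is useful.", "So are adapters."], padding=True, return_tensors="pt")

# With a LoRA config this takes the non-prompt-learning branch of forward() and
# simply runs the wrapped base model, returning its usual output object.
outputs = peft_model(**batch)
embeddings = outputs.last_hidden_state[:, 0]  # (batch_size, hidden_size); [CLS] pooling, purely illustrative
```

For prompt-learning configs the lower branch of `forward()` applies instead: prefix tuning injects the prompt through `past_key_values`, so the output sequence length is unchanged, while the other prompt methods concatenate `num_virtual_tokens` prompt embeddings in front of `inputs_embeds`, so the returned hidden states include those extra leading positions.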