How is PLoRA used? #12

Open
bang123-box opened this issue May 13, 2025 · 7 comments

Comments

@bang123-box

Hello, your work is very appealing to me, but after reading your source code I have a few questions. How is the PLoRA at line L169 of Coobiw/InternLM-XComposer2_Enhanced actually used? Your train.py already contains a LoRA implementation, so this PLoRA seems to serve no purpose. Looking forward to your reply.

@Coobiw
Collaborator

Coobiw commented May 13, 2025

PLoRA is a component that was already implemented and trained as part of the internlm-xc2 baseline.

@bang123-box
Author

Got it, that is indeed the case; the IXC2 source code also contains this PLoRA. Thank you for your reply.

@Coobiw
Collaborator

Coobiw commented May 13, 2025

Right. This is the LoRA inside XC2 that is applied specifically to visual tokens; it is not the same as the LoRA used later to fine-tune the LLM.
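
To make the distinction concrete, here is a toy sketch (illustrative only, not the XC2 implementation; the actual PLoRA class is quoted later in this thread): a standard LoRA adapter adds its low-rank update to every token, whereas a partial LoRA applies the update only at positions where im_mask marks image tokens.

import torch
import torch.nn as nn

class ToyPartialLoRALinear(nn.Linear):
    """Toy sketch of a 'partial' LoRA: the low-rank delta is added only to masked tokens."""

    def __init__(self, in_features, out_features, r=8, alpha=16):
        super().__init__(in_features, out_features, bias=False)
        self.lora_A = nn.Linear(in_features, r, bias=False)
        self.lora_B = nn.Linear(r, out_features, bias=False)
        self.scaling = alpha / r
        nn.init.zeros_(self.lora_B.weight)  # delta starts at zero

    def forward(self, x, im_mask=None):
        out = super().forward(x)  # base (frozen) projection, applied to all tokens
        if im_mask is not None and im_mask.any():
            # only the visual-token positions receive the low-rank update
            out[im_mask] = out[im_mask] + self.lora_B(self.lora_A(x[im_mask])) * self.scaling
        return out

# x: (batch, seq, hidden); im_mask marks which positions are image tokens
x = torch.randn(1, 6, 32)
im_mask = torch.tensor([[True, True, True, False, False, False]])
print(ToyPartialLoRALinear(32, 64)(x, im_mask).shape)  # torch.Size([1, 6, 64])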

@bang123-box
Author

PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): InternLMXComposer2ForCausalLM(
      (model): InternLM2Model(
        (tok_embeddings): Embedding(92544, 4096, padding_idx=2)
        (layers): ModuleList(
          (0-31): 32 x InternLM2DecoderLayer(
            (attention): InternLM2FlashAttention2(
              (wqkv): Linear(
                in_features=4096, out_features=6144, bias=False
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=4096, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=6144, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
              )
              (wo): Linear(
                in_features=4096, out_features=4096, bias=False
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=4096, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=4096, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
              )
              (rotary_emb): InternLM2DynamicNTKScalingRotaryEmbedding()
            )
            (feed_forward): InternLM2MLP(
              (w1): Linear(
                in_features=4096, out_features=14336, bias=False
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=4096, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=14336, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
              )
              (w3): Linear(
                in_features=4096, out_features=14336, bias=False
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=4096, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=14336, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
              )
              (w2): Linear(
                in_features=14336, out_features=4096, bias=False
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=14336, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=4096, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
              )
              (act_fn): SiLUActivation()
            )
            (attention_norm): InternLM2RMSNorm()
            (ffn_norm): InternLM2RMSNorm()
          )
        )
        (norm): InternLM2RMSNorm()
      )
      (output): Linear(in_features=4096, out_features=92544, bias=False)
      (vit): CLIPVisionTower(
        (vision_tower): CLIPVisionModel(
          (vision_model): CLIPVisionTransformer(
            (embeddings): CLIPVisionEmbeddings(
              (patch_embedding): Conv2d(3, 1024, kernel_size=(14, 14), stride=(14, 14), bias=False)
              (position_embedding): Embedding(1226, 1024)
            )
            (pre_layrnorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
            (encoder): CLIPEncoder(
              (layers): ModuleList(
                (0-23): 24 x CLIPEncoderLayer(
                  (self_attn): CLIPAttention(
                    (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
                    (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
                    (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
                    (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
                  )
                  (layer_norm1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
                  (mlp): CLIPMLP(
                    (activation_fn): QuickGELUActivation()
                    (fc1): Linear(in_features=1024, out_features=4096, bias=True)
                    (fc2): Linear(in_features=4096, out_features=1024, bias=True)
                  )
                  (layer_norm2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
                )
              )
            )
            (post_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          )
        )
      )
      (vision_proj): Sequential(
        (0): Linear(in_features=1024, out_features=4096, bias=True)
        (1): GELU(approximate='none')
        (2): Linear(in_features=4096, out_features=4096, bias=True)
      )
    )
  )
)

One more thing I would like to ask. I found that after applying LoRA to XC2 in the SFT stage, self.Plora_A and self.Plora_B inside the PLoRA class no longer appear, which would make the im_mask in that class's forward function ineffective; in other words, LoRA would no longer be applied to visual tokens during SFT. Yet in the checkpoint you released at https://huggingface.co/IDEA-FinAI/chartmoe, I can still see weights such as model.layers.0.attention.wqkv.Plora_A.weight and model.layers.0.attention.wqkv.Plora_B.weight (both in pytorch_model-00001-of-00002.bin). Following the SFT-stage LoRA fine-tuning approach, Plora_A and Plora_B would seemingly never be used, yet im_mask is still passed in during generate, which makes the PLoRA class execute this branch:

if torch.sum(im_mask) > 0:
    part_x = x[im_mask]
    res[im_mask] += self.Plora_B(self.Plora_A(self.lora_dropout(part_x))) * self.lora_scaling

@Coobiw
Collaborator

Coobiw commented May 15, 2025

The SFT-stage LoRA has no particular relationship with PLoRA. PLoRA is part of the original base model and is simply kept through every stage.
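
This is also why the released chartmoe checkpoint still contains Plora_A/Plora_B weights: they are ordinary parameters of the base model's state dict and get saved regardless of what peft does during SFT. A quick way to see this without loading the model is to inspect the shard index (a rough sketch; it assumes the standard pytorch_model.bin.index.json file name for the sharded .bin checkpoint):

import json
from huggingface_hub import hf_hub_download

# assumption: the repo uses the standard sharded-checkpoint index file name
idx_path = hf_hub_download("IDEA-FinAI/chartmoe", "pytorch_model.bin.index.json")
with open(idx_path) as f:
    weight_map = json.load(f)["weight_map"]

plora_keys = [k for k in weight_map if "Plora_" in k]
print(len(plora_keys), plora_keys[:2])  # the Plora_A/Plora_B weights of every layer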

@Coobiw
Collaborator

Coobiw commented May 15, 2025

As for PLoRA "disappearing" when LoRA is applied: I suggest printing each InternLM2MLP in detail; its w1, w2, and w3 should all still have the corresponding PLoRA A/B matrices. My guess is that they do not show up in the repr because PLoRA inherits from nn.Linear (and peft only attaches LoRA to Linear layers), so it may simply be displayed that way (worth verifying by inspection).
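
One way to verify this (a rough sketch, assuming a standard peft setup; adjust the loader, dtype, and target modules to your own training config) is to list the parameters of the first decoder layer after wrapping the model and check that the original Plora_A/Plora_B weights are still registered alongside the newly added peft lora_A/lora_B adapters:

import torch
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

# assumption: loading via AutoModelForCausalLM with remote code enabled;
# use whatever loader your training script actually uses
model = AutoModelForCausalLM.from_pretrained(
    "Coobiw/InternLM-XComposer2_Enhanced",
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
)

# assumption: target modules roughly matching the printed structure above
lora_cfg = LoraConfig(
    r=8, lora_alpha=16, lora_dropout=0.05,
    target_modules=["wqkv", "wo", "w1", "w2", "w3"],
    task_type="CAUSAL_LM",
)
peft_model = get_peft_model(model, lora_cfg)

# The module repr may hide PLoRA's internals (PLoRA subclasses nn.Linear), but the
# parameters are still there; list them explicitly for the first decoder layer:
for name, p in peft_model.named_parameters():
    if "layers.0." in name:
        print(name, tuple(p.shape), "trainable" if p.requires_grad else "frozen")
# expected: ...Plora_A.weight / ...Plora_B.weight appear as frozen base weights,
# while something like ...lora_A.default.weight / ...lora_B.default.weight are
# the trainable adapters added by peft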

You can refer to these two pieces of code:

1️⃣https://huggingface.co/Coobiw/InternLM-XComposer2_Enhanced/blob/main/modeling_internlm2.py#L278

class InternLM2MLP(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size

        self.w1 = PLoRA(
            self.hidden_size,
            self.intermediate_size,
            bias=False,
            lora_r=256,
            lora_alpha=256,
            lora_len=1225)
        self.w3 = PLoRA(
            self.hidden_size,
            self.intermediate_size,
            bias=False,
            lora_r=256,
            lora_alpha=256,
            lora_len=1225)
        self.w2 = PLoRA(
            self.intermediate_size,
            self.hidden_size,
            bias=False,
            lora_r=256,
            lora_alpha=256,
            lora_len=1225)

        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x, im_mask):
        down_proj = self.w2(
            self.act_fn(self.w1(x, im_mask)) * self.w3(x, im_mask), im_mask)

        return down_proj

2️⃣:https://huggingface.co/Coobiw/InternLM-XComposer2_Enhanced/blob/main/build_mlp.py#L169

class PLoRA(nn.Linear):

    def __init__(self,
                 in_features: int,
                 out_features: int,
                 bias: bool = True,
                 device=None,
                 dtype=None,
                 lora_r=8,
                 lora_alpha=16,
                 lora_dropout=0.05,
                 lora_len=0,
                 **kwargs) -> None:
        super().__init__(in_features, out_features, bias, device, dtype)
        self.lora_r = lora_r
        self.lora_alpha = lora_alpha
        self.lora_len = lora_len
        if lora_dropout > 0.:
            self.lora_dropout = nn.Dropout(p=lora_dropout)
        else:
            self.lora_dropout = lambda x: x
        self.lora_scaling = self.lora_alpha / self.lora_r

        self.Plora_A = nn.Linear(
            in_features, self.lora_r, bias=False, device=device, dtype=dtype)
        self.Plora_B = nn.Linear(
            self.lora_r, out_features, bias=False, device=device, dtype=dtype)

        self.reset_parameters()

    def reset_parameters(self):
        if hasattr(self, 'lora_A'):
            # initialize A the same way as the default for nn.Linear and B to zero
            nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5))
            nn.init.zeros_(self.lora_B.weight)

    def forward(self, x, im_mask=None):
        res = super().forward(x)
        if im_mask is not None:
            if torch.sum(im_mask) > 0:
                part_x = x[im_mask]
                res[im_mask] += self.Plora_B(
                    self.Plora_A(
                        self.lora_dropout(part_x))) * self.lora_scaling
            else:
                part_x = x[:, :1]
                res[:, :1] += self.Plora_B(
                    self.Plora_A(self.lora_dropout(part_x))) * 0
        return res
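
For reference, a minimal usage sketch of the class above (toy shapes for illustration; the real model uses hidden_size=4096 and lora_r=256). It shows why im_mask still has to be passed at generate time: only the positions where im_mask is True receive the Plora_B(Plora_A(...)) delta, independently of whatever peft LoRA is added during SFT.

import torch
import torch.nn as nn

# assumes the PLoRA class quoted above is already defined in scope
plora = PLoRA(in_features=32, out_features=64, bias=False,
              lora_r=8, lora_alpha=16, lora_len=16)

# reset_parameters() above only looks for 'lora_A' (which does not exist), so the
# base weight is normally filled by checkpoint loading; initialize it for this toy run
nn.init.normal_(plora.weight, std=0.02)
plora.eval()  # disable lora_dropout for a deterministic comparison

x = torch.randn(1, 6, 32)
# first three positions are image tokens, the rest are text tokens
im_mask = torch.tensor([[True, True, True, False, False, False]])

with torch.no_grad():
    out = plora(x, im_mask)   # image-token rows get base projection + PLoRA delta
    base_only = plora(x)      # im_mask=None: plain nn.Linear projection

print(torch.allclose(out[~im_mask], base_only[~im_mask]))  # True: text tokens untouched
print(torch.allclose(out[im_mask], base_only[im_mask]))    # False: image tokens got the delta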

@bang123-box
Author

Thanks again for your prompt replies; I will give it a try later.
