
[FEATURE] Unified embed Interfaces for Vision Transformer Models for MIM Pretrain Research #2540

@ryan-minato

Description


This feature request is related to challenges in Masked Image Modeling (MIM) pre-training using vision transformer models in timm. Currently, embedding and feature extraction are tightly coupled within forward_features, making it difficult to inject mask operations after initial embedding and positional encoding but before the transformer stages, which is a common MIM requirement. Researchers need to access embedded tokens for masking before passing them through subsequent transformer layers.

Describe the solution you'd like

I propose refactoring all vision transformer models in timm (e.g. vit, swin_transformer) to uniformly expose two distinct interfaces:

  • embed(self, x): This method should take the input tensor x (e.g., an image batch) and return the embedded tokens, including positional encodings where applicable. The output should be ready for the transformer encoder stages.
  • forward_stages(self, x): This method should take the output from embed(x) (i.e., the embedded and position-encoded tokens) and pass them through the transformer encoder layers.

This separation would make the existing forward_features(x) effectively equivalent to forward_stages(embed(x)). This allows researchers to easily perform mask operations on the embedded tokens returned by embed(x) before passing them to forward_stages(x), enabling flexible MIM pre-training experiments.
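For illustration, here is a minimal usage sketch of the proposed split for MIM pre-training. Note that embed and forward_stages are the interfaces proposed in this issue and do not exist in timm yet, and random_masking is just a hypothetical helper:

import torch
import timm

def random_masking(tokens, mask_ratio=0.75, num_prefix_tokens=1):
    # Randomly drop a fraction of patch tokens, keeping any prefix (e.g. class) tokens.
    prefix, patches = tokens[:, :num_prefix_tokens], tokens[:, num_prefix_tokens:]
    B, N, D = patches.shape
    num_keep = int(N * (1 - mask_ratio))
    keep_idx = torch.rand(B, N, device=tokens.device).argsort(dim=1)[:, :num_keep]
    kept = patches.gather(1, keep_idx.unsqueeze(-1).expand(-1, -1, D))
    return torch.cat([prefix, kept], dim=1), keep_idx

model = timm.create_model('vit_base_patch16_224', pretrained=False)
x = torch.randn(2, 3, 224, 224)

tokens = model.embed(x)                     # proposed: patch embedding + positional encoding
visible, keep_idx = random_masking(tokens)  # MIM-style masking between the two stages
features = model.forward_stages(visible)    # proposed: transformer blocks + final norm

The key point is that the masking policy lives entirely outside the model, so different MIM variants can reuse the same two methods without further changes to timm.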

Describe alternatives you've considered

I have considered alternative solutions, such as adding a mask parameter directly to the forward_features and forward methods, similar to the existing attn_mask support in vision_transformer:

def forward_features(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
"""Forward pass through feature layers (embeddings, transformer blocks, post-transformer norm)."""

While seemingly straightforward, this approach presents drawbacks:

  • Function Signature: It still alters the method's interface, even if default parameter values mitigate direct breakage.
  • Mask Value Flexibility: More importantly, it limits control over how masked-out positions are handled, e.g., whether they are replaced with a learnable mask token, set to zero, or dropped entirely. Separating embed and forward_stages gives the caller full control (see the sketch below).
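As a second illustration of that flexibility, here is a hedged sketch of a wrapper that keeps the sequence length fixed and fills masked positions with a learnable mask token (as in mask-token-based approaches such as SimMIM) instead of dropping them. MaskTokenMIM, mask_token, and num_prefix_tokens are illustrative names, and the code again assumes the proposed embed/forward_stages interfaces:

import torch
import torch.nn as nn

class MaskTokenMIM(nn.Module):
    # Wraps a backbone exposing the proposed embed()/forward_stages() interfaces and
    # replaces masked patch tokens with a learnable mask embedding instead of dropping them.
    def __init__(self, backbone, embed_dim, num_prefix_tokens=1):
        super().__init__()
        self.backbone = backbone
        self.num_prefix_tokens = num_prefix_tokens
        self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

    def forward(self, x, mask):
        # mask: (B, num_patches) bool tensor, True where a patch token should be masked
        tokens = self.backbone.embed(x)
        prefix = tokens[:, :self.num_prefix_tokens]
        patches = tokens[:, self.num_prefix_tokens:]
        patches = torch.where(mask.unsqueeze(-1), self.mask_token.to(patches.dtype), patches)
        return self.backbone.forward_stages(torch.cat([prefix, patches], dim=1))

Neither strategy is expressible through a single attn_mask argument, which is why the two-method split seems preferable.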

Additional context

If this refactoring aligns with the library's design, I would gladly contribute a Pull Request to implement these changes across relevant vision transformer models.
