
NaViT is slower than expected #347

@AmitMY

Description

Full reproduction: https://github.com/sign/image-latent-transformer/blob/main/examples/benchmark_image_encoder.py

I am trying to find a model that can encode many small images per batch.

I started with Hugging Face Transformers, creating a ViTModel:

from transformers import ViTModel

model = ViTModel.from_pretrained("WinKawaks/vit-tiny-patch16-224")

However, this model does not support variable-size images, and if the images are padded to a common size, it returns different results.
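A minimal sketch of that inconsistency (the padding amounts here are arbitrary, and interpolate_pos_encoding=True is needed for the non-224 input):

import torch
import torch.nn.functional as F
from transformers import ViTModel

model = ViTModel.from_pretrained("WinKawaks/vit-tiny-patch16-224").eval()

small = torch.rand(1, 3, 128, 128)
padded = F.pad(small, (0, 96, 0, 96))  # zero-pad right/bottom to 224x224

with torch.no_grad():
    out_small = model(pixel_values=small, interpolate_pos_encoding=True).pooler_output
    out_padded = model(pixel_values=padded).pooler_output

# The embeddings differ: the padding patches take part in self-attention,
# and the position grid is no longer the same.
print(torch.allclose(out_small, out_padded, atol=1e-4))  # False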

So I wrote a patch that allows the model to accept padded images - https://github.com/sign/image-latent-transformer/blob/main/image_latent_transformer/vision/batch_image_encoder.py - and that works, but it is ugly and only supports eager attention.
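The gist of such a patch (a sketch only; make_attention_bias is a hypothetical helper, the real code lives in the file linked above) is to turn a pixel-level validity mask into an additive attention bias so padding patches never receive attention:

import torch
import torch.nn.functional as F

def make_attention_bias(pixel_mask: torch.Tensor, patch_size: int) -> torch.Tensor:
    """Hypothetical helper: (B, H, W) pixel validity mask -> additive
    attention bias over [CLS] + patch tokens."""
    # A patch counts as valid if any pixel inside it is valid
    patch_mask = F.max_pool2d(
        pixel_mask.float().unsqueeze(1), kernel_size=patch_size
    ).flatten(1).bool()  # (B, num_patches)
    cls = torch.ones_like(patch_mask[:, :1])  # [CLS] is always valid
    token_mask = torch.cat([cls, patch_mask], dim=1)  # (B, 1 + num_patches)
    # 0 for valid keys, -inf for padding keys
    bias = torch.zeros(token_mask.shape, device=token_mask.device)
    bias.masked_fill_(~token_mask, float("-inf"))
    return bias[:, None, None, :]  # (B, 1, 1, seq_len), broadcast over heads/queries

The bias has to be added to the attention scores before the softmax, which is why this approach is tied to the eager attention implementation.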

I then moved to vit-pytorch and wrapped this library's NaViT model in the Hugging Face setup:

NaViT Model
import torch
from transformers import AutoModel, AutoConfig
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_outputs import BaseModelOutputWithPooling
from transformers.modeling_utils import PreTrainedModel
from vit_pytorch.na_vit_nested_tensor import NaViT


class NaViTConfig(PretrainedConfig):
    model_type = "navit"

    def __init__(
            self,
            image_size: int = 256,  # only used to set a default; NaViT handles var-size
            patch_size: int = 16,
            hidden_size: int = 512,
            dim: int = 512,
            depth: int = 6,
            heads: int = 8,
            mlp_dim: int = 1024,
            dropout: float = 0.0,
            emb_dropout: float = 0.0,
            token_dropout_prob: float = 0.1,
            **kwargs,
    ):
        super().__init__(**kwargs)
        self.image_size = image_size
        self.patch_size = patch_size
        self.hidden_size = hidden_size
        self.dim = dim
        self.depth = depth
        self.heads = heads
        self.mlp_dim = mlp_dim
        self.dropout = dropout
        self.emb_dropout = emb_dropout
        self.token_dropout_prob = token_dropout_prob


class NaViTModel(PreTrainedModel):
    config_class = NaViTConfig

    def __init__(self, config: NaViTConfig):
        super().__init__(config)

        self.navit = NaViT(
            image_size=config.image_size,
            patch_size=config.patch_size,
            num_classes=config.hidden_size,
            dim=config.dim,
            depth=config.depth,
            heads=config.heads,
            mlp_dim=config.mlp_dim,
            dropout=config.dropout,
            emb_dropout=config.emb_dropout,
            token_dropout_prob=config.token_dropout_prob,
        )

        # Initialize weights via HF utility (won't disturb vit-pytorch-initialized weights unless uninitialized)
        self.post_init()

    def forward(self, images: list[torch.Tensor], **kwargs):
        # vit-pytorch NaViT returns shape (B, num_classes)
        logits = self.navit(images)

        return BaseModelOutputWithPooling(
            last_hidden_state=None,  # per-token states are not exposed by this forward
            pooler_output=logits,
        )

AutoConfig.register(NaViTConfig.model_type, NaViTConfig)
AutoModel.register(NaViTConfig, NaViTModel)
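A quick smoke test of the wrapper (my usage sketch, not part of the benchmark; the sizes are multiples of the 16-pixel patch):

images = [torch.rand(3, 16, 64), torch.rand(3, 32, 48), torch.rand(3, 64, 64)]
model = NaViTModel(NaViTConfig()).eval()

with torch.no_grad():
    out = model(images)

print(out.pooler_output.shape)  # torch.Size([3, 512]) == (B, hidden_size)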

I created three models of similar size and benchmarked the various setups:

    vit_model = ViTModel.from_pretrained("WinKawaks/vit-tiny-patch16-224")

    patched_vit_model = ViTModel.from_pretrained("WinKawaks/vit-tiny-patch16-224")
    maybe_patch_vit_model(patched_vit_model)

    # A NaViT model of roughly similar size
    navit_model = NaViTModel(NaViTConfig(
        image_size=512,  # Max size for pos embeddings
        patch_size=vit_model.config.patch_size,
        hidden_size=vit_model.config.hidden_size,
        dim=vit_model.config.intermediate_size,
        depth=vit_model.config.num_hidden_layers // 2,
        heads=vit_model.config.num_attention_heads,
        mlp_dim=128,
    ))
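The timing itself follows the usual pattern (a sketch, not the exact benchmark code; the full script is linked above; this assumes torch is imported and a CUDA device is in use):

import time

def time_encoder(encode_fn, images, warmup: int = 1, runs: int = 3) -> float:
    """Average wall-clock seconds per run, synchronizing CUDA around the
    timed region so async kernel launches are actually counted."""
    for _ in range(warmup):
        encode_fn(images)
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(runs):
        encode_fn(images)
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / runs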

Batch size: 15k images (average size ~16x64)
GPU: A100

  1. ViT model with padding (inconsistent results) - 1 second
  2. ViT model with images grouped into same-size batches (see the bucketing sketch below) - 13 seconds
  3. ViT model, patched (consistent results) - 34 seconds
  4. NaViT model - 39 seconds
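For reference, the grouping in setup 2 buckets images by exact (height, width) so each bucket can be stacked into a dense batch; a minimal sketch of that idea (group_by_size is my name for it, not the benchmark's):

from collections import defaultdict
import torch

def group_by_size(images: list[torch.Tensor]) -> dict[tuple[int, int], torch.Tensor]:
    """Bucket (C, H, W) images by spatial size and stack each bucket."""
    buckets: dict[tuple[int, int], list[torch.Tensor]] = defaultdict(list)
    for image in images:
        buckets[tuple(image.shape[-2:])].append(image)
    return {size: torch.stack(batch) for size, batch in buckets.items()}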
