Skip to content

Westlake-AI/OpenToMe

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

53 Commits
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 

Repository files navigation

Toolbox and Benchmark for Token Merging Modules

News

  • 2025-06-23: Setup the basic framework of OpenToMe.

Installation

Install experimental dependencies

conda create -n opentome python=3.10.0
conda activate opentome
pip install torch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 --index-url https://download.pytorch.org/whl/cu124
pip install -r requirements.txt

Install OpenToMe from source

git git@github.com:Westlake-AI/OpenToMe.git
cd OpenToMe
pip install -e .

Getting Started

Model Examples

Here is an example of using ToMe with timm Attention blocks.

import torch
import timm
from torch import nn
from opentome.timm import Block, tome_apply_patch
from opentome.tome import check_parse_r


class TransformerBlock(nn.Module):
    def __init__(self, *, embed_dim=768, num_layers=12, num_heads=12, drop_path=0.0,
                 with_cls_token=True, init_values=1e-5, use_flash_attn=False, **kwargs):
        super(TransformerBlock, self).__init__()

        self.embed_dim = embed_dim
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.with_cls_token = with_cls_token

        if self.with_cls_token:
            self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
        self.pos_drop = nn.Dropout(p=0.0)

        dp_rates=[x.item() for x in torch.linspace(drop_path, 0.0, num_layers)]
        self.blocks = nn.Sequential(
            *[Block(
                dim=self.embed_dim,
                num_heads=self.num_heads,
                mlp_ratio=4.0,
                qkv_bias=True,
                init_values=init_values,
                drop_path=dp_rates[j],
                norm_layer=nn.LayerNorm,
                act_layer=nn.GELU,
                mlp_layer=timm.layers.Mlp,
                use_flash_attn=use_flash_attn,
            ) for j in range(num_layers)]
        )
        self.norm = nn.LayerNorm(self.embed_dim)

    def forward(self, x):
        B, N, C = x.shape
        if self.with_cls_token:
            cls_tokens = self.cls_token.expand(B, -1, -1)
            x = torch.cat((cls_tokens, x), dim=1)
        x = self.pos_drop(x)
        x = self.blocks(x)
        x = self.norm(x)
        return x

# test
embed_dim, token_num, merge_num, inflect = 384, 196, 100, 0.5
x = torch.randn(1, token_num, embed_dim)
# model = timm.create_model('vit_small_patch16_224')
model = TransformerBlock(embed_dim=384, num_layers=12, num_heads=8)
z = model.forward(x)
print(x.shape, z.shape)

# update tome
merge_ratio = check_parse_r(len(model.blocks), merge_num, token_num, inflect)
tome_apply_patch(model)
model.r = (merge_ratio, inflect)
model._tome_info["r"] = model.r
model._tome_info["total_merge"] = merge_num

z = model.forward(x)
print(x.shape, z.shape)

Image Classification on ImageNet

Here is an example of evaluate ImageNet validation set with various Token Compression methods.

CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 \
./evaluations/image_classification/in1k_example.py \
--model_name vit_base_patch16_224 \
--tome tome \
--merge_num 100 \
--dataset /PATH/TO/ImageNet/val \
--inflect -0.5 \

You can also run the evaluation with the bash example on GPU0:

bash evaluations/image_classification/in1k_eval.sh 0 tome 100 /PATH/TO/ImageNet/val 1 deit_small_patch16_224

Image ToMe Visualization

Here is an example of visualization with various Token Compression methods

CUDA_VISIBLE_DEVICES=0 python -m torch.distributed.launch --nproc_per_node=0 \
evaluations/visualizations/vis_classification.py \
--image_path ./demo \
--model_name deit_small_patch16_224 \
--tome tome \
--merge_num 100 \
--inflect -0.5 \
--save_vis True \

You can run the visualization with the bash example:

bash evaluations/visualizations/vis_eval.sh ./demo tome 100 1 deit_small_patch16_224

Token Compression Baselines

  • ToMe [ICLR'23] Token Merging: Your ViT but Faster paper code
  • DiffRate [ICCV'23] Diffrate: Differentiable Compression Rate for Efficient Vision Transformers paper code
  • DCT [ACL'23] Fourier Transformer: Fast Long Range Modeling by Removing Sequence Redundancy with FFT Operator paper code
  • ToFu [WACV'24] Token Fusion: Bridging the Gap between Token Pruning and Token Merging paper
  • CrossGET [ICML'24] CrossGET: Cross-Guided Ensemble of Tokens for Accelerating Vision-Language Transformers paper code
  • MCTF [CVPR'24] Multi-criteria Token Fusion with One-step-ahead Attention for Efficient Vision Transformers paper code
  • ATC [ECCV'24] paper code
  • DTEM [NeurIPS'24] Learning to Merge Tokens via Decoupled Embedding for Efficient Vision Transformers paper code
  • PiToMe [NeurIPS'24] Accelerating Transformers with Spectrum-Preserving Token Merging paper code
  • FPET [CVPR'25] paper code

Support Tasks (TODO List)

  • Image Classification
  • Image Generation
  • M/LLM Inference
  • Long Sequence
  • Throughput
  • AI for Science
  • ToMe Visualization

Citation

If you find this repository useful, please consider giving a star ⭐ and citation:

@article{2025opentome,
  title = {OpenToMe},
  author = {Siyuan Li and Xin Jin and Kai Yu},
  year = {2025},
  url={https://github.com/Westlake-AI/OpenToMe},
  urldate = {2025-08-15},
}

(back to top)

About

[Developing] Toolbox and Benchmark for Token Merging Modules

Resources

License

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published

Contributors 5