One line of code to add PyTorch feature caching and accelerate training by 30x+!
Torch Module Cache is a simple yet powerful PyTorch tool that enables model feature caching with a single line of code, significantly speeding up training and inference. Whether you are preprocessing a dataset or caching features from a pretrained model, it is all made easy.
- Minimal Code: Enable caching with just one decorator
- Significant Speedup: Real-world tests show 30x+ acceleration per epoch
- VRAM Friendly: The wrapped model is not loaded until a cache miss occurs, saving VRAM
- Flexible Caching: Supports both dataset and model feature caching
- Smart Inference: Supports an inference mode that disables caching globally
- Memory Optimized: Automatic cache memory management to prevent leaks
```bash
pip install torch-module-cache
```
Simply add the `@cache_module()` decorator to enable feature caching. This is especially effective when your model extracts features with a frozen pretrained backbone:
```python
import torch
import torch.nn as nn

from torch_module_cache import cache_module

# Only need to add one line of code to enable caching
@cache_module()
class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 3)

    def forward(self, x):
        return self.linear(x)

# Using the cache
model = MyModel()
x = torch.randn(1, 10)

# First run computes and caches the result
output1 = model(x, cache_key="key1")

# Second run uses the cached result
output2 = model(x, cache_key="key1")

# For batch processing, you can pass a list of cache keys (one per sample):
cache_keys = [f"key_{i}" for i in range(10)]
outputs = model(torch.randn(10, 10), cache_key=cache_keys)
```
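A quick way to see the effect is to time the first (cache-miss) call against a second (cache-hit) call with the same key. This is a minimal sketch using only the API shown above; the actual numbers depend on your model and cache level:

```python
import time

import torch

model = MyModel()
x = torch.randn(1, 10)

# First call: runs the forward pass and writes the cache
t0 = time.perf_counter()
model(x, cache_key="timing_demo")
first = time.perf_counter() - t0

# Second call with the same key: served from the cache
t0 = time.perf_counter()
model(x, cache_key="timing_demo")
second = time.perf_counter() - t0

print(f"cache miss: {first:.4f}s, cache hit: {second:.4f}s")
```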
Accelerate your model by caching features from pretrained models like ViT, ResNet, etc.:
```python
import timm
import torch
import torch.nn as nn

from torch_module_cache import cache_module

# Only need to add one line of code to enable caching
@cache_module()
class FeatureExtractor(nn.Module):
    def __init__(self):
        super().__init__()
        # Load a pretrained ViT
        self.vit = timm.create_model("vit_base_patch16_224", pretrained=True)
        self.vit.eval()  # Set to eval mode

    def forward(self, x):
        # Extract features from the ViT
        with torch.no_grad():
            features = self.vit.forward_features(x)
        return features

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        # `feature_extractor` is frozen, so its output can be cached safely
        self.feature_extractor = FeatureExtractor()
        self.classifier = nn.Linear(768, 10)  # ViT-Base features are 768-dim

    def forward(self, x, cache_key=None):
        # Features will be cached automatically
        features = self.feature_extractor(x, cache_key=cache_key)
        return self.classifier(features)
```
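In practice you would pass one cache key per sample (for example derived from the dataset index), warm the cache on the first pass, and get cache hits on every later epoch. The sketch below only uses the API shown above; the random images and key names are just for illustration:

```python
import torch

extractor = FeatureExtractor()
images = torch.randn(8, 3, 224, 224)  # toy batch, only for illustration

# One key per image; the first pass computes the ViT features and caches them
keys = [f"img_{i}" for i in range(images.shape[0])]
features = extractor(images, cache_key=keys)

# Any later pass with the same keys (e.g. the next epoch) is a cache hit
features_again = extractor(images, cache_key=keys)
print(features.shape)
```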
Still manually extracting features and saving them to .pt files? Use caching in your dataset to accelerate data loading with only one line of code:
```python
import torch.nn as nn
from torch.utils.data import Dataset

from torch_module_cache import cache_module

@cache_module(cache_name="feature_processor")
class FeatureProcessor(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 256)

    def forward(self, x):
        return self.linear(x)

class CachedDataset(Dataset):
    def __init__(self, data, labels):
        self.data = data
        self.labels = labels
        self.processor = FeatureProcessor()

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        raw_data = self.data[idx]
        # Use the sample index as the cache key; from the second epoch on,
        # preprocessing is served from the cache.
        processed_data = self.processor(raw_data, cache_key=f"sample_{idx}")
        return processed_data, self.labels[idx]
```
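Used with a standard DataLoader, the first epoch pays the preprocessing cost and later epochs read from the cache. A minimal sketch with toy tensors (the data shapes and label values are only for illustration):

```python
import torch
from torch.utils.data import DataLoader

# Toy raw data: 100 samples with 10 features each, plus binary labels
data = torch.randn(100, 10)
labels = torch.randint(0, 2, (100,))

dataset = CachedDataset(data, labels)
loader = DataLoader(dataset, batch_size=16, shuffle=True)

for epoch in range(3):
    # Epoch 0 runs FeatureProcessor; later epochs hit the cache
    for features, targets in loader:
        ...  # train on the cached 256-dim features
```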
Disable caching during inference:
```python
from torch_module_cache import enable_inference_mode

# Enable inference mode: caching is disabled and the wrapped model is
# initialized as soon as the instance is created
enable_inference_mode()

# The model computes directly without using the cache
model = MyModel()
output = model(x)
```
| Scenario | Without Cache | With Cache | Speedup |
|---|---|---|---|
| Dataset Preprocessing | 100s | 3.2s | 31.25x |
| ViT Feature Extraction | 2.10s | 0.024s | 86.82x |
Check out the examples directory for more usage examples.
The `@cache_module()` decorator accepts several configuration parameters:
```python
import torch.nn as nn

from torch_module_cache import cache_module, CacheLevel

@cache_module(
    # Path to store cache files (default: ~/.cache/torch-module-cache)
    cache_path="/path/to/cache",
    # Subfolder name for this specific model (default: class name)
    cache_name="my_model_cache",
    # Cache level: CacheLevel.DISK or CacheLevel.MEMORY
    cache_level=CacheLevel.MEMORY,
    # Whether to use safer loading options (recommended for untrusted data)
    safe_load=False,
    # Maximum memory usage in MB (default: None)
    max_memory_cache_size_mb=None,
)
class MyModel(nn.Module):
    ...  # your model implementation
```
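For example, a disk-backed configuration is a natural fit for the frozen ViT extractor shown earlier, since its features can be reused across runs. This is a sketch using only the parameters listed above; the class name, cache path, and cache name are arbitrary choices:

```python
import timm
import torch
import torch.nn as nn

from torch_module_cache import cache_module, CacheLevel

# Persist ViT features on disk so they survive across training runs.
@cache_module(
    cache_path="./cache",        # project-local cache directory (arbitrary choice)
    cache_name="vit_features",   # keeps these caches separate from other models
    cache_level=CacheLevel.DISK,
)
class DiskCachedViTExtractor(nn.Module):
    def __init__(self):
        super().__init__()
        self.vit = timm.create_model("vit_base_patch16_224", pretrained=True)
        self.vit.eval()

    def forward(self, x):
        with torch.no_grad():
            return self.vit.forward_features(x)
```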
```python
from torch_module_cache import clear_memory_caches, clear_disk_caches

# Clear all in-memory caches
clear_memory_caches()

# Clear all disk caches
clear_disk_caches()

# Clear caches for a specific model
clear_memory_caches(cache_name="my_model_cache")
clear_disk_caches(cache_name="my_model_cache")
```
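Clearing is mainly useful when cached results become stale, for example after you change the feature extractor or its preprocessing. A small sketch, assuming the earlier `FeatureExtractor` used the default `cache_name` (its class name):

```python
from torch_module_cache import clear_disk_caches

# The ViT backbone or its preprocessing changed, so the cached features are
# stale; drop only this model's disk cache and let it rebuild on the next run.
# (Assumes the default cache_name, i.e. the class name "FeatureExtractor".)
clear_disk_caches(cache_name="FeatureExtractor")
```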
Issues and Pull Requests are welcome!
MIT License