A repository of building blocks in PyTorch 2.0 for E(3)/SE(3)-equivariant neural networks, built on top of e3nn:
- Equivariant Linear Layers: `e3tools.nn.Linear`
- Equivariant Convolution: `e3tools.nn.Conv` and `e3tools.nn.SeparableConv`
- Equivariant Multi-Layer Perceptrons (MLPs): `e3tools.nn.EquivariantMLP`
- Equivariant Layer Norm: `e3tools.nn.LayerNorm`
- Equivariant Activations: `e3tools.nn.Gate`, `e3tools.nn.GateWrapper` and `e3tools.nn.Gated`
- Separable Equivariant Tensor Products: `e3tools.nn.SeparableTensorProduct`
- Extracting Irreps: `e3tools.nn.ExtractIrreps`
- Self-Interactions: `e3tools.nn.LinearSelfInteraction`
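The modules above are designed so that rotating the input rotates the output in the same way. As a rough, plain-Python illustration of why a dedicated gated activation is useful (this is not the e3tools API; every helper name below is hypothetical), the sketch compares an elementwise `tanh` applied to vector components, which breaks rotation equivariance, with a gate that scales each vector by a sigmoid of a rotation-invariant scalar, which preserves it.

```python
import math

# Plain-Python sketch; function names are hypothetical, not the e3tools API.


def rotation_z(theta):
    """3x3 rotation matrix about the z-axis by angle theta (radians)."""
    c, s = math.cos(theta), math.sin(theta)
    return [[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]]


def rotate(R, v):
    """Multiply a 3x3 matrix by a 3-vector."""
    return [sum(R[i][j] * v[j] for j in range(3)) for i in range(3)]


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def naive_activation(vectors):
    """Elementwise tanh on Cartesian components: NOT rotation-equivariant."""
    return [[math.tanh(x) for x in v] for v in vectors]


def gate(scalars, gates, vectors):
    """Gate-style activation.

    Scalars go through tanh; each vector is scaled by sigmoid(gate). The gate
    is a rotation-invariant scalar, so this scaling commutes with rotations.
    """
    out_scalars = [math.tanh(s) for s in scalars]
    out_vectors = [[sigmoid(g) * x for x in v] for g, v in zip(gates, vectors)]
    return out_scalars, out_vectors


def equivariance_error(fn, vectors, R):
    """Largest componentwise gap between fn(R v) and R fn(v)."""
    rotate_first = fn([rotate(R, v) for v in vectors])
    rotate_last = [rotate(R, v) for v in fn(vectors)]
    return max(
        abs(a - b)
        for va, vb in zip(rotate_first, rotate_last)
        for a, b in zip(va, vb)
    )


R = rotation_z(0.7)
vectors = [[1.0, 2.0, 0.5], [-0.3, 0.0, 2.0]]
gates = [1.2, -0.5]
print("naive:", equivariance_error(naive_activation, vectors, R))
print("gated:", equivariance_error(lambda vs: gate([0.3], gates, vs)[1], vectors, R))
```

The naive error is clearly nonzero, while the gated error is at the level of floating-point rounding.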
All modules are compatible with `torch.compile` for JIT compilation.
Note that for some e3nn modules you may need to turn off e3nn's use of the older TorchScript JIT compiler by adding the following at the top of your script:
```python
import e3nn
e3nn.set_optimization_defaults(jit_script_fx=False)
```
Install from PyPI:
```bash
pip install e3tools
```
or get the latest development version from GitHub:
```bash
pip install git+https://github.com/prescient-design/e3tools.git
```
We provide examples of convolution-based and attention-based E(3)-equivariant message-passing networks built with e3tools. We also provide an example training script on QM9:
```bash
python examples/train_qm9.py --model conv
```
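To make the idea of equivariant message passing concrete, the following plain-Python sketch (hypothetical helpers, not code from e3tools or its example networks) computes one vector-valued output per node as a sum over neighbours of a cutoff-weighted scalar feature times the relative position vector. Because it uses only relative positions and invariant weights, its output is unchanged by translations and transforms like the input under rotations and reflections. The cosine cutoff is an assumed choice for illustration.

```python
import math

# Hypothetical plain-Python sketch of one equivariant message-passing step.


def sub(a, b):
    return [x - y for x, y in zip(a, b)]


def norm(v):
    return math.sqrt(sum(x * x for x in v))


def matvec(M, v):
    return [sum(M[i][j] * v[j] for j in range(3)) for i in range(3)]


def max_gap(us, vs):
    """Largest componentwise difference between two lists of 3-vectors."""
    return max(abs(a - b) for u, v in zip(us, vs) for a, b in zip(u, v))


def cosine_cutoff(r, cutoff):
    """Decays smoothly from 1 at r=0 to 0 at r=cutoff (assumed choice)."""
    if r >= cutoff:
        return 0.0
    return 0.5 * (math.cos(math.pi * r / cutoff) + 1.0)


def message_passing_step(positions, features, cutoff=3.0):
    """Vector output per node i: sum over j != i of w(|r_ij|) * h_j * r_ij."""
    outputs = []
    for i, p_i in enumerate(positions):
        acc = [0.0, 0.0, 0.0]
        for j, p_j in enumerate(positions):
            if i == j:
                continue
            r_ij = sub(p_j, p_i)  # relative position: translation-invariant
            weight = cosine_cutoff(norm(r_ij), cutoff) * features[j]  # invariant scalar
            acc = [a + weight * x for a, x in zip(acc, r_ij)]
        outputs.append(acc)
    return outputs


positions = [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.5, 0.5], [2.5, 0.0, 1.0]]
features = [1.0, -0.5, 2.0, 0.3]
c, s = math.cos(0.9), math.sin(0.9)
rotation = [[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]]
reflection = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, -1.0]]

out = message_passing_step(positions, features)
for name, M in (("rotation", rotation), ("reflection", reflection)):
    transformed = message_passing_step([matvec(M, p) for p in positions], features)
    print(name, max_gap(transformed, [matvec(M, v) for v in out]))
```

Full equivariant convolutions typically go further, replacing the bare scalar features with irreps-valued features and combining them with spherical harmonics of the edge vectors through tensor products.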
We see an approximate 2.5x improvement in training speed with `torch.compile`.
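Speedups from compilation will generally depend on the model, hardware and input shapes. To check the effect on your own setup, a simple timing harness such as the sketch below (hypothetical helper names, standard library only) can compare an eager callable against its compiled counterpart. Warm-up calls are discarded because `torch.compile` compiles on the first call(s).

```python
import time
from statistics import median


def time_per_call(fn, *args, warmup=3, iters=20, sync=None):
    """Median wall-clock seconds per call of fn(*args).

    Warm-up calls are discarded so that one-off costs (such as compilation)
    do not dominate. `sync` is an optional callable run after each call,
    e.g. to wait for asynchronous GPU work to finish.
    """
    for _ in range(warmup):
        fn(*args)
    if sync is not None:
        sync()
    samples = []
    for _ in range(iters):
        start = time.perf_counter()
        fn(*args)
        if sync is not None:
            sync()
        samples.append(time.perf_counter() - start)
    return median(samples)


def speedup(baseline_fn, candidate_fn, *args, **kwargs):
    """Baseline time / candidate time; values above 1 mean candidate is faster."""
    return time_per_call(baseline_fn, *args, **kwargs) / time_per_call(candidate_fn, *args, **kwargs)


# Stand-in workloads for the usage example.
def slow(n):
    return sum(i * i for i in range(n * 100))


def fast(n):
    return sum(i * i for i in range(n))


print(f"speedup: {speedup(slow, fast, 1000):.1f}x")
```

In a PyTorch setting you might call `speedup(model, torch.compile(model), batch)`, adding `sync=torch.cuda.synchronize` when running on a GPU. The median is used rather than the mean so that occasional slow iterations, such as a recompilation, do not skew the result.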