A state-of-the-art generative model for crystal structure prediction and de novo generation of inorganic crystals. This open-source framework accompanies the ICML 2025 paper (also available on arXiv) which should be cited when using it.
- Overview.
- Installation.
- Included Datasets.
- Training.
- Generation.
- Visualization.
- Crystal Structure Prediction Metrics.
- De Novo Generation Metrics.
- Citing OMatG.
OMatG supports two crystal generation modes:
- Crystal structure prediction (CSP): Here, the atomic species are fixed during generation and only the fractional coordinates and lattice vectors are adapted to yield a stable crystal structure for the given composition.
- De novo generation (DNG): At the start of the generation, all atomic species are masked or randomly sampled. During the generation process, the species change together with the lattice vectors and fractional coordinates to yield a stable crystal structure.
OMatG leverages the command line interface of PyTorch Lightning for choosing the crystal generation mode, the interpolants, the dataset, and other hyperparameters. Typically, we recommend using YAML files to store configurations (and sparingly using individual command line arguments to override some of the configuration parameters). This allows for easy reproducibility and sharing of configurations.
The omg/conf_examples
directory contains some example configuration files. In
addition, we provide pretrained checkpoints of the models presented in the paper together with their configuration files
on Hugging Face.
A tutorial notebook for using OMatG including short exercises is available on Kaggle (solutions can be found here). Note that this notebook is part of a more general workshop on generative modeling and thus refers to generative models for fashion pieces. The relevant beginner-friendly notebook that introduces generative modeling with short coding exercises is also available on Kaggle (with solutions here).
Expand this section for a brief introduction to the theoretical background of OMatG.
Theoretical Background
OMatG implements the stochastic interpolants (SIs) framework for the modeling and generation of inorganic crystalline materials. SIs are a unifying framework for generative modeling that encompasses flow-matching and diffusion-based methods as specific instances, while offering a more general and flexible approach enabling the design of a broad class of novel generative models.
A stochastic interpolant is a stochastic process that bridges samples from a base distribution to samples from the target data distribution over a fictitious time t ∈ [0, 1]. The time-dependent probability density of the stochastic process satisfies a transport equation whose velocity field can be learned with a simple quadratic objective; new samples are then generated by integrating the corresponding ODE (or an equivalent SDE) from the base distribution to the data distribution.
The flexibility of the SI framework stems from the ability to tailor the choice of interpolants and to choose between deterministic (ODE) and stochastic (SDE) sampling schemes (see Fig. 1, which visualizes the tunable components of the SI framework for bridging samples from a base distribution (gray particles) to samples from a target distribution (purple particles); figure taken from the OMatG paper).
OMatG defines a crystalline material of N atoms by its atomic species, the fractional coordinates of its atoms, and its lattice vectors, and assigns a separate stochastic interpolant to each of these three data fields.
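For orientation, the generic linear form of a stochastic interpolant from the SI literature is sketched below; the specific choices of the coefficient functions for the individual data fields are defined in the paper and in the interpolant classes under `omg.si`, and may differ from this textbook form.

```latex
% Generic (spatially linear) stochastic interpolant between a base sample x_0
% and a data sample x_1 (standard form from the SI literature):
x_t = \alpha(t)\, x_0 + \beta(t)\, x_1 + \gamma(t)\, z,
\qquad z \sim \mathcal{N}(0, I), \quad t \in [0, 1],
% with boundary conditions ensuring that x_t starts at a base sample and ends
% at a data sample:
\alpha(0) = \beta(1) = 1, \qquad \alpha(1) = \beta(0) = 0, \qquad \gamma(0) = \gamma(1) = 0.
```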
Expand this section for tips on how to set up new configuration files.
Machine-learning models implemented with PyTorch Lightning rely on three essential parts:
- `Trainer`: The training engine.
- `LightningDataModule`: Handles data loading and preprocessing.
- `LightningModule`: Defines the model and training logic.
Configuration files of OMatG thus generally contain specifications for these three parts.
OMatG uses the standard PyTorch Lightning `Trainer`. Its parameters are specified in the `trainer` section of the configuration file, for example:
```yaml
trainer:
  callbacks: # List of callbacks to be used during training.
    - class_path: lightning.pytorch.callbacks.ModelCheckpoint
      init_args:
        filename: "best_val_loss_total"
        save_top_k: 1
        monitor: "val_loss_total"
        save_weights_only: true
  accelerator: "gpu"
  gradient_clip_val: 0.5
  gradient_clip_algorithm: "value"
  num_sanity_val_steps: 0
  precision: "32-true"
  max_epochs: 2000
  enable_progress_bar: true
```
Note that it is possible to initialize specialized classes in the configuration file by specifying the `class_path` and `init_args`. The `init_args` dictionary contains the arguments that are passed to the constructor of the class.
In addition to the trainer, one should specify the optimizer and (optionally) the learning rate scheduler in their own sections:
```yaml
optimizer:
  class_path: torch.optim.AdamW
  init_args:
    lr: 0.001
    weight_decay: 0.01

lr_scheduler:
  class_path: torch.optim.lr_scheduler.CosineAnnealingLR
  init_args:
    T_max: 2000
    eta_min: 1e-07
```
The `data` section of the configuration constructs the `OMGDataModule` (see `omg/datamodule/dataloader.py`). It mainly expects the `train_dataset`, `val_dataset`, and `predict_dataset` sections. Each of these sections should construct an `OMGTorchDataset` (see `omg/datamodule/dataloader.py` again). This can be done based on LMDB files:
```yaml
data:
  train_dataset:
    class_path: omg.datamodule.dataloader.OMGTorchDataset
    init_args:
      dataset:
        class_path: omg.datamodule.datamodule.DataModule
        init_args:
          lmdb_paths:
            - "data/mp_20/train.lmdb"
          niggli: False
  val_dataset:
    class_path: omg.datamodule.dataloader.OMGTorchDataset
    init_args:
      dataset:
        class_path: omg.datamodule.datamodule.DataModule
        init_args:
          lmdb_paths:
            - "data/mp_20/val.lmdb"
          niggli: False
  predict_dataset:
    class_path: omg.datamodule.dataloader.OMGTorchDataset
    init_args:
      dataset:
        class_path: omg.datamodule.datamodule.DataModule
        init_args:
          lmdb_paths:
            - "data/mp_20/test.lmdb"
          niggli: False
  batch_size: 32
  num_workers: 4
  pin_memory: True
  persistent_workers: True
```
Every record in the LMDB files should contain a crystal structure. The key of each record is assumed to be an (arbitrary) encoded string, while the value is assumed to be a pickled dictionary with, at least, the following keys (see the example after this list):
- `pos`: A `torch.Tensor` of shape `(N, 3)` containing the fractional coordinates of the atoms in the crystal structure.
- `cell`: A `torch.Tensor` of shape `(3, 3)` containing the lattice vectors of the crystal structure.
- `atomic_numbers`: A `torch.Tensor` of shape `(N,)` containing the atomic numbers of the atoms in the crystal structure.
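As an illustration of this record format, the following sketch writes a single, hypothetical one-atom structure to a new LMDB file with the `lmdb` and `pickle` packages; the key encoding, `map_size`, and single-file layout (`subdir=False`) are arbitrary choices for this example rather than requirements spelled out by OMatG.

```python
import pickle

import lmdb
import torch

# Open (or create) an LMDB database; subdir=False stores it in a single file,
# and map_size is the maximum database size in bytes.
env = lmdb.open("train.lmdb", map_size=2**30, subdir=False)

# A single hypothetical crystal structure with one atom (polonium, Z=84).
record = {
    "pos": torch.tensor([[0.0, 0.0, 0.0]]),  # fractional coordinates, shape (N, 3)
    "cell": torch.eye(3) * 3.35,             # lattice vectors, shape (3, 3)
    "atomic_numbers": torch.tensor([84]),    # atomic numbers, shape (N,)
}

with env.begin(write=True) as txn:
    # The key is an arbitrary encoded string; the value is the pickled dictionary.
    txn.put("0".encode("ascii"), pickle.dumps(record))

env.close()
```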
The `data` section can also contain additional parameters for the data loading (such as `batch_size`, `num_workers`, `pin_memory`, and `persistent_workers` in the above example). These parameters are passed to the underlying PyTorch `DataLoader` instances.
Within OMatG, the data is passed around as `torch_geometric.data.Data` instances. For a batch size of `batch_size`, these instances contain the following attributes (a toy example follows after the list):
- `n_atoms`: `torch.Tensor` of shape `(batch_size,)` containing the number of atoms in each configuration.
- `batch`: `torch.Tensor` of shape `(sum(n_atoms),)` containing the index of the configuration to which each atom belongs.
- `species`: `torch.Tensor` of shape `(sum(n_atoms),)` containing the atomic numbers of the atoms in the configurations.
- `pos`: `torch.Tensor` of shape `(sum(n_atoms), 3)` containing the atomic positions of the atoms in the configurations.
- `cell`: `torch.Tensor` of shape `(batch_size, 3, 3)` containing the cell vectors of the configurations.
- `ptr`: `torch.Tensor` of shape `(batch_size + 1,)` containing the indices of the first atom of each configuration in the `species` and `pos` tensors.
- `property`: dict containing the properties of the configurations.
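As a toy illustration of this layout (and not of how OMatG constructs these objects internally), the following sketch assembles such a `Data` instance by hand for a batch of two hypothetical structures with one and two atoms, respectively.

```python
import torch
from torch_geometric.data import Data

# Batch of two hypothetical structures: one with 1 atom, one with 2 atoms.
batch = Data(
    n_atoms=torch.tensor([1, 2]),
    # Atom 0 belongs to configuration 0; atoms 1 and 2 belong to configuration 1.
    batch=torch.tensor([0, 1, 1]),
    species=torch.tensor([84, 11, 17]),       # Po; Na, Cl
    pos=torch.tensor([[0.0, 0.0, 0.0],        # atomic (fractional) coordinates
                      [0.0, 0.0, 0.0],
                      [0.5, 0.5, 0.5]]),
    cell=torch.stack([torch.eye(3) * 3.35,    # one 3x3 cell per configuration
                      torch.eye(3) * 5.64]),
    ptr=torch.tensor([0, 1, 3]),              # first-atom index of each configuration
    property={},
)

print(batch)  # inspect the attributes and their shapes
```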
The `model` section of the configuration file constructs the `OMGLightningModule` (see `omg/omg_lightning.py`). Its arguments are documented in the class docstring. An example `model` section looks as follows:
```yaml
model:
  si: # Collection of stochastic interpolants.
    class_path: omg.si.stochastic_interpolants.StochasticInterpolants
    init_args:
      stochastic_interpolants:
        # Chemical species.
        # The SingleStochasticInterpolantIdentity keeps the species unchanged during interpolation (CSP task).
        # For DNG, use, e.g., omg.si.discrete_flow_matching_mask.DiscreteFlowMatchingMask.
        - class_path: omg.si.single_stochastic_interpolant_identity.SingleStochasticInterpolantIdentity
        # Fractional coordinates.
        - class_path: omg.si.single_stochastic_interpolant.SingleStochasticInterpolant
          init_args:
            # Use a periodic interpolant for fractional coordinates.
            interpolant: omg.si.interpolants.PeriodicLinearInterpolant
            gamma: null
            epsilon: null
            differential_equation_type: "ODE"
            integrator_kwargs:
              method: "euler"
            velocity_annealing_factor: 10.182659004291072
            correct_center_of_mass_motion: true
        # Lattice vectors.
        - class_path: omg.si.single_stochastic_interpolant.SingleStochasticInterpolant
          init_args:
            # Use a non-periodic interpolant for lattice vectors.
            interpolant: omg.si.interpolants.LinearInterpolant
            gamma: null
            epsilon: null
            differential_equation_type: "ODE"
            integrator_kwargs:
              method: "euler"
            velocity_annealing_factor: 1.824475401606087
            correct_center_of_mass_motion: false
      data_fields:
        # If the order of the data_fields changes,
        # the order of the above StochasticInterpolant inputs must also change.
        - "species"
        - "pos"
        - "cell"
      integration_time_steps: 1000
      relative_si_costs:
        species_loss: 0.0
        pos_loss_b: 0.999
        cell_loss_b: 0.001
  sampler:
    class_path: omg.sampler.sample_from_rng.SampleFromRNG
    init_args:
      # Uniform distribution for fractional coordinates.
      pos_distribution: null
      cell_distribution:
        class_path: omg.sampler.distributions.InformedLatticeDistribution
        init_args:
          dataset_name: mp_20
      species_distribution:
        # For DNG, use omg.sampler.distributions.MaskDistribution.
        class_path: omg.sampler.distributions.MirrorData
  model:
    class_path: omg.model.model.Model
    init_args:
      encoder:
        class_path: omg.model.encoders.cspnet_full.CSPNetFull
      head:
        class_path: omg.model.heads.pass_through.PassThrough
      time_embedder:
        class_path: omg.model.model_utils.SinusoidalTimeEmbeddings
        init_args:
          dim: 256
```
The `si` section combines the stochastic interpolants for the `species`, `pos`, and `cell` data fields of the crystal structures in the `StochasticInterpolants` class. This class is documented in its docstring but, in a nutshell, it is a container for multiple `StochasticInterpolant` instances. The typically used implementations of this abstract class are:
- `SingleStochasticInterpolant`: For continuous data fields such as fractional coordinates and lattice vectors with arbitrary base distributions. The specific interpolant and its parameters are specified on initialization of this class. Every interpolant has a periodic (for fractional coordinates) and a non-periodic (for lattice vectors) version.
- `SingleStochasticInterpolantOS`: For continuous data fields such as fractional coordinates and lattice vectors, but explicitly assuming a Gaussian base distribution as it implements one-sided stochastic interpolants.
- `SingleStochasticInterpolantIdentity`: For keeping the corresponding data field unchanged during interpolation and generation.
- `DiscreteFlowMatchingMask`: For discrete data fields such as atomic species with a completely masked base distribution.
- `DiscreteFlowMatchingUniform`: For discrete data fields such as atomic species with a uniform base distribution.
Every `StochasticInterpolant` in the `StochasticInterpolants` class computes losses and returns them in a dictionary (see the `loss_keys` method in the respective class). The `StochasticInterpolants` class prefixes these keys with the name of the corresponding data field so that the losses can be identified. The `relative_si_costs` section specifies the relative weights of these losses when they are added up during training.
The `sampler` section specifies the base distributions for the positions, lattice vectors, and atomic species. Depending on the choice of the stochastic interpolant, one should choose the matching base distribution:
- `SingleStochasticInterpolant`: The choice of the base distribution is arbitrary. As in the example above, we typically use a uniform distribution for the fractional coordinates and an informed base distribution for the lattice vectors.
- `SingleStochasticInterpolantOS`: Explicitly assumes a `NormalDistribution`.
- `SingleStochasticInterpolantIdentity`: Explicitly assumes that the training data is just taken over in the "random" sample, as implemented by the `MirrorData` distribution.
- `DiscreteFlowMatchingMask`: Explicitly assumes fully masked samples as the base distribution, as implemented in the `MaskDistribution`.
- `DiscreteFlowMatchingUniform`: Explicitly assumes uniformly distributed atomic species as the base distribution, which can be achieved by using `species_distribution: null`.
The model
section specifies the model architecture. In the above example, we just use DiffCSPNet.
Install the dependencies (see pyproject.toml) and the omg
package itself by running
pip install .
within the base directory of this repository. For editable mode (recommended for developers), use
pip install -e .
instead. You can use any Python version between 3.10 and 3.13.
If the code in this repository changes, the command pip install . has to be executed again to update the installed package. If you installed omg in editable mode, any code changes are directly available in the installed omg package.
NOTE: Installing PyTorch 2.7 based on the correct compute platform as described on the PyTorch webpage before installing
omg
can help minimize sources of installation errors. The same applies to PyG 2.6 and PyTorch Scatter 2.1.
Installing the omg
package as described above provides the omg
command for training, generation, and evaluation.
For convenience, we include several material datasets that can be used for training. They can be found in the
omg/data
directory and are described briefly below:
- MP-20: 45,229 structures from the Materials Project with a maximum of 20 atoms per structure.
- MPTS-52: Chronological data split of the Materials Project with 40,476 structures and up to 52 atoms per structure.
- Perov-5: A perovskite dataset containing 18,928 structures, each with five atoms.
- Carbon-24: A dataset of 10,153 structures consisting only of carbon with up to 24 atoms per structure.
- Alex-MP-20: A new split of a consolidated dataset of 675,204 structures from Alexandria and MP-20. In comparison to MatterGen's dataset, we removed 10% of the training data to create a test dataset. The Alex-MP-20 dataset is too large to be stored in this repository; we have made it available via the Hugging Face link associated with this project.
Run the following command to train OMatG from scratch based on a configuration file:
omg fit --config=<configuration_file.yaml>
This command will create checkpoints, log files, and cache files in the current working directory. The configuration file contains paths to LMDB files that are used, e.g., for training. The paths to these data files can either be relative to the working directory or relative to the omg directory within this repository (that is, use "data/mp_20/val.lmdb" in order to use the included mp_20 dataset).
If you want to include a Wandb logger with a name, add the --trainer.logger=WandbLogger --trainer.logger.name=<name> arguments. Other loggers can be found here.
In order to restart training from a checkpoint, add the --ckpt_path=<checkpoint_file.ckpt>
argument.
In order to seed the random number generators before training, use --seed_everything=<seed>
.
For generating new structures in an xyz file based on a trained model, run the following command:
omg predict --config=<configuration_file.yaml> --ckpt_path=<checkpoint_file.ckpt> --model.generation_xyz_filename=<xyz_file>
This command will generate one epoch of structures, that is, the number of generated structures is equal to the number of structures in the prediction dataset specified in the configuration file.
For an xyz filename `filename.xyz`, this command will also create a file `filename_init.xyz` that contains the initial structures that were integrated to yield the structures in `filename.xyz`. This file is required for the visualization below.
If you want to change the batch size of the generation, you can override the batch size in the configuration file with the --data.batch_size=<new_batch_size> argument.
Run the following command to compare distributions over the generated structures to distributions over the training dataset:
omg visualize --config=<configuration_file.yaml> --xyz_file=<xyz_file> --plot_name=<plot_name.pdf>
Run the following command to compute the metrics for the CSP task:
omg csp_metrics --config=<configuration_file.yaml> --xyz_file=<xyz_file>
This command attempts to match structures at the same index in the generated dataset and the prediction dataset. The metrics include the match rate between the generated structures and the structures in the prediction dataset, as well as the average (normalized) root-mean-square displacement between the matched structures. By default, these metrics are stored in the `csp_metrics.json` file. This command also plots a histogram of the root-mean-square distances between the matched structures in the `rmsds.pdf` file.
By default, this method first validates the generated structures and the structures in the prediction dataset
based on volume, structure, composition, and fingerprint checks (see ValidAtoms
class),
and calculates the match rate between the valid generated structures and the valid structures in the prediction dataset.
The (slow) validation can be skipped by using skip_validation=True
.
The validations and matchings are parallelized. The number of processes is determined by `os.cpu_count()`. This can be changed by setting the `--number_cpus` argument (which is probably most useful in cluster environments). Further arguments are documented in the `csp_metrics` method in the `OMGTrainer` class.
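To give a feeling for what matching two structures involves, the following sketch compares two hypothetical two-atom NaCl cells with pymatgen's `StructureMatcher` (if pymatgen is installed); it only illustrates the concept and does not necessarily reproduce the settings used internally by `omg csp_metrics`.

```python
from pymatgen.analysis.structure_matcher import StructureMatcher
from pymatgen.core import Lattice, Structure

# Two hypothetical two-atom NaCl cells with slightly different lattice constants.
reference = Structure(Lattice.cubic(5.64), ["Na", "Cl"], [[0, 0, 0], [0.5, 0.5, 0.5]])
generated = Structure(Lattice.cubic(5.70), ["Na", "Cl"], [[0, 0, 0], [0.5, 0.5, 0.5]])

matcher = StructureMatcher()  # default tolerances
print(matcher.fit(reference, generated))           # True if the structures match
print(matcher.get_rms_dist(reference, generated))  # (normalized RMS displacement, max distance) or None
```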
Run the following command to compute the metrics for the de novo generation task:
omg dng_metrics --config=<configuration_file.yaml> --xyz_file=<xyz_file> --dataset_name=<dataset_name>
The metrics include validity (structural and compositional) and Wasserstein distances between distributions of density,
volume fraction, number of atoms, number of unique elements, and average coordination number.
In addition, if `dataset_name` is set to `mp_20`, `carbon_24`, or `perov_5`, the metrics include coverage recall and precision. By default, these metrics are stored in the `dng_metrics.json` file.
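For intuition, each of the Wasserstein distances listed above compares the one-dimensional distribution of a property (e.g., density) over the generated structures with the corresponding distribution over a reference dataset. A minimal SciPy sketch with made-up density values (not the exact procedure of `omg dng_metrics`) looks as follows:

```python
import numpy as np
from scipy.stats import wasserstein_distance

rng = np.random.default_rng(seed=0)

# Hypothetical densities (g/cm^3) of generated and reference structures.
generated_densities = rng.normal(loc=4.5, scale=1.0, size=1000)
reference_densities = rng.normal(loc=4.2, scale=0.9, size=1000)

# One-dimensional Wasserstein (earth mover's) distance between the two samples.
print(wasserstein_distance(generated_densities, reference_densities))
```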
The validations are parallelized. The number of processes is determined by `os.cpu_count()`. This can be changed by setting the `--number_cpus` argument (which is probably most useful in cluster environments).
Stability related metrics can be computed, for example, with the MatterGen codebase.
Please cite the following paper when using OMatG in your work:
```bibtex
@inproceedings{hoellmer2025,
  title={Open Materials Generation with Stochastic Interpolants},
  author={Philipp H{\"o}llmer and Thomas Egg and Maya Martirossyan and Eric
  Fuemmeler and Zeren Shui and Amit Gupta and Pawan Prakash and Adrian
  Roitberg and Mingjie Liu and George Karypis and Mark Transtrum and Richard
  Hennig and Ellad B. Tadmor and Stefano Martiniani},
  booktitle={Forty-second International Conference on Machine Learning},
  year={2025},
  url={https://openreview.net/forum?id=gHGrzxFujU},
  archivePrefix={arXiv},
  eprint={2502.02582},
  primaryClass={cs.LG},
}
```