Skip to content

FERMat-ML/OMatG

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 

Repository files navigation

Open Materials Generation (OMatG)

Static Badge arXiv Hugging Face

Python PyTorch Lightning

License: MIT Static Badge

A state-of-the-art generative model for crystal structure prediction and de novo generation of inorganic crystals. This open-source framework accompanies the ICML 2025 paper (also available on arXiv) which should be cited when using it.

Crystal structure prediction of GaTe:

csp movie

De novo generation of Pd3Te2I3:

dng movie

Table of Contents

Overview

OMatG supports two crystal generation modes:

  1. Crystal structure prediction (CSP): Here, the atomic species are fixed during generation and only the fractional coordinates and lattice vectors are adapted to yield a stable crystal structure for the given composition.
  2. De novo generation (DNG): At the start of the generation, all atomic species are masked or randomly sampled. During the generation process, the species change together with the lattice vectors and fractional coordinates to yield a stable crystal structure.

OMatG leverages the command line interface of PyTorch Lightning for choosing the crystal generation mode, the interpolants, the dataset, and other hyperparameters. Typically, we recommend to use YAML files to store configurations (and sparsely use individual command line arguments to override some of the configuration parameters). This allows for easy reproducibility and sharing of configurations.

The omg/conf_examples directory contains some example configuration files. In addition, we provide pretrained checkpoints of the models presented in the paper together with their configuration files on Hugging Face.

A tutorial notebook for using OMatG including short exercises is available on Kaggle (solutions can be found here). Note that this notebook is part of a more general workshop on generative modeling and thus refers to generative models for fashion pieces. The relevant beginner-friendly notebook that introduces generative modeling with short coding exercises is also available on Kaggle (with solutions here).

Expand this section for a brief introduction to the theoretical background of OMatG.

Theoretical Background

OMatG implements the stochastic interpolants (SIs) framework for the modeling and generation of inorganic crystalline materials. SIs are a unifying framework for generative modeling that encompasses flow-matching and diffusion-based methods as specific instances, while offering a more general and flexible approach enabling the design of a broad class of novel generative models.

A stochastic interpolant $x_t = x(t, x_0, x_1, z)=\alpha(t),x_0 + \beta(t),x_1 + \gamma(t),z$ bridges samples $x_0$ from a (trivial) base distribution to samples $x_1$ from the target data distribution. Here, $t\in[0, 1]$ represents time and the random variable $z$ is drawn from a Gaussian distribution. The functional forms of $\alpha$, $\beta$, and $\gamma$ are flexible, only subject to a few constraints that, amongst other things, ensure that $x(t=0, x_0, x_1, z) = x_0$ and $x(t=1, x_0, x_1, z) = x_1$.

The time-dependent probability density of the stochastic process $x_t$ can be realized either via deterministic sampling through an ordinary differential equation (ODE) or stochastic sampling through a stochastic differential equation (SDE), only requiring a sample $x_0$ from the base distribution. This enables generative modeling by evolving samples from the base distribution to samples from the data distribution. Here, the required velocity term $b^\theta(t, x)$ for both ODE- and SDE-based sampling can be learned from data by "averaging over many paired samples ($x_0, x_1$)." For SDE-based sampling, an additional denoiser $z^\theta(t, x)$ can be learned likewise.

The flexibility of the SI framework stems from the ability to tailor the choice of interpolants and choosing between deterministic (ODE) and stochastic (SDE) sampling schemes (see Fig. 1 that visualizes the tunable components of the SI framework for bridging samples from a base distribution (gray particles) to samples from a target distribution (purple particles); figure taken from the OMatG paper.).

stochastic interpolants

OMatG defines a crystalline material of $N$ atoms by its unit cell that is described by three lattice vectors $\mathbf{L} \in \mathbb{R}^{3\times3}$, $N$ fractional coordinates $\mathbf{X}\in[0,1)^{3\times N}$ with periodic boundary conditions, and $N$ discrete atomic species $\mathbf{A}\in\mathbb{Z}^N_{>0}$. During training an generation, all three components ${\mathbf{A}, \mathbf{X}, \mathbf{L}}$ are considered simultaneously. The SI framework is applied to the continuous structural representations ${\mathbf{X}, \mathbf{L}}$ while the discrete atomic species $\mathbf{A}$ are treated with discrete flow matching.

Expand this section for tips on how to set up new configuration files.

Configuration Files

Machine-learning models implemented with PyTorch Lightning rely on three essential parts:

  1. Trainer: The training engine.
  2. LightningDataModule: Handles data loading and preprocessing.
  3. LightningModule: Defines the model and training logic.

Configuration files of OMatG thus generally contain specifications for these three parts.

Trainer

OMatG uses the standard PyTorch Lightning Trainer. Its parameters are specified in the trainer section of the configuration file, for example:

trainer:
  callbacks:  # List of callbacks to be used during training.
    - class_path: lightning.pytorch.callbacks.ModelCheckpoint
      init_args:
        filename: "best_val_loss_total"
        save_top_k: 1
        monitor: "val_loss_total"
        save_weights_only: true
  accelerator: "gpu"
  gradient_clip_val: 0.5
  gradient_clip_algorithm: "value"
  num_sanity_val_steps: 0
  precision: "32-true"
  max_epochs: 2000
  enable_progress_bar: true

Note that it is possible to initialize specialized classes in the configuration file by specifying the class_path and init_args. The init_args dictionary contains the arguments that are passed to the constructor of the class.

In addition to the trainer, one should specify the optimizer and (optionally) the learning rate scheduler in their own sections:

optimizer:
  class_path: torch.optim.AdamW
  init_args:
    lr: 0.001
    weight_decay: 0.01
lr_scheduler:
  class_path: torch.optim.lr_scheduler.CosineAnnealingLR
  init_args:
    T_max: 2000
    eta_min: 1e-07

LightingDataModule

The data section of the configuration constructs the OMGDataModule (see omg/datamodule/dataloader.py). It mainly expects the train_dataset, val_dataset, and predict_dataset sections. Each of these sections should construct an OMGTorchDataset (see omg/datamodule/dataloader.py again). This can be done based on lmdb files:

data:
  train_dataset:
    class_path: omg.datamodule.dataloader.OMGTorchDataset
    init_args:
      dataset:
        class_path: omg.datamodule.datamodule.DataModule
        init_args:
          lmdb_paths:
           - "data/mp_20/train.lmdb"
      niggli: False
  val_dataset:
    class_path: omg.datamodule.dataloader.OMGTorchDataset
    init_args:
      dataset:
        class_path: omg.datamodule.datamodule.DataModule
        init_args:
          lmdb_paths:
           - "data/mp_20/val.lmdb"
      niggli: False
  predict_dataset:
    class_path: omg.datamodule.dataloader.OMGTorchDataset
    init_args:
      dataset:
        class_path: omg.datamodule.datamodule.DataModule
        init_args:
          lmdb_paths:
           - "data/mp_20/test.lmdb"
      niggli: False
  batch_size: 32
  num_workers: 4
  pin_memory: True
  persistent_workers: True

Every record in the lmdb files should contain a crystal structure. The key of each record is assumed to be an (arbitrary) encoded string, while the value is assumed to be a pickled dictionary with, at least, the following keys:

  • pos: A torch.Tensor of shape (N, 3) containing the fractional coordinates of the atoms in the crystal structure.
  • cell: A torch.Tensor of shape (3, 3) containing the lattice vectors of the crystal structure.
  • atomic_numbers: A torch.Tensor of shape (N,) containing the atomic numbers of the atoms in the crystal structure.

The data section can also contain additional parameters for the data loading (such as batch_size, num_workers, pin_memory, and persistent_workers in the above example). These parameters are passed to the underlying PyTorch DataLoader instances.

Within OMatG, the data is passed around as torch_geometric.data.Data instances. For a batch size of batch_size, these instances contain the following attributes:

  • n_atoms: torch.Tensor of shape (batch_size, ) containing the number of atoms in each configuration.
  • batch: torch.Tensor of shape (sum(n_atoms),) containing the index of the configuration to which each atom belongs.
  • species: torch.Tensor of shape (sum(n_atoms),) containing the atomic numbers of the atoms in the configurations.
  • pos: torch.Tensor of shape (sum(n_atoms), 3) containing the atomic positions of the atoms in the configurations.
  • cell: torch.Tensor of shape (batch_size, 3, 3) containing the cell vectors of the configurations.
  • ptr: torch.Tensor of shape (batch_size + 1,) containing the indices of the first atom of each configuration in the species and pos tensors.
  • property: dict containing the properties of the configurations.

LightningModule

The model section of the configuration file constructs the OMGLightningModule (see omg/omg_lightning.py). Its arguments are documented in the class docstring. An example model section looks as follows:

model:
  si:  # Collection of stochastic interpolants.
    class_path: omg.si.stochastic_interpolants.StochasticInterpolants
    init_args:
      stochastic_interpolants:
        # Chemical species.
        # The SingleStochasticInterpolantIdentity keeps the species unchanged during interpolation (CSP task).
        # For DNG, use, e.g., omg.si.discrete_flow_matching_mask.DiscreteFlowMatchingMask.
        - class_path: omg.si.single_stochastic_interpolant_identity.SingleStochasticInterpolantIdentity
        # Fractional coordinates.
        - class_path: omg.si.single_stochastic_interpolant.SingleStochasticInterpolant
          init_args:
            # Use a periodic interpolant for fractional coordinates.
            interpolant: omg.si.interpolants.PeriodicLinearInterpolant
            gamma: null
            epsilon: null
            differential_equation_type: "ODE"
            integrator_kwargs:
              method: "euler"
            velocity_annealing_factor: 10.182659004291072
            correct_center_of_mass_motion: true
        # Lattice vectors.
        - class_path: omg.si.single_stochastic_interpolant.SingleStochasticInterpolant
          init_args:
            # Use a non-periodic interpolant for lattice vectors.
            interpolant: omg.si.interpolants.LinearInterpolant
            gamma: null
            epsilon: null
            differential_equation_type: "ODE"
            integrator_kwargs:
              method: "euler"
            velocity_annealing_factor: 1.824475401606087
            correct_center_of_mass_motion: false
      data_fields:
        # If the order of the data_fields changes,
        # the order of the above StochasticInterpolant inputs must also change.
        - "species"
        - "pos"
        - "cell"
      integration_time_steps: 1000
  relative_si_costs:
    species_loss: 0.0
    pos_loss_b: 0.999
    cell_loss_b: 0.001
  sampler:
    class_path: omg.sampler.sample_from_rng.SampleFromRNG
    init_args:
      # Uniform distribution for fractional coordinates.
      pos_distribution: null
      cell_distribution:
        class_path: omg.sampler.distributions.InformedLatticeDistribution
        init_args:
          dataset_name: mp_20
      species_distribution:
        # For DNG, use omg.sampler.distributions.MaskDistribution.
        class_path: omg.sampler.distributions.MirrorData
  model:
    class_path: omg.model.model.Model
    init_args:
      encoder:
        class_path: omg.model.encoders.cspnet_full.CSPNetFull
      head:
        class_path: omg.model.heads.pass_through.PassThrough
      time_embedder:
        class_path: omg.model.model_utils.SinusoidalTimeEmbeddings
        init_args:
          dim: 256

The si section combines the stochastic interpolants for the species, pos, and cell data fields of the crystal structures in the StochasticInterpolants class. This class is documented in its docstring but, in a nutshell, it is a container for multiple StochasticInterpolant instances. The typically used implementations of this abstract class are:

  • SingleStochasticInterpolant: For continuous data fields such as fractional coordinates and lattice vectors with arbitrary base distributions. The specific interpolant and its parameters are specified on initialization of this class. Every interpolant has a periodic (for fractional coordinates) and a non-periodic (for lattice vectors) version.
  • SingleStochasticInterpolantOS: For continuous data fields such as fractional coordinates and lattice vectors, but explicitly assuming a Gaussian base distribution as it implements one-sided stochastic interpolants.
  • SingleStochasticInterpolantIdentity: For keeping the corresponding data field unchanged during interpolation and generation.
  • DiscreteFlowMatchingMask: For discrete data fields such as atomic species with a completely masked base distribution.
  • DiscreteFlowMatchingUniform: For discrete data fields such as atomic species with a uniform base distribution.

Every StochasticInterpolant in the StochasticInterpolants class computes losses and returns them in a dictionary (see the loss_keys method in the respective class). The StochasticInterpolants class prefixes these keys with the name of the corresponding data field so that the losses can be identified. The relative_si_costs section specifies the relative weights of these losses when they are added up during training.

The sampler section specifies the base distributions for the positions, lattice vectors, and atomic species. Depending on the choice of the stochastic interpolant, one should choose the matching base distribution:

The model section specifies the model architecture. In the above example, we just use DiffCSPNet.

Installation

Install the dependencies (see pyproject.toml) and the omg package itself by running pip install . within the base directory of this repository. For editable mode (recommended for developers), use pip install -e . instead. You can use any Python version between 3.10 and 3.13.

If the code in this repository changes, the command pip install . has to be executed again to also change the code of the installed package. If you installed omg in editable mode, any changes in code are directly available in the installed omg package.

NOTE: Installing PyTorch 2.7 based on the correct compute platform as described on the PyTorch webpage before installing omg can help minimize sources of installation errors. The same applies to PyG 2.6 and PyTorch Scatter 2.1.

Installing the omg package as described above provides the omg command for training, generation, and evaluation.

Included Datasets

For convenience, we include several material datasets that can be used for training. They can be found in the omg/data directory and are described briefly below:

  • MP-20: 45,229 structures from the Materials Project with a maximum of 20 atoms per structure.
  • MPTS-52: Chronological data split of the Materials Project with 40,476 structures and up to 52 atoms per structure.
  • Perov-5: A perovskite dataset containing 18,928 structures each with five atoms per structure.
  • Carbon-24: A dataset of 10,153 structures consisting only of carbon with up to 24 atoms per structure.
  • Alex-MP-20: New split of a consolidated dataset of 675,204 structures of Alexandria and MP-20 structures. In comparison to MatterGen's dataset, we removed 10% of the training data to create a test dataset. The Alex-MP-20 dataset is too large to be stored in this repository. We have made it available via the HuggingFace link associated with this project.

Training

Run the following command to train OMatG from scratch based on a configuration file:

omg fit --config=<configuration_file.yaml>

This command will create checkpoints, log files, and cache files in the current working directory. The configuration file contain paths to lmdb files that are used, e.g., for training. The path to these data files can either be relative to the working directory, or relative to the omg directory within this repository (that is, use "data/mp_20/val.lmdb" in order to use the included mp_20 data set).

If you want to include a Wandb logger with a name, add the --trainer.logger=WandbLogger --trainer.logger.name=<name> argument. Other loggers can be found here.

In order to restart training from a checkpoint, add the --ckpt_path=<checkpoint_file.ckpt> argument.

In order to seed the random number generators before training, use --seed_everything=<seed>.

Generation

For generating new structures in an xyz file based on a trained model, run the following command:

omg predict --config=<configuration_file.yaml> --ckpt_path=<checkpoint_file.ckpt> --model.generation_xyz_filename=<xyz_file>

This command will generate one epoch of structures, that is, the number of generated structures is equal to the number structures in the prediction dataset specified in the configuration file.

For an xyz filename filename.xyz, this command will also create a file filename_init.xyz that contains the initial structures that were integrated to yield the structures in filename.xyz. This file is required for the visualization below.

If you want to change the batch size of the generation, you can overwrite the batch size in the configuration file with the --data.batch_size=<new_batch_size> argument.

Visualization

Run the following command to compare distributions over the generated structures to distributions over the training dataset:

omg visualize --config=<configuration_file.yaml> --xyz_file=<xyz_file> --plot_name=<plot_name.pdf>

Crystal Structure Prediction Metrics

Run the following command to compute the metrics for the CSP task:

omg csp_metrics --config=<configuration_file.yaml> --xyz_file=<xyz_file>

This command attempts to match structures at the same index in the generated dataset and the prediction dataset. The metrics include the match rate between the generated structures and the structures in the prediction dataset, as well as the average (normalized) root-mean square displacement between the matched structures. By default, these metrics are stored in the csp_metrics.json file. This command also plots the histogram of the root-mean-square distances between the matched structures in the rmsds.pdf file.

By default, this method first validates the generated structures and the structures in the prediction dataset based on volume, structure, composition, and fingerprint checks (see ValidAtoms class), and calculates the match rate between the valid generated structures and the valid structures in the prediction dataset. The (slow) validation can be skipped by using skip_validation=True.

The validations and matchings are parallelized. The number of processes is determined by os.cpu_count(). This can be changed by setting the --number_cpus argument (which is probably most useful in cluster environments).

Further arguments are documented in the csp_metrics method in the OMGTrainer class.

De Novo Generation Metrics

Run the following command to compute the metrics for the de novo generation task:

omg dng_metrics --config=<configuration_file.yaml> --xyz_file=<xyz_file> --dataset_name=<dataset_name>

The metrics include validity (structural and compositional) and Wasserstein distances between distributions of density, volume fraction, number of atoms, number of unique elements, and average coordination number. In addition, if dataset_name is set to mp_20, carbon_24, or perov_5, the metrics include coverage recall and precision. By default, these metrics are stored in the dng_metrics.json file.

The validations are parallelized. The number of processes is determined by os.cpu_count(). This can be changed by setting the --number_cpus argument (which is probably most useful in cluster environments).

Stability related metrics can be computed, for example, with the MatterGen codebase.

Citing OMatG

Please cite the following paper when using OMatG in your work:

@inproceedings{
    hoellmer2025,
    title={Open Materials Generation with Stochastic Interpolants},
    author={Philipp H{\"o}llmer and Thomas Egg and Maya Martirossyan and Eric
    Fuemmeler and Zeren Shui and Amit Gupta and Pawan Prakash and Adrian
    Roitberg and Mingjie Liu and George Karypis and Mark Transtrum and Richard
    Hennig and Ellad B. Tadmor and Stefano Martiniani},
    booktitle={Forty-second International Conference on Machine Learning},
    year={2025},
    url={https://openreview.net/forum?id=gHGrzxFujU},
    archivePrefix={arXiv},
    eprint={2502.02582},
    primaryClass={cs.LG},
}