Conversation
@CharlelieLrt CharlelieLrt commented Sep 5, 2025

PhysicsNeMo Pull Request

Description

Note: this PR is currently a draft and is open for feedback and suggestions. It will not be merged in its current form and might be broken down into smaller PRs.

Objectives

Overarching objective: refactor and improve all utilities and models related to diffusion in order to consolidate these components into a framework. More precisely, these diffusion components should:

  • Support all current diffusion use cases (CorrDiff, cBottle, StormCast, etc.) out of the box
  • Be composable
  • Be extendable
  • Be well-documented

PR objective: focuses on the EDM samplers (stochastic_sampler and deterministic_sampler). For these samplers, the objectives are to:

  1. Be agnostic to the diffusion model
  2. Be agnostic to the modality of the latent state x
  3. Be agnostic to the modality of the conditioning (for conditional diffusion models)
  4. Support a large range of guidance methods for plug-and-play generation (e.g. DPS)
  5. Support multiple implementations of multi-diffusion

Solutions

  1. Refactored the stochastic_sampler functional interface into an object-oriented interface to facilitate future extensibility.

  2. Model agnostic: relies on a callback whose signature is assumed invariant. To satisfy this invariance constraint, we also provide a surface adapter (i.e. a thin wrapper) that adapts the signature of any given Module to ensure compliance.

  3. Latent state agnostic: simple refactors to avoid unnecessary assumptions on the shape of the latent state.

  4. Conditioning agnostic: all conditioning variables are packed into a dict of tensors. The sampler never accesses the conditioning, since the model is responsible for handling all conditioning ops.

  5. Plug-and-play guidance: relies on callbacks passed to the sampler. Introduces a guidance API to facilitate creating these callbacks, and ensure compliance with the sampler requirements. For now two types of guidance are provided (model-based guidance for DPS, and data consistency guidance for inpainting/outpainting/channel infilling, etc.)

  6. Multi-diffusion: TBD. For now the plan is to defer multi-diffusion ops to a dedicated model wrapper.
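To make points (2) and (4) above concrete, here is a minimal sketch of what a signature-normalizing wrapper could look like. The name `SignatureAdapter`, the `arg_map` argument, and the `forward(x, sigma, condition=None)` callback signature are illustrative assumptions, not the PR's actual API.

```python
import torch


class SignatureAdapter(torch.nn.Module):
    """Hypothetical sketch of a thin "surface adapter" that exposes an
    invariant callback signature to the sampler. Names and signature are
    assumptions for illustration, not the PR's actual API."""

    def __init__(self, model: torch.nn.Module, arg_map: dict):
        super().__init__()
        self.model = model
        # Maps keys of the conditioning dict to the wrapped model's kwarg names
        self.arg_map = arg_map

    def forward(self, x, sigma, condition=None):
        condition = condition or {}
        # Repack the conditioning dict into the kwargs the wrapped model
        # expects; the sampler itself never inspects the conditioning.
        kwargs = {self.arg_map.get(k, k): v for k, v in condition.items()}
        return self.model(x, sigma, **kwargs)
```

A model expecting, say, a `lead_time` keyword could then be driven by the sampler through the invariant signature, with all conditioning handled on the model side.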

Remaining items

- Multi-diffusion: in their current implementation, the samplers use a patching object to extract patches from the latent state x; they also call methods from the model to extract the global positional embedding of these patches. These strong assumptions on the model API are not compatible with objective (1) above. A better solution might be to defer all multi-diffusion ops to a model wrapper that extracts patches, gets the positional embeddings, etc.

- Guidance: only pre-defined guidances are supported; there is no mechanism for user-defined guidance. The range of available off-the-shelf guidances should also be extended.

- Model-based guidance: only supports models that process batches, are implemented in PyTorch, and are compatible with torch.autograd.
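To illustrate why torch.autograd compatibility is required, a DPS-style guidance gradient can be sketched as follows. The function name `dps_guidance_sketch`, the `forward_op` operator, and the `scale` factor are hypothetical; this is a sketch of the general technique, not the PR's implementation.

```python
import torch


def dps_guidance_sketch(x0_hat, x_hat, y, forward_op, scale=1.0):
    """Illustrative DPS-style guidance term (hypothetical, not the PR's code):
    differentiates a data-fidelity loss between an observation ``y`` and the
    forward-mapped denoised estimate ``x0_hat`` with respect to the noisy
    state ``x_hat``."""
    residual = y - forward_op(x0_hat)
    loss = 0.5 * residual.pow(2).sum()
    # This call is why the guided model must be implemented in PyTorch and be
    # compatible with torch.autograd: the loss gradient is back-propagated
    # through the denoiser to the noisy state.
    (grad,) = torch.autograd.grad(loss, x_hat)
    return -scale * grad
```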

Checklist

  • I am familiar with the Contributing Guidelines.
  • New or existing tests cover these changes.
  • The documentation is up to date with these changes.
  • The CHANGELOG.md is up to date with these changes.
  • An issue is linked to this pull request.

Dependencies

Signed-off-by: Charlelie Laurent <claurent@nvidia.com>
@CharlelieLrt CharlelieLrt self-assigned this Sep 19, 2025
err1 = torch.abs(y - y_x0) ** self.norm_ord  # (*_y,)

# Compute log-likelihood p(y|x_0_hat)

This is relatively specific to DPS, I believe. Other model-based guidance approaches may use a different parameterization of the time-dependent variance (rather than with gamma), or a different loss altogether (cBottle TC uses BinaryCrossEntropy).

for guidance, guidance_args, guidance_kwargs in zip(
    guidances, guidances_args, guidances_kwargs
):
    if isinstance(guidance, ModelBasedGuidance):

There is a fundamental difference between kinds of guidance that use the clean and the noisy state (e.g. cBottle TC), so maybe it is worth making a separate class for guidance that does not require the clean state?


# Guidance (e.g. posterior sampling, etc.)
guidance_sum = 0.0
if guidances:

I like this modularity/allowing for multiple guidance terms!

# Activate gradient computation if needed for guidance
with torch.set_grad_enabled(req_grad):
    if req_grad:
        x_hat_in = x_hat.clone().detach().requires_grad_(True)

nit: first detaching and then cloning is more efficient.
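The suggestion can be sketched as follows (`x_hat` here is a stand-in tensor): detaching first means the subsequent clone is never recorded in the autograd graph, whereas clone-then-detach records a node and immediately discards it.

```python
import torch

x_hat = torch.randn(2, 3, requires_grad=True)

# clone() on a graph-connected tensor records a CloneBackward node that the
# subsequent detach() immediately discards; detaching first avoids that
# autograd bookkeeping entirely. The resulting tensor is the same either way:
# a fresh leaf holding a copy of x_hat's values.
x_hat_in = x_hat.detach().clone().requires_grad_(True)
```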

return torch.cat(x_generated)


class EDMStochasticSampler:

@CharlelieLrt do you plan to decouple the sampler from the choice of time discretization? E.g., we could use this Heun sampler with other time steps than the ones proposed in the EDM paper.
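For reference, the EDM noise-level discretization from Karras et al. (2022) can be written as a standalone schedule function; a sampler decoupled from the schedule could accept such a callable (or the resulting tensor of sigmas) as an argument. The function name and defaults below are illustrative, not part of the PR.

```python
import torch


def edm_sigma_steps(num_steps, sigma_min=0.002, sigma_max=80.0, rho=7.0):
    """EDM noise-level discretization (Karras et al., 2022, Eq. 5).
    Returns ``num_steps`` sigmas decreasing from ``sigma_max`` to
    ``sigma_min``, plus a trailing 0 for the final step."""
    i = torch.arange(num_steps, dtype=torch.float64)
    sigmas = (
        sigma_max ** (1 / rho)
        + i / (num_steps - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))
    ) ** rho
    return torch.cat([sigmas, torch.zeros(1, dtype=torch.float64)])
```

A sampler that takes the sigmas as input could then run the same Heun update over any monotone schedule, not just this one.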

import torch


class ModelBasedGuidance:

I would maybe call this DPSGuidance (to be more precise, this implementation also assumes Gaussian noise model), since there are several other guidance methods; see, e.g., https://arxiv.org/pdf/2503.11043 for an overview.

x : torch.Tensor
    The latent state of the diffusion model, typically of shape
    :math:`(B, *)`.
sigma : torch.Tensor

I think more generally the input to the model is t (which just coincides with sigma for the VE schedule in the EDM formulation).
