[DRAFT]: Refactor of diffusion samplers #1106
base: main
Conversation
Signed-off-by: Charlelie Laurent <claurent@nvidia.com>
```python
)
err1 = torch.abs((y - y_x0)) ** self.norm_ord  # (*_y,)

# Compute log-likelihood p(y|x_0_hat)
```
This is relatively specific to DPS, I believe. Other model-based guidance approaches may use a different parameterization of the time-dependent variance (rather than with gamma), or a different loss altogether (cBottle TC uses BinaryCrossEntropy, for example).
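For illustration only, a more loss-agnostic guidance could look like the sketch below. All names here (`ObservationGuidance`, `loss_fn`, `residual`) are hypothetical and not part of this PR; the point is just that the observation loss could be injected rather than hard-wired to the DPS-style p-norm.

```python
# Hypothetical sketch: make the observation loss configurable instead of
# hard-wiring the DPS-style p-norm with a gamma-parameterized variance.
import torch
import torch.nn.functional as F


class ObservationGuidance:
    def __init__(self, loss_fn=None, norm_ord: float = 2.0):
        if loss_fn is None:
            # Default: elementwise DPS-style p-norm residual
            loss_fn = lambda pred, target: torch.abs(target - pred) ** norm_ord
        self.loss_fn = loss_fn

    def residual(self, y_x0: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # Elementwise loss between the predicted observation from the denoised
        # estimate x_0_hat and the actual observation y
        return self.loss_fn(y_x0, y)


# A different loss can be plugged in, e.g. an elementwise binary cross-entropy:
# guidance = ObservationGuidance(
#     loss_fn=lambda pred, target: F.binary_cross_entropy(pred, target, reduction="none")
# )
```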
```python
for guidance, guidance_args, guidance_kwargs in zip(
    guidances, guidances_args, guidances_kwargs
):
    if isinstance(guidance, ModelBasedGuidance):
```
There is a fundamental difference between kinds of guidance that use the clean state and those that use the noisy state (e.g. cBottle TC), so maybe it is worth making a separate class for guidance that does not require the clean state?
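For illustration only, such a split could look like this (the class names and signatures are hypothetical, not part of the PR):

```python
# Hypothetical sketch: split guidance into two base classes depending on
# whether the callback needs the denoised estimate x0_hat or only the
# noisy latent state x.
import torch


class CleanStateGuidance:
    """Guidance evaluated on the denoised estimate x0_hat (e.g. DPS-style)."""

    def __call__(
        self, x: torch.Tensor, x0_hat: torch.Tensor, sigma: torch.Tensor
    ) -> torch.Tensor:
        raise NotImplementedError


class NoisyStateGuidance:
    """Guidance evaluated directly on the noisy state x (e.g. cBottle TC)."""

    def __call__(self, x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError
```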
```python
# Guidance (e.g. posterior sampling, etc...)
guidance_sum = 0.0
if guidances:
```
I like this modularity/allowing for multiple guidance terms!
```python
# Activate gradient computation if needed for guidance
with torch.set_grad_enabled(req_grad):
    if req_grad:
        x_hat_in = x_hat.clone().detach().requires_grad_(True)
```
nit: first detaching and then cloning is more efficient.
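For illustration, the suggested reordering would be:

```python
# detach() first returns a view outside of the autograd graph, so the
# subsequent clone() is not recorded as an op in the graph
x_hat_in = x_hat.detach().clone().requires_grad_(True)
```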
```python
    return torch.cat(x_generated)


class EDMStochasticSampler:
```
@CharlelieLrt do you plan to decouple the sampler from the choice of time discretization? E.g., we could use this Heun sampler with other time steps than the ones proposed in the EDM paper.
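As a rough illustration of what such a decoupling could look like (hypothetical signature, not the API proposed in this PR), the noise-level discretization could be passed to the sampler as a callable:

```python
# Hypothetical sketch: the Heun sampler receives an arbitrary noise-level
# schedule instead of hard-coding the EDM time discretization.
import torch


def edm_schedule(
    num_steps: int, sigma_min: float = 0.002, sigma_max: float = 80.0, rho: float = 7.0
) -> torch.Tensor:
    """EDM discretization from Karras et al. (2022), with sigma_N = 0 appended."""
    i = torch.arange(num_steps, dtype=torch.float64)
    sigmas = (
        sigma_max ** (1 / rho)
        + i / (num_steps - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))
    ) ** rho
    return torch.cat([sigmas, torch.zeros(1, dtype=torch.float64)])


class HeunSampler:
    def __init__(self, schedule=edm_schedule):
        # Any other discretization (linear, cosine, ...) could be plugged in
        # here without touching the sampler's update rule.
        self.schedule = schedule
```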
```python
import torch


class ModelBasedGuidance:
```
I would maybe call this DPSGuidance (to be more precise, this implementation also assumes a Gaussian noise model), since there are several other guidance methods; see, e.g., https://arxiv.org/pdf/2503.11043 for an overview.
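For reference, the measurement model assumed by this kind of DPS-style guidance is Gaussian; in the standard notation of the DPS literature (not taken from this PR):

```latex
% Gaussian measurement model y = A(x_0) + n, with n ~ N(0, sigma_y^2 I);
% the guidance gradient uses the log-likelihood evaluated at the denoised
% estimate \hat{x}_0:
\log p\bigl(y \mid \hat{x}_0\bigr)
  = -\frac{\lVert y - A(\hat{x}_0) \rVert_2^2}{2\sigma_y^2} + \mathrm{const}
```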
```python
x : torch.Tensor
    The latent state of the diffusion model, typically of shape
    :math:`(B, *)`.
sigma : torch.Tensor
```
I think more generally the input to the model is `t` (which just coincides with `sigma` for the VE schedule in the EDM formulation).
PhysicsNeMo Pull Request
Description
Note: this PR is currently a draft and is open for feedback and suggestions. It will not be merged in its current form and might be broken down into smaller PRs.
Objectives
Overarching objective: refactor and improve all utilities and models related to diffusion in order to consolidate these components into a framework. More precisely, these diffusion components should:
PR objective: focuses on the EDM samplers (`stochastic_sampler` and `deterministic_sampler`). For these samplers, the objective is to be:
Solutions
- Refactored the `stochastic_sampler` functional interface into an object-oriented interface to facilitate future extensibility.
- Model agnostic: relies on a callback whose signature is assumed invariant. To satisfy this invariance constraint, we also provide a surface adapter (i.e. a thin wrapper) that modifies the signature of any given `Module` to ensure compliance (see the sketch after this list).
- Latent state agnostic: simple refactors to avoid unnecessary assumptions on the shape of the latent state.
- Conditioning agnostic: all conditioning variables are packed into a dict of tensors. The sampler never accesses the conditioning, as the model is responsible for handling the conditioning ops.
- Plug-and-play guidance: relies on callbacks passed to the sampler. Introduces a guidance API to facilitate creating these callbacks and to ensure compliance with the sampler requirements. For now, two types of guidance are provided (model-based guidance for DPS, and data-consistency guidance for inpainting/outpainting/channel infilling, etc.).
- Multi-diffusion: TBD. For now the plan is to defer multi-diffusion ops to a dedicated model wrapper.
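A minimal sketch of how these pieces could fit together from a user's perspective. All names, signatures, and the commented usage below are hypothetical illustrations under the assumptions above, not the exact API introduced by this PR:

```python
# Hypothetical sketch: wrap an arbitrary Module behind the invariant signature
# the sampler expects, pack conditioning into a dict of tensors, and pass
# guidance callbacks to the sampler.
import torch


class SignatureAdapter(torch.nn.Module):
    """Thin wrapper exposing a fixed (x, sigma, condition) signature."""

    def __init__(self, model: torch.nn.Module):
        super().__init__()
        self.model = model

    def forward(
        self, x: torch.Tensor, sigma: torch.Tensor, condition: dict
    ) -> torch.Tensor:
        # The adapter (not the sampler) decides how the conditioning tensors
        # are consumed by the underlying model.
        return self.model(x, sigma, **condition)


# Hypothetical usage:
# denoiser = SignatureAdapter(my_edm_unet)
# condition = {"lead_time": lead_time, "low_res": low_res_field}
# sampler = EDMStochasticSampler(...)
# x_gen = sampler(denoiser, latents, condition=condition,
#                 guidances=[dps_guidance, data_consistency_guidance])
```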
Remaining items
- Multi-diffusion: in their current implementation, the samplers use a `patching` object to extract patches from the latent state `x`; they also call methods on the model to extract the global positional embeddings of these patches. These strong assumptions on the model API are not compatible with objective (1) above. A better solution might be to defer all multi-diffusion ops to a model wrapper that extracts patches, gets the positional embeddings, etc. (see the sketch after this list).
- Guidance: only supports pre-defined guidances, with no mechanism to allow user-defined guidance. Extend the range of available off-the-shelf guidances.
- Model-based guidance: only supports a model that processes batches, and that is implemented in PyTorch and compatible with `torch.autograd`.
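For illustration, the model wrapper mentioned in the multi-diffusion item could look roughly like the following. The class name and the `patching` methods are hypothetical, just to sketch the idea of hiding all patch logic from the sampler:

```python
# Hypothetical sketch: a wrapper that owns all patch extraction and positional
# embedding logic, so the sampler only ever sees a plain denoiser call.
import torch


class MultiDiffusionWrapper(torch.nn.Module):
    def __init__(self, model: torch.nn.Module, patching):
        super().__init__()
        self.model = model
        self.patching = patching  # hypothetical patch extractor / re-assembler

    def forward(
        self, x: torch.Tensor, sigma: torch.Tensor, condition: dict
    ) -> torch.Tensor:
        patches = self.patching.extract(x)               # split the latent state
        pos_emb = self.patching.global_pos_embedding(x)  # per-patch embeddings
        out = self.model(patches, sigma, pos_emb=pos_emb, **condition)
        return self.patching.reassemble(out)             # fuse back to global state
```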
Checklist
Dependencies