Simformer #1621
Conversation
Removing code duplication on embedding net handling
After rebasing this PR to merge into main, a bunch of collaborators' commits from the original parent branch appeared here. I tried to clean up a little by squashing commits, but it would have required resolving 285 different conflicts 😅 so I aborted the operation and had to keep everything as it is.
…t use of simformer (condition is an empty tensor)
… up time of not slow tests
…t is default True, in linear gaussian vf test
…ing a warning in case it is detected
…xture to gpu Pass device information to IID method in VectorFieldBasedPotential
Alright, as requested by Google:

I mark the commit below as the last commit for my GSoC. Nonetheless, I am still able to keep working on this to implement advice and fixes after review 👍
I started the review and made some initial comments, but I noticed that there is a conflict with `main`, which should be resolved first.
Overall, it's amazing to have all this implemented - well done!
Here are a couple of general comments:
- `sde_type` handling: score-based uses it; the FM variant drops it. This is correct, but should be explicit in the docs (and maybe warn if provided to FM). Clarify in `FlowMatchingSimformer` that `sde_type` is ignored; consider explicitly documenting that a provided `sde_type` is dropped (see above).
- Trainer docstrings list a `prior` argument that does not exist in init. Remove `prior` from both `Simformer` and `FlowMatchingSimformer` init docstrings.
- In both trainer docstrings, the line "kwargs: ... passed to the default builder if score_estimator is a string" names the wrong parameter; it should reference `mvf_estimator` (and also be consistent with "model" naming in the factory).
- Normalize terminology for `time_emb_type` across the file: consistently use "sinusoidal" | "random_fourier" (some places say "fourier").
- In the `VectorFieldSimformer` docstring, it would be helpful to add expected shapes and dtypes (see the sketch after this list):
  - inputs: [B, T, in_features]
  - condition_mask: [B, T], bool
  - edge_mask: Optional[[B, T, T]], bool
  - t: [B] or [B, 1], float
- Minor: I find it a bit confusing to have `Simformer` and `FMSimformer` as trainer classes and then `VectorFieldSimformer` as the actual NN class. Maybe, in both trainer class docstrings, add the sentence "This trainer uses the Simformer network (VectorFieldSimformer) under the hood." Or, a potential renaming could be "SimformerTrainer", "FlowMatchingSimformerTrainer", and "SimformerNet". I prefer these long class names if they add clarity.
- `VectorFieldSimformer` defaults to `num_layers = 4`, but `build_simformer_network` defaults to `num_layers = 5`.
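For illustration, a shape-annotated stub along these lines could be added (a suggestion only, not the PR's actual docstring; the convention that `True` marks observed entries in `condition_mask` is my assumption):

```python
from typing import Optional

from torch import Tensor


class VectorFieldSimformer:  # stub for illustration only
    def forward(
        self,
        inputs: Tensor,                      # [B, T, in_features], float
        condition_mask: Tensor,              # [B, T], bool
        edge_mask: Optional[Tensor] = None,  # [B, T, T], bool
        t: Optional[Tensor] = None,          # [B] or [B, 1], float
    ) -> Tensor:
        """Forward pass of the Simformer network.

        Args:
            inputs: Full input tensor stacking all T variables.
            condition_mask: True for observed variables, False for latent ones.
            edge_mask: DAG adjacency used to mask attention scores; None is
                equivalent to a dense all-ones mask.
            t: Diffusion/flow time.
        """
        ...
```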
I also made some small typo fixes and docs clarifications already locally and will push them now.
Looking forward to doing the full review.
```diff
 {
  "cell_type": "code",
- "execution_count": 1,
+ "execution_count": 16,
```
seems unrelated to this PR and a left-over from testing. please remove if possible.
```diff
 {
  "cell_type": "code",
- "execution_count": 5,
+ "execution_count": 6,
```
seems unrelated to this PR and a left-over from testing. please remove if possible.
```
sbi.inference.Simformer
sbi.inference.FlowMatchingSimformer
```
not sorted alphabetically, but maybe that's fine, it seems to be sorted semantically.
```diff
-__all__ = ["FMPE", "MarginalTrainer", "NLE", "NPE", "NPSE", "NRE", "simulate_for_sbi"]
+__all__ = [
+    "Simformer",
```
please sort alphabetically.
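E.g., assuming the list gains both new trainer classes, a case-sensitively sorted version might read:

```python
__all__ = [
    "FMPE",
    "FlowMatchingSimformer",
    "MarginalTrainer",
    "NLE",
    "NPE",
    "NPSE",
    "NRE",
    "Simformer",
    "simulate_for_sbi",
]
```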
```python
assert num_nans + num_infs == 0, (
    "Some invalid entries (NaN/Infs) were "
    "found in x. You probably passed these as the ground observed x's `x_obs`. "
    "Please, remove these values and provide reasonable observed x's to avoid "
    "the sampling process to run indefinitely."
)
```
In general, let's try to avoid using `assert` statements and rather replace them with informative errors; e.g., here, an `if`-statement and a `ValueError` would be appropriate.
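A minimal sketch of the suggested replacement, keeping the variables and message from the snippet above:

```python
if num_nans + num_infs > 0:
    raise ValueError(
        "Some invalid entries (NaN/Infs) were found in x. You probably "
        "passed these as the observed x's `x_obs`. Please remove these "
        "values and provide reasonable observed x's to prevent the "
        "sampling process from running indefinitely."
    )
```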
```python
class NeuralInference(ABC):
    """Abstract base class for neural inference methods."""


def check_if_proposal_has_default_x(proposal: Any):
```
I suggest moving this helper function somewhere else, into a utils file, or just further down, below the main classes in this file; e.g., it seems to be used only in `npe_base.py`, so we could move it there? Also, can we make the type more precise, e.g., `Union[Distribution, NeuralPosterior]`?
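For instance (the `NeuralPosterior` import path is my guess at sbi's layout and may differ):

```python
from typing import Union

from torch.distributions import Distribution

from sbi.inference.posteriors.base_posterior import NeuralPosterior


def check_if_proposal_has_default_x(
    proposal: Union[Distribution, NeuralPosterior],
) -> None:
    ...
```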
```python
num_layers: int = 4,
num_heads: int = 4,
mlp_ratio: int = 2,
ada_time: int = False,
```
The annotation `ada_time: int = False` should be a bool, i.e., `ada_time: bool = True/False`.
Simformer
Important
This PR is part of Google Summer of Code 2025
Note
Before opening this PR, I initially experimented with the Simformer and auxiliary components in a separate branch of my sbi fork. You can find it here. I used that branch (`simformer-dev`) as a first environment where I could experiment with solutions more freely, and I opened this PR once I had a minimum viable product. That branch basically served as my working environment for the first month and a half of the GSoC. Nevertheless, all the code I finalized there has been fully incorporated into the PR you are reading.

More specifically, in that branch I mainly worked on a first version of the Simformer neural network architecture and the "masked" interface. I also attempted to introduce a Joint distribution interface, i.e., a parallel interface to the current "Posterior" approach in sbi that could generalize better to the Simformer case, as the Simformer does not work in terms of "posterior", "likelihood", or the like, but more generally in terms of arbitrary conditionals. Nevertheless, the latter has been dropped in favor of a Wrapper class that adapts the more general Simformer approach to the existing sbi posterior interface (see below for more information).
Implemented the Simformer, Gloeckler et al. 2024, ICML. The Simformer aims to unify the various simulation-based inference paradigms (posterior, likelihood, or arbitrary conditional sampling) within a single framework, allowing users to sample from any conditional distribution of interest, potentially also acting as a novel data generator if one samples the unconditioned joint distribution of all variables.
The Simformer diverges from the standard sbi paradigm of data provided by means of `theta` and `x`; it rather exploits a full tensor `inputs` of data and two masks (a toy sketch follows the list below):

- `condition_mask`: identifies which variables are latent (to be inferred by the Simformer) and which are observed (ground data).
- `edge_mask`: identifies relationships between variables, equivalent to an adjacency matrix for a DAG. This mask is used directly by the transformer attention block to mask out certain attention scores.
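To make the two masks concrete, here is a toy sketch (shapes follow the reviewer's notes above; the convention that `True` marks observed entries is my assumption, not necessarily the PR's):

```python
import torch

# Toy sizes: B = batch, T = number of variables, F = features per variable.
B, T, F = 2, 3, 1
inputs = torch.randn(B, T, F)  # all variables (parameters and data) stacked

# condition_mask: which variables are observed vs. latent. Conditioning on
# the last variable corresponds to posterior-style inference.
condition_mask = torch.tensor([False, False, True]).repeat(B, 1)  # [B, T]

# edge_mask: DAG adjacency; entry [i, j] = True lets variable i attend to
# variable j. An all-ones mask imposes no structure.
edge_mask = torch.ones(T, T, dtype=torch.bool).repeat(B, 1, 1)  # [B, T, T]
```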
Design of the Masked Classes

To accomplish this, it has been necessary to create some "parallel" classes of the current `ScoreEstimator`, `VectorFieldEstimator`, etc. that work by means of this "masked" paradigm. Generally, each "Masked" version of an object sits exactly below its counterpart in the same Python file, e.g., `MaskedConditionalVectorFieldEstimator` is exactly below the code block of `ConditionalVectorFieldEstimator`. These classes simply consist in an overall refactor of the original counterpart, where each use of "`theta` and `x`" or "`inputs` and `condition`" has been replaced with a general "`inputs`, `condition_mask`, and `edge_mask`".
".It has been also introduced a Wrapper class able to adapt the original API of the Posterior to the Simformer one, thanks to this class one is able to simply call$(\theta, x)$ setting to the full input tensor, and back.
build_conditional()
method directly on the Simformer inference object and obtain a standard Posterior object that works as always—given some fixed condition and edge masks. The Wrapper handles all the shapes automatically and perform auxiliary operations to pass the data to a Simformer network and the underlying masked estimator; this is done mainly through two helper functions:assemble_full_inputs()
anddisassemble_full_inputs()
, which are able to convert between theAt inference time, an
edge_mask
can be specified, otherwise it will beNone
(equivalent to a full ones tensor, but memory safer),condition_mask
instead must be specifically passed atbuild_conditional
time; another option is to directly use thebuild_posterior()
andbuild_likelihood()
method which will automatically generate an appropriate condition_mask based onposterior_latent_idx
andposterior_observed_idx
parameters specified at init() of the Simformer.Also at training time an$\text{Bernoulli}(p=0.5)$ .
edge_mask
can be specified, if not the default value will still beNone
, more generally the user can pass a Callable to generate condition or edge masks, so that one can simply choose the mask distributions they prefer. Sets of tensors/lists or even just one tensor can be passed as well. Masks are also generated just-in-time (JIT) for the training, that is, they are not provided atappend_simulation()
, but during the train() in order to save up memory. Differently from inference time, here if a condition mask is not specified, a default generator will be used, producing masks sampled by aNote that the Simformer potentially allows the user to set any mask of their choice both at training and inference time, it is rather duty of the user to provide coherent definitions (callables, sets, or fixed tensors) that make sense, e.g. if the user passes a specific edge mask at training time, the Simformer will learn that specific DAG structure, it is then duty of the user to pass a coherent edge_mask also when calling build_conditional, build_posterior or build_likelihood.
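A rough end-to-end sketch of the flow just described (method names taken from this PR and sbi's standard API; `inputs` and `x_obs` are hypothetical user tensors, and exact signatures may differ in the final API):

```python
import torch

from sbi.inference import Simformer

# Hypothetical data: `inputs` stacks all variables, `x_obs` holds the
# observed values we want to condition on at inference time.
inputs = torch.randn(1000, 3, 1)
x_obs = torch.randn(1, 1)

trainer = Simformer()  # init args (e.g., posterior_latent_idx) omitted
trainer.append_simulations(inputs)
trainer.train()  # condition/edge masks are generated just-in-time here

# Fix the conditioning pattern: observe the last variable, infer the rest.
posterior = trainer.build_conditional(
    condition_mask=torch.tensor([False, False, True]),
    edge_mask=None,  # None behaves like a dense all-ones mask
)
samples = posterior.sample((1000,), x=x_obs)  # standard Posterior API
```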
Furthermore, the Simformer is also able to manage invalid inputs (`nan`'s and `inf`'s) natively: if `handle_invalid_x=True`, the Simformer will automatically spot invalid inputs at training time (still JIT), switch their state on the condition mask to latent (to be inferred), and also replace such values with small Gaussian noise for numerical stability (a sketch follows below).

Also, a flow-matching equivalent of the Simformer (the above assumed the score-based variant) has been provided.
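A sketch of the invalid-input handling just described (illustrative only, not the PR's actual code; reuses the toy shapes from the earlier mask sketch, again assuming `True` marks observed entries):

```python
import torch

B, T, F = 2, 3, 1
inputs = torch.randn(B, T, F)
inputs[0, 1, 0] = float("nan")  # corrupt one entry for demonstration
condition_mask = torch.tensor([False, False, True]).repeat(B, 1)

invalid = torch.isnan(inputs) | torch.isinf(inputs)  # [B, T, F]
invalid_vars = invalid.any(dim=-1)                   # [B, T]

# Invalid entries cannot serve as observations: switch them to latent.
condition_mask = condition_mask & ~invalid_vars

# Replace invalid values with small Gaussian noise for numerical stability.
inputs = torch.where(invalid, 0.01 * torch.randn_like(inputs), inputs)
```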
This PR also includes integration with the `mini-sbibm` benchmark suite, and a notebook tutorial for the Simformer (under `docs/advanced_tutorials`), where I showcase its use. I also tried to make the API Reference as clear as possible in the documentation.

Refactor of existing code
Parts of the existing code have been refactored, mainly to avoid repetition of code and keep everything DRY. The most important pieces of code that have been modified are:
- SDE-specific quantities such as `mean_t`, `std_t`, etc. have been moved into standard Mixins (see the sketch after this list). E.g., instead of `VEScoreEstimator(ConditionalScoreEstimator)` one now has `VarianceExplodingSDE`, which defines `mean_t`, `std_t`, etc., and `VEScoreEstimator` becomes `VEScoreEstimator(ConditionalScoreEstimator, VarianceExplodingSDE)`; this way I can also easily define `MaskedVEScoreEstimator(MaskedConditionalScoreEstimator, VarianceExplodingSDE)` without repeating the VE SDE pieces.
- The `NeuralInference` interface has been split using a Mixin too (`BaseNeuralInference`), which defines shared properties of both `NeuralInference` and `MaskedNeuralInference`. This also required some minor adjustments, mainly for methods such as `_resolve_prior()` and `_resolve_estimator()`; most importantly, a new `NoPrior` object has been created as a temporary solution for #1635 (keep prior optional and remove unnecessary copies from `ImproperPrior`).
- `ConditionalVectorFieldEstimator` and `MaskedConditionalVectorFieldEstimator` were simplified by moving shared code into a Mixin called `BaseConditionalVectorFieldEstimator`, mainly regarding the `mean_base` and `std_base` properties, and methods such as `diffusion_fn()`.
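A minimal sketch of the Mixin pattern described in the first bullet (class and method names from the PR; the bodies are placeholders, not the PR's actual formulas):

```python
import torch
from torch import Tensor


class VarianceExplodingSDE:
    """Mixin holding VE-SDE quantities shared by masked and unmasked estimators."""

    def mean_t(self, t: Tensor) -> Tensor:
        return torch.ones_like(t)  # VE SDE leaves the mean untouched

    def std_t(self, t: Tensor) -> Tensor:
        return t  # placeholder schedule, not the PR's actual formula


class ConditionalScoreEstimator: ...


class MaskedConditionalScoreEstimator: ...


# Both variants inherit the same SDE mixin instead of duplicating mean_t/std_t:
class VEScoreEstimator(ConditionalScoreEstimator, VarianceExplodingSDE): ...


class MaskedVEScoreEstimator(MaskedConditionalScoreEstimator, VarianceExplodingSDE): ...
```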
Summary of modified files
The files I modified should amount to the following:

`sbi/inference`

- `sbi/inference/trainers/base.py`: Added `MaskedNeuralInference`.
- `sbi/inference/trainers/vfpe/base_vf_inference.py`: Added `MaskedVectorFieldEstimatorBuilder` and `MaskedVectorFieldInference` (subclass of `MaskedNeuralInference`).
- `sbi/inference/trainers/vfpe/simformer.py`: New file introducing the Simformer inference class.

`sbi/neural_nets`

- `sbi/neural_nets/factory.py`: Added support for building Simformer networks (`simformer_nn`).
- `sbi/neural_nets/estimators/base.py`: Added `MaskedConditionalEstimator` and `MaskedConditionalVectorFieldEstimator` (subclass of `MaskedConditionalEstimator`).
- `sbi/neural_nets/estimators/score_estimator.py`: Added `MaskedConditionalScoreEstimator` (subclass of `MaskedConditionalVectorFieldEstimator`), placed directly above `ConditionalScoreEstimator`, and `MaskedVEScoreEstimator` (subclass of `MaskedConditionalScoreEstimator`).
- `sbi/neural_nets/net_builders/vector_field_nets.py`: `build_vector_field_estimator` updated to support `simformer` and `masked-score`. Added `MaskedSimformerBlock`, `MaskedDiTBlock`, `SimformerNet` (subclass of `MaskedVectorFieldNet`), and `build_simformer_network` (defines default architecture parameters).

`sbi/utils`

- `sbi/utils/vector_field_utils.py`: Added `MaskedVectorFieldNet`.

`sbi/analysis`

- `sbi/analysis/plots.py`: Minor fix to ensure CPU conversion in `ensure_numpy()` (added `.cpu()` before `.numpy()`).

Unit Test
Introduced benchmarks (`mini_sbibm`) and tests for the Simformer and related masked objects in:

- `tests/linearGaussian_vector_field_test.py`
- `tests/posterior_nn_test.py`
- `tests/vector_field_nets_test.py`
- `tests/vf_estimator_test.py` (which also includes shape tests on the Wrapper)
- `tests/bm_test.py`

Regarding the linear Gaussian tests, I tried to implement the Simformer tests within existing test methods as much as possible; nonetheless, the IID test and the SDE/ODE sampling-equivalence test are still provided as separate dedicated tests and fixtures.
New files
- `docs/advanced_tutorials/22_simformer.ipynb`
- `sbi/inference/trainers/vfpe/simformer.py`: including both the score-based and flow-matching Simformer interfaces

Thank you
Thank you sbi and Google for this opportunity. Implementing the Simformer has been so rewarding: not only did I learn something completely new, but most importantly I understood how to do it. Having to familiarize myself with new concepts, writing code within code made by others, and following my mentors' guidance are the real value of this experience. Special thanks to my mentors Manuel (@manuelgloeckler) and Jan (@janfb) for accepting my proposal, and to @manuelgloeckler in particular for having helped me throughout the whole journey!