Status: Open
Labels: enhancement (new feature or request), hackathon, score-matching-performance (improving the performance of score- and flow-matching methods)
Description
🚀 Feature Request
One cool property of diffusion models is that their sampling process can be "guided" toward a certain goal. Such a goal can be to truncate the prior, to slightly shift, rescale, or shrink the prior or the likelihood, and many other things.
To do so, we need to be able to modify the output score so that it satisfies the chosen goal (see here for some examples in the sbi context).
Describe the solution you'd like
- What is used in the sampling is the `potential`, i.e., see here, or more specifically, in the context of score matching, we estimate the gradient of the potential. Here we would need an "entry" point/API to change the output on user request.
- A simple way of guidance is "classifier-free" guidance, which essentially scales the likelihood contribution to the posterior score up or down. This would be a good start for this problem.
- This requires subtracting the time-marginal prior score, for which a suite of analytic solutions is provided in this branch: Score-based iid sampling #1381 (not yet merged), i.e., it provides $\nabla_{\theta_t} \log p(\theta_t)$.
- The score returned by the potential should change when guidance is applied, i.e., it needs to become $\alpha \cdot (s_\phi(\theta_t|x) - \nabla_{\theta_t} \log p(\theta_t)) + \nabla_{\theta_t} \log p(\theta_t)$.
- While these kinds of manipulations work well empirically, they usually only ensure that the end result at $t=0$ is "formally correct" and ignore correctness for $t>0$. Performance should be evaluated, and if "correctors" are required, a `UserWarning` should be raised (usually a Gibbs corrector is used in these cases, often referred to as "self-recurrence").
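
As a rough illustration of the guided score above, here is a minimal sketch in plain Python. It is not the sbi API: `guided_score`, `gaussian_score`, and the `(theta_t, t)` call signature are hypothetical names chosen for this example, and the Gaussian scores are toy stand-ins for the trained network $s_\phi$ and the analytic marginal prior score. At $t=0$, Bayes' rule gives $s(\theta|x) - \nabla_\theta \log p(\theta) = \nabla_\theta \log p(x|\theta)$, so the guided score targets a density proportional to $p(\theta)\,p(x|\theta)^\alpha$; with $\alpha = 1$ the posterior score is recovered and with $\alpha = 0$ only the prior score remains.

```python
from typing import Callable

# Hypothetical signature for this sketch: (theta_t, t) -> score at theta_t.
ScoreFn = Callable[[float, float], float]


def guided_score(posterior_score: ScoreFn, prior_score: ScoreFn, alpha: float) -> ScoreFn:
    """Return alpha * (s(theta_t | x) - grad log p(theta_t)) + grad log p(theta_t)."""

    def score(theta_t: float, t: float) -> float:
        prior = prior_score(theta_t, t)
        # Approximate likelihood contribution, rescaled by alpha.
        likelihood_part = posterior_score(theta_t, t) - prior
        return alpha * likelihood_part + prior

    return score


def gaussian_score(mean: float, var: float) -> ScoreFn:
    """Score of N(mean, var); a toy stand-in that ignores t for simplicity."""
    return lambda theta_t, t: -(theta_t - mean) / var


# Toy stand-ins (assumptions, not a trained model or a real marginal prior).
prior = gaussian_score(mean=0.0, var=4.0)
posterior = gaussian_score(mean=1.0, var=1.0)

unguided = guided_score(posterior, prior, alpha=1.0)    # equals the posterior score
prior_only = guided_score(posterior, prior, alpha=0.0)  # equals the prior score
sharpened = guided_score(posterior, prior, alpha=2.0)   # up-weights the likelihood

print(unguided(0.5, 0.1), prior_only(0.5, 0.1), sharpened(0.5, 0.1))
```

Because the wrapper only uses arithmetic operators, the same pattern should carry over to tensor-valued scores, but where exactly such a hook would live in the potential is the open design question of this issue.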