Open
Labels: enhancement (New feature or request), hackathon, score-matching-performance (Improving the performance of score- and flow-matching methods)
Description
🚀 Feature Request
One cool property of diffusion models is that their sampling process can be "guided" toward a certain goal. Such a goal can be to truncate the prior, or to slightly shift, rescale, or shrink the prior or the likelihood, among many other things.
To do so, we need to be able to change the output score so that it satisfies that goal (see here for some examples in the sbi context).
Describe the solution you'd like
- Sampling is driven by the potential (see here); more specifically, in the context of score matching we estimate the gradient of the potential. We would need an entry point/API there that changes the output on user request.
- A simple form of guidance is "classifier-free" guidance, which essentially scales the likelihood contribution to the posterior score up or down. This would be a good starting point for this problem.
- This requires subtracting the score of the time-marginal prior, $\nabla_{\theta_t} \log p(\theta_t)$, for which a suite of analytic solutions is provided in the branch of Score-based iid sampling #1381 (not yet merged).
- When guidance is applied, the score returned by the potential should change to $\alpha \cdot (s_\phi(\theta_t|x) - \nabla_{\theta_t} \log p(\theta_t)) + \nabla_{\theta_t} \log p(\theta_t)$.
- While these kinds of manipulation work well empirically, they usually only ensure that the result at $t=0$ is "formally correct" and do not guarantee correctness for $t>0$. Performance should be evaluated, and if "correctors" are required, a UserWarning should be raised (usually a Gibbs corrector is used in these cases, often referred to as "self-recurrence").
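To make the proposed change concrete, here is a minimal pure-Python sketch of the guided score. By Bayes' rule, $s_\phi(\theta_t|x) - \nabla_{\theta_t} \log p(\theta_t)$ plays the role of the likelihood score, so multiplying it by $\alpha$ tempers the likelihood. All names below (`gaussian_marginal_prior_score`, `make_guided_score_fn`, the `(theta_t, t)` call signature) are hypothetical and not part of the sbi API. The analytic marginal score assumes a 1D Gaussian prior and a forward process of the form $\theta_t = m_t \theta_0 + \sigma_t \epsilon$. The warning is only a placeholder for the "corrector" check, which would need an actual evaluation.

```python
import warnings
from typing import Callable

# Hypothetical signature for this sketch: (theta_t, t) -> score value.
ScoreFn = Callable[[float, float], float]


def gaussian_marginal_prior_score(
    theta_t: float,
    mean_scale: float,
    noise_std: float,
    prior_mean: float = 0.0,
    prior_std: float = 1.0,
) -> float:
    """Score of p_t(theta_t) for theta_0 ~ N(prior_mean, prior_std^2).

    Assumes theta_t = mean_scale * theta_0 + noise_std * eps, eps ~ N(0, 1),
    so p_t is Gaussian with mean mean_scale * prior_mean and the variance below.
    """
    var_t = (mean_scale * prior_std) ** 2 + noise_std**2
    return -(theta_t - mean_scale * prior_mean) / var_t


def make_guided_score_fn(
    posterior_score_fn: ScoreFn, prior_score_fn: ScoreFn, alpha: float
) -> ScoreFn:
    """Return alpha * (s_post - s_prior) + s_prior as a new score function."""
    if alpha != 1.0:
        # Placeholder: guidance is generally only exact at t = 0.
        warnings.warn(
            "Guided score may be inexact for t > 0; a corrector may be needed.",
            UserWarning,
        )

    def guided(theta_t: float, t: float) -> float:
        prior = prior_score_fn(theta_t, t)
        return alpha * (posterior_score_fn(theta_t, t) - prior) + prior

    return guided


# Worked check at t = 0: prior N(0, 1), likelihood x | theta ~ N(theta, 0.5^2),
# observed x = 1. The exact posterior is N(0.8, 1/5).
def posterior_score(theta_t: float, t: float) -> float:
    # Stand-in for the learned score; only valid at t = 0 in this toy example.
    return -(theta_t - 0.8) * 5.0


def prior_score(theta_t: float, t: float) -> float:
    return gaussian_marginal_prior_score(theta_t, mean_scale=1.0, noise_std=0.0)


guided = make_guided_score_fn(posterior_score, prior_score, alpha=2.0)
```

With $\alpha = 2$ the likelihood precision doubles from 4 to 8, so the tempered posterior has precision 9 and mean 8/9, which is where the guided score should vanish. With $\alpha = 1$ the wrapper returns the unmodified posterior score, and with $\alpha = 0$ it falls back to the prior score.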