-
Notifications
You must be signed in to change notification settings - Fork 227
Update to the AdvancedVI@0.4 interface #2506
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 33 commits
ed6946c
a94269d
a4711a9
3f8068b
222a638
57097f5
a42eea8
798f319
69a4972
cbcb8b5
081d6ff
a32a673
1bcec3e
b142832
061ec35
736bd3e
fd434d8
57108ee
8dc8067
297c32a
3010b5e
0c04434
17a8290
626c5b5
cb2c618
2b08a4b
0e496c4
c1533a8
231d6e2
c2ae04a
69639ec
ef9aeb1
43c19aa
cc18528
162899a
0b79495
3818152
91a9afe
12539aa
0653bf1
f74ec38
406824f
f62e7b8
e3b7618
f0374b6
187a65c
a5021d1
218eb23
4714c3c
f712755
8086398
37f6b06
c717220
e9f7f1e
6a8c6ed
c4d73fb
feb1a57
ea417fc
4c9a538
dfa8d20
a18f581
f9528e0
fb150c7
8174725
dec108b
b0d791e
29373ee
d21e652
4c72501
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,50 +1,149 @@ | ||
|
||
module Variational | ||
|
||
using DistributionsAD: DistributionsAD | ||
using DynamicPPL: DynamicPPL | ||
using StatsBase: StatsBase | ||
using StatsFuns: StatsFuns | ||
using LogDensityProblems: LogDensityProblems | ||
using DynamicPPL | ||
using ADTypes | ||
using Distributions | ||
using LinearAlgebra | ||
using LogDensityProblems | ||
using Random | ||
|
||
using Random: Random | ||
import ..Turing: DEFAULT_ADTYPE, PROGRESS | ||
|
||
import AdvancedVI | ||
import Bijectors | ||
|
||
# Reexports | ||
using AdvancedVI: vi, ADVI, ELBO, elbo, TruncatedADAGrad, DecayedADAGrad | ||
export vi, ADVI, ELBO, elbo, TruncatedADAGrad, DecayedADAGrad | ||
|
||
""" | ||
make_logjoint(model::Model; weight = 1.0) | ||
Constructs the logjoint as a function of latent variables, i.e. the map z → p(x ∣ z) p(z). | ||
The weight used to scale the likelihood, e.g. when doing stochastic gradient descent one needs to | ||
use `DynamicPPL.MiniBatch` context to run the `Model` with a weight `num_total_obs / batch_size`. | ||
## Notes | ||
- For sake of efficiency, the returned function is closes over an instance of `VarInfo`. This means that you *might* run into some weird behaviour if you call this method sequentially using different types; if that's the case, just generate a new one for each type using `make_logjoint`. | ||
""" | ||
function make_logjoint(model::DynamicPPL.Model; weight=1.0) | ||
# setup | ||
using AdvancedVI: RepGradELBO, ScoreGradELBO, DoG, DoWG | ||
export RepGradELBO, ScoreGradELBO, DoG, DoWG | ||
|
||
export vi, q_init, q_meanfield_gaussian, q_fullrank_gaussian | ||
|
||
include("bijectors.jl") | ||
|
||
function make_logdensity(model::DynamicPPL.Model) | ||
weight = 1.0 | ||
ctx = DynamicPPL.MiniBatchContext(DynamicPPL.DefaultContext(), weight) | ||
f = DynamicPPL.LogDensityFunction(model, DynamicPPL.VarInfo(model), ctx) | ||
return Base.Fix1(LogDensityProblems.logdensity, f) | ||
return DynamicPPL.LogDensityFunction(model, DynamicPPL.VarInfo(model), ctx) | ||
end | ||
|
||
# objectives | ||
function (elbo::ELBO)( | ||
function initialize_gaussian_scale( | ||
rng::Random.AbstractRNG, | ||
alg::AdvancedVI.VariationalInference, | ||
q, | ||
model::DynamicPPL.Model, | ||
num_samples; | ||
weight=1.0, | ||
location::AbstractVector, | ||
scale::AbstractMatrix; | ||
num_samples::Int=10, | ||
num_max_trials::Int=10, | ||
reduce_factor=one(eltype(scale)) / 2, | ||
) | ||
prob = make_logdensity(model) | ||
ℓπ = Base.Fix1(LogDensityProblems.logdensity, prob) | ||
varinfo = DynamicPPL.VarInfo(model) | ||
|
||
n_trial = 0 | ||
while true | ||
q = AdvancedVI.MvLocationScale(location, scale, Normal()) | ||
b = Bijectors.bijector(model; varinfo=varinfo) | ||
q_trans = Bijectors.transformed(q, Bijectors.inverse(b)) | ||
energy = mean(ℓπ, eachcol(rand(rng, q_trans, num_samples))) | ||
|
||
if isfinite(energy) | ||
return scale | ||
elseif n_trial == num_max_trials | ||
error("Could not find an initial") | ||
end | ||
|
||
scale = reduce_factor * scale | ||
n_trial += 1 | ||
end | ||
end | ||
|
||
function q_init( | ||
rng::Random.AbstractRNG, | ||
model::DynamicPPL.Model; | ||
location::Union{Nothing,<:AbstractVector}=nothing, | ||
scale::Union{Nothing,<:Diagonal,<:LowerTriangular}=nothing, | ||
meanfield::Bool=true, | ||
basedist::Distributions.UnivariateDistribution=Normal(), | ||
kwargs..., | ||
) | ||
return elbo(rng, alg, q, make_logjoint(model; weight=weight), num_samples; kwargs...) | ||
varinfo = DynamicPPL.VarInfo(model) | ||
# Use linked `varinfo` to determine the correct number of parameters. | ||
# TODO: Replace with `length` once this is implemented for `VarInfo`. | ||
varinfo_linked = DynamicPPL.link(varinfo, model) | ||
num_params = length(varinfo_linked[:]) | ||
|
||
μ = if isnothing(location) | ||
zeros(num_params) | ||
else | ||
@assert length(location) == num_params "Length of the provided location vector, $(length(location)), does not match dimension of the target distribution, $(num_params)." | ||
location | ||
end | ||
|
||
L = if isnothing(scale) | ||
if meanfield | ||
initialize_gaussian_scale(rng, model, μ, Diagonal(ones(num_params)); kwargs...) | ||
else | ||
L0 = LowerTriangular(Matrix{Float64}(I, num_params, num_params)) | ||
initialize_gaussian_scale(rng, model, μ, L0; kwargs...) | ||
end | ||
else | ||
@assert size(scale) == (num_params, num_params) "Dimensions of the provided scale matrix, $(size(scale)), does not match the dimension of the target distribution, $(num_params)." | ||
if meanfield | ||
Diagonal(diag(scale)) | ||
else | ||
scale | ||
end | ||
end | ||
q = AdvancedVI.MvLocationScale(μ, L, basedist) | ||
b = Bijectors.bijector(model; varinfo=varinfo) | ||
return Bijectors.transformed(q, Bijectors.inverse(b)) | ||
end | ||
|
||
# VI algorithms | ||
include("advi.jl") | ||
function q_meanfield_gaussian( | ||
rng::Random.AbstractRNG, | ||
model::DynamicPPL.Model; | ||
location::Union{Nothing,<:AbstractVector}=nothing, | ||
scale::Union{Nothing,<:Diagonal}=nothing, | ||
kwargs..., | ||
) | ||
return q_init(rng, model; location, scale, meanfield=true, basedist=Normal(), kwargs...) | ||
end | ||
|
||
function q_fullrank_gaussian( | ||
rng::Random.AbstractRNG, | ||
model::DynamicPPL.Model; | ||
location::Union{Nothing,<:AbstractVector}=nothing, | ||
scale::Union{Nothing,<:LowerTriangular}=nothing, | ||
kwargs..., | ||
) | ||
return q_init(rng, model; location, scale, meanfield=false, basedist=Normal(), kwargs...) | ||
yebai marked this conversation as resolved.
Show resolved
Hide resolved
|
||
end | ||
|
||
function vi( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This function is a thin wrapper around
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Given that the function |
||
model::DynamicPPL.Model, | ||
q::Bijectors.TransformedDistribution, | ||
n_iterations::Int; | ||
objective=RepGradELBO(10; entropy=AdvancedVI.ClosedFormEntropyZeroGradient()), | ||
show_progress::Bool=PROGRESS[], | ||
optimizer=AdvancedVI.DoWG(), | ||
averager=AdvancedVI.PolynomialAveraging(), | ||
operator=AdvancedVI.ProximalLocationScaleEntropy(), | ||
adtype::ADTypes.AbstractADType=DEFAULT_ADTYPE, | ||
kwargs..., | ||
) | ||
return AdvancedVI.optimize( | ||
make_logdensity(model), | ||
objective, | ||
q, | ||
n_iterations; | ||
show_progress=show_progress, | ||
adtype, | ||
optimizer, | ||
averager, | ||
operator, | ||
kwargs..., | ||
) | ||
end | ||
|
||
end |
This file was deleted.
Uh oh!
There was an error while loading. Please reload this page.