From ed6946c4c8ffcec0494b74ce6978bb4f13883c53 Mon Sep 17 00:00:00 2001
From: Kyurae Kim
Date: Fri, 14 Mar 2025 19:09:15 -0400
Subject: [PATCH 01/52] update to match the AdvancedVI@0.3 interface

---
 Project.toml                            |   4 +-
 src/variational/VariationalInference.jl | 185 ++++++++++++++++++++----
 src/variational/advi.jl                 | 140 ------------------
 src/variational/bijectors.jl            |  70 +++++++++
 4 files changed, 227 insertions(+), 172 deletions(-)
 delete mode 100644 src/variational/advi.jl
 create mode 100644 src/variational/bijectors.jl

diff --git a/Project.toml b/Project.toml
index 459f11dcbd..6fa6daa2a1 100644
--- a/Project.toml
+++ b/Project.toml
@@ -38,6 +38,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 StatsAPI = "82ae8749-77ed-4fe6-ae5f-f523153014b0"
 StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
 StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
+UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228"
 
 [weakdeps]
 DynamicHMC = "bbc10e6e-7c05-544b-b16e-64fede858acb"
@@ -54,7 +55,7 @@ Accessors = "0.1"
 AdvancedHMC = "0.3.0, 0.4.0, 0.5.2, 0.6"
 AdvancedMH = "0.8"
 AdvancedPS = "0.6.0"
-AdvancedVI = "0.2"
+AdvancedVI = "0.3.1"
 BangBang = "0.4.2"
 Bijectors = "0.14, 0.15"
 Compat = "4.15.0"
@@ -85,6 +86,7 @@ Statistics = "1.6"
 StatsAPI = "1.6"
 StatsBase = "0.32, 0.33, 0.34"
 StatsFuns = "0.8, 0.9, 1"
+UnicodePlots = "3"
 julia = "1.10"
 
 [extras]
diff --git a/src/variational/VariationalInference.jl b/src/variational/VariationalInference.jl
index 189d3f7001..db95093508 100644
--- a/src/variational/VariationalInference.jl
+++ b/src/variational/VariationalInference.jl
@@ -1,50 +1,173 @@
+
 module Variational
 
-using DistributionsAD: DistributionsAD
-using DynamicPPL: DynamicPPL
-using StatsBase: StatsBase
-using StatsFuns: StatsFuns
-using LogDensityProblems: LogDensityProblems
+using DynamicPPL
+using ADTypes
 using Distributions
+using LinearAlgebra
+using LogDensityProblems
+using Random
+using UnicodePlots
 
-using Random: Random
+import ..Turing: DEFAULT_ADTYPE, PROGRESS
 
 import AdvancedVI
 import Bijectors
 
 # Reexports
-using AdvancedVI: vi, ADVI, ELBO, elbo, TruncatedADAGrad, DecayedADAGrad
-export vi, ADVI, ELBO, elbo, TruncatedADAGrad, DecayedADAGrad
-
-"""
-    make_logjoint(model::Model; weight = 1.0)
-Constructs the logjoint as a function of latent variables, i.e. the map z → p(x ∣ z) p(z).
-The weight used to scale the likelihood, e.g. when doing stochastic gradient descent one needs to
-use `DynamicPPL.MiniBatch` context to run the `Model` with a weight `num_total_obs / batch_size`.
-## Notes
-- For sake of efficiency, the returned function closes over an instance of `VarInfo`. This means that you *might* run into some weird behaviour if you call this method sequentially using different types; if that's the case, just generate a new one for each type using `make_logjoint`.
-"""
-function make_logjoint(model::DynamicPPL.Model; weight=1.0)
-    # setup
+using AdvancedVI: RepGradELBO, ScoreGradELBO, DoG, DoWG
+export vi, RepGradELBO, ScoreGradELBO, DoG, DoWG
+
+export meanfield_gaussian, fullrank_gaussian
+
+include("bijectors.jl")
+
+function make_logdensity(model::DynamicPPL.Model)
+    weight = 1.0
     ctx = DynamicPPL.MiniBatchContext(DynamicPPL.DefaultContext(), weight)
-    f = DynamicPPL.LogDensityFunction(model, DynamicPPL.VarInfo(model), ctx)
-    return Base.Fix1(LogDensityProblems.logdensity, f)
+    return DynamicPPL.LogDensityFunction(model, DynamicPPL.VarInfo(model), ctx)
+end
+
+function initialize_gaussian_scale(
+    rng::Random.AbstractRNG,
+    model::DynamicPPL.Model,
+    location::AbstractVector,
+    scale::AbstractMatrix;
+    num_samples::Int = 10,
+    num_max_trials::Int = 10,
+    reduce_factor = one(eltype(scale))/2
+)
+    prob = make_logdensity(model)
+    ℓπ = Base.Fix1(LogDensityProblems.logdensity, prob)
+    varinfo = DynamicPPL.VarInfo(model)
+
+    n_trial = 0
+    while true
+        q = AdvancedVI.MvLocationScale(location, scale, Normal())
+        b = Bijectors.bijector(model; varinfo=varinfo)
+        q_trans = Bijectors.transformed(q, Bijectors.inverse(b))
+        energy = mean(ℓπ, eachcol(rand(rng, q_trans, num_samples)))
+
+        if isfinite(energy)
+            return scale
+        elseif n_trial == num_max_trials
+            error("Could not find an initial scale with finite energy")
+        end
+
+        scale = reduce_factor*scale
+        n_trial += 1
+    end
+end
+
+function meanfield_gaussian(
+    rng::Random.AbstractRNG,
+    model::DynamicPPL.Model,
+    location::Union{Nothing, <:AbstractVector} = nothing,
+    scale::Union{Nothing, <:Diagonal} = nothing;
+    kwargs...
+)
+    varinfo = DynamicPPL.VarInfo(model)
+    # Use linked `varinfo` to determine the correct number of parameters.
+    # TODO: Replace with `length` once this is implemented for `VarInfo`.
+    varinfo_linked = DynamicPPL.link(varinfo, model)
+    num_params = length(varinfo_linked[:])
+
+    μ = if isnothing(location)
+        zeros(num_params)
+    else
+        @assert length(location) == num_params "Length of the provided location vector, $(length(location)), does not match dimension of the target distribution, $(num_params)."
+        location
+    end
+
+    L = if isnothing(scale)
+        initialize_gaussian_scale(rng, model, μ, Diagonal(ones(num_params)); kwargs...)
+    else
+        @assert size(scale) == (num_params, num_params) "Dimensions of the provided scale matrix, $(size(scale)), does not match the dimension of the target distribution, $(num_params)."
+        L = scale
+    end
+
+    q = AdvancedVI.MeanFieldGaussian(μ, L)
+    b = Bijectors.bijector(model; varinfo=varinfo)
+    return Bijectors.transformed(q, Bijectors.inverse(b))
 end
 
-# objectives
-function (elbo::ELBO)(
+function meanfield_gaussian(
+    model::DynamicPPL.Model,
+    location::Union{Nothing, <:AbstractVector} = nothing,
+    scale::Union{Nothing, <:Diagonal} = nothing;
+    kwargs...
+)
+    meanfield_gaussian(Random.default_rng(), model, location, scale; kwargs...)
+end
+
+function fullrank_gaussian(
     rng::Random.AbstractRNG,
-    alg::AdvancedVI.VariationalInference,
-    q,
     model::DynamicPPL.Model,
-    num_samples;
-    weight=1.0,
-    kwargs...,
+    location::Union{Nothing, <:AbstractVector} = nothing,
+    scale::Union{Nothing, <:LowerTriangular} = nothing;
+    kwargs...
 )
-    return elbo(rng, alg, q, make_logjoint(model; weight=weight), num_samples; kwargs...)
+    varinfo = DynamicPPL.VarInfo(model)
+    # Use linked `varinfo` to determine the correct number of parameters.
+    # TODO: Replace with `length` once this is implemented for `VarInfo`.
+    varinfo_linked = DynamicPPL.link(varinfo, model)
+    num_params = length(varinfo_linked[:])
+
+    μ = if isnothing(location)
+        zeros(num_params)
+    else
+        @assert length(location) == num_params "Length of the provided location vector, $(length(location)), does not match dimension of the target distribution, $(num_params)."
+        location
+    end
+
+    L = if isnothing(scale)
+        L0 = LowerTriangular(Matrix{Float64}(I, num_params, num_params))
+        initialize_gaussian_scale(rng, model, μ, L0; kwargs...)
+    else
+        @assert size(scale) == (num_params, num_params) "Dimensions of the provided scale matrix, $(size(scale)), does not match the dimension of the target distribution, $(num_params)."
+        scale
+    end
+
+    q = AdvancedVI.FullRankGaussian(μ, L)
+    b = Bijectors.bijector(model; varinfo=varinfo)
+    return Bijectors.transformed(q, Bijectors.inverse(b))
+end
+
+function fullrank_gaussian(
+    model::DynamicPPL.Model,
+    location::Union{Nothing, <:AbstractVector} = nothing,
+    scale::Union{Nothing, <:Diagonal} = nothing;
+    kwargs...
+)
+    fullrank_gaussian(Random.default_rng(), model, location, scale; kwargs...)
 end
 
-# VI algorithms
-include("advi.jl")
+function vi(
+    model::DynamicPPL.Model,
+    q::Bijectors.TransformedDistribution,
+    n_iterations::Int;
+    objective=RepGradELBO(10, entropy=AdvancedVI.ClosedFormEntropyZeroGradient()),
+    show_progress::Bool=PROGRESS[],
+    optimizer=AdvancedVI.DoWG(),
+    averager=AdvancedVI.PolynomialAveraging(),
+    operator=AdvancedVI.ProximalLocationScaleEntropy(),
+    adtype::ADTypes.AbstractADType=DEFAULT_ADTYPE,
+)
+    q_avg_trans, _, stats, _ = AdvancedVI.optimize(
+        make_logdensity(model),
+        objective,
+        q,
+        n_iterations;
+        show_progress=show_progress,
+        adtype,
+        optimizer,
+        averager,
+        operator,
+    )
+    if show_progress
+        lineplot([stat.elbo for stat in stats], ylabel="Objective", xlabel="Iteration") |> display
+    end
+    return q_avg_trans
+end
 
 end
diff --git a/src/variational/advi.jl b/src/variational/advi.jl
deleted file mode 100644
index ec3e6552e3..0000000000
--- a/src/variational/advi.jl
+++ /dev/null
@@ -1,140 +0,0 @@
-# TODO: Move to Bijectors.jl if we find further use for this.
-"""
-    wrap_in_vec_reshape(f, in_size)
-
-Wraps a bijector `f` such that it operates on vectors of length `prod(in_size)` and produces
-a vector of length `prod(Bijectors.output(f, in_size))`.
-"""
-function wrap_in_vec_reshape(f, in_size)
-    vec_in_length = prod(in_size)
-    reshape_inner = Bijectors.Reshape((vec_in_length,), in_size)
-    out_size = Bijectors.output_size(f, in_size)
-    vec_out_length = prod(out_size)
-    reshape_outer = Bijectors.Reshape(out_size, (vec_out_length,))
-    return reshape_outer ∘ f ∘ reshape_inner
-end
-
-"""
-    bijector(model::Model[, sym2ranges = Val(false)])
-
-Returns a `Stacked <: Bijector` which maps from the support of the posterior to ℝᵈ with `d`
-denoting the dimensionality of the latent variables.
-"""
-function Bijectors.bijector(
-    model::DynamicPPL.Model, ::Val{sym2ranges}=Val(false); varinfo=DynamicPPL.VarInfo(model)
-) where {sym2ranges}
-    num_params = sum([
-        size(varinfo.metadata[sym].vals, 1) for sym in keys(varinfo.metadata)
-    ])
-
-    dists = vcat([varinfo.metadata[sym].dists for sym in keys(varinfo.metadata)]...)
-
-    num_ranges = sum([
-        length(varinfo.metadata[sym].ranges) for sym in keys(varinfo.metadata)
-    ])
-    ranges = Vector{UnitRange{Int}}(undef, num_ranges)
-    idx = 0
-    range_idx = 1
-
-    # ranges might be discontinuous => values are vectors of ranges rather than just ranges
-    sym_lookup = Dict{Symbol,Vector{UnitRange{Int}}}()
-    for sym in keys(varinfo.metadata)
-        sym_lookup[sym] = Vector{UnitRange{Int}}()
-        for r in varinfo.metadata[sym].ranges
-            ranges[range_idx] = idx .+ r
-            push!(sym_lookup[sym], ranges[range_idx])
-            range_idx += 1
-        end
-
-        idx += varinfo.metadata[sym].ranges[end][end]
-    end
-
-    bs = map(tuple(dists...)) do d
-        b = Bijectors.bijector(d)
-        if d isa Distributions.UnivariateDistribution
-            b
-        else
-            wrap_in_vec_reshape(b, size(d))
-        end
-    end
-
-    if sym2ranges
-        return (
-            Bijectors.Stacked(bs, ranges),
-            (; collect(zip(keys(sym_lookup), values(sym_lookup)))...),
-        )
-    else
-        return Bijectors.Stacked(bs, ranges)
-    end
-end
-
-"""
-    meanfield([rng, ]model::Model)
-
-Creates a mean-field approximation with multivariate normal as underlying distribution.
-"""
-meanfield(model::DynamicPPL.Model) = meanfield(Random.default_rng(), model)
-function meanfield(rng::Random.AbstractRNG, model::DynamicPPL.Model)
-    # Setup.
-    varinfo = DynamicPPL.VarInfo(model)
-    # Use linked `varinfo` to determine the correct number of parameters.
-    # TODO: Replace with `length` once this is implemented for `VarInfo`.
-    varinfo_linked = DynamicPPL.link(varinfo, model)
-    num_params = length(varinfo_linked[:])
-
-    # initial params
-    μ = randn(rng, num_params)
-    σ = StatsFuns.softplus.(randn(rng, num_params))
-
-    # Construct the base family.
-    d = DistributionsAD.TuringDiagMvNormal(μ, σ)
-
-    # Construct the bijector constrained → unconstrained.
-    b = Bijectors.bijector(model; varinfo=varinfo)
-
-    # We want to transform from unconstrained space to constrained,
-    # hence we need the inverse of `b`.
-    return Bijectors.transformed(d, Bijectors.inverse(b))
-end
-
-# Overloading stuff from `AdvancedVI` to specialize for Turing
-function AdvancedVI.update(d::DistributionsAD.TuringDiagMvNormal, μ, σ)
-    return DistributionsAD.TuringDiagMvNormal(μ, σ)
-end
-function AdvancedVI.update(td::Bijectors.TransformedDistribution, θ...)
-    return Bijectors.transformed(AdvancedVI.update(td.dist, θ...), td.transform)
-end
-function AdvancedVI.update(
-    td::Bijectors.TransformedDistribution{<:DistributionsAD.TuringDiagMvNormal},
-    θ::AbstractArray,
-)
-    # `length(td.dist) != length(td)` if `td.transform` changes the dimensionality,
-    # so we need to use the length of the underlying distribution `td.dist` here.
-    # TODO: Check if we can get away with `view` instead of `getindex` for all AD backends.
-    μ, ω = θ[begin:(begin + length(td.dist) - 1)], θ[(begin + length(td.dist)):end]
-    return AdvancedVI.update(td, μ, StatsFuns.softplus.(ω))
-end
-
-function AdvancedVI.vi(
-    model::DynamicPPL.Model, alg::AdvancedVI.ADVI; optimizer=AdvancedVI.TruncatedADAGrad()
-)
-    q = meanfield(model)
-    return AdvancedVI.vi(model, alg, q; optimizer=optimizer)
-end
-
-function AdvancedVI.vi(
-    model::DynamicPPL.Model,
-    alg::AdvancedVI.ADVI,
-    q::Bijectors.TransformedDistribution{<:DistributionsAD.TuringDiagMvNormal};
-    optimizer=AdvancedVI.TruncatedADAGrad(),
-)
-    # Initial parameters for mean-field approx
-    μ, σs = StatsBase.params(q)
-    θ = vcat(μ, StatsFuns.invsoftplus.(σs))
-
-    # Optimize
-    AdvancedVI.optimize!(elbo, alg, q, make_logjoint(model), θ; optimizer=optimizer)
-
-    # Return updated `Distribution`
-    return AdvancedVI.update(q, θ)
-end
diff --git a/src/variational/bijectors.jl b/src/variational/bijectors.jl
new file mode 100644
index 0000000000..e0633493f6
--- /dev/null
+++ b/src/variational/bijectors.jl
@@ -0,0 +1,70 @@
+
+# TODO: Move to Bijectors.jl if we find further use for this.
+"""
+    wrap_in_vec_reshape(f, in_size)
+
+Wraps a bijector `f` such that it operates on vectors of length `prod(in_size)` and produces
+a vector of length `prod(Bijectors.output(f, in_size))`.
+"""
+function wrap_in_vec_reshape(f, in_size)
+    vec_in_length = prod(in_size)
+    reshape_inner = Bijectors.Reshape((vec_in_length,), in_size)
+    out_size = Bijectors.output_size(f, in_size)
+    vec_out_length = prod(out_size)
+    reshape_outer = Bijectors.Reshape(out_size, (vec_out_length,))
+    return reshape_outer ∘ f ∘ reshape_inner
+end
+
+"""
+    bijector(model::Model[, sym2ranges = Val(false)])
+
+Returns a `Stacked <: Bijector` which maps from the support of the posterior to ℝᵈ with `d`
+denoting the dimensionality of the latent variables.
+"""
+function Bijectors.bijector(
+    model::DynamicPPL.Model, ::Val{sym2ranges}=Val(false); varinfo=DynamicPPL.VarInfo(model)
+) where {sym2ranges}
+    num_params = sum([
+        size(varinfo.metadata[sym].vals, 1) for sym in keys(varinfo.metadata)
+    ])
+
+    dists = vcat([varinfo.metadata[sym].dists for sym in keys(varinfo.metadata)]...)
+
+    num_ranges = sum([
+        length(varinfo.metadata[sym].ranges) for sym in keys(varinfo.metadata)
+    ])
+    ranges = Vector{UnitRange{Int}}(undef, num_ranges)
+    idx = 0
+    range_idx = 1
+
+    # ranges might be discontinuous => values are vectors of ranges rather than just ranges
+    sym_lookup = Dict{Symbol,Vector{UnitRange{Int}}}()
+    for sym in keys(varinfo.metadata)
+        sym_lookup[sym] = Vector{UnitRange{Int}}()
+        for r in varinfo.metadata[sym].ranges
+            ranges[range_idx] = idx .+ r
+            push!(sym_lookup[sym], ranges[range_idx])
+            range_idx += 1
+        end
+
+        idx += varinfo.metadata[sym].ranges[end][end]
+    end
+
+    bs = map(tuple(dists...)) do d
+        b = Bijectors.bijector(d)
+        if d isa Distributions.UnivariateDistribution
+            b
+        else
+            wrap_in_vec_reshape(b, size(d))
+        end
+    end
+
+    if sym2ranges
+        return (
+            Bijectors.Stacked(bs, ranges),
+            (; collect(zip(keys(sym_lookup), values(sym_lookup)))...),
+        )
+    else
+        return Bijectors.Stacked(bs, ranges)
+    end
+end
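Usage sketch (not part of the patch series): a minimal illustration of how the interface introduced by this first patch is meant to be called. The model below is hypothetical, and later patches in the series change `vi` to return the full `AdvancedVI.optimize` tuple rather than just the averaged posterior.

```julia
using Turing
using Turing.Variational

# An illustrative model; any DynamicPPL model works the same way.
@model function demo(x)
    s ~ InverseGamma(2, 3)
    m ~ Normal(0, sqrt(s))
    x .~ Normal(m, sqrt(s))
end

model = demo(randn(10))

# Build an initial approximation in the unconstrained (linked) space,
# then optimize it; as of this patch, `vi` returns the averaged posterior.
q0 = meanfield_gaussian(model)   # or fullrank_gaussian(model)
q = vi(model, q0, 1000)
samples = rand(q, 100)           # draws mapped back to the model's support
```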
From a94269d5ec189826ebc89bb025520e35b1329198 Mon Sep 17 00:00:00 2001
From: Kyurae Kim
Date: Fri, 14 Mar 2025 19:21:33 -0400
Subject: [PATCH 02/52] run formatter

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
---
 src/variational/VariationalInference.jl | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/variational/VariationalInference.jl b/src/variational/VariationalInference.jl
index db95093508..0d773a8ffc 100644
--- a/src/variational/VariationalInference.jl
+++ b/src/variational/VariationalInference.jl
@@ -33,9 +33,9 @@ function initialize_gaussian_scale(
     model::DynamicPPL.Model,
     location::AbstractVector,
     scale::AbstractMatrix;
-    num_samples::Int = 10,
-    num_max_trials::Int = 10,
-    reduce_factor = one(eltype(scale))/2
+    num_samples::Int=10,
+    num_max_trials::Int=10,
+    reduce_factor=one(eltype(scale)) / 2,
 )
     prob = make_logdensity(model)
     ℓπ = Base.Fix1(LogDensityProblems.logdensity, prob)

From a4711a9a1e4493ee10be811a3ca85f48bcbcb58e Mon Sep 17 00:00:00 2001
From: Kyurae Kim
Date: Fri, 14 Mar 2025 19:21:44 -0400
Subject: [PATCH 03/52] run formatter

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
---
 src/variational/VariationalInference.jl | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/variational/VariationalInference.jl b/src/variational/VariationalInference.jl
index 0d773a8ffc..8d5084626a 100644
--- a/src/variational/VariationalInference.jl
+++ b/src/variational/VariationalInference.jl
@@ -93,9 +93,9 @@ end
 
 function meanfield_gaussian(
     model::DynamicPPL.Model,
-    location::Union{Nothing, <:AbstractVector} = nothing,
-    scale::Union{Nothing, <:Diagonal} = nothing;
-    kwargs...
+    location::Union{Nothing,<:AbstractVector}=nothing,
+    scale::Union{Nothing,<:Diagonal}=nothing;
+    kwargs...,
 )
     meanfield_gaussian(Random.default_rng(), model, location, scale; kwargs...)
 end

From 3f8068be6dfd1c0c782dba363c8c2898d152a083 Mon Sep 17 00:00:00 2001
From: Kyurae Kim
Date: Fri, 14 Mar 2025 19:21:50 -0400
Subject: [PATCH 04/52] run formatter

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
---
 src/variational/VariationalInference.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/variational/VariationalInference.jl b/src/variational/VariationalInference.jl
index 8d5084626a..b9c112b3b5 100644
--- a/src/variational/VariationalInference.jl
+++ b/src/variational/VariationalInference.jl
@@ -54,7 +54,7 @@ function initialize_gaussian_scale(
             error("Could not find an initial scale with finite energy")
         end
 
-        scale = reduce_factor*scale
+        scale = reduce_factor * scale
         n_trial += 1
     end
 end

From 222a638d44ff639895d348817c15cecc56fcc176 Mon Sep 17 00:00:00 2001
From: Kyurae Kim
Date: Fri, 14 Mar 2025 19:21:55 -0400
Subject: [PATCH 05/52] run formatter

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
---
 src/variational/VariationalInference.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/variational/VariationalInference.jl b/src/variational/VariationalInference.jl
index b9c112b3b5..992ca09607 100644
--- a/src/variational/VariationalInference.jl
+++ b/src/variational/VariationalInference.jl
@@ -97,7 +97,7 @@ function meanfield_gaussian(
     scale::Union{Nothing,<:Diagonal}=nothing;
     kwargs...,
 )
-    meanfield_gaussian(Random.default_rng(), model, location, scale; kwargs...)
+    return meanfield_gaussian(Random.default_rng(), model, location, scale; kwargs...)
 end

From 57097f5c6c372b6dc7ada063f2a4ad0d6e27750d Mon Sep 17 00:00:00 2001
From: Kyurae Kim
Date: Fri, 14 Mar 2025 19:22:00 -0400
Subject: [PATCH 06/52] run formatter

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
---
 src/variational/VariationalInference.jl | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/variational/VariationalInference.jl b/src/variational/VariationalInference.jl
index 992ca09607..9aad42e438 100644
--- a/src/variational/VariationalInference.jl
+++ b/src/variational/VariationalInference.jl
@@ -62,9 +62,9 @@ end
 function meanfield_gaussian(
     rng::Random.AbstractRNG,
     model::DynamicPPL.Model,
-    location::Union{Nothing, <:AbstractVector} = nothing,
-    scale::Union{Nothing, <:Diagonal} = nothing;
-    kwargs...
+    location::Union{Nothing,<:AbstractVector}=nothing,
+    scale::Union{Nothing,<:Diagonal}=nothing;
+    kwargs...,
 )
     varinfo = DynamicPPL.VarInfo(model)

From a42eea8e4f6f197265061430799c60ef8c6f04ec Mon Sep 17 00:00:00 2001
From: Kyurae Kim
Date: Fri, 14 Mar 2025 19:22:06 -0400
Subject: [PATCH 07/52] run formatter

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
---
 src/variational/VariationalInference.jl | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/variational/VariationalInference.jl b/src/variational/VariationalInference.jl
index 9aad42e438..cea01500cd 100644
--- a/src/variational/VariationalInference.jl
+++ b/src/variational/VariationalInference.jl
@@ -135,9 +135,9 @@ end
 
 function fullrank_gaussian(
     model::DynamicPPL.Model,
-    location::Union{Nothing, <:AbstractVector} = nothing,
-    scale::Union{Nothing, <:Diagonal} = nothing;
-    kwargs...
+    location::Union{Nothing,<:AbstractVector}=nothing,
+    scale::Union{Nothing,<:Diagonal}=nothing;
+    kwargs...,
 )
     fullrank_gaussian(Random.default_rng(), model, location, scale; kwargs...)
 end

From 798f3198f681f23e3bd7605233ddf64508571e75 Mon Sep 17 00:00:00 2001
From: Kyurae Kim
Date: Fri, 14 Mar 2025 19:22:12 -0400
Subject: [PATCH 08/52] run formatter

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
---
 src/variational/VariationalInference.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/variational/VariationalInference.jl b/src/variational/VariationalInference.jl
index cea01500cd..0a676cf3ff 100644
--- a/src/variational/VariationalInference.jl
+++ b/src/variational/VariationalInference.jl
@@ -146,7 +146,7 @@ function vi(
     model::DynamicPPL.Model,
     q::Bijectors.TransformedDistribution,
     n_iterations::Int;
-    objective=RepGradELBO(10, entropy=AdvancedVI.ClosedFormEntropyZeroGradient()),
+    objective=RepGradELBO(10; entropy=AdvancedVI.ClosedFormEntropyZeroGradient()),
     show_progress::Bool=PROGRESS[],
     optimizer=AdvancedVI.DoWG(),
     averager=AdvancedVI.PolynomialAveraging(),

From 69a49720d2c9b4f37504cd458e298bd6cd09be14 Mon Sep 17 00:00:00 2001
From: Kyurae Kim
Date: Fri, 14 Mar 2025 19:22:20 -0400
Subject: [PATCH 09/52] run formatter

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
---
 src/variational/VariationalInference.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/variational/VariationalInference.jl b/src/variational/VariationalInference.jl
index 0a676cf3ff..63ab1540bc 100644
--- a/src/variational/VariationalInference.jl
+++ b/src/variational/VariationalInference.jl
@@ -139,7 +139,7 @@ function fullrank_gaussian(
     scale::Union{Nothing,<:Diagonal}=nothing;
     kwargs...,
 )
-    fullrank_gaussian(Random.default_rng(), model, location, scale; kwargs...)
+    return fullrank_gaussian(Random.default_rng(), model, location, scale; kwargs...)
 end

From cbcb8b5744604aad363c6d6ad8cf2d86f3a3bb2e Mon Sep 17 00:00:00 2001
From: Kyurae Kim
Date: Fri, 14 Mar 2025 19:22:26 -0400
Subject: [PATCH 10/52] run formatter

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
---
 src/variational/VariationalInference.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/variational/VariationalInference.jl b/src/variational/VariationalInference.jl
index 63ab1540bc..cbf4aab779 100644
--- a/src/variational/VariationalInference.jl
+++ b/src/variational/VariationalInference.jl
@@ -151,7 +151,7 @@ function vi(
     optimizer=AdvancedVI.DoWG(),
     averager=AdvancedVI.PolynomialAveraging(),
     operator=AdvancedVI.ProximalLocationScaleEntropy(),
-    adtype::ADTypes.AbstractADType=DEFAULT_ADTYPE, 
+    adtype::ADTypes.AbstractADType=DEFAULT_ADTYPE,
 )

From 081d6ff497daf6044fb93818a8c369d8dea519cb Mon Sep 17 00:00:00 2001
From: Kyurae Kim
Date: Fri, 14 Mar 2025 19:22:48 -0400
Subject: [PATCH 11/52] remove plotting

---
 Project.toml                            | 2 --
 src/variational/VariationalInference.jl | 8 +++-----
 2 files changed, 3 insertions(+), 7 deletions(-)

diff --git a/Project.toml b/Project.toml
index 6fa6daa2a1..ae940ca541 100644
--- a/Project.toml
+++ b/Project.toml
@@ -38,7 +38,6 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 StatsAPI = "82ae8749-77ed-4fe6-ae5f-f523153014b0"
 StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
 StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
-UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228"
 
 [weakdeps]
 DynamicHMC = "bbc10e6e-7c05-544b-b16e-64fede858acb"
@@ -85,7 +84,6 @@ Statistics = "1.6"
 StatsAPI = "1.6"
 StatsBase = "0.32, 0.33, 0.34"
 StatsFuns = "0.8, 0.9, 1"
-UnicodePlots = "3"
 julia = "1.10"
 
 [extras]
diff --git a/src/variational/VariationalInference.jl b/src/variational/VariationalInference.jl
index db95093508..d532444256 100644
--- a/src/variational/VariationalInference.jl
+++ b/src/variational/VariationalInference.jl
@@ -152,8 +152,9 @@ function vi(
     averager=AdvancedVI.PolynomialAveraging(),
     operator=AdvancedVI.ProximalLocationScaleEntropy(),
     adtype::ADTypes.AbstractADType=DEFAULT_ADTYPE,
+    kwargs...
 )
-    q_avg_trans, _, stats, _ = AdvancedVI.optimize(
+    return AdvancedVI.optimize(
         make_logdensity(model),
         objective,
         q,
@@ -163,11 +164,8 @@ function vi(
         optimizer,
         averager,
         operator,
+        kwargs...
     )
-    if show_progress
-        lineplot([stat.elbo for stat in stats], ylabel="Objective", xlabel="Iteration") |> display
-    end
-    return q_avg_trans
 end
 
 end

From 1bcec3e18e15d0f8ff45fc72008b32060b1e7435 Mon Sep 17 00:00:00 2001
From: Kyurae Kim
Date: Fri, 14 Mar 2025 19:29:52 -0400
Subject: [PATCH 12/52] fix formatting

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
---
 src/variational/VariationalInference.jl | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/variational/VariationalInference.jl b/src/variational/VariationalInference.jl
index f0b98bb314..4c1a0c289c 100644
--- a/src/variational/VariationalInference.jl
+++ b/src/variational/VariationalInference.jl
@@ -103,9 +103,9 @@ end
 function fullrank_gaussian(
     rng::Random.AbstractRNG,
     model::DynamicPPL.Model,
-    location::Union{Nothing, <:AbstractVector} = nothing,
-    scale::Union{Nothing, <:LowerTriangular} = nothing;
-    kwargs...
+    location::Union{Nothing,<:AbstractVector}=nothing,
+    scale::Union{Nothing,<:LowerTriangular}=nothing;
+    kwargs...,
 )
     varinfo = DynamicPPL.VarInfo(model)

From b142832c4d3cc0762bf2277d1d1905c0ac2a1c1a Mon Sep 17 00:00:00 2001
From: Kyurae Kim
Date: Fri, 14 Mar 2025 19:30:04 -0400
Subject: [PATCH 13/52] fix formatting

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
---
 src/variational/VariationalInference.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/variational/VariationalInference.jl b/src/variational/VariationalInference.jl
index 4c1a0c289c..c4d7fe4b04 100644
--- a/src/variational/VariationalInference.jl
+++ b/src/variational/VariationalInference.jl
@@ -152,7 +152,7 @@ function vi(
     averager=AdvancedVI.PolynomialAveraging(),
     operator=AdvancedVI.ProximalLocationScaleEntropy(),
     adtype::ADTypes.AbstractADType=DEFAULT_ADTYPE,
-    kwargs...
+    kwargs...,
 )
     return AdvancedVI.optimize(
         make_logdensity(model),

From 061ec35b66b8fed680bd67e17a3960b7f71166f9 Mon Sep 17 00:00:00 2001
From: Kyurae Kim
Date: Fri, 14 Mar 2025 19:30:10 -0400
Subject: [PATCH 14/52] fix formatting

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
---
 src/variational/VariationalInference.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/variational/VariationalInference.jl b/src/variational/VariationalInference.jl
index c4d7fe4b04..5810239378 100644
--- a/src/variational/VariationalInference.jl
+++ b/src/variational/VariationalInference.jl
@@ -164,7 +164,7 @@ function vi(
         optimizer,
         averager,
         operator,
-        kwargs...
+        kwargs...,
     )
 end

From 736bd3e4bef5e2b08d09727293896fedff690b8e Mon Sep 17 00:00:00 2001
From: Kyurae Kim
Date: Fri, 14 Mar 2025 19:32:14 -0400
Subject: [PATCH 15/52] remove unused dependency

---
 src/variational/VariationalInference.jl | 1 -
 1 file changed, 1 deletion(-)

diff --git a/src/variational/VariationalInference.jl b/src/variational/VariationalInference.jl
index f0b98bb314..022622e6f8 100644
--- a/src/variational/VariationalInference.jl
+++ b/src/variational/VariationalInference.jl
@@ -7,7 +7,6 @@ using Distributions
 using LinearAlgebra
 using LogDensityProblems
 using Random
-using UnicodePlots
 
 import ..Turing: DEFAULT_ADTYPE, PROGRESS

From 297c32a97bfbd118a5148d31769a365a9289046f Mon Sep 17 00:00:00 2001
From: Hong Ge <3279477+yebai@users.noreply.github.com>
Date: Thu, 20 Mar 2025 21:24:52 +0000
Subject: [PATCH 16/52] Update Project.toml

---
 test/Project.toml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/test/Project.toml b/test/Project.toml
index 36b7ebdec9..f1d829ae1f 100644
--- a/test/Project.toml
+++ b/test/Project.toml
@@ -43,7 +43,7 @@ AbstractMCMC = "5"
 AbstractPPL = "0.9, 0.10"
 AdvancedMH = "0.6, 0.7, 0.8"
 AdvancedPS = "=0.6.0"
-AdvancedVI = "0.2"
+AdvancedVI = "0.3"
 Aqua = "0.8"
 BangBang = "0.4"
 Bijectors = "0.14, 0.15"

From 0c0443402853163e7ae1512c88111af09029fb61 Mon Sep 17 00:00:00 2001
From: Kyurae Kim
Date: Tue, 25 Mar 2025 16:08:59 -0400
Subject: [PATCH 17/52] fix make some arguments of vi initializer to be
 optional kwargs

---
 src/variational/VariationalInference.jl | 20 ++++++++++----------
 1 file changed, 10 insertions(+), 10 deletions(-)

diff --git a/src/variational/VariationalInference.jl b/src/variational/VariationalInference.jl
index 5847f74afd..e449237212 100644
--- a/src/variational/VariationalInference.jl
+++ b/src/variational/VariationalInference.jl
@@ -60,9 +60,9 @@ end
 
 function meanfield_gaussian(
     rng::Random.AbstractRNG,
-    model::DynamicPPL.Model,
+    model::DynamicPPL.Model;
     location::Union{Nothing,<:AbstractVector}=nothing,
-    scale::Union{Nothing,<:Diagonal}=nothing;
+    scale::Union{Nothing,<:Diagonal}=nothing,
     kwargs...,
 )
     varinfo = DynamicPPL.VarInfo(model)
@@ -91,19 +91,19 @@ function meanfield_gaussian(
 end
 
 function meanfield_gaussian(
-    model::DynamicPPL.Model,
+    model::DynamicPPL.Model;
     location::Union{Nothing,<:AbstractVector}=nothing,
-    scale::Union{Nothing,<:Diagonal}=nothing;
+    scale::Union{Nothing,<:Diagonal}=nothing,
     kwargs...,
 )
-    return meanfield_gaussian(Random.default_rng(), model, location, scale; kwargs...)
+    return meanfield_gaussian(Random.default_rng(), model; location, scale, kwargs...)
 end
 
 function fullrank_gaussian(
     rng::Random.AbstractRNG,
-    model::DynamicPPL.Model,
+    model::DynamicPPL.Model;
     location::Union{Nothing,<:AbstractVector}=nothing,
-    scale::Union{Nothing,<:LowerTriangular}=nothing;
+    scale::Union{Nothing,<:LowerTriangular}=nothing,
     kwargs...,
 )
     varinfo = DynamicPPL.VarInfo(model)
@@ -133,12 +133,12 @@ function fullrank_gaussian(
 end
 
 function fullrank_gaussian(
-    model::DynamicPPL.Model,
+    model::DynamicPPL.Model;
     location::Union{Nothing,<:AbstractVector}=nothing,
-    scale::Union{Nothing,<:Diagonal}=nothing;
+    scale::Union{Nothing,<:LowerTriangular}=nothing,
     kwargs...,
 )
-    return fullrank_gaussian(Random.default_rng(), model, location, scale; kwargs...)
+    return fullrank_gaussian(Random.default_rng(), model; location, scale, kwargs...)
 end
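The net effect of the patch above is that `location` and `scale` move from optional positional arguments to keyword arguments of the initializers. A before/after sketch, assuming a model with `d` linked parameters (both `model` and `d` are illustrative):

```julia
using LinearAlgebra

# Before this patch (optional positional arguments):
# q0 = meanfield_gaussian(model, zeros(d), Diagonal(ones(d)))

# After this patch (keyword arguments):
q0 = meanfield_gaussian(model; location=zeros(d), scale=Diagonal(ones(d)))
```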
From 626c5b5f0ae2927b1dc29f64af50b8e967e8cf9c Mon Sep 17 00:00:00 2001
From: Kyurae Kim
Date: Tue, 25 Mar 2025 16:26:41 -0400
Subject: [PATCH 18/52] remove tests for custom optimizers

---
 test/runtests.jl | 4 ----
 1 file changed, 4 deletions(-)

diff --git a/test/runtests.jl b/test/runtests.jl
index 47b714188e..75ad71d90b 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -71,10 +71,6 @@ end
         end
     end
 
-    @testset "variational optimisers" begin
-        @timeit_include("variational/optimisers.jl")
-    end
-
    @testset "stdlib" verbose = true begin
        @timeit_include("stdlib/distributions.jl")
        @timeit_include("stdlib/RandomMeasures.jl")

From cb2c6181ada758fbd3bc364ab21f6262d3765212 Mon Sep 17 00:00:00 2001
From: Kyurae Kim
Date: Tue, 25 Mar 2025 16:27:36 -0400
Subject: [PATCH 19/52] remove unused file

---
 test/variational/optimisers.jl | 29 -----------------------------
 1 file changed, 29 deletions(-)
 delete mode 100644 test/variational/optimisers.jl

diff --git a/test/variational/optimisers.jl b/test/variational/optimisers.jl
deleted file mode 100644
index 6f64d5fb1f..0000000000
--- a/test/variational/optimisers.jl
+++ /dev/null
@@ -1,29 +0,0 @@
-module VariationalOptimisersTests
-
-using AdvancedVI: DecayedADAGrad, TruncatedADAGrad, apply!
-import ForwardDiff
-import ReverseDiff
-using Test: @test, @testset
-using Turing
-
-function test_opt(ADPack, opt)
-    θ = randn(10, 10)
-    θ_fit = randn(10, 10)
-    loss(x, θ_) = mean(sum(abs2, θ * x - θ_ * x; dims=1))
-    for t in 1:(10^4)
-        x = rand(10)
-        Δ = ADPack.gradient(θ_ -> loss(x, θ_), θ_fit)
-        Δ = apply!(opt, θ_fit, Δ)
-        @. θ_fit = θ_fit - Δ
-    end
-    @test loss(rand(10, 100), θ_fit) < 0.01
-    @test length(opt.acc) == 1
-end
-for opt in [TruncatedADAGrad(), DecayedADAGrad(1e-2)]
-    test_opt(ForwardDiff, opt)
-end
-for opt in [TruncatedADAGrad(), DecayedADAGrad(1e-2)]
-    test_opt(ReverseDiff, opt)
-end
-
-end

From c1533a863db45c251abc281a8fe53dc922973839 Mon Sep 17 00:00:00 2001
From: Hong Ge <3279477+yebai@users.noreply.github.com>
Date: Fri, 18 Apr 2025 21:10:02 +0100
Subject: [PATCH 20/52] Update src/variational/bijectors.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
---
 src/variational/bijectors.jl | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/src/variational/bijectors.jl b/src/variational/bijectors.jl
index e0633493f6..86078efaa4 100644
--- a/src/variational/bijectors.jl
+++ b/src/variational/bijectors.jl
@@ -22,7 +22,9 @@ Returns a `Stacked <: Bijector` which maps from the support of the posterior to
 denoting the dimensionality of the latent variables.
 """
 function Bijectors.bijector(
-    model::DynamicPPL.Model, ::Val{sym2ranges}=Val(false); varinfo=DynamicPPL.VarInfo(model)
+    model::DynamicPPL.Model,
+    (::Val{sym2ranges})=Val(false);
+    varinfo=DynamicPPL.VarInfo(model),
 ) where {sym2ranges}
     num_params = sum([
         size(varinfo.metadata[sym].vals, 1) for sym in keys(varinfo.metadata)

From 231d6e2d72aef59cf51da7810a89ddc2e3588b37 Mon Sep 17 00:00:00 2001
From: Hong Ge <3279477+yebai@users.noreply.github.com>
Date: Mon, 21 Apr 2025 11:57:44 +0100
Subject: [PATCH 21/52] Update Turing.jl

---
 src/Turing.jl | 2 --
 1 file changed, 2 deletions(-)

diff --git a/src/Turing.jl b/src/Turing.jl
index aa5fbe8500..4bd3058906 100644
--- a/src/Turing.jl
+++ b/src/Turing.jl
@@ -39,8 +39,6 @@ function setprogress!(progress::Bool)
     @info "[Turing]: progress logging is $(progress ? "enabled" : "disabled") globally"
     PROGRESS[] = progress
     AbstractMCMC.setprogress!(progress; silent=true)
-    # TODO: `AdvancedVI.turnprogress` is removed in AdvancedVI v0.3
-    AdvancedVI.turnprogress(progress)
     return progress
 end

From 69639ec4e3f12333cfecffc28ceea3369174df1b Mon Sep 17 00:00:00 2001
From: Kyurae Kim
Date: Tue, 29 Apr 2025 11:55:22 -0400
Subject: [PATCH 22/52] fix remove call to `AdvancedVI.turnprogress`, which has
 been removed

---
 src/Turing.jl | 2 --
 1 file changed, 2 deletions(-)

diff --git a/src/Turing.jl b/src/Turing.jl
index aa5fbe8500..4bd3058906 100644
--- a/src/Turing.jl
+++ b/src/Turing.jl
@@ -39,8 +39,6 @@ function setprogress!(progress::Bool)
     @info "[Turing]: progress logging is $(progress ? "enabled" : "disabled") globally"
     PROGRESS[] = progress
     AbstractMCMC.setprogress!(progress; silent=true)
-    # TODO: `AdvancedVI.turnprogress` is removed in AdvancedVI v0.3
-    AdvancedVI.turnprogress(progress)
     return progress
 end

From ef9aeb1cc59396092a68e8e8082a8b5c60203f8f Mon Sep 17 00:00:00 2001
From: Kyurae Kim
Date: Tue, 29 Apr 2025 12:14:30 -0400
Subject: [PATCH 23/52] apply comments from @yebai

---
 src/variational/VariationalInference.jl | 67 +++++++++----------------
 1 file changed, 23 insertions(+), 44 deletions(-)

diff --git a/src/variational/VariationalInference.jl b/src/variational/VariationalInference.jl
index e449237212..010b34a456 100644
--- a/src/variational/VariationalInference.jl
+++ b/src/variational/VariationalInference.jl
@@ -15,9 +15,9 @@ import Bijectors
 
 # Reexports
 using AdvancedVI: RepGradELBO, ScoreGradELBO, DoG, DoWG
-export vi, RepGradELBO, ScoreGradELBO, DoG, DoWG
+export RepGradELBO, ScoreGradELBO, DoG, DoWG
 
-export meanfield_gaussian, fullrank_gaussian
+export vi, q_init, q_meanfield_gaussian, q_fullrank_gaussian
 
 include("bijectors.jl")
@@ -58,11 +58,13 @@ end
 
-function meanfield_gaussian(
+function q_init(
     rng::Random.AbstractRNG,
     model::DynamicPPL.Model;
     location::Union{Nothing,<:AbstractVector}=nothing,
-    scale::Union{Nothing,<:Diagonal}=nothing,
+    scale::Union{Nothing,<:Diagonal,<:LowerTriangular}=nothing,
+    meanfield::Bool=true,
+    basedist::Distributions.UnivariateDistribution=Normal(),
     kwargs...,
 )
     varinfo = DynamicPPL.VarInfo(model)
@@ -79,66 +81,43 @@ end
     end
 
     L = if isnothing(scale)
-        initialize_gaussian_scale(rng, model, μ, Diagonal(ones(num_params)); kwargs...)
+        if meanfield
+            initialize_gaussian_scale(rng, model, μ, Diagonal(ones(num_params)); kwargs...)
+        else
+            L0 = LowerTriangular(Matrix{Float64}(I, num_params, num_params))
+            initialize_gaussian_scale(rng, model, μ, L0; kwargs...)
+        end
     else
         @assert size(scale) == (num_params, num_params) "Dimensions of the provided scale matrix, $(size(scale)), does not match the dimension of the target distribution, $(num_params)."
-        L = scale
+        if meanfield
+            Diagonal(diag(scale))
+        else
+            scale
+        end
     end
-
-    q = AdvancedVI.MeanFieldGaussian(μ, L)
+    q = AdvancedVI.MvLocationScale(μ, L, basedist)
     b = Bijectors.bijector(model; varinfo=varinfo)
     return Bijectors.transformed(q, Bijectors.inverse(b))
 end
 
-function meanfield_gaussian(
+function q_meanfield_gaussian(
+    rng::Random.AbstractRNG,
     model::DynamicPPL.Model;
     location::Union{Nothing,<:AbstractVector}=nothing,
     scale::Union{Nothing,<:Diagonal}=nothing,
     kwargs...,
 )
-    return meanfield_gaussian(Random.default_rng(), model; location, scale, kwargs...)
+    return q_init(rng, model; location, scale, meanfield=true, basedist=Normal(), kwargs...)
 end
 
-function fullrank_gaussian(
+function q_fullrank_gaussian(
     rng::Random.AbstractRNG,
     model::DynamicPPL.Model;
     location::Union{Nothing,<:AbstractVector}=nothing,
     scale::Union{Nothing,<:LowerTriangular}=nothing,
     kwargs...,
 )
-    varinfo = DynamicPPL.VarInfo(model)
-    # Use linked `varinfo` to determine the correct number of parameters.
-    # TODO: Replace with `length` once this is implemented for `VarInfo`.
-    varinfo_linked = DynamicPPL.link(varinfo, model)
-    num_params = length(varinfo_linked[:])
-
-    μ = if isnothing(location)
-        zeros(num_params)
-    else
-        @assert length(location) == num_params "Length of the provided location vector, $(length(location)), does not match dimension of the target distribution, $(num_params)."
-        location
-    end
-
-    L = if isnothing(scale)
-        L0 = LowerTriangular(Matrix{Float64}(I, num_params, num_params))
-        initialize_gaussian_scale(rng, model, μ, L0; kwargs...)
-    else
-        @assert size(scale) == (num_params, num_params) "Dimensions of the provided scale matrix, $(size(scale)), does not match the dimension of the target distribution, $(num_params)."
-        scale
-    end
-
-    q = AdvancedVI.FullRankGaussian(μ, L)
-    b = Bijectors.bijector(model; varinfo=varinfo)
-    return Bijectors.transformed(q, Bijectors.inverse(b))
-end
-
-function fullrank_gaussian(
-    model::DynamicPPL.Model;
-    location::Union{Nothing,<:AbstractVector}=nothing,
-    scale::Union{Nothing,<:LowerTriangular}=nothing,
-    kwargs...,
-)
-    return fullrank_gaussian(Random.default_rng(), model; location, scale, kwargs...)
+    return q_init(rng, model; location, scale, meanfield=false, basedist=Normal(), kwargs...)
 end
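After this refactor both Gaussian families are thin wrappers around the new `q_init` constructor. A sketch of the resulting API, assuming `model` is defined; the explicit RNG is still required here because the `Random.default_rng()` front-ends only arrive in a later patch:

```julia
using Random, LinearAlgebra
using Distributions: Normal

rng = Random.default_rng()
q_mf = q_meanfield_gaussian(rng, model)   # location-scale family with a Diagonal scale
q_fr = q_fullrank_gaussian(rng, model)    # same family, LowerTriangular scale

# Both are equivalent to calling `q_init` directly:
q = q_init(rng, model; meanfield=false, basedist=Normal())
```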
From cc18528d39e66d1e2d29ae05c4938feb705ef511 Mon Sep 17 00:00:00 2001
From: Hong Ge <3279477+yebai@users.noreply.github.com>
Date: Thu, 8 May 2025 14:08:34 +0100
Subject: [PATCH 24/52] Update src/variational/VariationalInference.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
---
 src/variational/VariationalInference.jl | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/src/variational/VariationalInference.jl b/src/variational/VariationalInference.jl
index 010b34a456..5df90177d5 100644
--- a/src/variational/VariationalInference.jl
+++ b/src/variational/VariationalInference.jl
@@ -117,7 +117,9 @@ function q_fullrank_gaussian(
     scale::Union{Nothing,<:LowerTriangular}=nothing,
     kwargs...,
 )
-    return q_init(rng, model; location, scale, meanfield=false, basedist=Normal(), kwargs...)
+    return q_init(
+        rng, model; location, scale, meanfield=false, basedist=Normal(), kwargs...
+    )
 end

From 0b79495f5bacfcff6c259a60cdb14b16b8ada879 Mon Sep 17 00:00:00 2001
From: Kyurae Kim
Date: Wed, 14 May 2025 16:55:08 -0400
Subject: [PATCH 25/52] add old interface as deprecated

---
 src/variational/VariationalInference.jl |  1 +
 src/variational/deprecated.jl           | 59 +++++++++++++++++++++++++
 2 files changed, 60 insertions(+)
 create mode 100644 src/variational/deprecated.jl

diff --git a/src/variational/VariationalInference.jl b/src/variational/VariationalInference.jl
index 5df90177d5..d884a9c374 100644
--- a/src/variational/VariationalInference.jl
+++ b/src/variational/VariationalInference.jl
@@ -20,6 +20,7 @@ export vi, q_init, q_meanfield_gaussian, q_fullrank_gaussian
 
 include("bijectors.jl")
+include("deprecated.jl")
 
 function make_logdensity(model::DynamicPPL.Model)
     weight = 1.0
diff --git a/src/variational/deprecated.jl b/src/variational/deprecated.jl
new file mode 100644
index 0000000000..21a93f0016
--- /dev/null
+++ b/src/variational/deprecated.jl
@@ -0,0 +1,59 @@
+
+import DistributionsAD
+export ADVI
+
+struct ADVI{AD}
+    "Number of samples used to estimate the ELBO in each optimization step."
+    samples_per_step::Int
+    "Maximum number of gradient steps."
+    max_iters::Int
+    "AD backend used for automatic differentiation."
+    adtype::AD
+end
+
+function ADVI(
+    samples_per_step::Int=1,
+    max_iters::Int=1000;
+    adtype::ADTypes.AbstractADType=ADTypes.AutoForwardDiff(),
+)
+    Base.depwarn(
+        "The type ADVI will be removed in future releases. Please refer to the new interface for `vi`",
+        :ADVI;
+        force=true,
+    )
+    return ADVI{typeof(adtype)}(samples_per_step, max_iters, adtype)
+end
+
+function vi(model::DynamicPPL.Model, alg::ADVI; kwargs...)
+    Base.depwarn(
+        "This specialization along with the type `ADVI` will be deprecated in future releases. Please refer to the new interface for `vi`.",
+        :vi;
+        force=true,
+    )
+    q = q_meanfield_gaussian(Random.default_rng(), model)
+    objective = AdvancedVI.RepGradELBO(
+        alg.samples_per_step; entropy=AdvancedVI.ClosedFormEntropy()
+    )
+    operator = AdvancedVI.IdentityOperator()
+    _, q_avg, _, _ = vi(model, q, alg.max_iters; objective, operator, kwargs...)
+    return q_avg
+end
+
+function vi(
+    model::DynamicPPL.Model,
+    alg::ADVI,
+    q::Bijectors.TransformedDistribution{<:DistributionsAD.TuringDiagMvNormal};
+    kwargs...,
+)
+    Base.depwarn(
+        "This specialization along with the type `ADVI` will be deprecated in future releases. Please refer to the new interface for `vi`.",
+        :vi;
+        force=true,
+    )
+    objective = AdvancedVI.RepGradELBO(
+        alg.samples_per_step; entropy=AdvancedVI.ClosedFormEntropy()
+    )
+    operator = AdvancedVI.IdentityOperator()
+    _, q_avg, _, _ = vi(model, q, alg.max_iters; objective, operator, kwargs...)
+    return q_avg
+end
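The shim above keeps the old ADVI-style call working while warning; the migration it points to is roughly the following sketch (the four-element destructuring mirrors the shim's own use of `vi`):

```julia
# Deprecated interface: still runs, but emits a depwarn.
q_old = vi(model, ADVI(10, 1000))

# Equivalent call through the new interface.
q0 = q_meanfield_gaussian(Random.default_rng(), model)
_, q_avg, _, _ = vi(model, q0, 1000)
```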
From 3818152bb9f999c14f689ba9a075ca2435040228 Mon Sep 17 00:00:00 2001
From: Kyurae Kim
Date: Wed, 14 May 2025 16:55:26 -0400
Subject: [PATCH 26/52] bump AdvancedVI version

---
 Project.toml      | 2 +-
 test/Project.toml | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/Project.toml b/Project.toml
index 671a2914c4..bc2826cdd9 100644
--- a/Project.toml
+++ b/Project.toml
@@ -53,7 +53,7 @@ Accessors = "0.1"
 AdvancedHMC = "0.3.0, 0.4.0, 0.5.2, 0.6, 0.7"
 AdvancedMH = "0.8"
 AdvancedPS = "0.6.0"
-AdvancedVI = "0.3.1"
+AdvancedVI = "0.4"
 BangBang = "0.4.2"
 Bijectors = "0.14, 0.15"
 Compat = "4.15.0"
diff --git a/test/Project.toml b/test/Project.toml
index f1d829ae1f..a8f47d80fd 100644
--- a/test/Project.toml
+++ b/test/Project.toml
@@ -43,7 +43,7 @@ AbstractMCMC = "5"
 AbstractPPL = "0.9, 0.10"
 AdvancedMH = "0.6, 0.7, 0.8"
 AdvancedPS = "=0.6.0"
-AdvancedVI = "0.3"
+AdvancedVI = "0.4"
 Aqua = "0.8"
 BangBang = "0.4"
 Bijectors = "0.14, 0.15"

From 91a9afe7b9ed312769712b277651f31583c93445 Mon Sep 17 00:00:00 2001
From: Kyurae Kim
Date: Wed, 14 May 2025 17:21:16 -0400
Subject: [PATCH 27/52] add deprecation for `meanfield`

---
 src/variational/deprecated.jl | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/src/variational/deprecated.jl b/src/variational/deprecated.jl
index 21a93f0016..de1919cb1d 100644
--- a/src/variational/deprecated.jl
+++ b/src/variational/deprecated.jl
@@ -2,6 +2,8 @@
 import DistributionsAD
 export ADVI
 
+Base.@deprecate meanfield(model) q_meanfield_gaussian(model)
+
 struct ADVI{AD}
     "Number of samples used to estimate the ELBO in each optimization step."
     samples_per_step::Int

From 12539aa331b4c2b299281e4fb363e33b8dee75df Mon Sep 17 00:00:00 2001
From: Kyurae Kim
Date: Wed, 14 May 2025 17:21:25 -0400
Subject: [PATCH 28/52] add `default_rng` interfaces

---
 src/variational/VariationalInference.jl | 16 +++++++++++++++-
 1 file changed, 15 insertions(+), 1 deletion(-)

diff --git a/src/variational/VariationalInference.jl b/src/variational/VariationalInference.jl
index d884a9c374..f5a4987e75 100644
--- a/src/variational/VariationalInference.jl
+++ b/src/variational/VariationalInference.jl
@@ -111,6 +111,10 @@ function q_meanfield_gaussian(
     return q_init(rng, model; location, scale, meanfield=true, basedist=Normal(), kwargs...)
 end
 
+function q_meanfield_gaussian(model::DynamicPPL.Model; kwargs...)
+    return q_meanfield_gaussian(Random.default_rng(), model; kwargs...)
+end
+
 function q_fullrank_gaussian(
     rng::Random.AbstractRNG,
     model::DynamicPPL.Model;
@@ -123,9 +127,14 @@ function q_fullrank_gaussian(
     )
 end
 
+function q_fullrank_gaussian(model::DynamicPPL.Model; kwargs...)
+    return q_fullrank_gaussian(Random.default_rng(), model; kwargs...)
+end
+
 function vi(
+    rng::Random.AbstractRNG,
     model::DynamicPPL.Model,
-    q::Bijectors.TransformedDistribution,
+    q,
     n_iterations::Int;
     objective=RepGradELBO(10; entropy=AdvancedVI.ClosedFormEntropyZeroGradient()),
     show_progress::Bool=PROGRESS[],
@@ -136,6 +145,7 @@ function vi(
     kwargs...,
 )
     return AdvancedVI.optimize(
+        rng,
         make_logdensity(model),
         objective,
         q,
@@ -149,4 +159,8 @@ function vi(
     )
 end
 
+function vi(model::DynamicPPL.Model, q, n_iterations::Int; kwargs...)
+    return vi(Random.default_rng(), model, q, n_iterations; kwargs...)
+end
+
 end
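With the `default_rng` methods above, neither the constructors nor `vi` need an explicit RNG anymore, and `vi` forwards the tuple returned by `AdvancedVI.optimize`. A sketch, with the destructuring order taken from the deprecation shim earlier in this series:

```julia
q0 = q_meanfield_gaussian(model)       # RNG supplied internally
_, q_avg, _, _ = vi(model, q0, 1000)   # second element: averaged posterior
samples = rand(q_avg, 100)
```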
From 0653bf1765735392b0c6f60e211a8c4ea20e724e Mon Sep 17 00:00:00 2001
From: Kyurae Kim
Date: Wed, 14 May 2025 17:21:43 -0400
Subject: [PATCH 29/52] add tests for variational inference

---
 test/variational/advi.jl | 126 +++++++++++++++++++++++++--------------
 1 file changed, 82 insertions(+), 44 deletions(-)

diff --git a/test/variational/advi.jl b/test/variational/advi.jl
index c2abacb675..6c0d54993f 100644
--- a/test/variational/advi.jl
+++ b/test/variational/advi.jl
@@ -1,9 +1,9 @@
+
 module AdvancedVITests
 
 using ..Models: gdemo_default
 using ..NumericalTests: check_gdemo
-import AdvancedVI
-using AdvancedVI: TruncatedADAGrad, DecayedADAGrad
+using AdvancedVI
 using Bijectors: Bijectors
 using Distributions: Dirichlet, Normal
 using LinearAlgebra: I
@@ -11,56 +11,93 @@ using MCMCChains: Chains
 import Random
 using Test: @test, @testset
 using Turing
-using Turing.Essential: TuringDiagMvNormal
+using Turing.Variational
 
-@testset "advi.jl" begin
-    @testset "advi constructor" begin
+@testset "ADVI" begin
+    @testset "default interface" begin
         Random.seed!(0)
         N = 500
 
-        s1 = ADVI()
-        q = vi(gdemo_default, s1)
-        c1 = rand(q, N)
-    end
+        for q0 in [q_meanfield_gaussian(gdemo_default), q_fullrank_gaussian(gdemo_default)]
+            _, q, _, _ = vi(gdemo_default, q0, N; show_progress=Turing.PROGRESS[])
+            c1 = rand(q, N)
+        end
+    end
 
-    @testset "advi inference" begin
-        @testset for opt in [TruncatedADAGrad(), DecayedADAGrad()]
-            Random.seed!(1)
-            N = 500
-
-            alg = ADVI(10, 5000)
-            q = vi(gdemo_default, alg; optimizer=opt)
-            samples = transpose(rand(q, N))
-            chn = Chains(reshape(samples, size(samples)..., 1), ["s", "m"])
-
-            # TODO: uhmm, seems like a large `eps` here...
-            check_gdemo(chn; atol=0.5)
-        end
-    end
+    @testset "custom interface $name" for (name, objective, operator, optimizer) in [
+        (
+            "ADVI with closed-form entropy",
+            RepGradELBO(10),
+            AdvancedVI.ProximalLocationScaleEntropy(),
+            DoG(),
+        ),
+        (
+            "ADVI with proximal entropy",
+            RepGradELBO(10; entropy=AdvancedVI.ClosedFormEntropyZeroGradient()),
+            AdvancedVI.ClipScale(),
+            DoG(),
+        ),
+        (
+            "ADVI with STL entropy",
+            RepGradELBO(10; entropy=AdvancedVI.StickingTheLandingEntropy()),
+            AdvancedVI.ClipScale(),
+            DoG(),
+        ),
+    ]
+        Random.seed!(0)
+        T = 1000
+        q, q_avg, _, _ = vi(
+            gdemo_default,
+            q_meanfield_gaussian(gdemo_default),
+            T;
+            objective,
+            optimizer,
+            operator,
+            show_progress=Turing.PROGRESS[],
+        )
+
+        N = 1000
+        c1 = rand(q_avg, N)
+        c2 = rand(q, N)
+    end
 
-    @testset "advi different interfaces" begin
-        Random.seed!(1234)
-
-        target = MvNormal(zeros(2), I)
-        logπ(z) = logpdf(target, z)
-        advi = ADVI(10, 1000)
-
-        # Using a function z ↦ q(⋅∣z)
-        getq(θ) = TuringDiagMvNormal(θ[1:2], exp.(θ[3:4]))
-        q = vi(logπ, advi, getq, randn(4))
-
-        xs = rand(target, 10)
-        @test mean(abs2, logpdf(q, xs) - logpdf(target, xs)) ≤ 0.07
-
-        # OR: implement `update` and pass a `Distribution`
-        function AdvancedVI.update(d::TuringDiagMvNormal, θ::AbstractArray{<:Real})
-            return TuringDiagMvNormal(θ[1:length(q)], exp.(θ[(length(q) + 1):end]))
-        end
-
-        q0 = TuringDiagMvNormal(zeros(2), ones(2))
-        q = vi(logπ, advi, q0, randn(4))
-
-        xs = rand(target, 10)
-        @test mean(abs2, logpdf(q, xs) - logpdf(target, xs)) ≤ 0.05
-    end
+    @testset "inference $name" for (name, objective, operator, optimizer) in [
+        (
+            "ADVI with closed-form entropy",
+            RepGradELBO(10),
+            AdvancedVI.ProximalLocationScaleEntropy(),
+            DoG(),
+        ),
+        (
+            "ADVI with proximal entropy",
+            RepGradELBO(10; entropy=AdvancedVI.ClosedFormEntropyZeroGradient()),
+            AdvancedVI.ClipScale(),
+            DoG(),
+        ),
+        (
+            "ADVI with STL entropy",
+            RepGradELBO(10; entropy=AdvancedVI.StickingTheLandingEntropy()),
+            AdvancedVI.ClipScale(),
+            DoG(),
+        ),
+    ]
+        Random.seed!(0)
+        T = 1000
+        q, q_avg, _, _ = vi(
+            gdemo_default,
+            q_meanfield_gaussian(gdemo_default),
+            T;
+            optimizer,
+            show_progress=Turing.PROGRESS[],
+        )
+
+        N = 1000
+        for q_out in [q_avg, q]
+            samples = transpose(rand(q_out, N))
+            chn = Chains(reshape(samples, size(samples)..., 1), ["s", "m"])
+
+            check_gdemo(chn; atol=0.5)
+        end
+    end
 
     # regression test for:
@@ -70,6 +107,7 @@ using Turing.Essential: TuringDiagMvNormal
         x ~ Dirichlet([1.0, 1.0])
         return x
     end
+
     Random.seed!(0)
     m = dirichlet()
     b = Bijectors.bijector(m)
@@ -81,7 +119,7 @@
     @test all(x0 .≈ x0_inv)
 
     # And regression for https://github.com/TuringLang/Turing.jl/issues/2160.
-    q = vi(m, ADVI(10, 1000))
+    _, q, _, _ = vi(m, q_meanfield_gaussian(m), 1000)
     x = rand(q, 1000)
     @test mean(eachcol(x)) ≈ [0.5, 0.5] atol = 0.1
 end
@@ -94,7 +132,7 @@
     end
 
     model = demo_issue2205() | (y=1.0,)
-    q = vi(model, ADVI(10, 1000))
+    _, q, _, _ = vi(model, q_meanfield_gaussian(model), 1000)
     # True mean.
     mean_true = 1 / 2
     var_true = 1 / 2

From f74ec3824bda90fa45f6b683f635852d495a9eb4 Mon Sep 17 00:00:00 2001
From: Kyurae Kim
Date: Wed, 14 May 2025 17:28:31 -0400
Subject: [PATCH 30/52] run formatter

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
---
 src/variational/deprecated.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/variational/deprecated.jl b/src/variational/deprecated.jl
index de1919cb1d..9a9f4777b5 100644
--- a/src/variational/deprecated.jl
+++ b/src/variational/deprecated.jl
@@ -2,7 +2,7 @@
 import DistributionsAD
 export ADVI
 
-Base.@deprecate meanfield(model) q_meanfield_gaussian(model) 
+Base.@deprecate meanfield(model) q_meanfield_gaussian(model)
 
 struct ADVI{AD}
     "Number of samples used to estimate the ELBO in each optimization step."

From f62e7b803669268724f712d4748749b3cd70c988 Mon Sep 17 00:00:00 2001
From: Kyurae Kim
Date: Sun, 18 May 2025 16:34:09 -0400
Subject: [PATCH 31/52] remove "src/variational/bijectors.jl" (moved to
 `DynamicPPL.jl`)

---
 Project.toml                            |  2 +-
 src/variational/VariationalInference.jl |  1 -
 src/variational/bijectors.jl            | 72 -------------------------
 test/Project.toml                       |  2 +-
 4 files changed, 2 insertions(+), 75 deletions(-)
 delete mode 100644 src/variational/bijectors.jl

diff --git a/Project.toml b/Project.toml
index 320ba51054..265683f272 100644
--- a/Project.toml
+++ b/Project.toml
@@ -62,7 +62,7 @@ Distributions = "0.25.77"
 DistributionsAD = "0.6"
 DocStringExtensions = "0.8, 0.9"
 DynamicHMC = "3.4"
-DynamicPPL = "0.36"
+DynamicPPL = "0.36.3"
 EllipticalSliceSampling = "0.5, 1, 2"
 ForwardDiff = "0.10.3"
 Libtask = "0.8.8"
diff --git a/src/variational/VariationalInference.jl b/src/variational/VariationalInference.jl
index f5a4987e75..48441522bc 100644
--- a/src/variational/VariationalInference.jl
+++ b/src/variational/VariationalInference.jl
@@ -19,7 +19,6 @@ export vi, q_init, q_meanfield_gaussian, q_fullrank_gaussian
 
-include("bijectors.jl")
 include("deprecated.jl")
diff --git a/src/variational/bijectors.jl b/src/variational/bijectors.jl
deleted file mode 100644
index 86078efaa4..0000000000
--- a/src/variational/bijectors.jl
+++ /dev/null
@@ -1,72 +0,0 @@
-
-# TODO: Move to Bijectors.jl if we find further use for this.
-"""
-    wrap_in_vec_reshape(f, in_size)
-
-Wraps a bijector `f` such that it operates on vectors of length `prod(in_size)` and produces
-a vector of length `prod(Bijectors.output(f, in_size))`.
-"""
-function wrap_in_vec_reshape(f, in_size)
-    vec_in_length = prod(in_size)
-    reshape_inner = Bijectors.Reshape((vec_in_length,), in_size)
-    out_size = Bijectors.output_size(f, in_size)
-    vec_out_length = prod(out_size)
-    reshape_outer = Bijectors.Reshape(out_size, (vec_out_length,))
-    return reshape_outer ∘ f ∘ reshape_inner
-end
-
-"""
-    bijector(model::Model[, sym2ranges = Val(false)])
-
-Returns a `Stacked <: Bijector` which maps from the support of the posterior to ℝᵈ with `d`
-denoting the dimensionality of the latent variables.
-"""
-function Bijectors.bijector(
-    model::DynamicPPL.Model,
-    (::Val{sym2ranges})=Val(false);
-    varinfo=DynamicPPL.VarInfo(model),
-) where {sym2ranges}
-    num_params = sum([
-        size(varinfo.metadata[sym].vals, 1) for sym in keys(varinfo.metadata)
-    ])
-
-    dists = vcat([varinfo.metadata[sym].dists for sym in keys(varinfo.metadata)]...)
-
-    num_ranges = sum([
-        length(varinfo.metadata[sym].ranges) for sym in keys(varinfo.metadata)
-    ])
-    ranges = Vector{UnitRange{Int}}(undef, num_ranges)
-    idx = 0
-    range_idx = 1
-
-    # ranges might be discontinuous => values are vectors of ranges rather than just ranges
-    sym_lookup = Dict{Symbol,Vector{UnitRange{Int}}}()
-    for sym in keys(varinfo.metadata)
-        sym_lookup[sym] = Vector{UnitRange{Int}}()
-        for r in varinfo.metadata[sym].ranges
-            ranges[range_idx] = idx .+ r
-            push!(sym_lookup[sym], ranges[range_idx])
-            range_idx += 1
-        end
-
-        idx += varinfo.metadata[sym].ranges[end][end]
-    end
-
-    bs = map(tuple(dists...)) do d
-        b = Bijectors.bijector(d)
-        if d isa Distributions.UnivariateDistribution
-            b
-        else
-            wrap_in_vec_reshape(b, size(d))
-        end
-    end
-
-    if sym2ranges
-        return (
-            Bijectors.Stacked(bs, ranges),
-            (; collect(zip(keys(sym_lookup), values(sym_lookup)))...),
-        )
-    else
-        return Bijectors.Stacked(bs, ranges)
-    end
-end
diff --git a/test/Project.toml b/test/Project.toml
index e5226bc364..c6a2ce19ec 100644
--- a/test/Project.toml
+++ b/test/Project.toml
@@ -52,7 +52,7 @@ Combinatorics = "1"
 Distributions = "0.25"
 DistributionsAD = "0.6.3"
 DynamicHMC = "2.1.6, 3.0"
-DynamicPPL = "0.36"
+DynamicPPL = "0.36.3"
 FiniteDifferences = "0.10.8, 0.11, 0.12"
 ForwardDiff = "0.10.12 - 0.10.32, 0.10"
 HypothesisTests = "0.11"

From f0374b6c815a00bc8f9263df38b5e59b685bb00b Mon Sep 17 00:00:00 2001
From: Kyurae Kim
Date: Fri, 23 May 2025 17:34:31 -0400
Subject: [PATCH 32/52] add more tests for variational inference initializer

---
 test/variational/advi.jl | 23 +++++++++++++++++++++++
 1 file changed, 23 insertions(+)

diff --git a/test/variational/advi.jl b/test/variational/advi.jl
index 6c0d54993f..d6832036e6 100644
--- a/test/variational/advi.jl
+++ b/test/variational/advi.jl
@@ -14,6 +14,29 @@ using Turing
 using Turing.Variational
 
 @testset "ADVI" begin
+    @testset "q initialization" begin
+        m = gdemo_default
+        d = length(Turing.DynamicPPL.VarInfo(m)[:])
+        for q0 in [q_meanfield_gaussian(m), q_fullrank_gaussian(m)]
+            rand(q)
+        end
+
+        μ = ones(d)
+        q = q_meanfield_gaussian(m; location=μ)
+        @assert mean(q) == μ
+
+        q = q_fullrank_gaussian(m; location=μ)
+        @assert mean(q) == μ
+
+        L = Diagonal(fill(0.1, d))
+        q = q_meanfield_gaussian(m; scale=L)
+        @assert cov(q) ≈ L*L
+
+        L = LowerTriangular(tril(0.001*I + I))
+        q = q_fullrank_gaussian(m; location=μ)
+        @assert cov(q) ≈ L*L'
+    end
+
     @testset "default interface" begin
         Random.seed!(0)
         N = 500

From 187a65c1912889c509f296f9152f47367c7a8f5e Mon Sep 17 00:00:00 2001
From: Kyurae Kim
Date: Fri, 23 May 2025 18:01:04 -0400
Subject: [PATCH 33/52] remove non-essential reexports, fix tests

---
 src/variational/VariationalInference.jl |  6 +----
 test/variational/advi.jl                | 42 ++++++++++++++-----------
 2 files changed, 24 insertions(+), 24 deletions(-)

diff --git a/src/variational/VariationalInference.jl b/src/variational/VariationalInference.jl
index 48441522bc..bb05725d70 100644
--- a/src/variational/VariationalInference.jl
+++ b/src/variational/VariationalInference.jl
@@ -13,10 +13,6 @@ import ..Turing: DEFAULT_ADTYPE, PROGRESS
 import AdvancedVI
 import Bijectors
 
-# Reexports
-using AdvancedVI: RepGradELBO, ScoreGradELBO, DoG, DoWG
-export RepGradELBO, ScoreGradELBO, DoG, DoWG
-
 export vi, q_init, q_meanfield_gaussian, q_fullrank_gaussian
 
 include("deprecated.jl")
@@ -131,7 +127,7 @@ function vi(
     model::DynamicPPL.Model,
     q,
     n_iterations::Int;
-    objective=RepGradELBO(10; entropy=AdvancedVI.ClosedFormEntropyZeroGradient()),
+    objective=AdvancedVI.RepGradELBO(10; entropy=AdvancedVI.ClosedFormEntropyZeroGradient()),
     show_progress::Bool=PROGRESS[],
     optimizer=AdvancedVI.DoWG(),
     averager=AdvancedVI.PolynomialAveraging(),
diff --git a/test/variational/advi.jl b/test/variational/advi.jl
index d6832036e6..d93e5b16b3 100644
--- a/test/variational/advi.jl
+++ b/test/variational/advi.jl
@@ -6,7 +6,7 @@ using ..NumericalTests: check_gdemo
 using AdvancedVI
 using Bijectors: Bijectors
 using Distributions: Dirichlet, Normal
-using LinearAlgebra: I
+using LinearAlgebra
 using MCMCChains: Chains
 import Random
 using Test: @test, @testset
@@ -17,24 +17,28 @@ using Turing.Variational
     @testset "q initialization" begin
         m = gdemo_default
         d = length(Turing.DynamicPPL.VarInfo(m)[:])
-        for q0 in [q_meanfield_gaussian(m), q_fullrank_gaussian(m)]
+        for q in [q_meanfield_gaussian(m), q_fullrank_gaussian(m)]
             rand(q)
         end
 
         μ = ones(d)
         q = q_meanfield_gaussian(m; location=μ)
-        @assert mean(q) == μ
+        println(q.dist.location)
+        @assert mean(q.dist) ≈ μ
 
         q = q_fullrank_gaussian(m; location=μ)
-        @assert mean(q) == μ
+        println(q.dist.location)
+        @assert mean(q.dist) ≈ μ
 
         L = Diagonal(fill(0.1, d))
         q = q_meanfield_gaussian(m; scale=L)
-        @assert cov(q) ≈ L*L
+        @assert cov(q.dist) ≈ L*L
 
-        L = LowerTriangular(tril(0.001*I + I))
-        q = q_fullrank_gaussian(m; location=μ)
-        @assert cov(q) ≈ L*L'
+        L = LowerTriangular(tril(0.01*ones(d,d) + I))
+        q = q_fullrank_gaussian(m; scale=L)
+        println(cov(q.dist))
+        println(L*L')
+        @assert cov(q.dist) ≈ L*L'
     end
@@ -54,21 +58,21 @@
     @testset "custom interface $name" for (name, objective, operator, optimizer) in [
         (
             "ADVI with closed-form entropy",
-            RepGradELBO(10),
+            AdvancedVI.RepGradELBO(10),
             AdvancedVI.ProximalLocationScaleEntropy(),
-            DoG(),
+            AdvancedVI.DoG(),
         ),
         (
             "ADVI with proximal entropy",
-            RepGradELBO(10; entropy=AdvancedVI.ClosedFormEntropyZeroGradient()),
+            AdvancedVI.RepGradELBO(10; entropy=AdvancedVI.ClosedFormEntropyZeroGradient()),
             AdvancedVI.ClipScale(),
-            DoG(),
+            AdvancedVI.DoG(),
         ),
         (
             "ADVI with STL entropy",
-            RepGradELBO(10; entropy=AdvancedVI.StickingTheLandingEntropy()),
+            AdvancedVI.RepGradELBO(10; entropy=AdvancedVI.StickingTheLandingEntropy()),
             AdvancedVI.ClipScale(),
-            DoG(),
+            AdvancedVI.DoG(),
         ),
     ]
@@ -91,21 +95,21 @@
     @testset "inference $name" for (name, objective, operator, optimizer) in [
         (
             "ADVI with closed-form entropy",
-            RepGradELBO(10),
+            AdvancedVI.RepGradELBO(10),
             AdvancedVI.ProximalLocationScaleEntropy(),
-            DoG(),
+            AdvancedVI.DoG(),
         ),
         (
             "ADVI with proximal entropy",
             RepGradELBO(10; entropy=AdvancedVI.ClosedFormEntropyZeroGradient()),
             AdvancedVI.ClipScale(),
-            DoG(),
+            AdvancedVI.DoG(),
         ),
         (
             "ADVI with STL entropy",
-            RepGradELBO(10; entropy=AdvancedVI.StickingTheLandingEntropy()),
+            AdvancedVI.RepGradELBO(10; entropy=AdvancedVI.StickingTheLandingEntropy()),
             AdvancedVI.ClipScale(),
-            DoG(),
+            AdvancedVI.DoG(),
         ),
     ]

From a5021d1386f433e5a5af1458f59f29354378250d Mon Sep 17 00:00:00 2001
From: Kyurae Kim
Date: Fri, 23 May 2025 18:11:24 -0400
Subject: [PATCH 34/52] run formatter, rename functions

---
 src/variational/VariationalInference.jl | 19 ++++++++++---------
 test/variational/advi.jl                |  2 +-
 2 files changed, 11 insertions(+), 10 deletions(-)

diff --git a/src/variational/VariationalInference.jl b/src/variational/VariationalInference.jl
index bb05725d70..32a5b40d75 100644
--- a/src/variational/VariationalInference.jl
+++ b/src/variational/VariationalInference.jl
@@ -13,7 +13,7 @@ import ..Turing: DEFAULT_ADTYPE, PROGRESS
 import AdvancedVI
 import Bijectors
 
-export vi, q_init, q_meanfield_gaussian, q_fullrank_gaussian
+export vi, q_locationscale, q_meanfield_gaussian, q_fullrank_gaussian
 
 include("deprecated.jl")
@@ -23,11 +23,12 @@ function make_logdensity(model::DynamicPPL.Model)
     return DynamicPPL.LogDensityFunction(model, DynamicPPL.VarInfo(model), ctx)
 end
 
-function initialize_gaussian_scale(
+function q_initialize_scale(
     rng::Random.AbstractRNG,
     model::DynamicPPL.Model,
     location::AbstractVector,
-    scale::AbstractMatrix;
+    scale::AbstractMatrix,
+    basedist::Distributions.UnivariateDistribution;
     num_samples::Int=10,
     num_max_trials::Int=10,
     reduce_factor=one(eltype(scale)) / 2,
@@ -39,7 +40,7 @@ function q_initialize_scale(
 
     n_trial = 0
     while true
-        q = AdvancedVI.MvLocationScale(location, scale, Normal())
+        q = AdvancedVI.MvLocationScale(location, scale, basedist)
         b = Bijectors.bijector(model; varinfo=varinfo)
         q_trans = Bijectors.transformed(q, Bijectors.inverse(b))
         energy = mean(ℓπ, eachcol(rand(rng, q_trans, num_samples)))
@@ -55,7 +56,7 @@ end
 
-function q_init(
+function q_locationscale(
     rng::Random.AbstractRNG,
     model::DynamicPPL.Model;
     location::Union{Nothing,<:AbstractVector}=nothing,
@@ -79,10 +80,12 @@ function q_locationscale(
     L = if isnothing(scale)
         if meanfield
-            initialize_gaussian_scale(rng, model, μ, Diagonal(ones(num_params)); kwargs...)
+            q_initialize_scale(rng, model, μ, Diagonal(ones(num_params)), basedist; kwargs...)
         else
             L0 = LowerTriangular(Matrix{Float64}(I, num_params, num_params))
-            initialize_gaussian_scale(rng, model, μ, L0; kwargs...)
+            q_initialize_scale(rng, model, μ, L0, basedist; kwargs...)
         end
     else
@@ -103,7 +104,7 @@ function q_meanfield_gaussian(
     scale::Union{Nothing,<:Diagonal}=nothing,
     kwargs...,
 )
-    return q_init(rng, model; location, scale, meanfield=true, basedist=Normal(), kwargs...)
+    return q_locationscale(rng, model; location, scale, meanfield=true, basedist=Normal(), kwargs...)
 end
@@ -117,7 +118,7 @@ function q_fullrank_gaussian(
     scale::Union{Nothing,<:LowerTriangular}=nothing,
     kwargs...,
 )
-    return q_init(
+    return q_locationscale(
         rng, model; location, scale, meanfield=false, basedist=Normal(), kwargs...
) end diff --git a/test/variational/advi.jl b/test/variational/advi.jl index d93e5b16b3..20e7ed6841 100644 --- a/test/variational/advi.jl +++ b/test/variational/advi.jl @@ -34,7 +34,7 @@ using Turing.Variational q = q_meanfield_gaussian(m; scale=L) @assert cov(q.dist) ≈ L*L - L = LowerTriangular(tril(0.01*ones(d,d) + I)) + L = LowerTriangular(tril(0.01*ones(d, d) + I)) q = q_fullrank_gaussian(m; scale=L) println(cov(q.dist)) println(L*L') From 218eb23dc8f124b6734d14c83ce5740930dfcdad Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Fri, 23 May 2025 18:58:29 -0400 Subject: [PATCH 35/52] add documentation --- docs/src/api.md | 14 +-- src/variational/VariationalInference.jl | 122 +++++++++++++++++++++++- 2 files changed, 127 insertions(+), 9 deletions(-) diff --git a/docs/src/api.md b/docs/src/api.md index 01f022e7e5..fe6fa4abdf 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -76,12 +76,14 @@ even though [`Prior()`](@ref) is actually defined in the `Turing.Inference` modu ### Variational inference -See the [variational inference tutorial](https://turinglang.org/docs/tutorials/09-variational-inference/) for a walkthrough on how to use these. - -| Exported symbol | Documentation | Description | -|:--------------- |:---------------------------- |:--------------------------------------- | -| `vi` | [`AdvancedVI.vi`](@extref) | Perform variational inference | -| `ADVI` | [`AdvancedVI.ADVI`](@extref) | Construct an instance of a VI algorithm | +See the [docs of AdvancedVI.jl](https://turinglang.org/AdvancedVI.jl/stable/) for detailed usage and the [variational inference tutorial](https://turinglang.org/docs/tutorials/09-variational-inference/) for a basic walkthrough. + +| Exported symbol | Documentation | Description | +|:---------------------- |:------------------------------------------------- |:---------------------------------------------------------------------------------------- | +| `vi` | [`Turing.vi`](@ref) | Perform variational inference | +| `q_locationscale` | [`Turing.Variational.q_locationscale`](@ref) | Find a numerically non-degenerate initialization for a location-scale variational family | +| `q_meanfield_gaussian` | [`Turing.Variational.q_meanfield_gaussian`](@ref) | Find a numerically non-degenerate initialization for a mean-field Gaussian family | +| `q_fullrank_gaussian` | [`Turing.Variational.q_fullrank_gaussian`](@ref) | Find a numerically non-degenerate initialization for a full-rank Gaussian family | ### Automatic differentiation types diff --git a/src/variational/VariationalInference.jl b/src/variational/VariationalInference.jl index 32a5b40d75..7190425e4d 100644 --- a/src/variational/VariationalInference.jl +++ b/src/variational/VariationalInference.jl @@ -23,6 +23,32 @@ function make_logdensity(model::DynamicPPL.Model) return DynamicPPL.LogDensityFunction(model, DynamicPPL.VarInfo(model), ctx) end +""" + q_initialize_scale([rng, ]model, location, scale, basedist; num_samples, num_max_trials, reduce_factor) + +Given an initial location-scale distribution `q` formed by `location`, `scale`, and `basedist`, shrink `scale` until the expectation of log-densities of `model` taken over `q` are finite. +If the log-densities are not finite even after `num_max_trials`, throw an error. 
+ +For reference, a location-scale distribution \$q\$ formed by `location`, `scale`, and `basedist` is a distribution where its sampling process \$z \\sim q\$ can be represented as +```julia +u = rand(basedist, d) +z = scale * u + location +``` + +# Arguments +- `model::DynamicPPL.Model`: The target `DynamicPPL.Model`. +- `location::AbstractVector`: The location parameter of the initialization. +- `scale::AbstractMatrix`: The scale parameter of the initialization. +- `basedist::Distributions.UnivariateDistribution`: The base distribution of the location-scale family. + +# Keyword Arguments +- `num_samples::Int`: Number of samples used to compute the average log-density at each trial. (Default: `10`.) +- `num_max_trials::Int`: Number of trials until throwing an error. (Default: `10`.) +- `reduce_factor::Real`: Factor for shrinking the scale. After `n` trials, the scale is then `scale*reduce_factor^n`. (Default: `0.5`.) + +# Returns +- `scale_adj`: The adjusted scale matrix matching the type of `scale`. +""" function q_initialize_scale( rng::Random.AbstractRNG, model::DynamicPPL.Model, @@ -31,7 +57,7 @@ function q_initialize_scale( basedist::Distributions.UnivariateDistribution; num_samples::Int=10, num_max_trials::Int=10, - reduce_factor=one(eltype(scale)) / 2, + reduce_factor::Real=one(eltype(scale)) / 2, ) prob = make_logdensity(model) ℓπ = Base.Fix1(LogDensityProblems.logdensity, prob) @@ -55,6 +81,35 @@ function q_initialize_scale( end end +""" + q_locationscale([rng, ]model; location, scale, meanfield, basedist) + +Find a numerically non-degenerate variational distribution `q` for approximating the target `model` within the location-scale variational family formed by the type of `scale` and `basedist`. + +The distribution can be manually specified by setting `location`, `scale`, and `basedist`. +Otherwise, it chooses a standard Gaussian by default. +Whether the default choice is used or not, the `scale` may be adjusted via `q_initialize_scale` so that the log-densities of `model` are finite over the samples from `q`. +If `meanfield` is set as `true`, the scale of `q` is restricted to be a diagonal matrix and only the diagonal of `scale` is used. + +For reference, a location-scale distribution \$q\$ formed by `location`, `scale`, and `basedist` is a distribution where its sampling process \$z \\sim q\$ can be represented as +```julia +u = rand(basedist, d) +z = scale * u + location +``` + +# Arguments +- `model::DynamicPPL.Model`: The target `DynamicPPL.Model`. + +# Keyword Arguments +- `location::Union{Nothing,<:AbstractVector}`: The location parameter of the initialization. If `nothing`, a vector of zeros is used. +- `scale::Union{Nothing,<:Diagonal,<:LowerTriangular}`: The scale parameter of the initialization. If `nothing`, an identity matrix is used. +- `basedist::Distributions.UnivariateDistribution`: The distribution + +The remaining keywords are passed to `q_initialize_scale`. + +# Returns +- `q::Bijectors.TransformedDistribution`: A `AdvancedVI.LocationScale` distribution matching the support of `model`. +""" function q_locationscale( rng::Random.AbstractRNG, model::DynamicPPL.Model; @@ -97,6 +152,23 @@ function q_locationscale( return Bijectors.transformed(q, Bijectors.inverse(b)) end +""" + q_meanfield_gaussian([rng, ]model; location, scale, kwargs...) + +Find a numerically non-degenerate mean-field Gaussian `q` for approximating the target `model`. + +# Arguments +- `model::DynamicPPL.Model`: The target `DynamicPPL.Model`. 
+ +# Keyword Arguments +- `location::Union{Nothing,<:AbstractVector}`: The location parameter of the initialization. If `nothing`, a vector of zeros is used. +- `scale::Union{Nothing,<:Diagonal}`: The scale parameter of the initialization. If `nothing`, an identity matrix is used. + +The remaining keyword arguments are passed to `q_locationscale`. + +# Returns +- `q::Bijectors.TransformedDistribution`: A `AdvancedVI.LocationScale` distribution matching the support of `model`. +""" function q_meanfield_gaussian( rng::Random.AbstractRNG, model::DynamicPPL.Model; @@ -104,13 +176,30 @@ function q_meanfield_gaussian( scale::Union{Nothing,<:Diagonal}=nothing, kwargs..., ) - return q_locationscale_init(rng, model; location, scale, meanfield=true, basedist=Normal(), kwargs...) + return q_locationscale(rng, model; location, scale, meanfield=true, basedist=Normal(), kwargs...) end function q_meanfield_gaussian(model::DynamicPPL.Model; kwargs...) return q_meanfield_gaussian(Random.default_rng(), model; kwargs...) end +""" + q_fullrank_gaussian([rng, ]model; location, scale, kwargs...) + +Find a numerically non-degenerate Gaussian `q` with a dense scale (traditionally referred to as "full-rank") for approximating the target `model`. + +# Arguments +- `model::DynamicPPL.Model`: The target `DynamicPPL.Model`. + +# Keyword Arguments +- `location::Union{Nothing,<:AbstractVector}`: The location parameter of the initialization. If `nothing`, a vector of zeros is used. +- `scale::Union{Nothing,<:LowerTriangular}`: The scale parameter of the initialization. If `nothing`, an identity matrix is used. + +The remaining keyword arguments are passed to `q_locationscale`. + +# Returns +- `q::Bijectors.TransformedDistribution`: A `AdvancedVI.LocationScale` distribution matching the support of `model`. +""" function q_fullrank_gaussian( rng::Random.AbstractRNG, model::DynamicPPL.Model; @@ -118,7 +207,7 @@ function q_fullrank_gaussian( scale::Union{Nothing,<:LowerTriangular}=nothing, kwargs..., ) - return q_locationscale_init( + return q_locationscale( rng, model; location, scale, meanfield=false, basedist=Normal(), kwargs... ) end @@ -127,6 +216,33 @@ function q_fullrank_gaussian(model::DynamicPPL.Model; kwargs...) return q_fullrank_gaussian(Random.default_rng(), model; kwargs...) end +""" + vi([rng, ]model, q, n_iterations; objective, show_progress, optimizer, averager, operator, adtype, kwargs...) + +Approximating the target `model` via variational inference by optimizing `objective` with the initialization `q`. +This is a thin wrapper around `AdvancedVI.optimize`. + +# Arguments +- `model::DynamicPPL.Model`: The target `DynamicPPL.Model`. +- `q`: The initial variational approximation. +- `n_iterations::Int`: Number of optimization steps. + +# Keyword Arguments +- `objective::AdvancedVI.AbstractVariationalObjective`: Variational objective to be optimized. +- `show_progress::Bool`: Whether to show the progress bar. (Default: `Turing.PROGRESS[]`.) +- `optimizer::Optimisers.AbstractRule`: Optimization algorithm. (Default: `AdvancedVI.DoWG`.) +- `averager::AdvancedVI.AbstractAverager`: Parameter averaging strategy. (Default: `AdvancedVI.PolynomialAveraging()`) +- `operator::AdvancedVI.AbstractOperator`: Operator applied after each optimization step. (Default: `AdvancedVI.ProximalLocationScaleEntropy()`.) +- `adtype::ADTypes.AbstractADType`: Automatic differentiation backend. (Default: `Turing.DEFAULT_ADTYPE`) + +See the docs of `AvancedVI.optimize` for additional keyword arguments. 
+ +# Returns +- `q`: Variational distribution formed by the last iterate of the optimization run. +- `q_avg`: Variational distribution formed by the averaged iterates according to `averager`. +- `state`: Collection of states used for optimization. This can be used to resume from a past call to `vi`. +- `info`: Information generated during the optimization run. +""" function vi( rng::Random.AbstractRNG, model::DynamicPPL.Model, From 4714c3cabdadcb06db67169ff41b4d8b9467a023 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Fri, 23 May 2025 19:03:05 -0400 Subject: [PATCH 36/52] fix run formatter --- src/variational/VariationalInference.jl | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/variational/VariationalInference.jl b/src/variational/VariationalInference.jl index 7190425e4d..e34160f0b8 100644 --- a/src/variational/VariationalInference.jl +++ b/src/variational/VariationalInference.jl @@ -176,7 +176,9 @@ function q_meanfield_gaussian( scale::Union{Nothing,<:Diagonal}=nothing, kwargs..., ) - return q_locationscale(rng, model; location, scale, meanfield=true, basedist=Normal(), kwargs...) + return q_locationscale( + rng, model; location, scale, meanfield=true, basedist=Normal(), kwargs... + ) end function q_meanfield_gaussian(model::DynamicPPL.Model; kwargs...) @@ -248,7 +250,9 @@ function vi( model::DynamicPPL.Model, q, n_iterations::Int; - objective=AdvancedVI.RepGradELBO(10; entropy=AdvancedVI.ClosedFormEntropyZeroGradient()), + objective=AdvancedVI.RepGradELBO( + 10; entropy=AdvancedVI.ClosedFormEntropyZeroGradient() + ), show_progress::Bool=PROGRESS[], optimizer=AdvancedVI.DoWG(), averager=AdvancedVI.PolynomialAveraging(), From f7127555e1ee3dd12fb9697775b7cf9965bfac18 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Fri, 23 May 2025 19:05:52 -0400 Subject: [PATCH 37/52] fix remove debug commits --- test/variational/advi.jl | 3 --- 1 file changed, 3 deletions(-) diff --git a/test/variational/advi.jl b/test/variational/advi.jl index 20e7ed6841..76c1add83f 100644 --- a/test/variational/advi.jl +++ b/test/variational/advi.jl @@ -23,7 +23,6 @@ using Turing.Variational μ = ones(d) q = q_meanfield_gaussian(m; location=μ) - println(q.dist.location) @assert mean(q.dist) ≈ μ q = q_fullrank_gaussian(m; location=μ) @@ -36,8 +35,6 @@ using Turing.Variational L = LowerTriangular(tril(0.01*ones(d, d) + I)) q = q_fullrank_gaussian(m; scale=L) - println(cov(q.dist)) - println(L*L') @assert cov(q.dist) ≈ L*L' end From 808639812111debd091cfb4f592002234d2c8b12 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Sat, 24 May 2025 15:05:07 -0400 Subject: [PATCH 38/52] run formatter Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- test/variational/advi.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/variational/advi.jl b/test/variational/advi.jl index 76c1add83f..b1a918d799 100644 --- a/test/variational/advi.jl +++ b/test/variational/advi.jl @@ -31,7 +31,7 @@ using Turing.Variational L = Diagonal(fill(0.1, d)) q = q_meanfield_gaussian(m; scale=L) - @assert cov(q.dist) ≈ L*L + @assert cov(q.dist) ≈ L * L L = LowerTriangular(tril(0.01*ones(d, d) + I)) q = q_fullrank_gaussian(m; scale=L) From 37f6b06b3e9dd4c4feea61ccf99db422c8fd1eb3 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Sat, 24 May 2025 15:05:14 -0400 Subject: [PATCH 39/52] run formatter Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- test/variational/advi.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) 
diff --git a/test/variational/advi.jl b/test/variational/advi.jl
index b1a918d799..c666cd05d8 100644
--- a/test/variational/advi.jl
+++ b/test/variational/advi.jl
@@ -33,7 +33,7 @@ using Turing.Variational
         q = q_meanfield_gaussian(m; scale=L)
         @assert cov(q.dist) ≈ L * L
 
-        L = LowerTriangular(tril(0.01*ones(d, d) + I))
+        L = LowerTriangular(tril(0.01 * ones(d, d) + I))
         q = q_fullrank_gaussian(m; scale=L)
         @assert cov(q.dist) ≈ L*L'
     end

From c71722057e0b21ba355fbafbcc101e806d3e1d1d Mon Sep 17 00:00:00 2001
From: Kyurae Kim
Date: Sat, 24 May 2025 15:05:22 -0400
Subject: [PATCH 40/52] run formatter

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
---
 test/variational/advi.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/test/variational/advi.jl b/test/variational/advi.jl
index c666cd05d8..d08ec7ee20 100644
--- a/test/variational/advi.jl
+++ b/test/variational/advi.jl
@@ -35,7 +35,7 @@ using Turing.Variational
 
         L = LowerTriangular(tril(0.01 * ones(d, d) + I))
         q = q_fullrank_gaussian(m; scale=L)
-        @assert cov(q.dist) ≈ L*L'
+        @assert cov(q.dist) ≈ L * L'
     end

From e9f7f1e43cb5106f40d72c17491a7cc7e0da44f6 Mon Sep 17 00:00:00 2001
From: Kyurae Kim
Date: Sat, 24 May 2025 15:50:25 -0400
Subject: [PATCH 41/52] add Variational submodule

---
 docs/make.jl                | 7 +++++--
 docs/src/api/Variational.md | 6 ++++++
 2 files changed, 11 insertions(+), 2 deletions(-)
 create mode 100644 docs/src/api/Variational.md

diff --git a/docs/make.jl b/docs/make.jl
index 978e5881b3..af24e7b1ec 100644
--- a/docs/make.jl
+++ b/docs/make.jl
@@ -23,8 +23,11 @@ makedocs(;
     pages=[
         "Home" => "index.md",
         "API" => "api.md",
-        "Submodule APIs" =>
-            ["Inference" => "api/Inference.md", "Optimisation" => "api/Optimisation.md"],
+        "Submodule APIs" => [
+            "Inference" => "api/Inference.md",
+            "Optimisation" => "api/Optimisation.md",
+            "Variational" => "api/Variational.md",
+        ],
     ],
     checkdocs=:exports,
     doctest=false,
diff --git a/docs/src/api/Variational.md b/docs/src/api/Variational.md
new file mode 100644
index 0000000000..382efe7e18
--- /dev/null
+++ b/docs/src/api/Variational.md
@@ -0,0 +1,6 @@
+# API: `Turing.Variational`
+
+```@autodocs
+Modules = [Turing.Variational]
+Order = [:type, :function]
+```

From 6a8c6edde0c0f3a0bb5531938a46a93800505414 Mon Sep 17 00:00:00 2001
From: Kyurae Kim
Date: Sat, 24 May 2025 16:16:47 -0400
Subject: [PATCH 42/52] fix docstring style

---
 src/variational/VariationalInference.jl | 49 ++++++++++++++++++-------
 test/variational/advi.jl                |  1 -
 2 files changed, 35 insertions(+), 15 deletions(-)

diff --git a/src/variational/VariationalInference.jl b/src/variational/VariationalInference.jl
index e34160f0b8..1aa8086167 100644
--- a/src/variational/VariationalInference.jl
+++ b/src/variational/VariationalInference.jl
@@ -24,7 +24,16 @@ function make_logdensity(model::DynamicPPL.Model)
 end
 
 """
-    q_initialize_scale([rng, ]model, location, scale, basedist; num_samples, num_max_trials, reduce_factor)
+    q_initialize_scale(
+        [rng::Random.AbstractRNG,]
+        model::DynamicPPL.Model,
+        location::AbstractVector,
+        scale::AbstractMatrix,
+        basedist::Distributions.UnivariateDistribution;
+        num_samples::Int=10,
+        num_max_trials::Int=10,
+        reduce_factor::Real=one(eltype(scale)) / 2
+    )
 
 Given an initial location-scale distribution `q` formed by `location`, `scale`, and `basedist`, shrink `scale` until the expectation of log-densities of `model` taken over `q` are finite.
If the log-densities are not finite even after `num_max_trials`, throw an error. @@ -36,15 +45,15 @@ z = scale * u + location ``` # Arguments -- `model::DynamicPPL.Model`: The target `DynamicPPL.Model`. -- `location::AbstractVector`: The location parameter of the initialization. -- `scale::AbstractMatrix`: The scale parameter of the initialization. -- `basedist::Distributions.UnivariateDistribution`: The base distribution of the location-scale family. +- `model`: The target `DynamicPPL.Model`. +- `location`: The location parameter of the initialization. +- `scale`: The scale parameter of the initialization. +- `basedist`: The base distribution of the location-scale family. # Keyword Arguments -- `num_samples::Int`: Number of samples used to compute the average log-density at each trial. (Default: `10`.) -- `num_max_trials::Int`: Number of trials until throwing an error. (Default: `10`.) -- `reduce_factor::Real`: Factor for shrinking the scale. After `n` trials, the scale is then `scale*reduce_factor^n`. (Default: `0.5`.) +- `num_samples`: Number of samples used to compute the average log-density at each trial. +- `num_max_trials`: Number of trials until throwing an error. +- `reduce_factor`: Factor for shrinking the scale. After `n` trials, the scale is then `scale*reduce_factor^n`. # Returns - `scale_adj`: The adjusted scale matrix matching the type of `scale`. @@ -82,7 +91,14 @@ function q_initialize_scale( end """ - q_locationscale([rng, ]model; location, scale, meanfield, basedist) + q_locationscale( + [rng::Random.AbstractRNG,] + model::DynamicPPL.Model; + location::Union{Nothing,<:AbstractVector}, + scale::Union{Nothing,<:Diagonal,<:LowerTriangular}, + meanfield::Bool=true, + basedist::Distributions.UnivariateDistribution + ) Find a numerically non-degenerate variational distribution `q` for approximating the target `model` within the location-scale variational family formed by the type of `scale` and `basedist`. @@ -98,12 +114,13 @@ z = scale * u + location ``` # Arguments -- `model::DynamicPPL.Model`: The target `DynamicPPL.Model`. +- `model`: The target `DynamicPPL.Model`. # Keyword Arguments -- `location::Union{Nothing,<:AbstractVector}`: The location parameter of the initialization. If `nothing`, a vector of zeros is used. -- `scale::Union{Nothing,<:Diagonal,<:LowerTriangular}`: The scale parameter of the initialization. If `nothing`, an identity matrix is used. -- `basedist::Distributions.UnivariateDistribution`: The distribution +- `location`: The location parameter of the initialization. If `nothing`, a vector of zeros is used. +- `scale`: The scale parameter of the initialization. If `nothing`, an identity matrix is used. +- `meanfield`: Whether to use the mean-field approximation. If `true`, `scale` is converted into a `Diagonal` matrix. Otherwise, it is converted into a `LowerTriangular` matrix. +- `basedist`: The base distribution of the location-scale family. The remaining keywords are passed to `q_initialize_scale`. @@ -144,7 +161,7 @@ function q_locationscale( if meanfield Diagonal(diag(scale)) else - scale + LowerTriangular(Matrix(scale)) end end q = AdvancedVI.MvLocationScale(μ, L, basedist) @@ -152,6 +169,10 @@ function q_locationscale( return Bijectors.transformed(q, Bijectors.inverse(b)) end +function q_locationscale(model::DynamicPPL.Model; kwargs...) + return q_locationscale(Random.default_rng(), model; kwargs...) +end + """ q_meanfield_gaussian([rng, ]model; location, scale, kwargs...) 
diff --git a/test/variational/advi.jl b/test/variational/advi.jl index d08ec7ee20..44bbe83945 100644 --- a/test/variational/advi.jl +++ b/test/variational/advi.jl @@ -26,7 +26,6 @@ using Turing.Variational @assert mean(q.dist) ≈ μ q = q_fullrank_gaussian(m; location=μ) - println(q.dist.location) @assert mean(q.dist) ≈ μ L = Diagonal(fill(0.1, d)) From c4d73fbc2d5c84f533d97fda379fbd29bba049fa Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Mon, 26 May 2025 15:08:16 -0400 Subject: [PATCH 43/52] update docstring style --- src/variational/VariationalInference.jl | 62 ++++++++++++++++++------- 1 file changed, 44 insertions(+), 18 deletions(-) diff --git a/src/variational/VariationalInference.jl b/src/variational/VariationalInference.jl index 1aa8086167..063a0ba27a 100644 --- a/src/variational/VariationalInference.jl +++ b/src/variational/VariationalInference.jl @@ -174,16 +174,22 @@ function q_locationscale(model::DynamicPPL.Model; kwargs...) end """ - q_meanfield_gaussian([rng, ]model; location, scale, kwargs...) + q_meanfield_gaussian( + [rng::Random.AbstractRNG,] + model::DynamicPPL.Model; + location::Union{Nothing,<:AbstractVector}, + scale::Union{Nothing,<:Diagonal}, + kwargs... + ) Find a numerically non-degenerate mean-field Gaussian `q` for approximating the target `model`. # Arguments -- `model::DynamicPPL.Model`: The target `DynamicPPL.Model`. +- `model`: The target `DynamicPPL.Model`. # Keyword Arguments -- `location::Union{Nothing,<:AbstractVector}`: The location parameter of the initialization. If `nothing`, a vector of zeros is used. -- `scale::Union{Nothing,<:Diagonal}`: The scale parameter of the initialization. If `nothing`, an identity matrix is used. +- `location`: The location parameter of the initialization. If `nothing`, a vector of zeros is used. +- `scale`: The scale parameter of the initialization. If `nothing`, an identity matrix is used. The remaining keyword arguments are passed to `q_locationscale`. @@ -207,16 +213,22 @@ function q_meanfield_gaussian(model::DynamicPPL.Model; kwargs...) end """ - q_fullrank_gaussian([rng, ]model; location, scale, kwargs...) + q_fullrank_gaussian( + [rng::Random.AbstractRNG,] + model::DynamicPPL.Model; + location::Union{Nothing,<:AbstractVector}, + scale::Union{Nothing,<:LowerTriangular}, + kwargs... + ) -Find a numerically non-degenerate Gaussian `q` with a dense scale (traditionally referred to as "full-rank") for approximating the target `model`. +Find a numerically non-degenerate Gaussian `q` with a scale with full-rank factors (traditionally referred to as a "full-rank family") for approximating the target `model`. # Arguments -- `model::DynamicPPL.Model`: The target `DynamicPPL.Model`. +- `model`: The target `DynamicPPL.Model`. # Keyword Arguments -- `location::Union{Nothing,<:AbstractVector}`: The location parameter of the initialization. If `nothing`, a vector of zeros is used. -- `scale::Union{Nothing,<:LowerTriangular}`: The scale parameter of the initialization. If `nothing`, an identity matrix is used. +- `location`: The location parameter of the initialization. If `nothing`, a vector of zeros is used. +- `scale`: The scale parameter of the initialization. If `nothing`, an identity matrix is used. The remaining keyword arguments are passed to `q_locationscale`. @@ -240,23 +252,37 @@ function q_fullrank_gaussian(model::DynamicPPL.Model; kwargs...) end """ - vi([rng, ]model, q, n_iterations; objective, show_progress, optimizer, averager, operator, adtype, kwargs...) 
+    vi(
+        [rng::Random.AbstractRNG,]
+        model::DynamicPPL.Model,
+        q,
+        n_iterations::Int;
+        objective::AdvancedVI.AbstractVariationalObjective=AdvancedVI.RepGradELBO(
+            10; entropy=AdvancedVI.ClosedFormEntropyZeroGradient()
+        ),
+        show_progress::Bool=Turing.PROGRESS[],
+        optimizer::Optimisers.AbstractRule=AdvancedVI.DoWG(),
+        averager::AdvancedVI.AbstractAverager=AdvancedVI.PolynomialAveraging(),
+        operator::AdvancedVI.AbstractOperator=AdvancedVI.ProximalLocationScaleEntropy(),
+        adtype::ADTypes.AbstractADType=Turing.DEFAULT_ADTYPE,
+        kwargs...
+    )
 
 Approximating the target `model` via variational inference by optimizing `objective` with the initialization `q`.
 This is a thin wrapper around `AdvancedVI.optimize`.
 
 # Arguments
-- `model::DynamicPPL.Model`: The target `DynamicPPL.Model`.
+- `model`: The target `DynamicPPL.Model`.
 - `q`: The initial variational approximation.
-- `n_iterations::Int`: Number of optimization steps.
+- `n_iterations`: Number of optimization steps.
 
 # Keyword Arguments
-- `objective::AdvancedVI.AbstractVariationalObjective`: Variational objective to be optimized.
-- `show_progress::Bool`: Whether to show the progress bar. (Default: `Turing.PROGRESS[]`.)
-- `optimizer::Optimisers.AbstractRule`: Optimization algorithm. (Default: `AdvancedVI.DoWG`.)
-- `averager::AdvancedVI.AbstractAverager`: Parameter averaging strategy. (Default: `AdvancedVI.PolynomialAveraging()`)
-- `operator::AdvancedVI.AbstractOperator`: Operator applied after each optimization step. (Default: `AdvancedVI.ProximalLocationScaleEntropy()`.)
-- `adtype::ADTypes.AbstractADType`: Automatic differentiation backend. (Default: `Turing.DEFAULT_ADTYPE`)
+- `objective`: Variational objective to be optimized.
+- `show_progress`: Whether to show the progress bar.
+- `optimizer`: Optimization algorithm.
+- `averager`: Parameter averaging strategy.
+- `operator`: Operator applied after each optimization step.
+- `adtype`: Automatic differentiation backend.
 
 See the docs of `AvancedVI.optimize` for additional keyword arguments.

From feb1a57c301d26595e7a79d4451acebfc3cc83c9 Mon Sep 17 00:00:00 2001
From: Kyurae Kim
Date: Mon, 26 May 2025 15:23:20 -0400
Subject: [PATCH 44/52] format docstring style

---
 src/variational/VariationalInference.jl | 36 ++++++++++++------------
 1 file changed, 18 insertions(+), 18 deletions(-)

diff --git a/src/variational/VariationalInference.jl b/src/variational/VariationalInference.jl
index 063a0ba27a..0102631cf0 100644
--- a/src/variational/VariationalInference.jl
+++ b/src/variational/VariationalInference.jl
@@ -30,9 +30,9 @@ end
         location::AbstractVector,
         scale::AbstractMatrix,
         basedist::Distributions.UnivariateDistribution;
-        num_samples::Int=10,
-        num_max_trials::Int=10,
-        reduce_factor::Real=one(eltype(scale)) / 2
+        num_samples::Int = 10,
+        num_max_trials::Int = 10,
+        reduce_factor::Real = one(eltype(scale)) / 2
     )
 
 Given an initial location-scale distribution `q` formed by `location`, `scale`, and `basedist`, shrink `scale` until the expectation of log-densities of `model` taken over `q` are finite.
@@ -94,10 +94,10 @@ end
     q_locationscale(
        [rng::Random.AbstractRNG,]
        model::DynamicPPL.Model;
-        location::Union{Nothing,<:AbstractVector},
-        scale::Union{Nothing,<:Diagonal,<:LowerTriangular},
-        meanfield::Bool=true,
-        basedist::Distributions.UnivariateDistribution
+        location::Union{Nothing,<:AbstractVector} = nothing,
+        scale::Union{Nothing,<:Diagonal,<:LowerTriangular} = nothing,
+        meanfield::Bool = true,
+        basedist::Distributions.UnivariateDistribution = Normal()
     )
 
 Find a numerically non-degenerate variational distribution `q` for approximating the target `model` within the location-scale variational family formed by the type of `scale` and `basedist`.
@@ -177,8 +177,8 @@ end
     q_meanfield_gaussian(
         [rng::Random.AbstractRNG,]
         model::DynamicPPL.Model;
-        location::Union{Nothing,<:AbstractVector},
-        scale::Union{Nothing,<:Diagonal},
+        location::Union{Nothing,<:AbstractVector} = nothing,
+        scale::Union{Nothing,<:Diagonal} = nothing,
         kwargs...
     )
 
@@ -216,8 +216,8 @@ end
     q_fullrank_gaussian(
         [rng::Random.AbstractRNG,]
         model::DynamicPPL.Model;
-        location::Union{Nothing,<:AbstractVector},
-        scale::Union{Nothing,<:LowerTriangular},
+        location::Union{Nothing,<:AbstractVector} = nothing,
+        scale::Union{Nothing,<:LowerTriangular} = nothing,
         kwargs...
     )
 
@@ -257,14 +257,14 @@ end
         model::DynamicPPL.Model,
         q,
         n_iterations::Int;
-        objective::AdvancedVI.AbstractVariationalObjective=AdvancedVI.RepGradELBO(
-            10; entropy=AdvancedVI.ClosedFormEntropyZeroGradient()
+        objective::AdvancedVI.AbstractVariationalObjective = AdvancedVI.RepGradELBO(
+            10; entropy = AdvancedVI.ClosedFormEntropyZeroGradient()
         ),
-        show_progress::Bool=Turing.PROGRESS[],
-        optimizer::Optimisers.AbstractRule=AdvancedVI.DoWG(),
-        averager::AdvancedVI.AbstractAverager=AdvancedVI.PolynomialAveraging(),
-        operator::AdvancedVI.AbstractOperator=AdvancedVI.ProximalLocationScaleEntropy(),
-        adtype::ADTypes.AbstractADType=Turing.DEFAULT_ADTYPE,
+        show_progress::Bool = Turing.PROGRESS[],
+        optimizer::Optimisers.AbstractRule = AdvancedVI.DoWG(),
+        averager::AdvancedVI.AbstractAverager = AdvancedVI.PolynomialAveraging(),
+        operator::AdvancedVI.AbstractOperator = AdvancedVI.ProximalLocationScaleEntropy(),
+        adtype::ADTypes.AbstractADType = Turing.DEFAULT_ADTYPE,
         kwargs...
     )

From 4c9a538d7ac11e2e6d2eae837c1f5a5965f65c7c Mon Sep 17 00:00:00 2001
From: Kyurae Kim
Date: Fri, 30 May 2025 17:10:02 -0400
Subject: [PATCH 45/52] fix typo

Co-authored-by: Penelope Yong
---
 src/variational/VariationalInference.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/variational/VariationalInference.jl b/src/variational/VariationalInference.jl
index 0102631cf0..b9428af112 100644
--- a/src/variational/VariationalInference.jl
+++ b/src/variational/VariationalInference.jl
@@ -284,7 +284,7 @@ This is a thin wrapper around `AdvancedVI.optimize`.
 - `operator`: Operator applied after each optimization step.
 - `adtype`: Automatic differentiation backend.
 
-See the docs of `AvancedVI.optimize` for additional keyword arguments.
+See the docs of `AdvancedVI.optimize` for additional keyword arguments.
 
 # Returns
 - `q`: Variational distribution formed by the last iterate of the optimization run.
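
Taken together, the docstrings finalized in the patches above describe the following end-to-end usage. The sketch below is illustrative only and is not part of the patch series: the model `demo` and its data are hypothetical stand-ins for `gdemo_default` from the test suite, and the return-value destructuring follows the `_, q_avg, _, _ = vi(...)` convention used in `test/variational/advi.jl`.

```julia
using Turing

# A hypothetical toy model, standing in for `gdemo_default` in the tests.
@model function demo(x)
    s ~ InverseGamma(2, 3)
    m ~ Normal(0, sqrt(s))
    for i in eachindex(x)
        x[i] ~ Normal(m, sqrt(s))
    end
end

model = demo(randn(100))

# Construct a numerically non-degenerate initialization whose support
# matches that of `model`; by default this is a standard mean-field
# Gaussian, with the scale shrunk via `q_initialize_scale` if needed.
q0 = q_meanfield_gaussian(model)

# Optimize the default objective (`RepGradELBO` with the zero-gradient
# closed-form entropy) for 1000 steps; `q_avg` is the averaged iterate,
# and `rand` draws one sample per column.
_, q_avg, _, _ = vi(model, q0, 1000; show_progress=false)
samples = rand(q_avg, 1000)
```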
From dfa8d201cf18d858f64936d966fb9c52152ef262 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Fri, 30 May 2025 18:39:20 -0400 Subject: [PATCH 46/52] fix use fixed seed with StableRNGs --- test/variational/advi.jl | 31 +++++++++++++++++-------------- 1 file changed, 17 insertions(+), 14 deletions(-) diff --git a/test/variational/advi.jl b/test/variational/advi.jl index 44bbe83945..ed8f745df2 100644 --- a/test/variational/advi.jl +++ b/test/variational/advi.jl @@ -3,12 +3,14 @@ module AdvancedVITests using ..Models: gdemo_default using ..NumericalTests: check_gdemo + using AdvancedVI using Bijectors: Bijectors using Distributions: Dirichlet, Normal using LinearAlgebra using MCMCChains: Chains -import Random +using Random +using StableRNGs: StableRNG using Test: @test, @testset using Turing using Turing.Variational @@ -38,12 +40,9 @@ using Turing.Variational end @testset "default interface" begin - Random.seed!(0) - N = 500 - for q0 in [q_meanfield_gaussian(gdemo_default), q_fullrank_gaussian(gdemo_default)] - _, q, _, _ = vi(gdemo_default, q0, N; show_progress=Turing.PROGRESS[]) - c1 = rand(q, N) + _, q, _, _ = vi(gdemo_default, q0, 100; show_progress=Turing.PROGRESS[]) + c1 = rand(q, 10) end end @@ -67,7 +66,6 @@ using Turing.Variational AdvancedVI.DoG(), ), ] - Random.seed!(0) T = 1000 q, q_avg, _, _ = vi( gdemo_default, @@ -104,9 +102,11 @@ using Turing.Variational AdvancedVI.DoG(), ), ] - Random.seed!(0) + rng = StableRNG(0x517e1d9bf89bf94f) + T = 1000 q, q_avg, _, _ = vi( + rng, gdemo_default, q_meanfield_gaussian(gdemo_default), T; @@ -116,7 +116,7 @@ using Turing.Variational N = 1000 for q_out in [q_avg, q] - samples = transpose(rand(q_out, N)) + samples = transpose(rand(rng, q_out, N)) chn = Chains(reshape(samples, size(samples)..., 1), ["s", "m"]) check_gdemo(chn; atol=0.5) @@ -126,11 +126,12 @@ using Turing.Variational # regression test for: # https://github.com/TuringLang/Turing.jl/issues/2065 @testset "simplex bijector" begin + rng = StableRNG(0x517e1d9bf89bf94f) + @model function dirichlet() x ~ Dirichlet([1.0, 1.0]) return x end - Random.seed!(0) m = dirichlet() b = Bijectors.bijector(m) @@ -142,25 +143,27 @@ using Turing.Variational @test all(x0 .≈ x0_inv) # And regression for https://github.com/TuringLang/Turing.jl/issues/2160. - _, q, _, _ = vi(m, q_meanfield_gaussian(m), 1000) - x = rand(q, 1000) + _, q, _, _ = vi(rng, m, q_meanfield_gaussian(m), 1000) + x = rand(rng, q, 1000) @test mean(eachcol(x)) ≈ [0.5, 0.5] atol = 0.1 end # Ref: https://github.com/TuringLang/Turing.jl/issues/2205 @testset "with `condition` (issue #2205)" begin + rng = StableRNG(0x517e1d9bf89bf94f) + @model function demo_issue2205() x ~ Normal() return y ~ Normal(x, 1) end model = demo_issue2205() | (y=1.0,) - _, q, _, _ = vi(model, q_meanfield_gaussian(model), 1000) + _, q, _, _ = vi(rng, model, q_meanfield_gaussian(model), 1000) # True mean. mean_true = 1 / 2 var_true = 1 / 2 # Check the mean and variance of the posterior. 
- samples = rand(q, 1000) + samples = rand(rng, q, 1000) mean_est = mean(samples) var_est = var(samples) @test mean_est ≈ mean_true atol = 0.2 From a18f5818430ffc2541a715718be76017d4591572 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Fri, 30 May 2025 18:47:11 -0400 Subject: [PATCH 47/52] fix export variational families --- src/Turing.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/Turing.jl b/src/Turing.jl index 4bd3058906..a4cee0451f 100644 --- a/src/Turing.jl +++ b/src/Turing.jl @@ -114,6 +114,9 @@ export # Variational inference - AdvancedVI vi, ADVI, + q_locationscale, + q_meanfield_gaussian, + q_fullrank_gaussian, # ADTypes AutoForwardDiff, AutoReverseDiff, From f9528e0f51fa7a86d9168f58760ecdfe12912e1f Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Fri, 30 May 2025 18:52:17 -0400 Subject: [PATCH 48/52] fix forma Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/Turing.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/Turing.jl b/src/Turing.jl index a4cee0451f..8004e78962 100644 --- a/src/Turing.jl +++ b/src/Turing.jl @@ -114,9 +114,9 @@ export # Variational inference - AdvancedVI vi, ADVI, - q_locationscale, - q_meanfield_gaussian, - q_fullrank_gaussian, + q_locationscale, + q_meanfield_gaussian, + q_fullrank_gaussian, # ADTypes AutoForwardDiff, AutoReverseDiff, From dec108bf5443cab6668ca6cd1a391b2516a7aeca Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Mon, 2 Jun 2025 18:42:35 -0400 Subject: [PATCH 49/52] update changelog for advancedvi 0.4 --- HISTORY.md | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/HISTORY.md b/HISTORY.md index f8ef0e2e54..37823cca14 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,3 +1,13 @@ + +# Release 0.39 +The interface for `AdvancedVI` was updated to match v0.4 version v0.2 of `AdvancedVI`. +The v0.4 version of `AdvancedVI` introduces various new features: +- location-scale families with dense scale matrices, +- parameter-free stochastic optimization algorithms like `DoG` and `DoWG`, +- proximal operators for stable optimization, +- the sticking-the-landing control variate for faster convergence, and +- the score gradient estimator for non-differentiable targets. + # Release 0.38.4 The minimum Julia version was increased to 1.10.2 (from 1.10.0). From b0d791e5bc1c50eab3ea06c40c2d147581e17f76 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Mon, 2 Jun 2025 18:44:00 -0400 Subject: [PATCH 50/52] fix version number --- HISTORY.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/HISTORY.md b/HISTORY.md index 37823cca14..f6130f2453 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,5 +1,5 @@ -# Release 0.39 +# Release 0.39.0 The interface for `AdvancedVI` was updated to match v0.4 version v0.2 of `AdvancedVI`. The v0.4 version of `AdvancedVI` introduces various new features: - location-scale families with dense scale matrices, From 29373eef6819e99aeef1a736e8478e90bbfd4540 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 3 Jun 2025 00:22:52 +0100 Subject: [PATCH 51/52] Format & add some links --- HISTORY.md | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/HISTORY.md b/HISTORY.md index f6130f2453..34264cc959 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,12 +1,16 @@ - # Release 0.39.0 -The interface for `AdvancedVI` was updated to match v0.4 version v0.2 of `AdvancedVI`. 
-The v0.4 version of `AdvancedVI` introduces various new features:
-- location-scale families with dense scale matrices,
-- parameter-free stochastic optimization algorithms like `DoG` and `DoWG`,
-- proximal operators for stable optimization,
-- the sticking-the-landing control variate for faster convergence, and
-- the score gradient estimator for non-differentiable targets.
+
+Turing's variational inference interface was updated to match version 0.4 of AdvancedVI.jl.
+
+AdvancedVI v0.4 introduces various new features:
+
+  - location-scale families with dense scale matrices,
+  - parameter-free stochastic optimization algorithms like `DoG` and `DoWG`,
+  - proximal operators for stable optimization,
+  - the sticking-the-landing control variate for faster convergence, and
+  - the score gradient estimator for non-differentiable targets.
+
+Please see the [Turing API documentation](https://turinglang.org/Turing.jl/stable/api/#Variational-inference), and [AdvancedVI's documentation](https://turinglang.org/AdvancedVI.jl/stable/), for more details.
 
 # Release 0.38.4

From 4c72501e650f3ec990435c8a09733d8ee84b2299 Mon Sep 17 00:00:00 2001
From: Kyurae Kim
Date: Tue, 3 Jun 2025 08:24:26 -0400
Subject: [PATCH 52/52] fix formatting

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
---
 HISTORY.md | 1 +
 1 file changed, 1 insertion(+)

diff --git a/HISTORY.md b/HISTORY.md
index 180a17a4bb..1265fafabc 100644
--- a/HISTORY.md
+++ b/HISTORY.md
@@ -1,6 +1,7 @@
 # Release 0.39.0
 
 ## Update to the AdvancedVI interface
+
 Turing's variational inference interface was updated to match version 0.4 of AdvancedVI.jl.
 
 AdvancedVI v0.4 introduces various new features:
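
To close, the changelog entries above can be connected back to the test matrix earlier in the series. The following is a hedged sketch, not part of the patch series: it assumes a `model` constructed as in the previous example, and it reproduces one of the objective/operator/optimizer combinations exercised in `test/variational/advi.jl` (the sticking-the-landing entropy estimator with scale clipping and the parameter-free `DoG` optimizer).

```julia
using Turing
import AdvancedVI

# A dense ("full-rank") scale matrix is now supported alongside the
# mean-field (diagonal) family.
q0 = q_fullrank_gaussian(model)

# Customize the objective, the post-step operator, and the optimizer;
# all remaining keywords are forwarded to `AdvancedVI.optimize`.
_, q_avg, _, _ = vi(
    model,
    q0,
    1000;
    objective=AdvancedVI.RepGradELBO(10; entropy=AdvancedVI.StickingTheLandingEntropy()),
    operator=AdvancedVI.ClipScale(),
    optimizer=AdvancedVI.DoG(),
)
```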