Skip to content

Commit 1de9c70

Browse files
authored
Merge pull request #260 from Julia-Tempering/fix-259
Make variational refs elide AD buffering
2 parents 47d96dd + 65be78d commit 1de9c70

File tree

6 files changed

+50
-6
lines changed

6 files changed

+50
-6
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "Pigeons"
22
uuid = "0eb8d820-af6a-4919-95ae-11206f830c31"
33
authors = ["Alexandre Bouchard-Côté <bouchard@stat.ubc.ca>, Nikola Surjanovic <nikola.surjanovic@stat.ubc.ca>, Paul Tiede <ptiede91@gmail.com>, Trevor Campbell, Miguel Biron-Lattes, Saifuddin Syed"]
4-
version = "0.4.3"
4+
version = "0.4.4"
55

66
[deps]
77
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"

src/explorers/BufferedAD.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ future.
133133
exploration step.
134134
"""
135135
function get_buffer(a::Augmentation{<:Dict{Symbol, BufferedAD}}, key::Symbol, args...)
136-
dict = a.contents
136+
dict = a.contents
137137
if !haskey(dict, key)
138138
dict[key] = LogDensityProblemsAD.ADgradient(args...)
139139
end

src/includes.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ include("targets/DistributionLogPotential.jl")
8080
include("pt/checks.jl")
8181
include("explorers/BufferedAD.jl")
8282
include("variational/GaussianReference.jl")
83+
include("variational/VariationalReference.jl")
8384
include("paths/ScaledPrecisionNormalPath.jl")
8485
include("targets/toy_mvn_target.jl")
8586
include("explorers/AAPS.jl")

src/variational/GaussianReference.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,8 @@ function LogDensityProblems.dimension(log_potential::GaussianReference)
6262
return length(log_potential.mean[:singleton_variable])
6363
end
6464

65-
LogDensityProblemsAD.ADgradient(kind::Val, log_potential::GaussianReference, buffers::Augmentation) =
66-
BufferedAD(log_potential, buffers)
65+
LogDensityProblemsAD.ADgradient(kind::Val, log_potential::GaussianReference, replica::Replica) =
66+
BufferedAD(log_potential, replica.recorders.buffers)
6767

6868
function LogDensityProblems.logdensity_and_gradient(log_potential::BufferedAD{GaussianReference}, x)
6969
variational = log_potential.enclosed
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
#=
2+
Methods common to all variational references
3+
=#
4+
5+
# Currently implemented variational references
6+
const VariationalReference = Union{GaussianReference}
7+
8+
# Elide the AD buffering system
9+
# Reasoning:
10+
# 1. Variational refs usually have analytic gradients anyway
11+
# 2. It can be challenging to distinguish between the proper reference and the
12+
# variational reference in the buffering system, especially since the var ref
13+
# is not activated immediately
14+
get_buffer(
15+
::Augmentation{<:Dict{Symbol, BufferedAD}},
16+
::Symbol,
17+
kind,
18+
log_potential::VariationalReference,
19+
replica::Replica) = LogDensityProblemsAD.ADgradient(kind, log_potential, replica)

test/test_BufferedAD.jl

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,18 +22,26 @@ end
2222

2323
function test_BufferedAD_usage(pt)
2424
replica = last(pt.replicas)
25-
25+
backend = pt.shared.explorer.default_autodiff_backend
26+
2627
# BufferedAD were created and stored
2728
@test haskey(replica.recorders.ad_buffers.contents, :target)
2829
@test haskey(replica.recorders.ad_buffers.contents, :reference)
2930

3031
# ADgradient uses the stored BufferedAD
3132
int_lp = Pigeons.find_log_potential(replica, pt.shared.tempering, pt.shared)
32-
int_ad = ADgradient(pt.shared.explorer.default_autodiff_backend, int_lp, replica)
33+
int_ad = ADgradient(backend, int_lp, replica)
3334
@test int_ad isa Pigeons.InterpolatedAD
3435
@test int_ad.ref_ad === replica.recorders.ad_buffers.contents[:reference]
3536
@test int_ad.target_ad === replica.recorders.ad_buffers.contents[:target]
3637

38+
# check that the stored BufferedAD is the correct one
39+
# NB: can only check type equality because extra buffers can be different (e.g. for StanLogPotential)
40+
ref_ad = ADgradient(backend, int_lp.path.ref, replica)
41+
target_ad = ADgradient(backend, int_lp.path.target, replica)
42+
@test typeof(int_ad.ref_ad) === typeof(ref_ad)
43+
@test typeof(int_ad.target_ad) === typeof(target_ad)
44+
3745
# target and ref share the same gradient buffer
3846
if int_ad.ref_ad.buffer isa DiffResults.MutableDiffResult
3947
@test DiffResults.gradient(int_ad.ref_ad.buffer) === DiffResults.gradient(int_ad.target_ad.buffer)
@@ -100,3 +108,19 @@ end
100108
end
101109
Pigeons.set_tape_compilation_strategy!(true) # reverse setting
102110
end
111+
112+
@testset "Variational reference elides the AD augmentation" begin
113+
target = Pigeons.toy_stan_unid_target(100)
114+
pt = pigeons(
115+
target = target,
116+
variational = GaussianReference(),
117+
n_chains = 5,
118+
n_chains_variational = 5,
119+
n_rounds = 7
120+
)
121+
replica = pt.replicas[end];
122+
var_ref = pt.shared.tempering.variational_leg.path.ref
123+
var_ref_ad = Pigeons.get_buffer(replica.recorders.ad_buffers, :reference, Val(:ForwardDiff), var_ref, replica)
124+
@test var_ref_ad === ADgradient(Val(:ForwardDiff), var_ref, replica)
125+
@test var_ref_ad != Pigeons.get_buffer(replica.recorders.ad_buffers, :reference, Val(:ForwardDiff), target, replica)
126+
end

0 commit comments

Comments
 (0)