Skip to content

Commit 1dcd37c

Browse files
committed
Fix imports
1 parent a8bedc1 commit 1dcd37c

File tree

5 files changed

+26
-25
lines changed

5 files changed

+26
-25
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ version = "0.39.1"
55
[deps]
66
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
77
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
8+
AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf"
89
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
910
AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
1011
AdvancedMH = "5b7e9947-ddc0-4b3f-9b55-0d8042f74170"
@@ -49,6 +50,7 @@ TuringOptimExt = "Optim"
4950
[compat]
5051
ADTypes = "1.9"
5152
AbstractMCMC = "5.5"
53+
AbstractPPL = "0.11.0"
5254
Accessors = "0.1"
5355
AdvancedHMC = "0.3.0, 0.4.0, 0.5.2, 0.6, 0.7, 0.8"
5456
AdvancedMH = "0.8"

src/mcmc/Inference.jl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
11
module Inference
22

33
using DynamicPPL:
4+
DynamicPPL,
45
@model,
56
Metadata,
67
VarInfo,
8+
LogDensityFunction,
9+
SimpleVarInfo,
10+
AbstractVarInfo,
711
# TODO(mhauru) all_varnames_grouped_by_symbol isn't exported by DPPL, because it is only
812
# implemented for NTVarInfo. It is used by mh.jl. Either refactor mh.jl to not use it
913
# or implement it for other VarInfo types and export it from DPPL.
@@ -24,6 +28,7 @@ using DynamicPPL:
2428
DefaultContext,
2529
PriorContext,
2630
LikelihoodContext,
31+
SamplingContext,
2732
set_flag!,
2833
unset_flag!
2934
using Distributions, Libtask, Bijectors
@@ -32,14 +37,14 @@ using LinearAlgebra
3237
using ..Turing: PROGRESS, Turing
3338
using StatsFuns: logsumexp
3439
using Random: AbstractRNG
35-
using DynamicPPL
3640
using AbstractMCMC: AbstractModel, AbstractSampler
3741
using DocStringExtensions: FIELDS, TYPEDEF, TYPEDFIELDS
38-
using DataStructures: OrderedSet
42+
using DataStructures: OrderedSet, OrderedDict
3943
using Accessors: Accessors
4044

4145
import ADTypes
4246
import AbstractMCMC
47+
import AbstractPPL
4348
import AdvancedHMC
4449
const AHMC = AdvancedHMC
4550
import AdvancedMH
@@ -74,8 +79,6 @@ export InferenceAlgorithm,
7479
PG,
7580
RepeatSampler,
7681
Prior,
77-
assume,
78-
observe,
7982
predict,
8083
externalsampler
8184

src/mcmc/gibbs.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ function get_conditioned_gibbs(context::GibbsContext, vns::AbstractArray{<:VarNa
124124
end
125125

126126
function is_target_varname(ctx::GibbsContext, vn::VarName)
127-
return any(Base.Fix2(subsumes, vn), ctx.target_varnames)
127+
return any(Base.Fix2(AbstractPPL.subsumes, vn), ctx.target_varnames)
128128
end
129129

130130
function is_target_varname(context::GibbsContext, vns::AbstractArray{<:VarName})
@@ -660,7 +660,7 @@ function gibbs_step_recursive(
660660

661661
# Construct the conditional model and the varinfo that this sampler should use.
662662
conditioned_model, context = make_conditional(model, varnames, global_vi)
663-
vi = subset(global_vi, varnames)
663+
vi = DynamicPPL.subset(global_vi, varnames)
664664
vi = match_linking!!(vi, state, model)
665665

666666
# TODO(mhauru) The below may be overkill. If the varnames for this sampler are not

src/mcmc/mh.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -289,7 +289,7 @@ end
289289

290290
function maybe_link!!(varinfo, sampler, proposal, model)
291291
return if should_link(varinfo, sampler, proposal)
292-
link!!(varinfo, model)
292+
DynamicPPL.link!!(varinfo, model)
293293
else
294294
varinfo
295295
end

src/mcmc/particle_mcmc.jl

Lines changed: 14 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -193,10 +193,10 @@ function DynamicPPL.initialstep(
193193
kwargs...,
194194
)
195195
# Reset the VarInfo.
196-
reset_num_produce!(vi)
197-
set_retained_vns_del!(vi)
198-
resetlogp!!(vi)
199-
empty!!(vi)
196+
DynamicPPL.reset_num_produce!(vi)
197+
DynamicPPL.set_retained_vns_del!(vi)
198+
DynamicPPL.resetlogp!!(vi)
199+
DynamicPPL.empty!!(vi)
200200

201201
# Create a new set of particles.
202202
particles = AdvancedPS.ParticleContainer(
@@ -327,9 +327,9 @@ function DynamicPPL.initialstep(
327327
kwargs...,
328328
)
329329
# Reset the VarInfo before new sweep
330-
reset_num_produce!(vi)
331-
set_retained_vns_del!(vi)
332-
resetlogp!!(vi)
330+
DynamicPPL.reset_num_produce!(vi)
331+
DynamicPPL.set_retained_vns_del!(vi)
332+
DynamicPPL.resetlogp!!(vi)
333333

334334
# Create a new set of particles
335335
num_particles = spl.alg.nparticles
@@ -359,14 +359,14 @@ function AbstractMCMC.step(
359359
)
360360
# Reset the VarInfo before new sweep.
361361
vi = state.vi
362-
reset_num_produce!(vi)
363-
resetlogp!!(vi)
362+
DynamicPPL.reset_num_produce!(vi)
363+
DynamicPPL.resetlogp!!(vi)
364364

365365
# Create reference particle for which the samples will be retained.
366366
reference = AdvancedPS.forkr(AdvancedPS.Trace(model, spl, vi, state.rng))
367367

368368
# For all other particles, do not retain the variables but resample them.
369-
set_retained_vns_del!(vi)
369+
DynamicPPL.set_retained_vns_del!(vi)
370370

371371
# Create a new set of particles.
372372
num_particles = spl.alg.nparticles
@@ -429,23 +429,19 @@ function trace_local_rng_maybe(rng::Random.AbstractRNG)
429429
end
430430

431431
function DynamicPPL.assume(
432-
rng,
433-
spl::Sampler{<:Union{PG,SMC}},
434-
dist::Distribution,
435-
vn::VarName,
436-
_vi::AbstractVarInfo,
432+
rng, ::Sampler{<:Union{PG,SMC}}, dist::Distribution, vn::VarName, _vi::AbstractVarInfo
437433
)
438434
vi = trace_local_varinfo_maybe(_vi)
439435
trng = trace_local_rng_maybe(rng)
440436

441437
if ~haskey(vi, vn)
442438
r = rand(trng, dist)
443439
push!!(vi, vn, r, dist)
444-
elseif is_flagged(vi, vn, "del")
445-
unset_flag!(vi, vn, "del") # Reference particle parent
440+
elseif DynamicPPL.is_flagged(vi, vn, "del")
441+
DynamicPPL.unset_flag!(vi, vn, "del") # Reference particle parent
446442
r = rand(trng, dist)
447443
vi[vn] = DynamicPPL.tovec(r)
448-
setorder!(vi, vn, get_num_produce(vi))
444+
DynamicPPL.setorder!(vi, vn, DynamicPPL.get_num_produce(vi))
449445
else
450446
r = vi[vn]
451447
end

0 commit comments

Comments
 (0)