Skip to content

Commit 5dedc30

Browse files
authored
Patch un-imported Accessors in MCMCChainsExt (#239)
1 parent 26ef662 commit 5dedc30

File tree

3 files changed

+34
-1
lines changed

3 files changed

+34
-1
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "JuliaBUGS"
22
uuid = "ba9fb4c0-828e-4473-b6a1-cd2560fee5bf"
3-
version = "0.7.0"
3+
version = "0.7.1"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"

ext/JuliaBUGSMCMCChainsExt.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ using JuliaBUGS.AbstractPPL
66
using JuliaBUGS.BUGSPrimitives
77
using JuliaBUGS.LogDensityProblems
88
using JuliaBUGS.LogDensityProblemsAD
9+
using JuliaBUGS: Accessors
910
using AbstractMCMC
1011
using MCMCChains: Chains
1112

test/ext/mcmchains.jl

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,4 +91,36 @@
9191
@test means[:beta].nt.mean[1] 2.1 atol = 0.2
9292
@test means[:sigma].nt.mean[1] 0.9 atol = 0.2
9393
@test means[:gen_quant].nt.mean[1] 4.2 atol = 0.2
94+
95+
# test for more complicated varnames
96+
model_def = @bugs begin
97+
A[1, 1:3] ~ Dirichlet(ones(3))
98+
A[2, 1:3] ~ Dirichlet(ones(3))
99+
A[3, 1:3] ~ Dirichlet(ones(3))
100+
101+
mu[1:3] ~ MvNormal(zeros(3), 10 * Diagonal(ones(3)))
102+
sigma[1] ~ InverseGamma(2, 3)
103+
sigma[2] ~ InverseGamma(2, 3)
104+
sigma[3] ~ InverseGamma(2, 3)
105+
end
106+
model = compile(model_def, (;))
107+
ad_model = ADgradient(:ReverseDiff, model; compile=Val(true))
108+
hmc_chain = AbstractMCMC.sample(ad_model, NUTS(0.8), 10; chain_type=Chains)
109+
@test hmc_chain.name_map[:parameters] == [
110+
Symbol("sigma[3]"),
111+
Symbol("sigma[2]"),
112+
Symbol("sigma[1]"),
113+
Symbol("mu[1:3][1]"),
114+
Symbol("mu[1:3][2]"),
115+
Symbol("mu[1:3][3]"),
116+
Symbol("A[3, 1:3][1]"),
117+
Symbol("A[3, 1:3][2]"),
118+
Symbol("A[3, 1:3][3]"),
119+
Symbol("A[2, 1:3][1]"),
120+
Symbol("A[2, 1:3][2]"),
121+
Symbol("A[2, 1:3][3]"),
122+
Symbol("A[1, 1:3][1]"),
123+
Symbol("A[1, 1:3][2]"),
124+
Symbol("A[1, 1:3][3]"),
125+
]
94126
end

0 commit comments

Comments
 (0)