Skip to content

Commit 8ccf877

Browse files
authored
Fixes sorting bug (#183)
* Fixes sorting bug. * Increment patch version. * Minor corrections. * Minor test corrections.
1 parent 4274a98 commit 8ccf877

File tree

4 files changed

+48
-13
lines changed

4 files changed

+48
-13
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ uuid = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
33
keywords = ["markov chain monte carlo", "probablistic programming"]
44
license = "MIT"
55
desc = " Chain types and utility functions for MCMC simulations."
6-
version = "3.0.0"
6+
version = "3.0.1"
77

88
[deps]
99
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"

src/stats.jl

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -159,19 +159,22 @@ function describe(io::IO,
159159
showall::Bool=false,
160160
sections::Union{Symbol, Vector{Symbol}}=Symbol[:parameters],
161161
digits::Int=4,
162+
sorted=false,
162163
args...
163164
)
164165
dfs = vcat(summarystats(c,
165166
showall=showall,
166167
sections=sections,
167168
etype=etype,
168169
digits=digits,
170+
sorted=sorted,
169171
args...),
170172
quantile(c,
171173
showall=showall,
172174
sections=sections,
173175
q=q,
174-
digits=digits))
176+
digits=digits,
177+
sorted=sorted))
175178
return dfs
176179
end
177180

@@ -223,20 +226,24 @@ function quantile(chn::Chains;
223226
append_chains=true,
224227
showall=false,
225228
sections::Union{Symbol, Vector{Symbol}}=Symbol[:parameters],
226-
digits::Int=4)
229+
digits::Int=4,
230+
sorted=false)
227231
# compute quantiles
228232
funs = Function[]
229233
func_names = String[]
230234
for i in q
231235
push!(funs, x -> quantile(cskip(x), i))
232236
push!(func_names, "$(string(100*i))%")
233237
end
238+
234239
return summarize(chn, funs...;
240+
sections=sections,
235241
func_names=func_names,
236242
showall=showall,
237-
sections=sections,
238-
name = "Quantiles",
239-
digits=digits)
243+
name="Quantiles",
244+
digits=digits,
245+
append_chains=append_chains,
246+
sorted=sorted)
240247
end
241248

242249
"""
@@ -257,7 +264,8 @@ function ess(chn::Chains;
257264
showall=false,
258265
sections::Union{Symbol, Vector{Symbol}}=Symbol[:parameters],
259266
maxlag = 250,
260-
digits::Int=4
267+
digits::Int=4,
268+
sorted=false
261269
)
262270
param = showall ? names(chn) : names(chn, sections)
263271
n_chain_orig = size(chn, 3)
@@ -401,7 +409,7 @@ function summarystats(chn::Chains;
401409
func_names = [:mean, :std, :naive_se, :mcse]
402410

403411
# Caluclate ESS separately.
404-
ess_df = ess(chn, sections=sections, showall=showall)
412+
ess_df = ess(chn, sections=sections, showall=showall, sorted=sorted)
405413

406414
# Summarize.
407415
summary_df = summarize(chn, funs...;

src/summarize.jl

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ function Base.getindex(
128128
)
129129
s = s isa AbstractArray ? s : [s]
130130
ks = ks isa AbstractArray ? ks : [ks]
131-
ind = indexin(s, c.nt[:parameters])
131+
ind = indexin(Symbol.(s), Symbol.(c.nt[:parameters]))
132132

133133
not_found = map(x -> x === nothing, ind)
134134

@@ -269,10 +269,9 @@ function summarize(chn::Chains, funs...;
269269

270270
if additional_df != nothing
271271
if append_chains
272-
ret_df = merge(ret_df, additional_df.nt)
273-
# ret_df = join(ret_df, additional_df, on=:parameters)
272+
ret_df = merge_cdf(ret_df, additional_df.nt)
274273
else
275-
ret_df = [merge(r, additional_df.nt) for r in ret_df]
274+
ret_df = [merge_cdf(r, additional_df.nt) for r in ret_df]
276275
end
277276
end
278277

@@ -288,3 +287,31 @@ function handle_funs(fns)
288287
tmp = [string(f) for f in fns]
289288
Symbol.([split(tmp[i], ".")[end] for i in 1:length(tmp)])
290289
end
290+
291+
"""
292+
Collects the keys of a named tuple and maintains parameter name ordering. Used
293+
when an additional namedtuple is passed to `summarize` to be joined.
294+
"""
295+
function merge_cdf(n1::NamedTuple, n2::NamedTuple)
296+
ks1 = collect(keys(n1))
297+
ks2 = collect(keys(n2))
298+
ks = tuple(unique(vcat(ks1, ks2))...)
299+
if :parameters in ks1 && :parameters in ks2
300+
p1 = Symbol.(n1.parameters)
301+
p2 = Symbol.(n2.parameters)
302+
inds = indexin(p1, p2)
303+
304+
vals = []
305+
for k in ks
306+
if k in ks1
307+
push!(vals, n1[k])
308+
else
309+
push!(vals, n2[k][inds])
310+
end
311+
end
312+
313+
return NamedTuple{ks}(tuple(vals...))
314+
else
315+
return merge(n1, n2)
316+
end
317+
end

test/summarize_tests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ using Statistics: std
1515
@test names(parm_df) == [:parameters, :mean, :std, :naive_se, :mcse, :ess, :r_hat]
1616

1717
all_sections_df = summarize(chns, sections=[:parameters, :internals])
18-
@test all_sections_df[:,:parameters] == [:a, :b, :c, :d, :e, :f, :g, :h]
18+
@test all_sections_df[:,:parameters] == ["a", "b", "c", "d", "e", "f", "g", "h"]
1919
@test size(all_sections_df) == (8, 7)
2020

2121
two_parms_two_funs_df = summarize(chns[[:a, :b]], mean, std)

0 commit comments

Comments
 (0)