Skip to content

Commit 711a298

Browse files
authored
Better support for AbstractString (#397)
* Support `AbstractString` * Add tests * Bump version
1 parent 8c3f5ed commit 711a298

File tree

4 files changed

+67
-33
lines changed

4 files changed

+67
-33
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ uuid = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
33
keywords = ["markov chain monte carlo", "probablistic programming"]
44
license = "MIT"
55
desc = "Chain types and utility functions for MCMC simulations."
6-
version = "5.6.1"
6+
version = "5.7.0"
77

88
[deps]
99
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"

src/chains.jl

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ function Chains(
8080
end
8181

8282
"""
83-
Chains(c::Chains, section::Union{Symbol,String})
83+
Chains(c::Chains, section::Union{Symbol,AbstractString})
8484
Chains(c::Chains, sections)
8585
8686
Return a new chain with only a specific `section` or multiple `sections` pulled out.
@@ -101,7 +101,7 @@ julia> names(chn2)
101101
:a
102102
```
103103
"""
104-
Chains(c::Chains, section::Union{Symbol,String}) = Chains(c, (section,))
104+
Chains(c::Chains, section::Union{Symbol,AbstractString}) = Chains(c, (section,))
105105
function Chains(chn::Chains, sections)
106106
# Make sure the sections exist first.
107107
all(haskey(chn.name_map, Symbol(x)) for x in sections) ||
@@ -121,7 +121,7 @@ Chains(chain::Chains, ::Nothing) = chain
121121
# Groups of parameters
122122

123123
"""
124-
namesingroup(chains::Chains, sym::Symbol; index_type::Symbol=:bracket)
124+
namesingroup(chains::Chains, sym::Union{AbstractString,Symbol}; index_type::Symbol=:bracket)
125125
126126
Return the parameters with the same name `sym`, but have a different index. Bracket indexing format
127127
in the form of `:sym[index]` is assumed by default. Use `index_type=:dot` for parameters with dot
@@ -147,7 +147,7 @@ julia> namesingroup(chn, :A; index_type=:dot)
147147
Symbol("A.2")
148148
```
149149
"""
150-
namesingroup(chains::Chains, sym::String; kwargs...) = namesingroup(chains, Symbol(sym); kwargs...)
150+
namesingroup(chains::Chains, sym::AbstractString; kwargs...) = namesingroup(chains, Symbol(sym); kwargs...)
151151
function namesingroup(chains::Chains, sym::Symbol; index_type::Symbol=:bracket)
152152
if index_type !== :bracket && index_type !== :dot
153153
error("index_type must be :bracket or :dot")
@@ -161,14 +161,14 @@ function namesingroup(chains::Chains, sym::Symbol; index_type::Symbol=:bracket)
161161
end
162162

163163
"""
164-
group(chains::Chains, name::Union{String,Symbol}; index_type::Symbol=:bracket)
164+
group(chains::Chains, name::Union{AbstractString,Symbol}; index_type::Symbol=:bracket)
165165
166166
Return a subset of the chain containing parameters with the same `name`, but a different index.
167167
168168
Bracket indexing format in the form of `:name[index]` is assumed by default. Use `index_type=:dot` for parameters with dot
169169
indexing, i.e. `:sym.index`.
170170
"""
171-
function group(chains::Chains, name::Union{String,Symbol}; kwargs...)
171+
function group(chains::Chains, name::Union{AbstractString,Symbol}; kwargs...)
172172
return chains[:, namesingroup(chains, name; kwargs...), :]
173173
end
174174

@@ -177,8 +177,8 @@ end
177177
Base.getindex(c::Chains, i::Integer) = c[i, :, :]
178178
Base.getindex(c::Chains, i::AbstractVector{<:Integer}) = c[i, :, :]
179179

180-
Base.getindex(c::Chains, v::String) = c[:, Symbol(v), :]
181-
Base.getindex(c::Chains, v::AbstractVector{String}) = c[:, Symbol.(v), :]
180+
Base.getindex(c::Chains, v::AbstractString) = c[:, Symbol(v), :]
181+
Base.getindex(c::Chains, v::AbstractVector{<:AbstractString}) = c[:, Symbol.(v), :]
182182

183183
Base.getindex(c::Chains, v::Symbol) = c[:, v, :]
184184
Base.getindex(c::Chains, v::AbstractVector{Symbol}) = c[:, v, :]
@@ -199,7 +199,7 @@ _toindex(i, j, k::Integer) = (i, string2symbol(j), k:k)
199199
_toindex(i::Integer, j, k::Integer) = (i:i, string2symbol(j), k:k)
200200

201201
# return an array or a number if a single parameter is specified
202-
const SingleIndex = Union{Symbol,String,Integer}
202+
const SingleIndex = Union{Symbol,AbstractString,Integer}
203203
_toindex(i, j::SingleIndex, k) = (i, string2symbol(j), k)
204204
_toindex(i::Integer, j::SingleIndex, k) = (i, string2symbol(j), k)
205205
_toindex(i, j::SingleIndex, k::Integer) = (i, string2symbol(j), k)
@@ -542,7 +542,7 @@ Return multiple `Chains` objects, each containing only a single section.
542542
function get_sections(chains::Chains, sections = keys(chains.name_map))
543543
return [Chains(chains, section) for section in sections]
544544
end
545-
get_sections(chains::Chains, section::Union{Symbol, String}) = Chains(chains, section)
545+
get_sections(chains::Chains, section::Union{Symbol, AbstractString}) = Chains(chains, section)
546546

547547
"""
548548
sections(c::Chains)
@@ -727,7 +727,7 @@ function _clean_sections(chains::Chains, sections)
727727
haskey(chains.name_map, Symbol(section))
728728
end
729729
end
730-
function _clean_sections(chains::Chains, section::Union{String,Symbol})
730+
function _clean_sections(chains::Chains, section::Union{AbstractString,Symbol})
731731
return haskey(chains.name_map, Symbol(section)) ? section : ()
732732
end
733733
_clean_sections(::Chains, ::Nothing) = nothing

src/utils.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,13 @@ end
2727
2828
Convert strings to symbols.
2929
30-
If `x isa String`, the corresponding `Symbol` is returned. Likewise, if
31-
`x isa AbstractVector{String}`, the corresponding vector of `Symbol`s is returned. In all
32-
other cases, input `x` is returned.
30+
If `x isa AbstractString`, the corresponding `Symbol` is returned.
31+
Likewise, if `x isa AbstractVector{<:AbstractString}`, the corresponding vector of `Symbol`s is returned.
32+
In all other cases, input `x` is returned.
3333
"""
3434
string2symbol(x) = x
35-
string2symbol(x::String) = Symbol(x)
36-
string2symbol(x::AbstractVector{String}) = Symbol.(x)
35+
string2symbol(x::AbstractString) = Symbol(x)
36+
string2symbol(x::AbstractVector{<:AbstractString}) = Symbol.(x)
3737

3838
#################### Mathematical Operators ####################
3939
function cummean(x::AbstractArray)

test/diagnostic_tests.jl

Lines changed: 50 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -101,13 +101,42 @@ end
101101
end
102102

103103
@testset "indexing tests" begin
104-
@test chn[:,1,:] isa AbstractMatrix
105-
@test chn[200:300, "param_1", :] isa AbstractMatrix
106-
@test chn[200:300, ["param_1", "param_3"], :] isa Chains
107-
@test chn[200:300, "param_1", 1] isa AbstractVector
108-
@test size(chn[:,1,:]) == (niter, nchains)
109-
@test chn[:,1,1] == val[:,1,1]
110-
@test chn[:,1,2] == val[:,1,2]
104+
c = chn[:, 1, :]
105+
@test c isa AbstractMatrix
106+
@test size(c) == (niter, nchains)
107+
@test c == val[:, 1, :]
108+
109+
for i in 1:2
110+
c = chn[:, 1, i]
111+
@test c isa AbstractVector
112+
@test length(c) == niter
113+
@test c == val[:, 1, i]
114+
end
115+
116+
for p in (:param_1, "param_1", SubString("param_1", 1))
117+
c = chn[200:300, p, :]
118+
@test c isa AbstractMatrix
119+
@test size(c) == (101, size(chn, 3))
120+
@test c == val[200:300, 1, :]
121+
122+
c = chn[200:300, p, 1]
123+
@test c isa AbstractVector
124+
@test length(c) == 101
125+
@test c == val[200:300, 1, 1]
126+
end
127+
128+
for ps in (
129+
[:param_1, :param_3],
130+
["param_1", "param_3"],
131+
[SubString("param_1", 1), "param_3"],
132+
["param_1", SubString("param_3", 1)],
133+
[SubString("param_1", 1), SubString("param_3", 1)],
134+
)
135+
c = chn[200:300, ps, :]
136+
@test c isa Chains
137+
@test size(c) == (101, 2, nchains)
138+
@test c.value.data == val[200:300, [1, 3], :]
139+
end
111140
end
112141

113142
@testset "names and groups tests" begin
@@ -116,18 +145,23 @@ end
116145
(@inferred replacenames(chn, Dict("param_2" => "param[2]",
117146
"param_3" => "param[3]"))).value
118147
@test names(chn2) == [:param_1, Symbol("param[2]"), Symbol("param[3]"), :param_4]
119-
@test namesingroup(chn2, "param") == Symbol.(["param[2]", "param[3]"])
148+
for p in (:param, "param", SubString("param", 1))
149+
@test namesingroup(chn2, p) == Symbol.(["param[2]", "param[3]"])
150+
end
120151

121-
chn3 = group(chn2, "param")
122-
@test names(chn3) == Symbol.(["param[2]", "param[3]"])
123-
@test chn3.value == chn[:, [:param_2, :param_3], :].value
152+
for p in (:param, "param", SubString("param", 1))
153+
chn3 = group(chn2, p)
154+
@test names(chn3) == Symbol.(["param[2]", "param[3]"])
155+
@test chn3.value == chn[:, [:param_2, :param_3], :].value
156+
end
124157

125158
stan_chn = Chains(rand(100, 3, 1), ["a.1", "a[2]", "b"])
126-
@test namesingroup(stan_chn, "a"; index_type=:dot) == [Symbol("a.1")]
127-
@test namesingroup(stan_chn, :a; index_type=:dot) == [Symbol("a.1")]
128-
@test names(group(stan_chn, :a; index_type=:dot)) == [Symbol("a.1")]
129-
@test_throws Exception namesingroup(stan_chn, :a; index_type=:x)
130-
@test_throws Exception group(stan_chn, :a; index_type=:x)
159+
for p in (:a, "a", SubString("a", 1))
160+
@test namesingroup(stan_chn, p; index_type=:dot) == [Symbol("a.1")]
161+
@test names(group(stan_chn, p; index_type=:dot)) == [Symbol("a.1")]
162+
@test_throws Exception namesingroup(stan_chn, p; index_type=:x)
163+
@test_throws Exception group(stan_chn, p; index_type=:x)
164+
end
131165
end
132166

133167
@testset "function tests" begin

0 commit comments

Comments
 (0)