Skip to content

Commit 54691bf

Browse files
Allow empty subsets of VarInfos (#692)
* Allow empty subsets of VarInfos * Run JuliaFormatter Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent 27ba772 commit 54691bf

File tree

3 files changed

+23
-17
lines changed

3 files changed

+23
-17
lines changed

src/simple_varinfo.jl

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -439,22 +439,17 @@ function subset(varinfo::SimpleVarInfo, vns::AbstractVector{<:VarName})
439439
return Accessors.@set varinfo.values = _subset(varinfo.values, vns)
440440
end
441441

442-
function _subset(x::AbstractDict, vns)
442+
function _subset(x::AbstractDict, vns::AbstractVector{VN}) where {VN<:VarName}
443443
vns_present = collect(keys(x))
444-
vns_found = mapreduce(vcat, vns) do vn
444+
vns_found = mapreduce(vcat, vns; init=VN[]) do vn
445445
return filter(Base.Fix1(subsumes, vn), vns_present)
446446
end
447-
448-
# NOTE: This `vns` to be subsume varnames explicitly present in `x`.
447+
C = ConstructionBase.constructorof(typeof(x))
449448
if isempty(vns_found)
450-
throw(
451-
ArgumentError(
452-
"Cannot subset `AbstractDict` with `VarName` which does not subsume any keys.",
453-
),
454-
)
449+
return C()
450+
else
451+
return C(vn => x[vn] for vn in vns_found)
455452
end
456-
C = ConstructionBase.constructorof(typeof(x))
457-
return C(vn => x[vn] for vn in vns_found)
458453
end
459454

460455
function _subset(x::NamedTuple, vns)

src/varinfo.jl

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -368,20 +368,24 @@ function subset(varinfo::TypedVarInfo, vns::AbstractVector{<:VarName})
368368
)
369369
end
370370

371-
function subset(metadata::Metadata, vns_given::AbstractVector{<:VarName})
371+
function subset(metadata::Metadata, vns_given::AbstractVector{VN}) where {VN<:VarName}
372372
# TODO: Should we error if `vns` contains a variable that is not in `metadata`?
373373
# For each `vn` in `vns`, get the variables subsumed by `vn`.
374-
vns = mapreduce(vcat, vns_given) do vn
374+
vns = mapreduce(vcat, vns_given; init=VN[]) do vn
375375
filter(Base.Fix1(subsumes, vn), metadata.vns)
376376
end
377377
indices_for_vns = map(Base.Fix1(getindex, metadata.idcs), vns)
378-
indices = Dict(vn => i for (i, vn) in enumerate(vns))
378+
indices = if isempty(vns)
379+
Dict{VarName,Int}()
380+
else
381+
Dict(vn => i for (i, vn) in enumerate(vns))
382+
end
379383
# Construct new `vals` and `ranges`.
380384
vals_original = metadata.vals
381385
ranges_original = metadata.ranges
382386
# Allocate the new `vals`. and `ranges`.
383-
vals = similar(metadata.vals, sum(length, ranges_original[indices_for_vns]))
384-
ranges = similar(ranges_original)
387+
vals = similar(metadata.vals, sum(length, ranges_original[indices_for_vns]; init=0))
388+
ranges = similar(ranges_original, length(vns))
385389
# The new range `r` for `vns[i]` is offset by `offset` and
386390
# has the same length as the original range `r_original`.
387391
# The new `indices` (from above) ensures ordering according to `vns`.
@@ -415,7 +419,7 @@ function subset(metadata::Metadata, vns_given::AbstractVector{<:VarName})
415419
ranges,
416420
vals,
417421
metadata.dists[indices_for_vns],
418-
metadata.gids,
422+
metadata.gids[indices_for_vns],
419423
metadata.orders[indices_for_vns],
420424
flags,
421425
)

test/varinfo.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -578,6 +578,13 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,)
578578
else
579579
vns_supported_standard
580580
end
581+
582+
@testset ("$(convert(Vector{VarName}, vns_subset)) empty") for vns_subset in
583+
vns_supported
584+
varinfo_subset = subset(varinfo, VarName[])
585+
@test isempty(varinfo_subset)
586+
end
587+
581588
@testset "$(convert(Vector{VarName}, vns_subset))" for vns_subset in
582589
vns_supported
583590
varinfo_subset = subset(varinfo, vns_subset)

0 commit comments

Comments
 (0)