Error creating InferenceData from Turing Chains with extra info #147

Open
@sethaxen

Description

Currently it seems that if the Turing chain's `info` contains arbitrary Julia objects (such as the model, sampler, or sampler state), we can't map these to the InferenceData info, and `from_mcmcchains` fails.

using ArviZ, Turing

julia> @model function foo()
    x ~ Normal()
end
foo (generic function with 1 method)

julia> chn = sample(foo(),NUTS(),200); # this is fine

julia> chn.info
NamedTuple()

julia> from_mcmcchains(chn)
InferenceData with groups:
	> posterior
	> sample_stats

julia> chn = sample(foo(), NUTS(), 200; save_state=true) # sampling works, but the conversion below will error

julia> chn.info
(model = DynamicPPL.Model{var"#3#4", (), (), (), Tuple{}, Tuple{}}(:foo, var"#3#4"(), NamedTuple(), NamedTuple()), sampler = DynamicPPL.Sampler{NUTS{Turing.Core.ForwardDiffAD{40}, (), AdvancedHMC.DiagEuclideanMetric}}(NUTS{Turing.Core.ForwardDiffAD{40}, (), AdvancedHMC.DiagEuclideanMetric}(-1, 0.65, 10, 1000.0, 0.0), DynamicPPL.Selector(0x00016a8da36513f2, :default, false)), samplerstate = Turing.Inference.HMCState{DynamicPPL.TypedVarInfo{NamedTuple{(:x,), Tuple{DynamicPPL.Metadata{Dict{AbstractPPL.VarName{:x, Tuple{}}, Int64}, Vector{Normal{Float64}}, Vector{AbstractPPL.VarName{:x, Tuple{}}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}}, Float64}, AdvancedHMC.NUTS{AdvancedHMC.MultinomialTS, AdvancedHMC.GeneralisedNoUTurn, AdvancedHMC.Leapfrog{Float64}, Float64}, AdvancedHMC.Hamiltonian{AdvancedHMC.DiagEuclideanMetric{Float64, Vector{Float64}}, Turing.Inference.var"#logπ#54"{DynamicPPL.TypedVarInfo{NamedTuple{(:x,), Tuple{DynamicPPL.Metadata{Dict{AbstractPPL.VarName{:x, Tuple{}}, Int64}, Vector{Normal{Float64}}, Vector{AbstractPPL.VarName{:x, Tuple{}}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}}, Float64}, DynamicPPL.Sampler{NUTS{Turing.Core.ForwardDiffAD{40}, (), AdvancedHMC.DiagEuclideanMetric}}, DynamicPPL.Model{var"#3#4", (), (), (), Tuple{}, Tuple{}}}, Turing.Inference.var"#∂logπ∂θ#53"{DynamicPPL.TypedVarInfo{NamedTuple{(:x,), Tuple{DynamicPPL.Metadata{Dict{AbstractPPL.VarName{:x, Tuple{}}, Int64}, Vector{Normal{Float64}}, Vector{AbstractPPL.VarName{:x, Tuple{}}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}}, Float64}, DynamicPPL.Sampler{NUTS{Turing.Core.ForwardDiffAD{40}, (), AdvancedHMC.DiagEuclideanMetric}}, DynamicPPL.Model{var"#3#4", (), (), (), Tuple{}, Tuple{}}}}, AdvancedHMC.PhasePoint{Vector{Float64}, AdvancedHMC.DualValue{Float64, Vector{Float64}}}, AdvancedHMC.Adaptation.StanHMCAdaptor{AdvancedHMC.Adaptation.WelfordVar{Float64, Vector{Float64}}, 
AdvancedHMC.Adaptation.NesterovDualAveraging{Float64}}}(DynamicPPL.TypedVarInfo{NamedTuple{(:x,), Tuple{DynamicPPL.Metadata{Dict{AbstractPPL.VarName{:x, Tuple{}}, Int64}, Vector{Normal{Float64}}, Vector{AbstractPPL.VarName{:x, Tuple{}}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}}, Float64}((x = DynamicPPL.Metadata{Dict{AbstractPPL.VarName{:x, Tuple{}}, Int64}, Vector{Normal{Float64}}, Vector{AbstractPPL.VarName{:x, Tuple{}}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}(Dict(x => 1), [x], UnitRange{Int64}[1:1], [0.007224315165188178], Normal{Float64}[Normal{Float64}(μ=0.0, σ=1.0)], Set{DynamicPPL.Selector}[Set([DynamicPPL.Selector(0x00016a8da36513f2, :default, false)])], [0], Dict{String, BitVector}("del" => [0], "trans" => [1])),), Base.RefValue{Float64}(-0.9189646285694758), Base.RefValue{Int64}(0)), 299, NUTS{MultinomialTS,Generalised}(integrator=Leapfrog(ϵ=1.43), max_depth=10), Δ_max=1000.0), Hamiltonian(metric=DiagEuclideanMetric([1.0])), AdvancedHMC.PhasePoint{Vector{Float64}, AdvancedHMC.DualValue{Float64, Vector{Float64}}}([0.007224315165188178], [-0.5308342150394731], AdvancedHMC.DualValue{Float64, Vector{Float64}}(-0.9189646285694758, [0.007224315165188178]), AdvancedHMC.DualValue{Float64, Vector{Float64}}(-0.14089248192828682, [-0.5308342150394731])), StanHMCAdaptor(
    pc=WelfordVar,
    ssa=NesterovDualAveraging=0.05, t_0=10.0, κ=0.75, δ=0.65, state.ϵ=1.425166901462951),
    init_buffer=75, term_buffer=50, window_size=25,
    state=window(76, 50), window_splits()
)))

julia> from_mcmcchains(chn)
ERROR: PyError ($(Expr(:escape, :(ccall(#= /Users/sethaxen/.julia/packages/PyCall/L0fLP/src/pyfncall.jl:43 =# @pysym(:PyObject_Call), PyPtr, (PyPtr, PyPtr, PyPtr), o, pyargsptr, kw))))) <class 'TypeError'>
TypeError("cannot pickle 'PyCall.jlwrap' object")
  File "/Users/sethaxen/.julia/conda/3/lib/python3.8/site-packages/arviz/data/inference_data.py", line 1837, in concat
    args_groups[group] = deepcopy(group_data) if copy else group_data
  File "/Users/sethaxen/.julia/conda/3/lib/python3.8/copy.py", line 153, in deepcopy
    y = copier(memo)
  File "/Users/sethaxen/.julia/conda/3/lib/python3.8/site-packages/xarray/core/dataset.py", line 1425, in __deepcopy__
    return self.copy(deep=True)
  File "/Users/sethaxen/.julia/conda/3/lib/python3.8/site-packages/xarray/core/dataset.py", line 1322, in copy
    attrs = copy.deepcopy(self._attrs) if deep else copy.copy(self._attrs)
  File "/Users/sethaxen/.julia/conda/3/lib/python3.8/copy.py", line 146, in deepcopy
    y = copier(x, memo)
  File "/Users/sethaxen/.julia/conda/3/lib/python3.8/copy.py", line 230, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "/Users/sethaxen/.julia/conda/3/lib/python3.8/copy.py", line 161, in deepcopy
    rv = reductor(4)

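The `TypeError` arises because the Julia objects in `chn.info` reach Python as `PyCall.jlwrap` objects, which ArviZ's `concat` cannot deepcopy (pickle) when copying the dataset attributes. We should probably filter the info on our end before InferenceData creation, passing through only plain data, so that these errors can't happen.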
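A minimal sketch of such a filter, assuming that only numbers, strings, and arrays/tuples/NamedTuples built from them are safe to hand to Python. `is_plain` and `filter_info` are hypothetical helpers for illustration, not part of the ArviZ.jl API, and the set of "plain" types is an assumption:

```julia
# Hypothetical helpers (not ArviZ.jl API): decide whether a value is plain data.
# Which types count as "plain" is an assumption made for this sketch.
is_plain(::Union{Number,AbstractString}) = true
is_plain(x::Union{AbstractArray,Tuple}) = all(is_plain, x)
is_plain(x::NamedTuple) = all(is_plain, values(x))
is_plain(::Any) = false  # models, samplers, sampler states, functions, ...

# Return a new NamedTuple containing only the plain entries of `info`.
function filter_info(info::NamedTuple)
    keep = Tuple(k for k in keys(info) if is_plain(getfield(info, k)))
    vals = Tuple(getfield(info, k) for k in keep)
    return NamedTuple{keep}(vals)
end

# Example: a function value stands in for objects like `model` or `sampler`.
info = (start_time = 1.0, stop_time = 2.5, model = sin)
filter_info(info)  # (start_time = 1.0, stop_time = 2.5)
```

Applied to the `save_state=true` chain above, this would drop `model`, `sampler`, and `samplerstate`, leaving an empty `NamedTuple()`, which is the case that already converts without error. An allowlist is used rather than a blocklist because other code may put further arbitrary objects into `info`.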
