From f687db008da9a5222c17b518d4cdb1151ff2fa61 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Fri, 6 Jun 2025 14:42:44 +0100 Subject: [PATCH 1/7] AdvancedPS v0.7 support, work in progress --- Project.toml | 4 ++-- src/mcmc/particle_mcmc.jl | 24 ++++++++++++------------ test/Project.toml | 2 +- 3 files changed, 15 insertions(+), 15 deletions(-) diff --git a/Project.toml b/Project.toml index 6fcb779e3..0cde5c600 100644 --- a/Project.toml +++ b/Project.toml @@ -52,7 +52,7 @@ AbstractMCMC = "5.5" Accessors = "0.1" AdvancedHMC = "0.3.0, 0.4.0, 0.5.2, 0.6, 0.7, 0.8" AdvancedMH = "0.8" -AdvancedPS = "0.6.0" +AdvancedPS = "0.7" AdvancedVI = "0.4" BangBang = "0.4.2" Bijectors = "0.14, 0.15" @@ -65,7 +65,7 @@ DynamicHMC = "3.4" DynamicPPL = "0.36.3" EllipticalSliceSampling = "0.5, 1, 2" ForwardDiff = "0.10.3, 1" -Libtask = "0.8.8" +Libtask = "0.9.1" LinearAlgebra = "1" LogDensityProblems = "2" MCMCChains = "5, 6, 7" diff --git a/src/mcmc/particle_mcmc.jl b/src/mcmc/particle_mcmc.jl index ac5cd7648..2af01e08f 100644 --- a/src/mcmc/particle_mcmc.jl +++ b/src/mcmc/particle_mcmc.jl @@ -25,9 +25,8 @@ function TracedModel( "Sampling with `$(sampler.alg)` does not support models with keyword arguments. See issue #2007 for more details.", ) end - return TracedModel{AbstractSampler,AbstractVarInfo,Model,Tuple}( - model, sampler, varinfo, (model.f, args...) - ) + evaluator = (model.f, args...) + return TracedModel(model, sampler, varinfo, evaluator) end function AdvancedPS.advance!( @@ -71,8 +70,10 @@ function AdvancedPS.update_rng!( return trace end -function Libtask.TapedTask(model::TracedModel, ::Random.AbstractRNG, args...; kwargs...) # RNG ? - return Libtask.TapedTask(model.evaluator[1], model.evaluator[2:end]...; kwargs...) +function Libtask.TapedTask(taped_globals, model::TracedModel, args...; kwargs...) # RNG ? + return Libtask.TapedTask( + taped_globals, model.evaluator[1], model.evaluator[2:end]...; kwargs... + ) end abstract type ParticleInference <: InferenceAlgorithm end @@ -402,11 +403,11 @@ end function trace_local_varinfo_maybe(varinfo) try - trace = AdvancedPS.current_trace() - return trace.model.f.varinfo + trace = Libtask.get_taped_globals(Any).other + return (trace === nothing ? varinfo : trace.model.f.varinfo)::AbstractVarInfo catch e # NOTE: this heuristic allows Libtask evaluating a model outside a `Trace`. - if e == KeyError(:__trace) || current_task().storage isa Nothing + if e == KeyError(:task_variable) return varinfo else rethrow(e) @@ -416,11 +417,10 @@ end function trace_local_rng_maybe(rng::Random.AbstractRNG) try - trace = AdvancedPS.current_trace() - return trace.rng + return Libtask.get_taped_globals(Any).rng catch e # NOTE: this heuristic allows Libtask evaluating a model outside a `Trace`. - if e == KeyError(:__trace) || current_task().storage isa Nothing + if e == KeyError(:task_variable) return rng else rethrow(e) @@ -485,6 +485,6 @@ function AdvancedPS.Trace( tmodel = TracedModel(model, sampler, newvarinfo, rng) newtrace = AdvancedPS.Trace(tmodel, rng) - AdvancedPS.addreference!(newtrace.model.ctask.task, newtrace) + AdvancedPS.addreference!(newtrace.model.ctask, newtrace) return newtrace end diff --git a/test/Project.toml b/test/Project.toml index 7cab77a01..303a5453e 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -43,7 +43,7 @@ ADTypes = "1" AbstractMCMC = "5" AbstractPPL = "0.9, 0.10, 0.11" AdvancedMH = "0.6, 0.7, 0.8" -AdvancedPS = "=0.6.0" +AdvancedPS = "0.7" AdvancedVI = "0.4" Aqua = "0.8" BangBang = "0.4" From 2366bfa5d11017745f639eb4aefa71c586ff1337 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 19 Jun 2025 15:34:34 +0100 Subject: [PATCH 2/7] Fixing particle_mcmc.jl --- Project.toml | 2 +- src/mcmc/particle_mcmc.jl | 36 ++++++++++++++++++++++-------------- test/mcmc/particle_mcmc.jl | 1 + 3 files changed, 24 insertions(+), 15 deletions(-) diff --git a/Project.toml b/Project.toml index 0cde5c600..511797cf0 100644 --- a/Project.toml +++ b/Project.toml @@ -65,7 +65,7 @@ DynamicHMC = "3.4" DynamicPPL = "0.36.3" EllipticalSliceSampling = "0.5, 1, 2" ForwardDiff = "0.10.3, 1" -Libtask = "0.9.1" +Libtask = "0.9.2" LinearAlgebra = "1" LogDensityProblems = "2" MCMCChains = "5, 6, 7" diff --git a/src/mcmc/particle_mcmc.jl b/src/mcmc/particle_mcmc.jl index 2af01e08f..782595909 100644 --- a/src/mcmc/particle_mcmc.jl +++ b/src/mcmc/particle_mcmc.jl @@ -58,19 +58,7 @@ function AdvancedPS.reset_logprob!(trace::TracedModel) return trace end -function AdvancedPS.update_rng!( - trace::AdvancedPS.Trace{<:AdvancedPS.LibtaskModel{<:TracedModel}} -) - # Extract the `args`. - args = trace.model.ctask.args - # From `args`, extract the `SamplingContext`, which contains the RNG. - sampling_context = args[3] - rng = sampling_context.rng - trace.rng = rng - return trace -end - -function Libtask.TapedTask(taped_globals, model::TracedModel, args...; kwargs...) # RNG ? +function Libtask.TapedTask(taped_globals::Any, model::TracedModel, args...; kwargs...) # RNG ? return Libtask.TapedTask( taped_globals, model.evaluator[1], model.evaluator[2:end]...; kwargs... ) @@ -485,6 +473,26 @@ function AdvancedPS.Trace( tmodel = TracedModel(model, sampler, newvarinfo, rng) newtrace = AdvancedPS.Trace(tmodel, rng) - AdvancedPS.addreference!(newtrace.model.ctask, newtrace) + AdvancedPS.addreference!(newtrace) return newtrace end + +# We need to tell Libtask which calls may have `produce` calls within them. In practice most +# of these won't be needed, because of inline and the fact that `might_produce` is only +# called on `:invoke` expressions rather than `:call`s, but since those are implementation +# details of the compiler we define a bunch of these here, starting with +# `acclogp_observe!!` which is what calls `produce`, and going up the call stack. +Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.acclogp_observe!!),Vararg}}) = true +Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.tilde_observe!!),Vararg}}) = true +Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.evaluate!!),Vararg}}) = true +function Libtask.might_produce( + ::Type{<:Tuple{typeof(DynamicPPL.evaluate_threadsafe!!),Vararg}} +) + return true +end +function Libtask.might_produce( + ::Type{<:Tuple{typeof(DynamicPPL.evaluate_threadunsafe!!),Vararg}} +) + return true +end +Libtask.might_produce(::Type{<:Tuple{<:DynamicPPL.Model,Vararg}}) = true diff --git a/test/mcmc/particle_mcmc.jl b/test/mcmc/particle_mcmc.jl index 699ee6854..7a2f5fe1c 100644 --- a/test/mcmc/particle_mcmc.jl +++ b/test/mcmc/particle_mcmc.jl @@ -34,6 +34,7 @@ using Turing tested = sample(normal(), SMC(), 100) + # TODO(mhauru) This needs an explanation for why it fails. # failing test @model function fail_smc() a ~ Normal(4, 5) From b4823d968babe29060562743a9fc6be834cd5cce Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Fri, 20 Jun 2025 11:03:50 +0100 Subject: [PATCH 3/7] Remove use of AdvancedPS.addreference! --- src/mcmc/particle_mcmc.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/mcmc/particle_mcmc.jl b/src/mcmc/particle_mcmc.jl index 782595909..c1f3ab443 100644 --- a/src/mcmc/particle_mcmc.jl +++ b/src/mcmc/particle_mcmc.jl @@ -473,7 +473,6 @@ function AdvancedPS.Trace( tmodel = TracedModel(model, sampler, newvarinfo, rng) newtrace = AdvancedPS.Trace(tmodel, rng) - AdvancedPS.addreference!(newtrace) return newtrace end From d34dd3db9733d3b9ac509c4c7b8eaf8845bb8b4d Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Fri, 20 Jun 2025 11:05:22 +0100 Subject: [PATCH 4/7] Improve a comment --- src/mcmc/particle_mcmc.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/mcmc/particle_mcmc.jl b/src/mcmc/particle_mcmc.jl index c1f3ab443..c6a5fe7ca 100644 --- a/src/mcmc/particle_mcmc.jl +++ b/src/mcmc/particle_mcmc.jl @@ -477,10 +477,10 @@ function AdvancedPS.Trace( end # We need to tell Libtask which calls may have `produce` calls within them. In practice most -# of these won't be needed, because of inline and the fact that `might_produce` is only +# of these won't be needed, because of inlining and the fact that `might_produce` is only # called on `:invoke` expressions rather than `:call`s, but since those are implementation -# details of the compiler we define a bunch of these here, starting with -# `acclogp_observe!!` which is what calls `produce`, and going up the call stack. +# details of the compiler, we set a bunch of methods as might_produce = true. We start with +# `acclogp_observe!!` which is what calls `produce` and go up the call stack. Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.acclogp_observe!!),Vararg}}) = true Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.tilde_observe!!),Vararg}}) = true Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.evaluate!!),Vararg}}) = true From 7cf8ee08d6fac64d8b8b8b71edf8f02a868805a1 Mon Sep 17 00:00:00 2001 From: Hong Ge <3279477+yebai@users.noreply.github.com> Date: Mon, 23 Jun 2025 22:20:24 +0100 Subject: [PATCH 5/7] Update Project.toml (#2598) --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index bfaa0d61a..b0a0652c9 100644 --- a/Project.toml +++ b/Project.toml @@ -85,7 +85,7 @@ Statistics = "1.6" StatsAPI = "1.6" StatsBase = "0.32, 0.33, 0.34" StatsFuns = "0.8, 0.9, 1" -julia = "1.10.2" +julia = "1.10.8" [extras] DynamicHMC = "bbc10e6e-7c05-544b-b16e-64fede858acb" From 1c6fad9b5e50adc6ad452e157ee04c29582b6bda Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 3 Jul 2025 15:11:13 +0100 Subject: [PATCH 6/7] Fix a bug and a test --- src/mcmc/particle_mcmc.jl | 2 +- test/essential/container.jl | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/mcmc/particle_mcmc.jl b/src/mcmc/particle_mcmc.jl index c73944b23..a81f436c8 100644 --- a/src/mcmc/particle_mcmc.jl +++ b/src/mcmc/particle_mcmc.jl @@ -58,7 +58,7 @@ function AdvancedPS.reset_logprob!(trace::TracedModel) return trace end -function Libtask.TapedTask(taped_globals::Any, model::TracedModel, args...; kwargs...) # RNG ? +function Libtask.TapedTask(taped_globals, model::TracedModel; kwargs...) return Libtask.TapedTask( taped_globals, model.evaluator[1], model.evaluator[2:end]...; kwargs... ) diff --git a/test/essential/container.jl b/test/essential/container.jl index 1cb790d5a..cbd7a6fe2 100644 --- a/test/essential/container.jl +++ b/test/essential/container.jl @@ -23,8 +23,8 @@ using Turing model = test() trace = AdvancedPS.Trace(model, sampler, vi, AdvancedPS.TracedRNG()) - # Make sure we link the traces - @test haskey(trace.model.ctask.task.storage, :__trace) + # Make sure the backreference from taped_globals to the trace is in place. + @test trace.model.ctask.taped_globals.other === trace res = AdvancedPS.advance!(trace, false) @test DynamicPPL.get_num_produce(trace.model.f.varinfo) == 1 From df31909982b588c504db9f08a7b0b7cf1b4611b1 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 10 Jul 2025 17:26:14 +0100 Subject: [PATCH 7/7] Bump Libtask to 0.9.3 Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com> --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index da6060f66..0f232dc90 100644 --- a/Project.toml +++ b/Project.toml @@ -67,7 +67,7 @@ DynamicHMC = "3.4" DynamicPPL = "0.36.3" EllipticalSliceSampling = "0.5, 1, 2" ForwardDiff = "0.10.3, 1" -Libtask = "0.9.2" +Libtask = "0.9.3" LinearAlgebra = "1" LogDensityProblems = "2" MCMCChains = "5, 6, 7"