Skip to content

Commit b0818b9

Browse files
authored
Add Aqua tests (#775)
* Add Aqua tests * Fix logpdf(::NamedDist) method ambiguity * Fix SimpleVarInfo method ambiguity * Fix VarInfo method ambiguity * Add InteractiveUtils compat entry See: https://discourse.julialang.org/t/psa-compat-requirements-in-the-general-registry-are-changing/104958 * Add Random.AbstractRNG type annotation * Remove unneeded getsym method * Fix (newly introduced 😅) ConditionContext method ambiguity * Fix unwrap_right_left_vns method ambiguity * KernelAbstractions is a weakdep not a dep * Fix StaticTransformation / ThreadSafeVarInfo link/invlink ambiguity * Fix more RNGs * Don't run Aqua tests on CI min versions * Fix ternary in GitHub Actions expression
1 parent 4bc43a4 commit b0818b9

File tree

14 files changed

+61
-13
lines changed

14 files changed

+61
-13
lines changed

.github/workflows/CI.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,8 @@ jobs:
7070
env:
7171
GROUP: ${{ matrix.test_group }}
7272
JULIA_NUM_THREADS: ${{ matrix.runner.num_threads }}
73+
# Only run Aqua tests on latest version
74+
AQUA: ${{ matrix.runner.version == '1' && 'true' || 'false' }}
7375

7476
- uses: julia-actions/julia-processcoverage@v1
7577

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@ DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
1616
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
1717
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
1818
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
19-
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
2019
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
2120
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
2221
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
@@ -30,6 +29,7 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
3029
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
3130
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
3231
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
32+
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
3333
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
3434
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
3535

@@ -56,6 +56,7 @@ Distributions = "0.25"
5656
DocStringExtensions = "0.9"
5757
EnzymeCore = "0.6 - 0.8"
5858
ForwardDiff = "0.10.12"
59+
InteractiveUtils = "1"
5960
JET = "0.9"
6061
KernelAbstractions = "0.9.33"
6162
LinearAlgebra = "1.6"

benchmarks/src/DynamicPPLBenchmarks.jl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ using LogDensityProblems: LogDensityProblems
99
using ForwardDiff: ForwardDiff
1010
using Mooncake: Mooncake
1111
using ReverseDiff: ReverseDiff
12+
using StableRNGs: StableRNG
1213

1314
include("./Models.jl")
1415
using .Models: Models
@@ -61,18 +62,20 @@ The AD backend should be specified as a Symbol (e.g. `:forwarddiff`, `:reversedi
6162
`islinked` determines whether to link the VarInfo for evaluation.
6263
"""
6364
function make_suite(model, varinfo_choice::Symbol, adbackend::Symbol, islinked::Bool)
65+
rng = StableRNG(23)
66+
6467
suite = BenchmarkGroup()
6568

6669
vi = if varinfo_choice == :untyped
6770
vi = VarInfo()
68-
model(vi)
71+
model(rng, vi)
6972
vi
7073
elseif varinfo_choice == :typed
71-
VarInfo(model)
74+
VarInfo(rng, model)
7275
elseif varinfo_choice == :simple_namedtuple
73-
SimpleVarInfo{Float64}(model())
76+
SimpleVarInfo{Float64}(model(rng))
7477
elseif varinfo_choice == :simple_dict
75-
retvals = model()
78+
retvals = model(rng)
7679
vns = [VarName{k}() for k in keys(retvals)]
7780
SimpleVarInfo{Float64}(Dict(zip(vns, values(retvals))))
7881
else

src/compiler.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,10 @@ x[1][3]
250250
```
251251
"""
252252
unwrap_right_left_vns(right, left, vns) = right, left, vns
253-
function unwrap_right_left_vns(right::NamedDist, left, vns)
253+
function unwrap_right_left_vns(right::NamedDist, left::AbstractArray, ::VarName)
254+
return unwrap_right_left_vns(right.dist, left, right.name)
255+
end
256+
function unwrap_right_left_vns(right::NamedDist, left::AbstractMatrix, ::VarName)
254257
return unwrap_right_left_vns(right.dist, left, right.name)
255258
end
256259
function unwrap_right_left_vns(

src/context_implementations.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,7 @@ function tilde_observe!!(context, right, left, vi)
195195
return left, acclogp_observe!!(context, vi, logp)
196196
end
197197

198-
function assume(rng, spl::Sampler, dist)
198+
function assume(rng::Random.AbstractRNG, spl::Sampler, dist)
199199
return error("DynamicPPL.assume: unmanaged inference algorithm: $(typeof(spl))")
200200
end
201201

src/contexts.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -335,6 +335,8 @@ function ConditionContext(values::Union{NamedTuple,AbstractDict})
335335
end
336336
# Optimisation when there are no values to condition on
337337
ConditionContext(::NamedTuple{()}, context::AbstractContext) = context
338+
# Same as above, and avoids method ambiguity with below
339+
ConditionContext(::NamedTuple{()}, context::NamedConditionContext) = context
338340
# Collapse consecutive levels of `ConditionContext`. Note that this overrides
339341
# values inside the child context, thus giving precedence to the outermost
340342
# `ConditionContext`.

src/distribution_wrappers.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,10 @@ Base.length(dist::NamedDist) = Base.length(dist.dist)
1717
Base.size(dist::NamedDist) = Base.size(dist.dist)
1818

1919
Distributions.logpdf(dist::NamedDist, x::Real) = Distributions.logpdf(dist.dist, x)
20+
function Distributions.logpdf(dist::NamedDist, x::AbstractArray{<:Real,0})
21+
# extract the singleton value from 0-dimensional array
22+
return Distributions.logpdf(dist.dist, first(x))
23+
end
2024
function Distributions.logpdf(dist::NamedDist, x::AbstractArray{<:Real})
2125
return Distributions.logpdf(dist.dist, x)
2226
end

src/simple_varinfo.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -232,8 +232,14 @@ function SimpleVarInfo(; kwargs...)
232232
end
233233

234234
# Constructor from `Model`.
235-
SimpleVarInfo(model::Model, args...) = SimpleVarInfo{Float64}(model, args...)
236-
function SimpleVarInfo{T}(model::Model, args...) where {T<:Real}
235+
function SimpleVarInfo(
236+
model::Model, args::Union{AbstractVarInfo,AbstractSampler,AbstractContext}...
237+
)
238+
return SimpleVarInfo{Float64}(model, args...)
239+
end
240+
function SimpleVarInfo{T}(
241+
model::Model, args::Union{AbstractVarInfo,AbstractSampler,AbstractContext}...
242+
) where {T<:Real}
237243
return last(evaluate!!(model, SimpleVarInfo{T}(), args...))
238244
end
239245

src/threadsafe.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,19 @@ function invlink(t::DynamicTransformation, vi::ThreadSafeVarInfo, model::Model)
115115
return invlink!!(t, deepcopy(vi), model)
116116
end
117117

118+
# These two StaticTransformation methods needed to resolve ambiguities
119+
function link!!(
120+
t::StaticTransformation{<:Bijectors.Transform}, vi::ThreadSafeVarInfo, model::Model
121+
)
122+
return Accessors.@set vi.varinfo = link!!(t, vi.varinfo, model)
123+
end
124+
125+
function invlink!!(
126+
t::StaticTransformation{<:Bijectors.Transform}, vi::ThreadSafeVarInfo, model::Model
127+
)
128+
return Accessors.@set vi.varinfo = invlink!!(t, vi.varinfo, model)
129+
end
130+
118131
function maybe_invlink_before_eval!!(vi::ThreadSafeVarInfo, model::Model)
119132
# Defer to the wrapped `AbstractVarInfo` object.
120133
# NOTE: When computing `getlogp` for `ThreadSafeVarInfo` we do include the

src/varinfo.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,11 @@ function VarInfo(
200200
)
201201
return typed_varinfo(model, SamplingContext(rng, sampler, context), metadata)
202202
end
203-
VarInfo(model::Model, args...) = VarInfo(Random.default_rng(), model, args...)
203+
function VarInfo(
204+
model::Model, args::Union{AbstractSampler,AbstractContext,Metadata,VarNamedVector}...
205+
)
206+
return VarInfo(Random.default_rng(), model, args...)
207+
end
204208

205209
"""
206210
vector_length(varinfo::VarInfo)

0 commit comments

Comments
 (0)