Skip to content

Commit 519493b

Browse files
committed
DistributionsChainRulesCoreExt: make rrule/frule take ChainRulesCore.RuleConfig
This is what was happening implicitly already: https://github.com/JuliaDiff/ChainRulesCore.jl/blob/9627bd6a949ca88ba867bf4fa177a45bf780a248/src/rules.jl#L137-L138 ``` unning tests: 1 ambiguities found. To get a list, set `broken = false`. Ambiguity #1 frule(::ChainRulesCore.RuleConfig, args...) in ChainRulesCore at /home/runner/.julia/packages/ChainRulesCore/I1EbV/src/rules.jl:64 frule(::Any, ::typeof(Distributions.logpdf), d::Distributions.Uniform, x::Real) in Distributions.DistributionsChainRulesCoreExt at /home/runner/work/Distributions.jl/Distributions.jl/ext/DistributionsChainRulesCoreExt/univariate/continuous/uniform.jl:1 Possible fix, define frule(::ChainRulesCore.RuleConfig, ::typeof(Distributions.logpdf), ::Distributions.Uniform, ::Real) Aqua: Test Failed at /home/runner/.julia/packages/Aqua/tHrmY/src/ambiguities.jl:78 Expression: iszero(num_ambiguities) Stacktrace: [1] _test_ambiguities(packages::Vector{Base.PkgId}; broken::Bool, kwargs::Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}) @ Aqua ~/.julia/packages/Aqua/tHrmY/src/ambiguities.jl:78 [2] _test_ambiguities @ ~/.julia/packages/Aqua/tHrmY/src/ambiguities.jl:69 [inlined] [3] test_ambiguities(packages::Module; kwargs::Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}) @ Aqua ~/.julia/packages/Aqua/tHrmY/src/ambiguities.jl:28 [4] test_ambiguities(packages::Module) @ Aqua ~/.julia/packages/Aqua/tHrmY/src/ambiguities.jl:28 [5] macro expansion @ ~/work/Distributions.jl/Distributions.jl/test/aqua.jl:19 [inlined] [6] macro expansion @ /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.6/Test/src/Test.jl:1151 [inlined] [7] top-level scope @ ~/work/Distributions.jl/Distributions.jl/test/aqua.jl:9 ```
1 parent 133a93e commit 519493b

File tree

5 files changed

+10
-10
lines changed

5 files changed

+10
-10
lines changed

ext/DistributionsChainRulesCoreExt/eachvariate.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
function ChainRulesCore.rrule(::Type{Distributions.EachVariate{V}}, x::AbstractArray{<:Real}) where {V}
1+
function ChainRulesCore.rrule(::ChainRulesCore.RuleConfig, ::Type{Distributions.EachVariate{V}}, x::AbstractArray{<:Real}) where {V}
22
y = Distributions.EachVariate{V}(x)
33
size_x = size(x)
44
function EachVariate_pullback(Δ)

ext/DistributionsChainRulesCoreExt/multivariate/dirichlet.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
function ChainRulesCore.frule((_, Δalpha)::Tuple{Any,Any}, ::Type{DT}, alpha::AbstractVector{T}; check_args::Bool = true) where {T <: Real, DT <: Union{Dirichlet{T}, Dirichlet}}
1+
function ChainRulesCore.frule(::ChainRulesCore.RuleConfig, (_, Δalpha)::Tuple{Any,Any}, ::Type{DT}, alpha::AbstractVector{T}; check_args::Bool = true) where {T <: Real, DT <: Union{Dirichlet{T}, Dirichlet}}
22
d = DT(alpha; check_args=check_args)
33
∂alpha0 = sum(Δalpha)
44
digamma_alpha0 = SpecialFunctions.digamma(d.alpha0)
@@ -9,7 +9,7 @@ function ChainRulesCore.frule((_, Δalpha)::Tuple{Any,Any}, ::Type{DT}, alpha::A
99
return d, Δd
1010
end
1111

12-
function ChainRulesCore.rrule(::Type{DT}, alpha::AbstractVector{T}; check_args::Bool = true) where {T <: Real, DT <: Union{Dirichlet{T}, Dirichlet}}
12+
function ChainRulesCore.rrule(::ChainRulesCore.RuleConfig, ::Type{DT}, alpha::AbstractVector{T}; check_args::Bool = true) where {T <: Real, DT <: Union{Dirichlet{T}, Dirichlet}}
1313
d = DT(alpha; check_args=check_args)
1414
digamma_alpha0 = SpecialFunctions.digamma(d.alpha0)
1515
function Dirichlet_pullback(_Δd)
@@ -20,7 +20,7 @@ function ChainRulesCore.rrule(::Type{DT}, alpha::AbstractVector{T}; check_args::
2020
return d, Dirichlet_pullback
2121
end
2222

23-
function ChainRulesCore.frule((_, Δd, Δx)::Tuple{Any,Any,Any}, ::typeof(Distributions._logpdf), d::Dirichlet, x::AbstractVector{<:Real})
23+
function ChainRulesCore.frule(::ChainRulesCore.RuleConfig, (_, Δd, Δx)::Tuple{Any,Any,Any}, ::typeof(Distributions._logpdf), d::Dirichlet, x::AbstractVector{<:Real})
2424
Ω = Distributions._logpdf(d, x)
2525
∂alpha = sum(Broadcast.instantiate(Broadcast.broadcasted(Δd.alpha, Δx, d.alpha, x) do Δalphai, Δxi, alphai, xi
2626
StatsFuns.xlogy(Δalphai, xi) + (alphai - 1) * Δxi / xi
@@ -33,7 +33,7 @@ function ChainRulesCore.frule((_, Δd, Δx)::Tuple{Any,Any,Any}, ::typeof(Distri
3333
return Ω, ΔΩ
3434
end
3535

36-
function ChainRulesCore.rrule(::typeof(Distributions._logpdf), d::T, x::AbstractVector{<:Real}) where {T<:Dirichlet}
36+
function ChainRulesCore.rrule(::ChainRulesCore.RuleConfig, ::typeof(Distributions._logpdf), d::T, x::AbstractVector{<:Real}) where {T<:Dirichlet}
3737
Ω = Distributions._logpdf(d, x)
3838
isfinite_Ω = isfinite(Ω)
3939
alpha = d.alpha

ext/DistributionsChainRulesCoreExt/univariate/continuous/uniform.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
function ChainRulesCore.frule((_, Δd, _), ::typeof(logpdf), d::Uniform, x::Real)
1+
function ChainRulesCore.frule(::ChainRulesCore.RuleConfig, (_, Δd, _), ::typeof(logpdf), d::Uniform, x::Real)
22
# Compute log probability
33
a, b = params(d)
44
insupport = a <= x <= b
@@ -12,7 +12,7 @@ function ChainRulesCore.frule((_, Δd, _), ::typeof(logpdf), d::Uniform, x::Real
1212
return Ω, ΔΩ
1313
end
1414

15-
function ChainRulesCore.rrule(::typeof(logpdf), d::Uniform, x::Real)
15+
function ChainRulesCore.rrule(::ChainRulesCore.RuleConfig, ::typeof(logpdf), d::Uniform, x::Real)
1616
# Compute log probability
1717
a, b = params(d)
1818
insupport = a <= x <= b

ext/DistributionsChainRulesCoreExt/univariate/discrete/negativebinomial.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ function (f::LogPDFNegativeBinomialPullback{D})(Δ) where {D}
1111
return ChainRulesCore.NoTangent(), Δd, ChainRulesCore.NoTangent()
1212
end
1313

14-
function ChainRulesCore.rrule(::typeof(logpdf), d::NegativeBinomial, k::Real)
14+
function ChainRulesCore.rrule(::ChainRulesCore.RuleConfig, ::typeof(logpdf), d::NegativeBinomial, k::Real)
1515
# Compute log probability (as in the definition of `logpdf(d, k)` above)
1616
r, p = params(d)
1717
z = StatsFuns.xlogy(r, p) + StatsFuns.xlog1py(k, -p)

ext/DistributionsChainRulesCoreExt/univariate/discrete/poissonbinomial.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,13 @@ for f in (:poissonbinomial_pdf, :poissonbinomial_pdf_fft)
22
pullback = Symbol(f, :_pullback)
33
@eval begin
44
function ChainRulesCore.frule(
5-
(_, Δp)::Tuple{<:Any,<:AbstractVector{<:Real}}, ::typeof(Distributions.$f), p::AbstractVector{<:Real}
5+
::ChainRulesCore.RuleConfig, (_, Δp)::Tuple{<:Any,<:AbstractVector{<:Real}}, ::typeof(Distributions.$f), p::AbstractVector{<:Real}
66
)
77
y = Distributions.$f(p)
88
A = Distributions.poissonbinomial_pdf_partialderivatives(p)
99
return y, A' * Δp
1010
end
11-
function ChainRulesCore.rrule(::typeof(Distributions.$f), p::AbstractVector{<:Real})
11+
function ChainRulesCore.rrule(::ChainRulesCore.RuleConfig, ::typeof(Distributions.$f), p::AbstractVector{<:Real})
1212
y = Distributions.$f(p)
1313
A = Distributions.poissonbinomial_pdf_partialderivatives(p)
1414
function $pullback(Δy)

0 commit comments

Comments
 (0)