@@ -129,14 +129,53 @@ end
129
129
@test entropy (Dirichlet (ones (N))) ≈ - loggamma (N)
130
130
end
131
131
132
# Exercises the ChainRules forward (`frule`) and reverse (`rrule`) rules for the
# `Dirichlet` constructor and for `Distributions._logpdf`, once per dimension `n`.
@testset "Dirichlet differentiation $n" for n in (2, 10)
    alpha = rand(n)
    Δalpha = randn(n)

    # Rules for the constructor. The direct calls are kept so that the destructuring
    # fails loudly if a rule does not return a pair; `d` is reused below.
    d, _ = ChainRulesCore.frule((nothing, Δalpha), Dirichlet, alpha)
    ChainRulesTestUtils.test_frule(Dirichlet ⊢ ChainRulesCore.NoTangent(), alpha ⊢ Δalpha)
    _, _ = ChainRulesCore.rrule(Dirichlet, alpha)
    ChainRulesTestUtils.test_rrule(Dirichlet{Float64} ⊢ ChainRulesCore.NoTangent(), alpha)

    # `x` is normalised to sum to one and `Δx` is centred to sum to zero,
    # such that x ∈ Δ and x + Δx still sums to one.
    x = rand(n)
    x ./= sum(x)
    Δx = 0.05 * rand(n)
    Δx .-= mean(Δx)
    ChainRulesTestUtils.test_frule(
        Distributions._logpdf ⊢ ChainRulesCore.NoTangent(), d, x ⊢ Δx
    )

    @testset "finite diff f/r-rule logpdf" begin
        for _ in 1:10
            # Loop-local names (`xi`, `Δxi`, ...) so the outer `x`, `Δx` and `Δalpha`
            # are not silently overwritten.
            xi = rand(n)
            xi ./= sum(xi)
            Δxi = 0.005 * rand(n)
            Δxi .-= mean(Δxi)

            # Skip draws for which either finite-difference point leaves the support.
            (insupport(d, xi + Δxi) && insupport(d, xi - Δxi)) || continue

            y, pullback = ChainRulesCore.rrule(Distributions._logpdf, d, xi)
            # Zero tangent for `d`, so Δy only reflects the perturbation of the point.
            yf, Δy = ChainRulesCore.frule(
                (
                    ChainRulesCore.NoTangent(),
                    map(zero, ChainRulesTestUtils.rand_tangent(d)),
                    Δxi,
                ),
                Distributions._logpdf,
                d,
                xi,
            )
            y2 = Distributions._logpdf(d, xi + Δxi)
            y1 = Distributions._logpdf(d, xi - Δxi)
            @test isfinite(y)
            @test y == yf
            # Forward difference against the pushforward.
            @test Δy ≈ y2 - y atol=5e-3
            _, ∂dist, ∂x = pullback(1.0)
            # Central difference against the pullback with respect to the point.
            @test y2 - y1 ≈ dot(2Δxi, ∂x) atol=5e-3 rtol=1e-6

            # Mutate only `d.alpha` to compute a new value, changing this term and
            # not the other fields of `d`. The perturbation scale is capped by the
            # smallest entry: each centred entry has magnitude below `scale`, so all
            # perturbed entries stay strictly positive without a random assertion.
            alpha_saved = copy(d.alpha)
            scale = min(0.03, minimum(alpha_saved) / 2)
            δalpha = scale * rand(n)
            δalpha .-= mean(δalpha)
            @test all(>(0), alpha_saved + δalpha)
            ya = try
                d.alpha .+= δalpha
                Distributions._logpdf(d, xi)
            finally
                # Restore exactly from the copy; subtracting δalpha again would leave
                # rounding drift behind and would be skipped if an error were thrown.
                copyto!(d.alpha, alpha_saved)
            end
            @test ya - y ≈ dot(δalpha, ∂dist.alpha) atol=5e-5 rtol=1e-6
        end
    end
end
0 commit comments