2
2
# Two-argument `sum` method: returns `dot(xs, weights)`.
# NOTE(review): this adds a method to a Base function for Base-owned argument types
# (type piracy); presumably deliberate for this test file -- confirm before reusing.
function Base.sum(xs::AbstractArray, weights::AbstractArray)
    return dot(xs, weights)
end
3
3
# Field-less rule configuration type for this file; its `RuleConfig` type parameter
# declares `HasReverseMode`.
struct SumRuleConfig <: RuleConfig{Union{HasReverseMode}}
end
4
4
5
- @testset " Maps and Reductions" begin
5
# Shared rule configuration passed as the first argument of the `rrule(CFG, ...)`
# calls in this file, constructed once rather than at every call site.
const CFG = ChainRulesTestUtils.ADviaRuleConfig()
6
+
7
+ @testset " Reductions" begin
8
# Forward-mode rule check for `sum` applied to a tuple of five random numbers.
@testset "sum(::Tuple)" begin
    test_frule(sum, Tuple(rand(5)))
end
6
11
@testset " sum(x; dims=$dims )" for dims in (:, 2 , (1 ,3 ))
7
12
# Forward
8
13
test_frule (sum, rand (5 ); fkwargs= (;dims= dims))
@@ -79,12 +84,11 @@ struct SumRuleConfig <: RuleConfig{Union{HasReverseMode}} end
79
84
test_rrule (sum, inv, transpose (view (x, 1 , :)))
80
85
81
86
# Make sure we preserve type for StaticArrays
82
- ADviaRuleConfig = ChainRulesTestUtils. ADviaRuleConfig
83
- _, pb = rrule (ADviaRuleConfig (), sum, abs, @SVector [1.0 , - 3.0 ])
87
+ _, pb = rrule (CFG, sum, abs, @SVector [1.0 , - 3.0 ])
84
88
@test pb (1.0 ) isa Tuple{NoTangent, NoTangent, SVector{2 , Float64}}
85
89
86
90
# make sure we preserve type for Diagonal
87
- _, pb = rrule (ADviaRuleConfig () , sum, abs, Diagonal ([1.0 , - 3.0 ]))
91
+ _, pb = rrule (CFG , sum, abs, Diagonal ([1.0 , - 3.0 ]))
88
92
@test pb (1.0 )[3 ] isa Diagonal
89
93
90
94
# Boolean -- via @non_differentiable, test that this isn't ambiguous
@@ -173,7 +177,64 @@ struct SumRuleConfig <: RuleConfig{Union{HasReverseMode}} end
173
177
@test unthunk (rrule (prod, v)[2 ](1f0 )[2 ]) == zeros (4 )
174
178
test_rrule (prod, v)
175
179
end
176
- end # prod
180
+ end # prod
181
+
182
# Reverse-mode tests for `foldl(f, ::Array)`: hand-computed pullbacks, call-order
# checks, gradients with respect to a callable struct, and finite-difference checks.
# `CFG`, `Counter` and `Multiplier` are defined outside this testset.
@testset "foldl(f, ::Array)" begin
    # Simple
    y1, b1 = rrule(CFG, foldl, *, [1, 2, 3]; init=1)
    @test y1 == 6
    # These pullback comparisons used to be bare expressions whose result was
    # discarded; they are now asserted like every other pullback check here.
    @test b1(7) == (NoTangent(), NoTangent(), [42, 21, 14])

    y2, b2 = rrule(CFG, foldl, *, [1 2; 0 4])  # without init, needs vcat
    @test y2 == 0
    @test b2(8) == (NoTangent(), NoTangent(), [0 0; 64 0])  # matrix, needs reshape

    # Test execution order
    # NOTE(review): the expected values imply each `Counter` call computes
    # `(a + b) * k` with `k` advancing per call -- confirm against its definition.
    c5 = Counter()
    y5, b5 = rrule(CFG, foldl, c5, [5, 7, 11])
    @test c5 == Counter(2)
    @test y5 == ((5 + 7) * 1 + 11) * 2 == foldl(Counter(), [5, 7, 11])
    @test b5(1) == (NoTangent(), NoTangent(), [12 * 32, 12 * 42, 22])
    @test c5 == Counter(42)

    c6 = Counter()
    y6, b6 = rrule(CFG, foldl, c6, [5, 7, 11], init=3)
    @test c6 == Counter(3)
    @test y6 == (((3 + 5) * 1 + 7) * 2 + 11) * 3 == foldl(Counter(), [5, 7, 11], init=3)
    @test b6(1) == (NoTangent(), NoTangent(), [63 * 33 * 13, 43 * 13, 23])
    @test c6 == Counter(63)

    # Test gradient of function
    y7, b7 = rrule(CFG, foldl, Multiplier(3), [5, 7, 11])
    @test y7 == foldl((x, y) -> x * y * 3, [5, 7, 11])
    @test b7(1) == (NoTangent(), Tangent{Multiplier{Int}}(x=2310,), [693, 495, 315])

    y8, b8 = rrule(CFG, foldl, Multiplier(13), [5, 7, 11], init=3)
    @test y8 == 2_537_535 == foldl((x, y) -> x * y * 13, [5, 7, 11], init=3)
    @test b8(1) ==
        (NoTangent(), Tangent{Multiplier{Int}}(x=585585,), [507507, 362505, 230685])
    # To find these numbers:
    # ForwardDiff.derivative(z -> foldl((x,y)->x*y*z, [5,7,11], init=3), 13)
    # ForwardDiff.gradient(z -> foldl((x,y)->x*y*13, z, init=3), [5,7,11]) |> string

    # Finite differencing
    test_rrule(foldl, /, 1 .+ rand(3, 4))
    test_rrule(foldl, *, rand(ComplexF64, 3, 4); fkwargs=(; init=rand(ComplexF64)))
    test_rrule(foldl, +, rand(ComplexF64, 7); fkwargs=(; init=rand(ComplexF64)))
    test_rrule(foldl, max, rand(3); fkwargs=(; init=999))
end
225
# Reverse-mode tests for `foldl(f, ::Tuple)`; the testset only runs on Julia >= 1.5.
VERSION >= v"1.5" && @testset "foldl(f, ::Tuple)" begin
    y1, b1 = rrule(CFG, foldl, *, (1, 2, 3); init=1)
    @test y1 == 6
    # These pullback comparisons used to be bare expressions whose result was
    # discarded; they are now asserted.
    @test b1(7) == (NoTangent(), NoTangent(), Tangent{NTuple{3,Int}}(42, 21, 14))

    y2, b2 = rrule(CFG, foldl, *, (1, 2, 0, 4))
    @test y2 == 0
    @test b2(8) == (NoTangent(), NoTangent(), Tangent{NTuple{4,Int}}(0, 0, 64, 0))

    # Finite differencing
    test_rrule(foldl, /, Tuple(1 .+ rand(5)))
    test_rrule(foldl, *, Tuple(rand(ComplexF64, 5)))
end
177
238
end
178
239
179
240
@testset " Accumulations" begin
@@ -188,14 +249,14 @@ end
188
249
# `cumprod` reverse-rule checks on a 4x5 matrix and a 3x3x3 array, once for each
# `dims` in (1, 2, 3), first on the rounded random input and then again after two
# entries have been set to zero.
@testset "higher dimensions, dims=$dims" for dims in (1, 2, 3)
    m = round.(10 .* randn(4, 5), sigdigits=3)
    test_rrule(cumprod, m; fkwargs=(; dims=dims), atol=0.1)
    m[2, 2] = 0
    m[2, 4] = 0
    test_rrule(cumprod, m; fkwargs=(; dims=dims))

    t = round.(10 .* randn(3, 3, 3), sigdigits=3)
    test_rrule(cumprod, t; fkwargs=(; dims=dims))
    t[2, 2, 2] = 0
    t[2, 3, 3] = 0
    test_rrule(cumprod, t; fkwargs=(; dims=dims))
end
201
262
211
272
back = rrule (cumprod, Diagonal ([1 , 2 ]); dims= 1 )[2 ]
212
273
@test unthunk (back (fill (0.5 , 2 , 2 ))[2 ]) ≈ [1 / 2 0 ; 0 0 ] # ProjectTo'd to Diagonal now
213
274
end
275
+ end # cumprod
276
+
277
# Reverse-mode tests for `accumulate(f, ::Array)`: hand-computed pullbacks, call-order
# checks, gradients with respect to a callable struct, and finite-difference checks.
# `CFG`, `Counter` and `Multiplier` are defined outside this testset.
@testset "accumulate(f, ::Array)" begin
    # Simple
    y1, b1 = rrule(CFG, accumulate, *, [1, 2, 3, 4]; init=1)
    @test y1 == [1, 2, 6, 24]
    @test b1([1, 1, 1, 1]) == (NoTangent(), NoTangent(), [33, 16, 10, 6])

    if VERSION >= v"1.5"
        y2, b2 = rrule(CFG, accumulate, /, [1 2; 3 4])
        @test y2 ≈ accumulate(/, [1 2; 3 4])
        # The expected gradient is a 2x2 matrix; the minus signs are unary (no space
        # after them), otherwise each row would parse as a single subtraction.
        @test b2(ones(2, 2))[3] ≈ [1.5416666 -0.104166664; -0.18055555 -0.010416667] atol=1e-6
    end

    # Test execution order
    # NOTE(review): the expected values imply each `Counter` call computes
    # `(a + b) * k` with `k` advancing per call -- confirm against its definition.
    c3 = Counter()
    y3, b3 = rrule(CFG, accumulate, c3, [5, 7, 11]; init=3)
    @test c3 == Counter(3)
    @test y3 == [8, 30, 123] == accumulate(Counter(), [5, 7, 11]; init=3)
    @test b3([1, 1, 1]) == (NoTangent(), NoTangent(), [29169, 602, 23])  # the 23 is clear!

    c4 = Counter()
    y4, b4 = rrule(CFG, accumulate, c4, [5, 7, 11])
    @test c4 == Counter(2)
    @test y4 == [5, (5 + 7) * 1, ((5 + 7) * 1 + 11) * 2] == accumulate(Counter(), [5, 7, 11])
    @test b4([1, 1, 1]) == (NoTangent(), NoTangent(), [417, 42 * (1 + 12), 22])

    # Test gradient of function
    y7, b7 = rrule(CFG, accumulate, Multiplier(3), [5, 7, 11])
    @test y7 == accumulate((x, y) -> x * y * 3, [5, 7, 11])
    @test b7([1, 1, 1]) ==
        (NoTangent(), Tangent{Multiplier{Int}}(x=2345,), [715, 510, 315])

    y8, b8 = rrule(CFG, accumulate, Multiplier(13), [5, 7, 11], init=3)
    @test y8 == [195, 17745, 2537535] == accumulate((x, y) -> x * y * 13, [5, 7, 11], init=3)
    @test b8([1, 1, 1]) ==
        (NoTangent(), Tangent{Multiplier{Int}}(x=588330,), [511095, 365040, 230685])
    # To find these numbers:
    # ForwardDiff.derivative(z -> sum(accumulate((x,y)->x*y*z, [5,7,11], init=3)), 13)
    # ForwardDiff.gradient(z -> sum(accumulate((x,y)->x*y*13, z, init=3)), [5,7,11]) |> string

    # Finite differencing
    test_rrule(accumulate, *, randn(5); fkwargs=(; init=rand()))
    if VERSION >= v"1.5"
        test_rrule(accumulate, /, 1 .+ rand(3, 4))
        test_rrule(accumulate, ^, 1 .+ rand(2, 3); fkwargs=(; init=rand()))
    end
end
321
# Reverse-mode tests for `accumulate(f, ::Tuple)`; the testset only runs on
# Julia >= 1.5.
VERSION >= v"1.5" && @testset "accumulate(f, ::Tuple)" begin
    # Simple
    y1, b1 = rrule(CFG, accumulate, *, (1, 2, 3, 4); init=1)
    @test y1 == (1, 2, 6, 24)
    @test b1((1, 1, 1, 1)) ==
        (NoTangent(), NoTangent(), Tangent{NTuple{4,Int}}(33, 16, 10, 6))

    # Finite differencing
    test_rrule(accumulate, *, Tuple(randn(5)); fkwargs=(; init=rand()))
    test_rrule(accumulate, /, Tuple(1 .+ rand(5)); check_inferred=false)
end
215
331
end
0 commit comments