Skip to content

Commit 8215253

Browse files
committed
adapt many tests from Zygote
1 parent 29a14f9 commit 8215253

File tree

5 files changed

+1508
-6
lines changed

5 files changed

+1508
-6
lines changed

test/Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,14 @@
22
ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
33
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
44
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
5+
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
6+
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
57
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
68
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
79
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
810
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
911
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
12+
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1013
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
1114
StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
1215
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"

test/chainrules.jl

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
2+
# This file has integration tests for some rules defined in ChainRules.jl,
3+
# especially those which aim to support higher derivatives, as properly
4+
# testing those is difficult.
5+
6+
using Diffractor, ChainRulesCore, ForwardDiff
7+
8+
#####
9+
##### Base/array.jl
10+
#####
11+
12+
13+
14+
15+
16+
17+
#####
18+
##### Base/arraymath.jl
19+
#####
20+
21+
22+
23+
24+
#####
25+
##### Base/base.jl
26+
#####
27+
28+
29+
30+
31+
32+
#####
33+
##### Base/indexing.jl
34+
#####
35+
36+
37+
38+
39+
#####
40+
##### Base/mapreduce.jl
41+
#####
42+
43+
44+
45+
46+
#####
47+
##### LinearAlgebra/dense.jl
48+
#####
49+
50+
51+
52+
53+
#####
54+
##### LinearAlgebra/norm.jl
55+
#####
56+
57+

test/runtests.jl

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,32 @@
11
using Diffractor
2+
using Test
3+
4+
@testset verbose=true "ChainRules integration.jl" begin
5+
include("chainrules.jl")
6+
end
7+
@testset verbose=true "from Zygote" begin
8+
include("zygote_features.jl")
9+
end
10+
@testset verbose=true "from Zygote's gradcheck.jl" begin
11+
include("zygote_gradcheck.jl")
12+
end
13+
@testset verbose=true "Unit tests" begin
14+
15+
# The rest of this file is unchanged, except the very end,
16+
# but IMO we should move these tests to a new file.
17+
18+
# Loading Diffractor: var"'" globally will break many tests above, which use it for adjoint.
19+
220
using Diffractor: var"'", ∂⃖, DiffractorRuleConfig
321
using ChainRules
422
using ChainRulesCore
523
using ChainRulesCore: ZeroTangent, NoTangent, frule_via_ad, rrule_via_ad
624
using Symbolics
725
using LinearAlgebra
826

9-
using Test
10-
1127
const fwd = Diffractor.PrimeDerivativeFwd
1228
const bwd = Diffractor.PrimeDerivativeBack
1329

14-
@testset verbose=true "Diffractor.jl" begin # overall testset, ensures all tests run
15-
1630
# Unit tests
1731
function tup2(f)
1832
a, b = ∂⃖{2}()(f, 1)
@@ -276,7 +290,11 @@ end
276290
@test_broken gradient(z -> gradient(x -> sum((y -> (x^2*y)).([1,2,3])), z)[1], 5.0) == (12.0,)
277291
end
278292

293+
294+
end
295+
296+
@testset verbose=true "pseudo-Flux" begin
279297
# Higher order control flow not yet supported (https://github.com/JuliaDiff/Diffractor.jl/issues/24)
280-
#include("pinn.jl")
298+
# include("pinn.jl")
299+
end
281300

282-
end # overall testset

0 commit comments

Comments
 (0)