Skip to content

Commit 9928324

Browse files
committed
adapt many tests from Zygote
1 parent bc22ad6 commit 9928324

File tree

5 files changed

+1505
-2
lines changed

5 files changed

+1505
-2
lines changed

test/Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,14 @@
22
ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
33
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
44
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
5+
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
6+
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
57
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
68
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
79
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
810
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
911
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
12+
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1013
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
1114
StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
1215
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"

test/chainrules.jl

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
2+
# This file has integration tests for some rules defined in ChainRules.jl,
3+
# especially those which aim to support higher derivatives, as properly
4+
# testing those is difficult.
5+
6+
using Diffractor, ChainRulesCore, ForwardDiff
7+
8+
#####
9+
##### Base/array.jl
10+
#####
11+
12+
13+
14+
15+
16+
17+
#####
18+
##### Base/arraymath.jl
19+
#####
20+
21+
22+
23+
24+
#####
25+
##### Base/base.jl
26+
#####
27+
28+
29+
30+
31+
32+
#####
33+
##### Base/indexing.jl
34+
#####
35+
36+
37+
38+
39+
#####
40+
##### Base/mapreduce.jl
41+
#####
42+
43+
44+
45+
46+
#####
47+
##### LinearAlgebra/dense.jl
48+
#####
49+
50+
51+
52+
53+
#####
54+
##### LinearAlgebra/norm.jl
55+
#####
56+
57+

test/runtests.jl

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,29 @@
11
using Diffractor
2+
using Test
3+
4+
@testset verbose=true "ChainRules integration.jl" begin
5+
include("chainrules.jl")
6+
end
7+
@testset verbose=true "from Zygote" begin
8+
include("zygote_features.jl")
9+
end
10+
@testset verbose=true "from Zygote's gradcheck.jl" begin
11+
include("zygote_gradcheck.jl")
12+
end
13+
@testset verbose=true "Unit tests" begin
14+
15+
# The rest of this file is unchanged, except the very end,
16+
# but IMO we should move these tests to a new file.
17+
18+
# Loading Diffractor: var"'" globally will break many tests above, which use it for adjoint.
19+
220
using Diffractor: var"'", ∂⃖, DiffractorRuleConfig
321
using ChainRules
422
using ChainRulesCore
523
using ChainRulesCore: ZeroTangent, NoTangent, frule_via_ad, rrule_via_ad
624
using Symbolics
725
using LinearAlgebra
826

9-
using Test
10-
1127
# Unit tests
1228
function tup2(f)
1329
a, b = ∂⃖{2}()(f, 1)
@@ -214,4 +230,7 @@ z45, delta45 = frule_via_ad(DiffractorRuleConfig(), (0,1), x -> log(exp(x)), 2)
214230
@test z45 2.0
215231
@test delta45 1.0
216232

233+
end
234+
@testset verbose=true "pseudo-Flux" begin
217235
include("pinn.jl")
236+
end

0 commit comments

Comments
 (0)