Skip to content

Commit ce86af0

Browse files
committed
adapt many tests from Zygote
1 parent 55d2871 commit ce86af0

File tree

5 files changed

+1510
-5
lines changed

5 files changed

+1510
-5
lines changed

test/Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,14 @@
22
ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
33
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
44
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
5+
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
6+
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
57
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
68
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
79
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
810
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
911
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
12+
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1013
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
1114
StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
1215
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"

test/chainrules.jl

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
2+
# This file has integration tests for some rules defined in ChainRules.jl,
3+
# especially those which aim to support higher derivatives, as properly
4+
# testing those is difficult.
5+
6+
using Diffractor, ChainRulesCore, ForwardDiff
7+
8+
#####
9+
##### Base/array.jl
10+
#####
11+
12+
13+
14+
15+
16+
17+
#####
18+
##### Base/arraymath.jl
19+
#####
20+
21+
22+
23+
24+
#####
25+
##### Base/base.jl
26+
#####
27+
28+
29+
30+
31+
32+
#####
33+
##### Base/indexing.jl
34+
#####
35+
36+
37+
38+
39+
#####
40+
##### Base/mapreduce.jl
41+
#####
42+
43+
44+
45+
46+
#####
47+
##### LinearAlgebra/dense.jl
48+
#####
49+
50+
51+
52+
53+
#####
54+
##### LinearAlgebra/norm.jl
55+
#####
56+
57+

test/runtests.jl

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,33 @@
11
using Diffractor
2+
using Test
3+
4+
@testset verbose=true "ChainRules integration.jl" begin
5+
include("chainrules.jl")
6+
end
7+
@testset verbose=true "from Zygote" begin
8+
include("zygote_features.jl")
9+
end
10+
@testset verbose=true "from Zygote's gradcheck.jl" begin
11+
include("zygote_gradcheck.jl")
12+
end
13+
@testset verbose=true "Unit tests" begin
14+
15+
# The rest of this file is unchanged, except the very end,
16+
# but IMO we should move these tests to a new file.
17+
18+
# Loading Diffractor: var"'" globally will break many tests above, which use it for adjoint.
19+
220
using Diffractor: var"'", ∂⃖, DiffractorRuleConfig
321
using ChainRules
422
using ChainRulesCore
523
using ChainRulesCore: ZeroTangent, NoTangent, frule_via_ad, rrule_via_ad
624
using Symbolics
725
using LinearAlgebra
8-
using Test
26+
927

1028
const fwd = Diffractor.PrimeDerivativeFwd
1129
const bwd = Diffractor.PrimeDerivativeBack
1230

13-
@testset verbose=true "Diffractor.jl" begin # overall testset, ensures all tests run
14-
1531
# Unit tests
1632
function tup2(f)
1733
a, b = ∂⃖{2}()(f, 1)
@@ -219,6 +235,7 @@ z45, delta45 = frule_via_ad(DiffractorRuleConfig(), (0,1), x -> log(exp(x)), 2)
219235
@test z45 2.0
220236
@test delta45 1.0
221237

238+
<<<<<<< HEAD
222239
# PR #82 - getindex on non-numeric arrays
223240
@test gradient(ls -> ls[1](1.), [Base.Fix1(*, 1.)])[1][1] isa Tangent{<:Base.Fix1}
224241

@@ -281,7 +298,11 @@ end
281298
@test_broken gradient(z -> gradient(x -> sum((y -> (x^2*y)).([1,2,3])), z)[1], 5.0) == (12.0,)
282299
end
283300

301+
302+
end
303+
304+
@testset verbose=true "pseudo-Flux" begin
284305
# Higher order control flow not yet supported (https://github.com/JuliaDiff/Diffractor.jl/issues/24)
285-
#include("pinn.jl")
306+
# include("pinn.jl")
307+
end
286308

287-
end # overall testset

0 commit comments

Comments
 (0)