@@ -5,9 +5,13 @@ using ChainRulesCore
5
5
using ChainRulesCore: ZeroTangent, NoTangent, frule_via_ad, rrule_via_ad
6
6
using Symbolics
7
7
using LinearAlgebra
8
-
9
8
using Test
10
9
10
+ const fwd = Diffractor. PrimeDerivativeFwd
11
+ const bwd = Diffractor. PrimeDerivativeBack
12
+
13
+ @testset verbose= true " Diffractor.jl" begin # overall testset, ensures all tests run
14
+
11
15
# Unit tests
12
16
function tup2 (f)
13
17
a, b = ∂⃖ {2} ()(f, 1 )
@@ -88,9 +92,9 @@ let var"'" = Diffractor.PrimeDerivativeBack
88
92
@test @inferred (sin' (1.0 )) == cos (1.0 )
89
93
@test @inferred (sin'' (1.0 )) == - sin (1.0 )
90
94
@test sin''' (1.0 ) == - cos (1.0 )
91
- @test sin'''' (1.0 ) == sin (1.0 )
92
- @test sin''''' (1.0 ) == cos (1.0 )
93
- @test sin'''''' (1.0 ) == - sin (1.0 )
95
+ @test sin'''' (1.0 ) == sin (1.0 ) broken = VERSION >= v " 1.8 "
96
+ @test sin''''' (1.0 ) == cos (1.0 ) broken = VERSION >= v " 1.8 "
97
+ @test sin'''''' (1.0 ) == - sin (1.0 ) broken = VERSION >= v " 1.8 "
94
98
95
99
f_getfield (x) = getfield ((x,), 1 )
96
100
@test f_getfield' (1 ) == 1
@@ -101,9 +105,9 @@ let var"'" = Diffractor.PrimeDerivativeBack
101
105
102
106
complicated_2sin (x) = (x = map (sin, Diffractor. xfill (x, 2 )); x[1 ] + x[2 ])
103
107
@test @inferred (complicated_2sin' (1.0 )) == 2 sin' (1.0 )
104
- @test @inferred (complicated_2sin'' (1.0 )) == 2 sin'' (1.0 )
105
- @test @inferred (complicated_2sin''' (1.0 )) == 2 sin''' (1.0 )
106
- @test @inferred (complicated_2sin'''' (1.0 )) == 2 sin'''' (1.0 )
108
+ @test @inferred (complicated_2sin'' (1.0 )) == 2 sin'' (1.0 ) broken = true
109
+ @test @inferred (complicated_2sin''' (1.0 )) == 2 sin''' (1.0 ) broken = true
110
+ @test @inferred (complicated_2sin'''' (1.0 )) == 2 sin'''' (1.0 ) broken = true
107
111
108
112
# Control flow cases
109
113
@test @inferred ((x-> simple_control_flow (true , x))' (1.0 )) == sin' (1.0 )
149
153
# Regression tests
150
154
@test gradient (x -> sum (abs2, x .+ 1.0 ), zeros (3 ))[1 ] == [2.0 , 2.0 , 2.0 ]
151
155
152
- const fwd = Diffractor. PrimeDerivativeFwd
153
- const bwd = Diffractor. PrimeDerivativeBack
154
-
155
156
function f_broadcast (a)
156
157
l = a / 2.0 * [[0. 1. 1. ]; [1. 0. 1. ]; [1. 1. 0. ]]
157
158
return sum (l)
161
162
# Make sure that there's no infinite recursion in kwarg calls
162
163
g_kw (;x= 1.0 ) = sin (x)
163
164
f_kw (x) = g_kw (;x)
164
- @test bwd (f_kw)(1.0 ) == bwd (sin)(1.0 )
165
+ @test bwd (f_kw)(1.0 ) == bwd (sin)(1.0 ) broken= true
166
+ #=
167
+ MethodError: no method matching +(::Tangent{var"#g_kw#47"{var"#g_kw#11#48"}, NamedTuple{(Symbol("#g_kw#11"),), Tuple{ZeroTangent}}}, ::Tangent{Diffractor.KwFunc{var"#g_kw#47"{var"#g_kw#11#48"}, var"#g_kw#47##kw"}, NamedTuple{(:kwf,), Tuple{ZeroTangent}}})
168
+ ...
169
+ [2] elementwise_add(a::NamedTuple{(:contents,), Tuple{Tangent{var"#g_kw#47"{var"#g_kw#11#48"}, NamedTuple{(Symbol("#g_kw#11"),), Tuple{ZeroTangent}}}}}, b::NamedTuple{(:contents,), Tuple{Tangent{Diffractor.KwFunc{var"#g_kw#47"{var"#g_kw#11#48"}, var"#g_kw#47##kw"}, NamedTuple{(:kwf,), Tuple{ZeroTangent}}}}})
170
+ @ ChainRulesCore ~/.julia/packages/ChainRulesCore/ctmSK/src/tangent_types/tangent.jl:287
171
+ [3] +(a::Tangent{Core.Box, NamedTuple{(:contents,), Tuple{Tangent{var"#g_kw#47"{var"#g_kw#11#48"}, NamedTuple{(Symbol("#g_kw#11"),), Tuple{ZeroTangent}}}}}}, b::Tangent{Core.Box, NamedTuple{(:contents,), Tuple{Tangent{Diffractor.KwFunc{var"#g_kw#47"{var"#g_kw#11#48"}, var"#g_kw#47##kw"}, NamedTuple{(:kwf,), Tuple{ZeroTangent}}}}}})
172
+ @ ChainRulesCore ~/.julia/packages/ChainRulesCore/ctmSK/src/tangent_arithmetic.jl:130
173
+ =#
165
174
166
175
function f_crit_edge (a, b, c, x)
167
176
# A function with two critical edges. This used to trigger an issue where
@@ -220,3 +229,5 @@ z45, delta45 = frule_via_ad(DiffractorRuleConfig(), (0,1), x -> log(exp(x)), 2)
220
229
221
230
# Higher order control flow not yet supported (https://github.com/JuliaDiff/Diffractor.jl/issues/24)
222
231
# include("pinn.jl")
232
+
233
+ end # overall testset
0 commit comments