1
1
module reverse_tests
2
2
using Diffractor
3
- using Diffractor: var"'" , ∂⃖, DiffractorRuleConfig
3
+ using Diffractor: ∂⃖, DiffractorRuleConfig
4
4
using ChainRules
5
5
using ChainRulesCore
6
6
using ChainRulesCore: ZeroTangent, NoTangent, frule_via_ad, rrule_via_ad
@@ -44,14 +44,6 @@ function simple_control_flow(b, x)
44
44
end
45
45
end
46
46
47
- function myprod (xs)
48
- s = 1
49
- for x in xs
50
- s *= x
51
- end
52
- return s
53
- end
54
-
55
47
function mypow (x, n)
56
48
r = one (x)
57
49
while n > 0
@@ -79,10 +71,11 @@ let var"'" = Diffractor.PrimeDerivativeBack
79
71
@test @inferred (sin' (1.0 )) == cos (1.0 )
80
72
@test @inferred (sin'' (1.0 )) == - sin (1.0 )
81
73
@test @inferred (sin''' (1.0 )) == - cos (1.0 )
82
- # TODO These currently cause segfaults c.f. https://github.com/JuliaLang/julia/pull/48742
83
- # @test sin''''(1.0) == sin(1.0)
84
- # @test sin'''''(1.0) == cos(1.0)
85
- # @test sin''''''(1.0) == -sin(1.0)
74
+ # FIXME : These error with:
75
+ # Control flow support not fully implemented yet for higher-order reverse mode (TODO )
76
+ @test_broken @inferred (sin'''' (1.0 )) == sin (1.0 )
77
+ @test_broken @inferred (sin''''' (1.0 )) == cos (1.0 )
78
+ @test_broken @inferred (sin'''''' (1.0 )) == - sin (1.0 )
86
79
87
80
f_getfield (x) = getfield ((x,), 1 )
88
81
@test f_getfield' (1 ) == 1
@@ -93,10 +86,10 @@ let var"'" = Diffractor.PrimeDerivativeBack
93
86
94
87
complicated_2sin (x) = (x = map (sin, Diffractor. xfill (x, 2 )); x[1 ] + x[2 ])
95
88
@test @inferred (complicated_2sin' (1.0 )) == 2 sin' (1.0 )
96
- @test @inferred (complicated_2sin '' ( 1.0 )) == 2 sin '' ( 1.0 ) broken = true
97
- @test @inferred (complicated_2sin''' (1.0 )) == 2 sin''' (1.0 ) broken = true
98
- # TODO This currently causes a segfault, c.f. https://github.com/JuliaLang/julia/pull/48742
99
- # @test @inferred(complicated_2sin''''(1.0)) == 2sin''''(1.0) broken=true
89
+ # FIXME : These error with: Control flow support not fully implemented yet for higher-order reverse mode ( TODO )
90
+ @test_broken @inferred (complicated_2sin'' (1.0 )) == 2 sin'' (1.0 )
91
+ @test_broken @inferred (complicated_2sin ''' ( 1.0 )) == 2 sin ''' ( 1.0 )
92
+ @test_broken @inferred (complicated_2sin'''' (1.0 )) == 2 sin'''' (1.0 )
100
93
101
94
# Control flow cases
102
95
@test @inferred ((x-> simple_control_flow (true , x))' (1.0 )) == sin' (1.0 )
0 commit comments