@@ -101,12 +101,12 @@ struct ∂⃖weaveInnerOdd{N, O}; b̄; end
101
101
end
102
102
@Base . constprop :aggressive function (w: :∂⃖weaveInnerOdd {N, O})(Δ) where {N, O}
103
103
@destruct c, c̄ = w. b̄ (Δ... )
104
- return (c̄, c), ∂⃖weaveInnerEven {plus1(N) , O} ()
104
+ return (c̄, c), ∂⃖weaveInnerEven {N+1 , O} ()
105
105
end
106
106
struct ∂⃖weaveInnerEven{N, O}; end
107
107
@Base . constprop :aggressive function (w: :∂⃖weaveInnerEven {N, O})(Δ′, x... ) where {N, O}
108
108
@destruct y, ȳ = Δ′ (x... )
109
- return y, ∂⃖weaveInnerOdd {plus1(N) , O} (ȳ)
109
+ return y, ∂⃖weaveInnerOdd {N+1 , O} (ȳ)
110
110
end
111
111
112
112
struct ∂⃖weaveOuterOdd{N, O}; end
@@ -115,15 +115,15 @@ struct ∂⃖weaveOuterOdd{N, O}; end
115
115
end
116
116
@Base . constprop :aggressive function (w: :∂⃖weaveOuterOdd {N, O})((Δ′′, Δ′′′)) where {N, O}
117
117
@destruct α, ᾱ = Δ′′′ (Δ′′)
118
- return (NoTangent (), α... ), ∂⃖weaveOuterEven {plus1(N) , O} (ᾱ)
118
+ return (NoTangent (), α... ), ∂⃖weaveOuterEven {N+1 , O} (ᾱ)
119
119
end
120
120
struct ∂⃖weaveOuterEven{N, O}; ᾱ end
121
121
@Base . constprop :aggressive function (w: :∂⃖weaveOuterEven {N, O})(Δ⁴... ) where {N, O}
122
- return w. ᾱ (Base. tail (Δ⁴)... ), ∂⃖weaveOuterOdd {plus1(N) , O} ()
122
+ return w. ᾱ (Base. tail (Δ⁴)... ), ∂⃖weaveOuterOdd {N+1 , O} ()
123
123
end
124
124
125
125
function (:: ∂⃖{N})(:: ∂⃖{1 }, args... ) where {N}
126
- @destruct (a, ā) = ∂⃖ {plus1(N) } ()(args... )
126
+ @destruct (a, ā) = ∂⃖ {N+1 } ()(args... )
127
127
let O = c_order (N)
128
128
(a, Protected {N} (@opaque Δ-> begin
129
129
(b, b̄) = ā (Δ)
@@ -188,10 +188,10 @@ end
188
188
(:: ∂⃖rruleD{N, N})(Δ... ) where {N} = error (" Should not be reached" )
189
189
190
190
# ∂⃖rrule
191
- @Base . pure term_depth (N) = 2 ^ (N- 2 )
191
+ term_depth (N) = 1 << (N- 2 )
192
192
function (:: ∂⃖rrule{N})(z, z̄) where {N}
193
193
@destruct (y, ȳ) = z
194
- y, ∂⃖rruleA {term_depth(N), 1} (∂⃖ {minus1(N) } (), ȳ, z̄)
194
+ y, ∂⃖rruleA {term_depth(N), 1} (∂⃖ {N-1 } (), ȳ, z̄)
195
195
end
196
196
197
197
function (:: ∂⃖{N})(f:: Core.IntrinsicFunction , args... ) where {N}
@@ -217,7 +217,7 @@ function (::∂⃖{N})(f::T, args...) where {T, N}
217
217
end
218
218
return z
219
219
else
220
- ∂⃖p = ∂⃖ {minus1(N) } ()
220
+ ∂⃖p = ∂⃖ {N-1 } ()
221
221
@destruct z, z̄ = ∂⃖p (rrule, f, args... )
222
222
if z === nothing
223
223
return ∂⃖recurse {N} ()(f, args... )
@@ -231,7 +231,7 @@ function ChainRulesCore.rrule_via_ad(::DiffractorRuleConfig, f::T, args...) wher
231
231
Tuple {Any, Any} (∂⃖ {1} ()(f, args... ))
232
232
end
233
233
234
- @Base . pure function (:: ∂⃖{1 })(:: typeof (Core. apply_type), head, args... )
234
+ @Base . assume_effects :total function (:: ∂⃖{1 })(:: typeof (Core. apply_type), head, args... )
235
235
return rrule (Core. apply_type, head, args... )
236
236
end
237
237
@@ -284,8 +284,8 @@ struct EvenOddEven{O, P, F, G}; f::F; g::G; end
284
284
EvenOddEven {O, P} (f:: F , g:: G ) where {O, P, F, G} = EvenOddEven {O, P, F, G} (f, g)
285
285
struct EvenOddOdd{O, P, F, G}; f:: F ; g:: G ; end
286
286
EvenOddOdd {O, P} (f:: F , g:: G ) where {O, P, F, G} = EvenOddOdd {O, P, F, G} (f, g)
287
- @Base . constprop :aggressive (o:: EvenOddOdd{O, P, F, G} )(Δ) where {O, P, F, G} = (o. f (Δ), EvenOddEven {plus1(O) , P, F, G} (o. f, o. g))
288
- @Base . constprop :aggressive (e:: EvenOddEven{O, P, F, G} )(Δ... ) where {O, P, F, G} = (e. g (Δ... ), EvenOddOdd {plus1(O) , P, F, G} (e. f, e. g))
287
+ @Base . constprop :aggressive (o:: EvenOddOdd{O, P, F, G} )(Δ) where {O, P, F, G} = (o. f (Δ), EvenOddEven {O+1 , P, F, G} (o. f, o. g))
288
+ @Base . constprop :aggressive (e:: EvenOddEven{O, P, F, G} )(Δ... ) where {O, P, F, G} = (e. g (Δ... ), EvenOddOdd {O+1 , P, F, G} (e. f, e. g))
289
289
@Base . constprop :aggressive (o:: EvenOddOdd{O, O} )(Δ) where {O} = o. f (Δ)
290
290
291
291
@@ -363,11 +363,11 @@ struct ApplyOdd{O, P}; u; ∂⃖f; end
363
363
struct ApplyEven{O, P}; u; ∂⃖∂⃖f; end
364
364
@Base . constprop :aggressive function (a:: ApplyOdd{O, P} )(Δ) where {O, P}
365
365
r, ∂⃖∂⃖f = a.∂⃖f (Δ)
366
- (a. u (r), ApplyEven {plus1(O) , P} (a. u, ∂⃖∂⃖f))
366
+ (a. u (r), ApplyEven {O+1 , P} (a. u, ∂⃖∂⃖f))
367
367
end
368
368
@Base . constprop :aggressive function (a:: ApplyEven{O, P} )(_, _, ff, args... ) where {O, P}
369
369
r, ∂⃖∂⃖∂⃖f = Core. _apply_iterate (iterate, a.∂⃖∂⃖f, (ff,), args... )
370
- (r, ApplyOdd {plus1(O) , P} (a. u, ∂⃖∂⃖∂⃖f))
370
+ (r, ApplyOdd {O+1 , P} (a. u, ∂⃖∂⃖∂⃖f))
371
371
end
372
372
@Base . constprop :aggressive function (a:: ApplyOdd{O, O} )(Δ) where {O}
373
373
r = a.∂⃖f (Δ)
@@ -381,10 +381,10 @@ function (this::∂⃖{N})(::typeof(Core._apply_iterate), iterate, f, args::Unio
381
381
end
382
382
383
383
384
- @Base . pure c_order (N:: Int ) = 2 ^ N - 1
384
+ c_order (N:: Int ) = 1 << N - 1
385
385
386
- @Base . pure function (:: ∂⃖{N})(:: typeof (Core. apply_type), head, args... ) where {N}
387
- Core. apply_type (head, args... ), NonDiffOdd {plus1(plus1( length(args))) , 1, c_order(N)} ()
386
+ @Base . assume_effects :total function (:: ∂⃖{N})(:: typeof (Core. apply_type), head, args... ) where {N}
387
+ Core. apply_type (head, args... ), NonDiffOdd {length(args)+2 , 1, c_order(N)} ()
388
388
end
389
389
390
390
@Base . constprop :aggressive lifted_getfield (x, s) = getfield (x, s)
0 commit comments