@@ -100,12 +100,12 @@ struct ∂⃖weaveInnerOdd{N, O}; b̄; end
100
100
end
101
101
@Base . constprop :aggressive function (w: :∂⃖weaveInnerOdd {N, O})(Δ) where {N, O}
102
102
@destruct c, c̄ = w. b̄ (Δ... )
103
- return (c̄, c), ∂⃖weaveInnerEven {plus1(N) , O} ()
103
+ return (c̄, c), ∂⃖weaveInnerEven {N+1 , O} ()
104
104
end
105
105
struct ∂⃖weaveInnerEven{N, O}; end
106
106
@Base . constprop :aggressive function (w: :∂⃖weaveInnerEven {N, O})(Δ′, x... ) where {N, O}
107
107
@destruct y, ȳ = Δ′ (x... )
108
- return y, ∂⃖weaveInnerOdd {plus1(N) , O} (ȳ)
108
+ return y, ∂⃖weaveInnerOdd {N+1 , O} (ȳ)
109
109
end
110
110
111
111
struct ∂⃖weaveOuterOdd{N, O}; end
@@ -114,15 +114,15 @@ struct ∂⃖weaveOuterOdd{N, O}; end
114
114
end
115
115
@Base . constprop :aggressive function (w: :∂⃖weaveOuterOdd {N, O})((Δ′′, Δ′′′)) where {N, O}
116
116
@destruct α, ᾱ = Δ′′′ (Δ′′)
117
- return (NoTangent (), α... ), ∂⃖weaveOuterEven {plus1(N) , O} (ᾱ)
117
+ return (NoTangent (), α... ), ∂⃖weaveOuterEven {N+1 , O} (ᾱ)
118
118
end
119
119
struct ∂⃖weaveOuterEven{N, O}; ᾱ end
120
120
@Base . constprop :aggressive function (w: :∂⃖weaveOuterEven {N, O})(Δ⁴... ) where {N, O}
121
- return w. ᾱ (Base. tail (Δ⁴)... ), ∂⃖weaveOuterOdd {plus1(N) , O} ()
121
+ return w. ᾱ (Base. tail (Δ⁴)... ), ∂⃖weaveOuterOdd {N+1 , O} ()
122
122
end
123
123
124
124
function (:: ∂⃖{N})(:: ∂⃖{1 }, args... ) where {N}
125
- @destruct (a, ā) = ∂⃖ {plus1(N) } ()(args... )
125
+ @destruct (a, ā) = ∂⃖ {N+1 } ()(args... )
126
126
let O = c_order (N)
127
127
(a, Protected {N} (@opaque Δ-> begin
128
128
(b, b̄) = ā (Δ)
@@ -187,10 +187,10 @@ end
187
187
(:: ∂⃖rruleD{N, N})(Δ... ) where {N} = error (" Should not be reached" )
188
188
189
189
# ∂⃖rrule
190
- @Base . pure term_depth (N) = 2 ^ (N- 2 )
190
+ term_depth (N) = 1 << (N- 2 )
191
191
function (:: ∂⃖rrule{N})(z, z̄) where {N}
192
192
@destruct (y, ȳ) = z
193
- y, ∂⃖rruleA {term_depth(N), 1} (∂⃖ {minus1(N) } (), ȳ, z̄)
193
+ y, ∂⃖rruleA {term_depth(N), 1} (∂⃖ {N-1 } (), ȳ, z̄)
194
194
end
195
195
196
196
function (:: ∂⃖{N})(f:: Core.IntrinsicFunction , args... ) where {N}
@@ -216,7 +216,7 @@ function (::∂⃖{N})(f::T, args...) where {T, N}
216
216
end
217
217
return z
218
218
else
219
- ∂⃖p = ∂⃖ {minus1(N) } ()
219
+ ∂⃖p = ∂⃖ {N-1 } ()
220
220
@destruct z, z̄ = ∂⃖p (rrule, f, args... )
221
221
if z === nothing
222
222
return ∂⃖recurse {N} ()(f, args... )
@@ -230,7 +230,7 @@ function ChainRulesCore.rrule_via_ad(::DiffractorRuleConfig, f::T, args...) wher
230
230
Tuple {Any, Any} (∂⃖ {1} ()(f, args... ))
231
231
end
232
232
233
- @Base . pure function (:: ∂⃖{1 })(:: typeof (Core. apply_type), head, args... )
233
+ @Base . assume_effects :total function (:: ∂⃖{1 })(:: typeof (Core. apply_type), head, args... )
234
234
return rrule (Core. apply_type, head, args... )
235
235
end
236
236
@@ -283,8 +283,8 @@ struct EvenOddEven{O, P, F, G}; f::F; g::G; end
283
283
EvenOddEven {O, P} (f:: F , g:: G ) where {O, P, F, G} = EvenOddEven {O, P, F, G} (f, g)
284
284
struct EvenOddOdd{O, P, F, G}; f:: F ; g:: G ; end
285
285
EvenOddOdd {O, P} (f:: F , g:: G ) where {O, P, F, G} = EvenOddOdd {O, P, F, G} (f, g)
286
- @Base . constprop :aggressive (o:: EvenOddOdd{O, P, F, G} )(Δ) where {O, P, F, G} = (o. f (Δ), EvenOddEven {plus1(O) , P, F, G} (o. f, o. g))
287
- @Base . constprop :aggressive (e:: EvenOddEven{O, P, F, G} )(Δ... ) where {O, P, F, G} = (e. g (Δ... ), EvenOddOdd {plus1(O) , P, F, G} (e. f, e. g))
286
+ @Base . constprop :aggressive (o:: EvenOddOdd{O, P, F, G} )(Δ) where {O, P, F, G} = (o. f (Δ), EvenOddEven {O+1 , P, F, G} (o. f, o. g))
287
+ @Base . constprop :aggressive (e:: EvenOddEven{O, P, F, G} )(Δ... ) where {O, P, F, G} = (e. g (Δ... ), EvenOddOdd {O+1 , P, F, G} (e. f, e. g))
288
288
@Base . constprop :aggressive (o:: EvenOddOdd{O, O} )(Δ) where {O} = o. f (Δ)
289
289
290
290
@@ -362,11 +362,11 @@ struct ApplyOdd{O, P}; u; ∂⃖f; end
362
362
struct ApplyEven{O, P}; u; ∂⃖∂⃖f; end
363
363
@Base . constprop :aggressive function (a:: ApplyOdd{O, P} )(Δ) where {O, P}
364
364
r, ∂⃖∂⃖f = a.∂⃖f (Δ)
365
- (a. u (r), ApplyEven {plus1(O) , P} (a. u, ∂⃖∂⃖f))
365
+ (a. u (r), ApplyEven {O+1 , P} (a. u, ∂⃖∂⃖f))
366
366
end
367
367
@Base . constprop :aggressive function (a:: ApplyEven{O, P} )(_, _, ff, args... ) where {O, P}
368
368
r, ∂⃖∂⃖∂⃖f = Core. _apply_iterate (iterate, a.∂⃖∂⃖f, (ff,), args... )
369
- (r, ApplyOdd {plus1(O) , P} (a. u, ∂⃖∂⃖∂⃖f))
369
+ (r, ApplyOdd {O+1 , P} (a. u, ∂⃖∂⃖∂⃖f))
370
370
end
371
371
@Base . constprop :aggressive function (a:: ApplyOdd{O, O} )(Δ) where {O}
372
372
r = a.∂⃖f (Δ)
@@ -380,10 +380,10 @@ function (this::∂⃖{N})(::typeof(Core._apply_iterate), iterate, f, args::Unio
380
380
end
381
381
382
382
383
- @Base . pure c_order (N:: Int ) = 2 ^ N - 1
383
+ c_order (N:: Int ) = 1 << N - 1
384
384
385
- @Base . pure function (:: ∂⃖{N})(:: typeof (Core. apply_type), head, args... ) where {N}
386
- Core. apply_type (head, args... ), NonDiffOdd {plus1(plus1( length(args))) , 1, c_order(N)} ()
385
+ @Base . assume_effects :total function (:: ∂⃖{N})(:: typeof (Core. apply_type), head, args... ) where {N}
386
+ Core. apply_type (head, args... ), NonDiffOdd {length(args)+2 , 1, c_order(N)} ()
387
387
end
388
388
389
389
@Base . constprop :aggressive lifted_getfield (x, s) = getfield (x, s)
0 commit comments