@@ -12,11 +12,11 @@ function (g::∇getindex)(Δ)
12
12
(ChainRulesCore. NoTangent (), Δ′, map (_ -> nothing , g. i)... )
13
13
end
14
14
15
- function ChainRulesCore. rrule (g:: ∇getindex , Δ)
15
+ function ChainRulesCore. rrule (:: DiffractorRuleConfig , g:: ∇getindex , Δ)
16
16
g (Δ), Δ′′-> (nothing , Δ′′[1 ][g. i... ])
17
17
end
18
18
19
- function ChainRulesCore. rrule (:: typeof (getindex), xs:: Array , i... )
19
+ function ChainRulesCore. rrule (:: DiffractorRuleConfig , :: typeof (getindex), xs:: Array , i... )
20
20
xs[i... ], ∇getindex (xs, i)
21
21
end
22
22
@@ -37,14 +37,14 @@ function assert_gf(f)
37
37
@assert sizeof (sin) == 0
38
38
end
39
39
40
- function ChainRulesCore. rrule (:: typeof (assert_gf), f)
40
+ function ChainRulesCore. rrule (:: DiffractorRuleConfig , :: typeof (assert_gf), f)
41
41
assert_gf (f), Δ-> begin
42
42
(NoTangent (), NoTangent ())
43
43
end
44
44
end
45
45
46
46
#=
47
- function ChainRulesCore.rrule(::typeof(map), f, xs::Vector...)
47
+ function ChainRulesCore.rrule(::DiffractorRuleConfig, :: typeof(map), f, xs::Vector...)
48
48
assert_gf(f)
49
49
primal, dual = reversediff_array(f, xs...)
50
50
primal, Δ->begin
@@ -94,7 +94,7 @@ function ChainRulesCore.frule((_, ∂A, ∂B), ::typeof(*), A::AbstractMatrix{<:
94
94
end
95
95
96
96
#=
97
- function ChainRulesCore.rrule(::typeof(map), f, xs::Vector)
97
+ function ChainRulesCore.rrule(::DiffractorRuleConfig, :: typeof(map), f, xs::Vector)
98
98
assert_gf(f)
99
99
arrs = reversediff_array(f, xs)
100
100
primal = getfield(arrs, 1)
105
105
=#
106
106
107
107
#=
108
- function ChainRulesCore.rrule(::typeof(map), f, xs::Vector, ys::Vector)
108
+ function ChainRulesCore.rrule(::DiffractorRuleConfig, :: typeof(map), f, xs::Vector, ys::Vector)
109
109
assert_gf(f)
110
110
arrs = reversediff_array(f, xs, ys)
111
111
primal = getfield(arrs, 1)
@@ -116,14 +116,14 @@ end
116
116
=#
117
117
118
118
xsum (x:: Vector ) = sum (x)
119
- function ChainRulesCore. rrule (:: typeof (xsum), x:: Vector )
119
+ function ChainRulesCore. rrule (:: DiffractorRuleConfig , :: typeof (xsum), x:: Vector )
120
120
xsum (x), let xdims= size (x)
121
121
Δ-> (NoTangent (), xfill (Δ, xdims... ))
122
122
end
123
123
end
124
124
125
125
xfill (x, dims... ) = fill (x, dims... )
126
- function ChainRulesCore. rrule (:: typeof (xfill), x, dim)
126
+ function ChainRulesCore. rrule (:: DiffractorRuleConfig , :: typeof (xfill), x, dim)
127
127
xfill (x, dim), Δ-> (NoTangent (), xsum (Δ), NoTangent ())
128
128
end
129
129
@@ -137,11 +137,11 @@ struct NonDiffOdd{N, O, P}; end
137
137
# This should not happen
138
138
(:: NonDiffEven{N, O, O} )(Δ... ) where {N, O} = error ()
139
139
140
- @Base . pure function ChainRulesCore. rrule (:: typeof (Core. apply_type), head, args... )
140
+ @Base . pure function ChainRulesCore. rrule (:: DiffractorRuleConfig , :: typeof (Core. apply_type), head, args... )
141
141
Core. apply_type (head, args... ), NonDiffOdd {plus1(plus1(length(args))), 1, 1} ()
142
142
end
143
143
144
- function ChainRulesCore. rrule (:: typeof (Core. tuple), args... )
144
+ function ChainRulesCore. rrule (:: DiffractorRuleConfig , :: typeof (Core. tuple), args... )
145
145
Core. tuple (args... ), Δ-> Core. tuple (NoTangent (), Δ... )
146
146
end
147
147
151
151
ChainRulesCore. canonicalize (:: ChainRulesCore.ZeroTangent ) = ChainRulesCore. ZeroTangent ()
152
152
153
153
# Skip AD'ing through the axis computation
154
- function ChainRules. rrule (:: typeof (Base. Broadcast. instantiate), bc:: Base.Broadcast.Broadcasted )
154
+ function ChainRules. rrule (:: DiffractorRuleConfig , :: typeof (Base. Broadcast. instantiate), bc:: Base.Broadcast.Broadcasted )
155
155
return Base. Broadcast. instantiate (bc), Δ-> begin
156
156
Core. tuple (NoTangent (), Δ)
157
157
end
@@ -169,11 +169,11 @@ struct to_tuple{N}; end
169
169
end
170
170
(:: to_tuple )(Δ:: SArray ) = getfield (Δ, :data )
171
171
172
- function ChainRules. rrule (:: Type{SArray{S, T, N, L}} , x:: NTuple{L,T} ) where {S, T, N, L}
172
+ function ChainRules. rrule (:: DiffractorRuleConfig , :: Type{SArray{S, T, N, L}} , x:: NTuple{L,T} ) where {S, T, N, L}
173
173
SArray {S, T, N, L} (x), to_tuple {L} ()
174
174
end
175
175
176
- function ChainRules. rrule (:: Type{SArray{S, T, N, L}} , x:: NTuple{L,Any} ) where {S, T, N, L}
176
+ function ChainRules. rrule (:: DiffractorRuleConfig , :: Type{SArray{S, T, N, L}} , x:: NTuple{L,Any} ) where {S, T, N, L}
177
177
SArray {S, T, N, L} (x), to_tuple {L} ()
178
178
end
179
179
@@ -191,22 +191,22 @@ function ChainRules.frule((_, ∂A), ::typeof(getindex), A::AbstractArray, args.
191
191
getindex (A, args... ), getindex (∂A, args... )
192
192
end
193
193
194
- function ChainRules. rrule (:: typeof (map), :: typeof (+ ), A:: AbstractArray , B:: AbstractArray )
194
+ function ChainRules. rrule (:: DiffractorRuleConfig , :: typeof (map), :: typeof (+ ), A:: AbstractArray , B:: AbstractArray )
195
195
map (+ , A, B), Δ-> (NoTangent (), NoTangent (), Δ, Δ)
196
196
end
197
197
198
- function ChainRules. rrule (:: typeof (map), :: typeof (+ ), A:: AbstractVector , B:: AbstractVector )
198
+ function ChainRules. rrule (:: DiffractorRuleConfig , :: typeof (map), :: typeof (+ ), A:: AbstractVector , B:: AbstractVector )
199
199
map (+ , A, B), Δ-> (NoTangent (), NoTangent (), Δ, Δ)
200
200
end
201
201
202
- function ChainRules. rrule (AT:: Type{<:Array{T,N}} , x:: AbstractArray{S,N} ) where {T,S,N}
202
+ function ChainRules. rrule (:: DiffractorRuleConfig , AT:: Type{<:Array{T,N}} , x:: AbstractArray{S,N} ) where {T,S,N}
203
203
# We're leaving these in the eltype that the cotangent vector already has.
204
204
# There isn't really a good reason to believe we should convert to the
205
205
# original array type, so don't unless explicitly requested.
206
206
AT (x), Δ-> (NoTangent (), Δ)
207
207
end
208
208
209
- function ChainRules. rrule (AT:: Type{<:Array} , undef:: UndefInitializer , args... )
209
+ function ChainRules. rrule (:: DiffractorRuleConfig , AT:: Type{<:Array} , undef:: UndefInitializer , args... )
210
210
# We're leaving these in the eltype that the cotangent vector already has.
211
211
# There isn't really a good reason to believe we should convert to the
212
212
# original array type, so don't unless explicitly requested.
@@ -217,7 +217,7 @@ function unzip_tuple(t::Tuple)
217
217
map (x-> x[1 ], t), map (x-> x[2 ], t)
218
218
end
219
219
220
- function ChainRules. rrule (:: typeof (unzip_tuple), args:: Tuple )
220
+ function ChainRules. rrule (:: DiffractorRuleConfig , :: typeof (unzip_tuple), args:: Tuple )
221
221
unzip_tuple (args), Δ-> (NoTangent (), map ((x,y)-> (x,y), Δ... ))
222
222
end
223
223
228
228
back_apply (x, y) = x (y)
229
229
back_apply_zero (x) = x (Zero ())
230
230
231
- function ChainRules. rrule (:: typeof (map), f, args:: Tuple )
231
+ function ChainRules. rrule (:: DiffractorRuleConfig , :: typeof (map), f, args:: Tuple )
232
232
a, b = unzip_tuple (map (BackMap (f), args))
233
233
function back (Δ)
234
234
(fs, xs) = unzip_tuple (map (back_apply, b, Δ))
@@ -241,14 +241,14 @@ function ChainRules.rrule(::typeof(map), f, args::Tuple)
241
241
a, back
242
242
end
243
243
244
- function ChainRules. rrule (:: typeof (Base. ntuple), f, n)
244
+ function ChainRules. rrule (:: DiffractorRuleConfig , :: typeof (Base. ntuple), f, n)
245
245
a, b = unzip_tuple (ntuple (BackMap (f), n))
246
246
a, function (Δ)
247
247
(NoTangent (), sum (map (back_apply, b, Δ)), NoTangent ())
248
248
end
249
249
end
250
250
251
- function ChainRules. frule (_, :: Type{Vector{T}} , undef:: UndefInitializer , dims:: Int... ) where {T}
251
+ function ChainRules. frule (:: DiffractorRuleConfig , _, :: Type{Vector{T}} , undef:: UndefInitializer , dims:: Int... ) where {T}
252
252
Vector {T} (undef, dims... ), zeros (T, dims... )
253
253
end
254
254
@@ -258,11 +258,11 @@ end
258
258
ChainRulesCore. canonicalize (:: NoTangent ) = NoTangent ()
259
259
260
260
# Disable thunking at higher order (TODO : These should go into ChainRulesCore)
261
- function ChainRulesCore. rrule (:: Type{Thunk} , thnk)
261
+ function ChainRulesCore. rrule (:: DiffractorRuleConfig , :: Type{Thunk} , thnk)
262
262
z, ∂z = ∂⃖¹ (thnk)
263
263
z, Δ-> (NoTangent (), ∂z (Δ)... )
264
264
end
265
265
266
- function ChainRulesCore. rrule (:: Type{InplaceableThunk} , add!!, val)
266
+ function ChainRulesCore. rrule (:: DiffractorRuleConfig , :: Type{InplaceableThunk} , add!!, val)
267
267
val, Δ-> (NoTangent (), NoTangent (), Δ)
268
268
end
0 commit comments