Skip to content

Commit 018510d

Browse files
authored
add DiffractorRuleConfig to extra rrules (#80)
1 parent a2ea087 commit 018510d

File tree

2 files changed

+26
-26
lines changed

2 files changed

+26
-26
lines changed

src/extra_rules.jl

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,11 @@ function (g::∇getindex)(Δ)
1212
(ChainRulesCore.NoTangent(), Δ′, map(_ -> nothing, g.i)...)
1313
end
1414

15-
function ChainRulesCore.rrule(g::∇getindex, Δ)
15+
function ChainRulesCore.rrule(::DiffractorRuleConfig, g::∇getindex, Δ)
1616
g(Δ), Δ′′->(nothing, Δ′′[1][g.i...])
1717
end
1818

19-
function ChainRulesCore.rrule(::typeof(getindex), xs::Array, i...)
19+
function ChainRulesCore.rrule(::DiffractorRuleConfig, ::typeof(getindex), xs::Array, i...)
2020
xs[i...], ∇getindex(xs, i)
2121
end
2222

@@ -37,14 +37,14 @@ function assert_gf(f)
3737
@assert sizeof(sin) == 0
3838
end
3939

40-
function ChainRulesCore.rrule(::typeof(assert_gf), f)
40+
function ChainRulesCore.rrule(::DiffractorRuleConfig, ::typeof(assert_gf), f)
4141
assert_gf(f), Δ->begin
4242
(NoTangent(), NoTangent())
4343
end
4444
end
4545

4646
#=
47-
function ChainRulesCore.rrule(::typeof(map), f, xs::Vector...)
47+
function ChainRulesCore.rrule(::DiffractorRuleConfig, ::typeof(map), f, xs::Vector...)
4848
assert_gf(f)
4949
primal, dual = reversediff_array(f, xs...)
5050
primal, Δ->begin
@@ -94,7 +94,7 @@ function ChainRulesCore.frule((_, ∂A, ∂B), ::typeof(*), A::AbstractMatrix{<:
9494
end
9595

9696
#=
97-
function ChainRulesCore.rrule(::typeof(map), f, xs::Vector)
97+
function ChainRulesCore.rrule(::DiffractorRuleConfig, ::typeof(map), f, xs::Vector)
9898
assert_gf(f)
9999
arrs = reversediff_array(f, xs)
100100
primal = getfield(arrs, 1)
@@ -105,7 +105,7 @@ end
105105
=#
106106

107107
#=
108-
function ChainRulesCore.rrule(::typeof(map), f, xs::Vector, ys::Vector)
108+
function ChainRulesCore.rrule(::DiffractorRuleConfig, ::typeof(map), f, xs::Vector, ys::Vector)
109109
assert_gf(f)
110110
arrs = reversediff_array(f, xs, ys)
111111
primal = getfield(arrs, 1)
@@ -116,14 +116,14 @@ end
116116
=#
117117

118118
xsum(x::Vector) = sum(x)
119-
function ChainRulesCore.rrule(::typeof(xsum), x::Vector)
119+
function ChainRulesCore.rrule(::DiffractorRuleConfig, ::typeof(xsum), x::Vector)
120120
xsum(x), let xdims=size(x)
121121
Δ->(NoTangent(), xfill(Δ, xdims...))
122122
end
123123
end
124124

125125
xfill(x, dims...) = fill(x, dims...)
126-
function ChainRulesCore.rrule(::typeof(xfill), x, dim)
126+
function ChainRulesCore.rrule(::DiffractorRuleConfig, ::typeof(xfill), x, dim)
127127
xfill(x, dim), Δ->(NoTangent(), xsum(Δ), NoTangent())
128128
end
129129

@@ -137,11 +137,11 @@ struct NonDiffOdd{N, O, P}; end
137137
# This should not happen
138138
(::NonDiffEven{N, O, O})(Δ...) where {N, O} = error()
139139

140-
@Base.pure function ChainRulesCore.rrule(::typeof(Core.apply_type), head, args...)
140+
@Base.pure function ChainRulesCore.rrule(::DiffractorRuleConfig, ::typeof(Core.apply_type), head, args...)
141141
Core.apply_type(head, args...), NonDiffOdd{plus1(plus1(length(args))), 1, 1}()
142142
end
143143

144-
function ChainRulesCore.rrule(::typeof(Core.tuple), args...)
144+
function ChainRulesCore.rrule(::DiffractorRuleConfig, ::typeof(Core.tuple), args...)
145145
Core.tuple(args...), Δ->Core.tuple(NoTangent(), Δ...)
146146
end
147147

@@ -151,7 +151,7 @@ end
151151
ChainRulesCore.canonicalize(::ChainRulesCore.ZeroTangent) = ChainRulesCore.ZeroTangent()
152152

153153
# Skip AD'ing through the axis computation
154-
function ChainRules.rrule(::typeof(Base.Broadcast.instantiate), bc::Base.Broadcast.Broadcasted)
154+
function ChainRules.rrule(::DiffractorRuleConfig, ::typeof(Base.Broadcast.instantiate), bc::Base.Broadcast.Broadcasted)
155155
return Base.Broadcast.instantiate(bc), Δ->begin
156156
Core.tuple(NoTangent(), Δ)
157157
end
@@ -169,11 +169,11 @@ struct to_tuple{N}; end
169169
end
170170
(::to_tuple)(Δ::SArray) = getfield(Δ, :data)
171171

172-
function ChainRules.rrule(::Type{SArray{S, T, N, L}}, x::NTuple{L,T}) where {S, T, N, L}
172+
function ChainRules.rrule(::DiffractorRuleConfig, ::Type{SArray{S, T, N, L}}, x::NTuple{L,T}) where {S, T, N, L}
173173
SArray{S, T, N, L}(x), to_tuple{L}()
174174
end
175175

176-
function ChainRules.rrule(::Type{SArray{S, T, N, L}}, x::NTuple{L,Any}) where {S, T, N, L}
176+
function ChainRules.rrule(::DiffractorRuleConfig, ::Type{SArray{S, T, N, L}}, x::NTuple{L,Any}) where {S, T, N, L}
177177
SArray{S, T, N, L}(x), to_tuple{L}()
178178
end
179179

@@ -191,22 +191,22 @@ function ChainRules.frule((_, ∂A), ::typeof(getindex), A::AbstractArray, args.
191191
getindex(A, args...), getindex(∂A, args...)
192192
end
193193

194-
function ChainRules.rrule(::typeof(map), ::typeof(+), A::AbstractArray, B::AbstractArray)
194+
function ChainRules.rrule(::DiffractorRuleConfig, ::typeof(map), ::typeof(+), A::AbstractArray, B::AbstractArray)
195195
map(+, A, B), Δ->(NoTangent(), NoTangent(), Δ, Δ)
196196
end
197197

198-
function ChainRules.rrule(::typeof(map), ::typeof(+), A::AbstractVector, B::AbstractVector)
198+
function ChainRules.rrule(::DiffractorRuleConfig, ::typeof(map), ::typeof(+), A::AbstractVector, B::AbstractVector)
199199
map(+, A, B), Δ->(NoTangent(), NoTangent(), Δ, Δ)
200200
end
201201

202-
function ChainRules.rrule(AT::Type{<:Array{T,N}}, x::AbstractArray{S,N}) where {T,S,N}
202+
function ChainRules.rrule(::DiffractorRuleConfig, AT::Type{<:Array{T,N}}, x::AbstractArray{S,N}) where {T,S,N}
203203
# We're leaving these in the eltype that the cotangent vector already has.
204204
# There isn't really a good reason to believe we should convert to the
205205
# original array type, so don't unless explicitly requested.
206206
AT(x), Δ->(NoTangent(), Δ)
207207
end
208208

209-
function ChainRules.rrule(AT::Type{<:Array}, undef::UndefInitializer, args...)
209+
function ChainRules.rrule(::DiffractorRuleConfig, AT::Type{<:Array}, undef::UndefInitializer, args...)
210210
# We're leaving these in the eltype that the cotangent vector already has.
211211
# There isn't really a good reason to believe we should convert to the
212212
# original array type, so don't unless explicitly requested.
@@ -217,7 +217,7 @@ function unzip_tuple(t::Tuple)
217217
map(x->x[1], t), map(x->x[2], t)
218218
end
219219

220-
function ChainRules.rrule(::typeof(unzip_tuple), args::Tuple)
220+
function ChainRules.rrule(::DiffractorRuleConfig, ::typeof(unzip_tuple), args::Tuple)
221221
unzip_tuple(args), Δ->(NoTangent(), map((x,y)->(x,y), Δ...))
222222
end
223223

@@ -228,7 +228,7 @@ end
228228
back_apply(x, y) = x(y)
229229
back_apply_zero(x) = x(Zero())
230230

231-
function ChainRules.rrule(::typeof(map), f, args::Tuple)
231+
function ChainRules.rrule(::DiffractorRuleConfig, ::typeof(map), f, args::Tuple)
232232
a, b = unzip_tuple(map(BackMap(f), args))
233233
function back(Δ)
234234
(fs, xs) = unzip_tuple(map(back_apply, b, Δ))
@@ -241,14 +241,14 @@ function ChainRules.rrule(::typeof(map), f, args::Tuple)
241241
a, back
242242
end
243243

244-
function ChainRules.rrule(::typeof(Base.ntuple), f, n)
244+
function ChainRules.rrule(::DiffractorRuleConfig, ::typeof(Base.ntuple), f, n)
245245
a, b = unzip_tuple(ntuple(BackMap(f), n))
246246
a, function (Δ)
247247
(NoTangent(), sum(map(back_apply, b, Δ)), NoTangent())
248248
end
249249
end
250250

251-
function ChainRules.frule(_, ::Type{Vector{T}}, undef::UndefInitializer, dims::Int...) where {T}
251+
function ChainRules.frule(::DiffractorRuleConfig, _, ::Type{Vector{T}}, undef::UndefInitializer, dims::Int...) where {T}
252252
Vector{T}(undef, dims...), zeros(T, dims...)
253253
end
254254

@@ -258,11 +258,11 @@ end
258258
ChainRulesCore.canonicalize(::NoTangent) = NoTangent()
259259

260260
# Disable thunking at higher order (TODO: These should go into ChainRulesCore)
261-
function ChainRulesCore.rrule(::Type{Thunk}, thnk)
261+
function ChainRulesCore.rrule(::DiffractorRuleConfig, ::Type{Thunk}, thnk)
262262
z, ∂z = ∂⃖¹(thnk)
263263
z, Δ->(NoTangent(), ∂z(Δ)...)
264264
end
265265

266-
function ChainRulesCore.rrule(::Type{InplaceableThunk}, add!!, val)
266+
function ChainRulesCore.rrule(::DiffractorRuleConfig, ::Type{InplaceableThunk}, add!!, val)
267267
val, Δ->(NoTangent(), NoTangent(), Δ)
268268
end

src/stage1/broadcast.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,12 +62,12 @@ unbroadcast(x::AbstractArray, x̄::Nothing) = NoTangent()
6262

6363
const Numeric = Union{Number, AbstractArray{<:Number, N} where N}
6464

65-
function ChainRulesCore.rrule(::typeof(broadcasted), ::typeof(+), xs::Numeric...)
65+
function ChainRulesCore.rrule(::DiffractorRuleConfig, ::typeof(broadcasted), ::typeof(+), xs::Numeric...)
6666
broadcast(+, xs...), ȳ -> (NoTangent(), NoTangent(), map(x -> unbroadcast(x, unthunk(ȳ)), xs)...)
6767
end
6868

69-
ChainRulesCore.rrule(::typeof(broadcasted), ::typeof(-), x::Numeric, y::Numeric) = x .- y,
69+
ChainRulesCore.rrule(::DiffractorRuleConfig, ::typeof(broadcasted), ::typeof(-), x::Numeric, y::Numeric) = x .- y,
7070
Δ -> let Δ=unthunk(Δ); (NoTangent(), NoTangent(), unbroadcast(x, Δ), -unbroadcast(y, Δ)); end
7171

72-
ChainRulesCore.rrule(::typeof(broadcasted), ::typeof(*), x::Numeric, y::Numeric) = x.*y,
72+
ChainRulesCore.rrule(::DiffractorRuleConfig, ::typeof(broadcasted), ::typeof(*), x::Numeric, y::Numeric) = x.*y,
7373
-> let=unthunk(z̄); (NoTangent(), NoTangent(), unbroadcast(x, z̄ .* conj.(y)), unbroadcast(y, z̄ .* conj.(x))); end

0 commit comments

Comments
 (0)