Skip to content

Commit 57e64ff

Browse files
committed
Fix #38 - UnApply for Vector
1 parent 1149a4c commit 57e64ff

File tree

4 files changed

+25
-7
lines changed

4 files changed

+25
-7
lines changed

src/extra_rules.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,3 +256,13 @@ end
256256
@ChainRules.non_differentiable Base.throw(err)
257257
@ChainRules.non_differentiable Core.Compiler.return_type(args...)
258258
ChainRulesCore.canonicalize(::NoTangent) = NoTangent()
259+
260+
# Disable thunking at higher order (TODO: These should go into ChainRulesCore)
261+
function ChainRulesCore.rrule(::Type{Thunk}, thnk)
262+
z, ∂z = ∂⃖¹(thnk)
263+
z, Δ->(NoTangent(), ∂z(Δ)...)
264+
end
265+
266+
function ChainRulesCore.rrule(::Type{InplaceableThunk}, add!!, val)
267+
val, Δ->(NoTangent(), NoTangent(), Δ)
268+
end

src/stage1/generated.jl

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -336,12 +336,16 @@ function (::∂⃖{N})(::typeof(Core.tuple), args::Vararg{Any, M}) where {N, M}
336336
)
337337
end
338338

339-
struct UnApply{Spec}; end
340-
@generated function (::UnApply{Spec})(Δ) where Spec
339+
struct UnApply{Spec, Types}; end
340+
@generated function (::UnApply{Spec, Types})(Δ) where {Spec, Types}
341341
args = Any[NoTangent(), NoTangent(), :(Δ[1])]
342342
start = 2
343-
for l in Spec
344-
push!(args, :(Δ[$(start:(start+l-1))]))
343+
for (l, T) in zip(Spec, Types.parameters)
344+
if T <: Array
345+
push!(args, :([Δ[$(start:(start+l-1))]...]))
346+
else
347+
push!(args, :(Δ[$(start:(start+l-1))]))
348+
end
345349
start += l
346350
end
347351
:(Core.tuple($(args...)))
@@ -362,10 +366,10 @@ end
362366
a.u(r)
363367
end
364368

365-
function (this::∂⃖{N})(::typeof(Core._apply_iterate), iterate, f, args::Union{Tuple, NamedTuple}...) where {N}
369+
function (this::∂⃖{N})(::typeof(Core._apply_iterate), iterate, f, args::Union{Tuple, Vector, NamedTuple}...) where {N}
366370
@assert iterate === Base.iterate
367371
x, ∂⃖f = Core._apply_iterate(iterate, this, (f,), args...)
368-
return x, ApplyOdd{1, c_order(N)}(UnApply{map(length, args)}(), ∂⃖f)
372+
return x, ApplyOdd{1, c_order(N)}(UnApply{map(length, args), typeof(args)}(), ∂⃖f)
369373
end
370374

371375

src/stage1/recurse.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -571,7 +571,7 @@ function transform!(ci, meth, nargs, sparams, N)
571571
if is_accumable(stmt.args[2])
572572
accum!(stmt.args[2], nt)
573573
end
574-
elseif isa(stmt, GlobalRef) || isexpr(stmt, :static_parameter) || isexpr(stmt, :throw_undef_if_not)
574+
elseif isa(stmt, GlobalRef) || isexpr(stmt, :static_parameter) || isexpr(stmt, :throw_undef_if_not) || isexpr(stmt, :loopinfo)
575575
# We drop gradients for globals and static parameters
576576
elseif isexpr(stmt, :inbounds)
577577
# Nothing to do

test/runtests.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,4 +188,8 @@ let var"'" = bwd
188188
@test (x->x^5)'''(1.0) == 60.
189189
end
190190

191+
# Issue #38 - Splatting arrays
192+
@test gradient(x -> max(x...), (1,2,3))[1] == (0.0, 0.0, 1.0)
193+
@test gradient(x -> max(x...), [1,2,3])[1] == [0.0, 0.0, 1.0]
194+
191195
include("pinn.jl")

0 commit comments

Comments
 (0)