Skip to content

Commit 27c4c77

Browse files
bors[bot]ToucheSir
andauthored
Merge #1616
1616: Warn on reconstruct length mismatch r=CarloLucibello a=ToucheSir Ref. #1601. This is kept as a plain warning for backwards compat, but perhaps we want to consider it a bugfix and error/depwarn instead? ### PR Checklist - [x] Tests are added - [ ] Entry in NEWS.md - [ ] Documentation, if applicable - [ ] API changes require approval from a committer (different from the author, if applicable) Co-authored-by: Brian Chen <ToucheSir@users.noreply.github.com>
2 parents 335286a + 0f9e672 commit 27c4c77

File tree

2 files changed

+19
-2
lines changed

2 files changed

+19
-2
lines changed

src/utils.jl

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -610,16 +610,24 @@ end
610610

611611
function _restructure(m, xs)
612612
i = 0
613-
fmap(m) do x
613+
= fmap(m) do x
614614
x isa AbstractArray || return x
615615
x = reshape(xs[i.+(1:length(x))], size(x))
616616
i += length(x)
617617
return x
618618
end
619+
length(xs) == i || @warn "Expected $(i) params, got $(length(xs))"
620+
return
619621
end
620622

621623
@adjoint function _restructure(m, xs)
622-
_restructure(m, xs), dm -> (nothing,destructure(dm)[1])
624+
m̄, numel = _restructure(m, xs), length(xs)
625+
function _restructure_pullback(dm)
626+
xs′ = destructure(dm)[1]
627+
numel == length(xs′) || @warn "Expected $(numel) params, got $(length(xs′))"
628+
return (nothing, xs′)
629+
end
630+
return m̄, _restructure_pullback
623631
end
624632

625633
"""

test/utils.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -378,6 +378,15 @@ end
378378
p, re = destructure(m)
379379
testdense(re(p), bt)
380380
end
381+
382+
@testset "restructure in gradient" begin
383+
x = rand(Float32, 3, 1)
384+
m = dm(zeros)
385+
∇m = gradient(m -> sum(m(x)), m)[1]
386+
p, re = destructure(m)
387+
∇p = gradient-> sum(re(θ)(x)), p)[1]
388+
@test ∇p destructure(∇m)[1]
389+
end
381390
end
382391
end
383392

0 commit comments

Comments
 (0)