Skip to content

Commit 0f9e672

Browse files
committed
Warn in restructure pullback as well
1 parent fe8e91f commit 0f9e672

File tree

1 file changed

+7
-2
lines changed

1 file changed

+7
-2
lines changed

src/utils.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -616,13 +616,18 @@ function _restructure(m, xs)
616616
i += length(x)
617617
return x
618618
end
619-
620619
length(xs) == i || @warn "Expected $(i) params, got $(length(xs))"
621620
return
622621
end
623622

624623
@adjoint function _restructure(m, xs)
625-
_restructure(m, xs), dm -> (nothing,destructure(dm)[1])
624+
m̄, numel = _restructure(m, xs), length(xs)
625+
function _restructure_pullback(dm)
626+
xs′ = destructure(dm)[1]
627+
numel == length(xs′) || @warn "Expected $(numel) params, got $(length(xs′))"
628+
return (nothing, xs′)
629+
end
630+
return m̄, _restructure_pullback
626631
end
627632

628633
"""

0 commit comments

Comments
 (0)