We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent fe8e91f commit 0f9e672Copy full SHA for 0f9e672
src/utils.jl
@@ -616,13 +616,18 @@ function _restructure(m, xs)
616
i += length(x)
617
return x
618
end
619
-
620
length(xs) == i || @warn "Expected $(i) params, got $(length(xs))"
621
return m̄
622
623
624
@adjoint function _restructure(m, xs)
625
- _restructure(m, xs), dm -> (nothing,destructure(dm)[1])
+ m̄, numel = _restructure(m, xs), length(xs)
+ function _restructure_pullback(dm)
626
+ xs′ = destructure(dm)[1]
627
+ numel == length(xs′) || @warn "Expected $(numel) params, got $(length(xs′))"
628
+ return (nothing, xs′)
629
+ end
630
+ return m̄, _restructure_pullback
631
632
633
"""
0 commit comments