Skip to content

Commit e7b50f5

Browse files
committed
Add some custom rrule
1 parent 81770b7 commit e7b50f5

File tree

2 files changed

+23
-2
lines changed

2 files changed

+23
-2
lines changed

src/layers.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -335,7 +335,7 @@ If `init` is not passed, it creates a MultiScale Regularized Deep Equilibrium Ne
335335
"""
336336
function MultiScaleSkipDeepEquilibriumNetwork(main_layers::Tuple, mapping_layers::Matrix,
337337
post_fuse_layer::Union{Nothing, Tuple}, init::Tuple, solver, scales; kwargs...)
338-
init = Chain(Parallel(nothing, init...), x -> mapreduce(__flatten, vcat, x))
338+
init = Chain(Parallel(nothing, init...), __flatten_vcat)
339339
return MultiScaleDeepEquilibriumNetwork(main_layers, mapping_layers, post_fuse_layer,
340340
solver, scales; init, kwargs...)
341341
end
@@ -393,6 +393,6 @@ end
393393
u, x = z
394394
u_ = __split_and_reshape(u, m.split_idxs, m.scales)
395395
u_res, st = Lux.apply(m.model, ($(inputs...),), ps, st)
396-
return mapreduce(__flatten, vcat, u_res), st
396+
return __flatten_vcat(u_res), st
397397
end
398398
end

src/utils.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,31 @@ end
1111
__split_and_reshape(x::AbstractMatrix, ::Nothing, ::Nothing) = x
1212
__split_and_reshape(x::AbstractArray, ::Nothing, ::Nothing) = x
1313

14+
function __split_and_reshape(y::AbstractMatrix, x)
15+
szs = [prod(size(xᵢ)[1:(end - 1)]) for xᵢ in x]
16+
counters = vcat(0, cumsum(szs)[1:(end - 1)])
17+
return map((sz, c, xᵢ) -> reshape(view(y, (c + 1):(c + sz), :), size(xᵢ)),
18+
szs, counters, x)
19+
end
20+
1421
@inline __flatten(x::AbstractVector) = reshape(x, length(x), 1)
1522
@inline __flatten(x::AbstractMatrix) = x
1623
@inline __flatten(x::AbstractArray) = reshape(x, :, size(x, ndims(x)))
1724

25+
@inline __flatten_vcat(x) = mapreduce(__flatten, vcat, x)
26+
27+
function CRC.rrule(::typeof(__flatten_vcat), x)
28+
y = __flatten_vcat(x)
29+
projects = CRC.ProjectTo.(x)
30+
function ∇__flatten_vcat(∂y)
31+
∂y isa CRC.NoTangent && return (CRC.NoTangent(), CRC.NoTangent())
32+
∂x = __split_and_reshape(∂y, x)
33+
∂x = map((∂xᵢ, project) -> project(∂xᵢ), ∂x, projects)
34+
return CRC.NoTangent(), ∂x
35+
end
36+
return y, ∇__flatten_vcat
37+
end
38+
1839
@inline __check_unrolled_mode(::Val{d}) where {d} = Val(d 1)
1940
@inline __check_unrolled_mode(st::NamedTuple) = __check_unrolled_mode(st.fixed_depth)
2041

0 commit comments

Comments
 (0)