|
11 | 11 | __split_and_reshape(x::AbstractMatrix, ::Nothing, ::Nothing) = x
|
12 | 12 | __split_and_reshape(x::AbstractArray, ::Nothing, ::Nothing) = x
|
13 | 13 |
|
| 14 | +function __split_and_reshape(y::AbstractMatrix, x) |
| 15 | + szs = [prod(size(xᵢ)[1:(end - 1)]) for xᵢ in x] |
| 16 | + counters = vcat(0, cumsum(szs)[1:(end - 1)]) |
| 17 | + return map((sz, c, xᵢ) -> reshape(view(y, (c + 1):(c + sz), :), size(xᵢ)), |
| 18 | + szs, counters, x) |
| 19 | +end |
| 20 | + |
14 | 21 | @inline __flatten(x::AbstractVector) = reshape(x, length(x), 1)
|
15 | 22 | @inline __flatten(x::AbstractMatrix) = x
|
16 | 23 | @inline __flatten(x::AbstractArray) = reshape(x, :, size(x, ndims(x)))
|
17 | 24 |
|
| 25 | +@inline __flatten_vcat(x) = mapreduce(__flatten, vcat, x) |
| 26 | + |
| 27 | +function CRC.rrule(::typeof(__flatten_vcat), x) |
| 28 | + y = __flatten_vcat(x) |
| 29 | + projects = CRC.ProjectTo.(x) |
| 30 | + function ∇__flatten_vcat(∂y) |
| 31 | + ∂y isa CRC.NoTangent && return (CRC.NoTangent(), CRC.NoTangent()) |
| 32 | + ∂x = __split_and_reshape(∂y, x) |
| 33 | + ∂x = map((∂xᵢ, project) -> project(∂xᵢ), ∂x, projects) |
| 34 | + return CRC.NoTangent(), ∂x |
| 35 | + end |
| 36 | + return y, ∇__flatten_vcat |
| 37 | +end |
| 38 | + |
18 | 39 | @inline __check_unrolled_mode(::Val{d}) where {d} = Val(d ≥ 1)
|
19 | 40 | @inline __check_unrolled_mode(st::NamedTuple) = __check_unrolled_mode(st.fixed_depth)
|
20 | 41 |
|
|
0 commit comments