Motivation and description
Given #2185 and other issues caused by the current mutability of the `Recur` interface, we should move to a more standard blocked (i.e. 3D for a simple RNN) interface. This has the benefits of:
- cleaning up the recurrent interface so it is more easily used by people coming from other packages,
- making workflows using convRNNs easier to support, and
- potentially enabling some optimizations we can handle on the Flux side (see Lux's `Recurrence` with `return_sequence=true` vs `false`).
I have not tested how we might fix the gradients by moving to this restricted interface. But if we decide to remove the statefulness (see below), we can fix gradients as done in FluxML/Fluxperimental.jl#7.
Possible Implementation
I see two ways we can make this change: one is a wider change to the Flux `Chain` interface, and the other tries to fix only `Recur`. In either case, the implementation would assume the final dimension of the multi-dimensional array is the time index. For a simple RNN, the incoming array would be assumed to be Features x Batch x Time. Passing a 2d or 1d array to `Recur` would produce an error, to avoid ambiguities.
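To make the proposed convention concrete, here is a minimal sketch of a blocked 3D forward pass using a hand-rolled dense RNN cell. This is illustrative only, not Flux API: `rnn_forward` and its parameter names are invented for this sketch, and the 2d/1d error check mirrors the restriction proposed above.

```julia
# Illustrative hand-rolled RNN over a blocked input; not Flux API.
# x is assumed to be Features x Batch x Time (time is the last dimension).
function rnn_forward(Wx, Wh, b, h0, x::AbstractArray{<:Real,3})
    h = h0
    out = similar(x, size(Wh, 1), size(x, 2), size(x, 3))
    for t in axes(x, 3)                       # iterate over the time index
        h = tanh.(Wx * view(x, :, :, t) .+ Wh * h .+ b)
        out[:, :, t] = h                      # keep the full sequence of states
    end
    return out
end

# Under the proposed restriction, anything that is not 3d errors out:
rnn_forward(Wx, Wh, b, h0, x) = error("expected Features x Batch x Time")

feat, hidden, batch, steps = 4, 8, 16, 10
Wx = randn(Float32, hidden, feat)
Wh = randn(Float32, hidden, hidden)
b  = zeros(Float32, hidden)
h0 = zeros(Float32, hidden, batch)
x  = rand(Float32, feat, batch, steps)

y = rnn_forward(Wx, Wh, b, h0, x)
@assert size(y) == (hidden, batch, steps)
```

Returning the full Features x Batch x Time block of hidden states corresponds to Lux's `return_sequence=true` behaviour; returning only `view(out, :, :, size(x, 3))` would correspond to `return_sequence=false`.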
One possible implementation is to go ahead and do the full change over to removing state from the network generally. See FluxML/Fluxperimental.jl#7. This would overhaul large parts of the interface into chain, and could be targeted at 0.14. See the implementation done in the above PR and FluxML/Fluxperimental.jl#5 for details.
The second possible approach is to first remove the loop-over-timesteps interface and replace it with the 3D interface. This initial change restricts the interface to 3D inputs, but I haven't tested how we could fix gradients while maintaining mutability and statefulness in `Recur`. The interface/implementation would likely look much like:
Flux.jl/src/layers/recurrent.jl, lines 184 to 188 in c9c262d:

```julia
function (m::Recur)(x::AbstractArray{T, 3}) where T
  h = [m(x_t) for x_t in eachlastdim(x)]
  sze = size(h[1])
  reshape(reduce(hcat, h), sze[1], sze[2], length(h))
end
```
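The stacking step in that snippet (`reduce(hcat, h)` followed by a `reshape`) can be checked in isolation. Below is a small self-contained sketch with a plain function standing in for the stateful `m(x_t)` call, and a comprehension over the last dimension standing in for `eachlastdim`:

```julia
# Stand-in for a stateful Recur cell: an elementwise map we can verify against.
f(x_t) = 2f0 .* x_t

x = rand(Float32, 4, 16, 10)                    # Features x Batch x Time

# Equivalent of `[m(x_t) for x_t in eachlastdim(x)]`:
h = [f(view(x, :, :, t)) for t in axes(x, 3)]

# hcat flattens the time slices into a Features x (Batch*Time) matrix,
# and reshape restores the Features x Batch x Time block.
sze = size(h[1])
y = reshape(reduce(hcat, h), sze[1], sze[2], length(h))

@assert size(y) == (4, 16, 10)
@assert y == 2f0 .* x     # column-major reshape puts each slice back at its timestep
```

Because Julia arrays are column-major, the first `Batch` columns of the concatenated matrix become the `t = 1` slice after the reshape, so the time ordering is preserved.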