Commit 33f99ef

bors[bot] and darsnack authored

Merge #1462

1462: Add Parallel layer r=DhairyaLGandhi a=darsnack

Since #1289 stalled, I have added an implementation of `Parallel` with some of the changes we discussed during ML calls. This version excludes most of the structural layers in #1289, like `Join`, `Split`, and `Nop`. I also added the ability for the user to specify the reduction operator. If it is acceptable, I would like to remap `SkipConnection` to `Parallel` (not a deprecation exactly).

The reason for submitting this PR now is because I am creating pre-trained weights for the networks in FluxML/Metalhead.jl#70, and there is a lot of code that can be replaced with a `Parallel`. So, I'd like to have `Parallel` in Flux before continuing with training to make the process easier.

### PR Checklist

- [x] Tests are added
- [x] Entry in NEWS.md
- [x] Documentation, if applicable
- [x] Final review from @DhairyaLGandhi (for API changes).

cc @CarloLucibello

Co-authored-by: Kyle Daruwalla <daruwalla@wisc.edu>
Co-authored-by: Kyle Daruwalla <daruwalla.k.public@icloud.com>

2 parents 40db137 + 377c977 commit 33f99ef
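The commit message mentions remapping `SkipConnection` onto `Parallel`. As a minimal sketch of why that is possible (not code from this diff; the layer size is made up), `Parallel` with `identity` as its second path computes the same thing as `SkipConnection`:

```julia
using Flux

# SkipConnection(layer, connection)(x) computes connection(layer(x), x), while
# Parallel(connection, layer, identity)(x) computes connection(layer(x), identity(x)),
# i.e. the same quantity, which is what makes the remapping plausible.
layer = Dense(4, 4)
x = rand(Float32, 4)

skip = SkipConnection(layer, +)
par  = Parallel(+, layer, identity)

skip(x) ≈ par(x)  # true
```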

File tree

7 files changed: +210 / -5 lines

NEWS.md

Lines changed: 1 addition & 0 deletions

@@ -8,6 +8,7 @@
 * Removed kwarg only constructors for [`convolutional layers`](https://github.com/FluxML/Flux.jl/pull/1379).
 * Add [sparse initialization](https://github.com/FluxML/Flux.jl/pull/1454) as described in [Deep learning via Hessian-free optimization](https://dl.acm.org/doi/abs/10.5555/3104322.3104416).
 * Moved GPU CI to use buildkite instead of GitLab
+* New [`Parallel` layer](https://github.com/FluxML/Flux.jl/pull/1462) adds inception module-like building blocks.
 * Other new features and bug fixes (see GitHub releases page)
 
 ## v0.11.2

docs/src/models/advanced.md

Lines changed: 133 additions & 0 deletions

@@ -70,3 +70,136 @@ by simply deleting it from `ps`:
 ps = params(m)
 delete!(ps, m[2].b)
 ```
+
+## Custom multiple input or output layer
+
+Sometimes a model needs to receive several separate inputs at once or produce several separate outputs at once. In other words, there are multiple paths within this high-level layer, each processing a different input or producing a different output. A simple example of this in the machine learning literature is the [inception module](https://www.cv-foundation.org/openaccess/content_cvpr_2016/papers/Szegedy_Rethinking_the_Inception_CVPR_2016_paper.pdf).
+
+Naively, we could have a struct that stores the weights along each path and implement the joining/splitting in the forward pass function. But that would mean a new struct any time the operations along each path change. Instead, this guide will show you how to construct a high-level layer (like [`Chain`](@ref)) that is made of multiple sub-layers for each path.
+
+### Multiple inputs: a custom `Join` layer
+
+Our custom `Join` layer will accept multiple inputs at once, pass each input through a separate path, then combine the results together. Note that this layer can already be constructed using [`Parallel`](@ref), but we will first walk through how to do this manually.
+
+We start by defining a new struct, `Join`, that stores the different paths and a combine operation as its fields.
+```julia
+using Flux
+using CUDA
+
+# custom join layer
+struct Join{T, F}
+  combine::F
+  paths::T
+end
+
+# allow Join(op, m1, m2, ...) as a constructor
+Join(combine, paths...) = Join(combine, paths)
+```
+Notice that we parameterized the type of the `paths` field. This is necessary for fast Julia code; in general, `T` might be a `Tuple` or `Vector`, but we don't need to pay attention to what it specifically is. The same goes for the `combine` field.
+
+The next step is to use [`Flux.@functor`](@ref) to make our struct behave like a Flux layer. This is important so that calling `params` on a `Join` returns the underlying weight arrays on each path.
+```julia
+Flux.@functor Join
+```
+
+Finally, we define the forward pass. For `Join`, this means applying each layer in `paths` to its corresponding input array, then using `combine` to merge the results.
+```julia
+(m::Join)(xs::Tuple) = m.combine(map((f, x) -> f(x), m.paths, xs)...)
+(m::Join)(xs...) = m(xs)
+```
+Lastly, we can test our new layer. Thanks to the proper abstractions in Julia, our layer works on GPU arrays out of the box!
+```julia
+model = Chain(
+          Join(vcat,
+            Chain(
+              Dense(1, 5),
+              Dense(5, 1)
+            ),
+            Dense(1, 2),
+            Dense(1, 1),
+          ),
+          Dense(4, 1)
+        ) |> gpu
+
+xs = map(gpu, (rand(1), rand(1), rand(1)))
+
+model(xs)
+# returns a single float vector with one value
+```
+
+#### Using `Parallel`
+
+Flux already provides [`Parallel`](@ref), which offers the same functionality. In this case, `Join` is just syntactic sugar for `Parallel`.
+```julia
+Join(combine, paths) = Parallel(combine, paths)
+Join(combine, paths...) = Join(combine, paths)
+
+# use the vararg/tuple version of the Parallel forward pass
+model = Chain(
+          Join(vcat,
+            Chain(
+              Dense(1, 5),
+              Dense(5, 1)
+            ),
+            Dense(1, 2),
+            Dense(1, 1),
+          ),
+          Dense(4, 1)
+        ) |> gpu
+
+xs = map(gpu, (rand(1), rand(1), rand(1)))
+
+model(xs)
+# returns a single float vector with one value
+```
+
+### Multiple outputs: a custom `Split` layer
+
+Our custom `Split` layer will accept a single input, then pass it through several separate paths to produce multiple outputs.
+
+We start by following the same steps as the `Join` layer: define a struct, use [`Flux.@functor`](@ref), and define the forward pass.
+```julia
+using Flux
+using CUDA
+
+# custom split layer
+struct Split{T}
+  paths::T
+end
+
+Split(paths...) = Split(paths)
+
+Flux.@functor Split
+
+(m::Split)(x::AbstractArray) = tuple(map(f -> f(x), m.paths)...)
+```
+
+Now we can test to see that our `Split` does indeed produce multiple outputs.
+```julia
+model = Chain(
+          Dense(10, 5),
+          Split(
+            Dense(5, 1),
+            Dense(5, 3),
+            Dense(5, 2)
+          )
+        ) |> gpu
+
+model(gpu(rand(10)))
+# returns a tuple with three float vectors
+```
+
+A custom loss function for the multiple outputs may look like this:
+```julia
+using Statistics
+
+# assuming model returns the output of a Split
+# x is a single input
+# ys is a tuple of outputs
+function loss(x, ys, model)
+  # take the RMS over all the individual MSE losses
+  ŷs = model(x)
+  return sqrt(mean(Flux.mse(y, ŷ) for (y, ŷ) in zip(ys, ŷs)))
+end
+```
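As a follow-up to the loss function above, here is a minimal, hypothetical sketch (not part of the committed diff) of plugging it into Flux's gradient machinery, assuming the `model` and output sizes from the `Split` example:

```julia
using Flux

# Hypothetical usage of loss(x, ys, model) from the documentation snippet above.
# `model` is assumed to be the Chain ending in a Split with outputs of sizes 1, 3 and 2.
# (Move x and ys with `gpu` if the model lives on the GPU.)
x  = rand(Float32, 10)
ys = (rand(Float32, 1), rand(Float32, 3), rand(Float32, 2))

ps = Flux.params(model)
gs = gradient(() -> loss(x, ys, model), ps)   # gradients w.r.t. all path weights
Flux.Optimise.update!(Descent(0.1), ps, gs)   # one optimisation step
```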

docs/src/models/layers.md

Lines changed: 1 addition & 0 deletions

@@ -49,6 +49,7 @@ But in contrast to the layers described in the other sections are not readily gr
 ```@docs
 Maxout
 SkipConnection
+Parallel
 ```
 
 ## Normalisation & Regularisation

src/Flux.jl

Lines changed: 6 additions & 4 deletions

@@ -11,10 +11,12 @@ using Zygote: Params, @adjoint, gradient, pullback, @nograd
 
 export gradient
 
-export Chain, Dense, Maxout, RNN, LSTM, GRU, SamePad, Conv, CrossCor, ConvTranspose,
-  AdaptiveMaxPool, AdaptiveMeanPool, GlobalMaxPool, GlobalMeanPool, MaxPool,
-  MeanPool, flatten, DepthwiseConv, Dropout, AlphaDropout, LayerNorm, BatchNorm,
-  InstanceNorm, GroupNorm, SkipConnection, params, fmap, cpu, gpu, f32, f64,
+export Chain, Dense, Maxout, SkipConnection, Parallel, flatten,
+  RNN, LSTM, GRU,
+  SamePad, Conv, CrossCor, ConvTranspose, DepthwiseConv,
+  AdaptiveMaxPool, AdaptiveMeanPool, GlobalMaxPool, GlobalMeanPool, MaxPool, MeanPool,
+  Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm, GroupNorm,
+  params, fmap, cpu, gpu, f32, f64,
   testmode!, trainmode!
 
 include("optimise/Optimise.jl")

src/layers/basic.jl

Lines changed: 48 additions & 0 deletions

@@ -253,3 +253,51 @@ end
 function Base.show(io::IO, b::SkipConnection)
   print(io, "SkipConnection(", b.layers, ", ", b.connection, ")")
 end
+
+"""
+    Parallel(connection, layers...)
+
+Create a 'Parallel' layer that passes an input array to each path in
+`layers`, reducing the output with `connection`.
+
+Called with one input `x`, this is equivalent to `reduce(connection, [l(x) for l in layers])`.
+If called with multiple inputs, they are `zip`ped with the layers, thus `Parallel(+, f, g)(x, y) = f(x) + g(y)`.
+
+# Examples
+
+```jldoctest
+julia> model = Chain(Dense(3, 5),
+                     Parallel(vcat, Dense(5, 4), Chain(Dense(5, 7), Dense(7, 4))),
+                     Dense(8, 17));
+
+julia> size(model(rand(3)))
+(17,)
+
+julia> model = Parallel(+, Dense(10, 2), Dense(5, 2))
+Parallel(+, Dense(10, 2), Dense(5, 2))
+
+julia> size(model(rand(10), rand(5)))
+(2,)
+```
+"""
+struct Parallel{F, T}
+  connection::F
+  layers::T
+end
+
+Parallel(connection, layers...) = Parallel(connection, layers)
+
+@functor Parallel
+
+(m::Parallel)(x::AbstractArray) = mapreduce(f -> f(x), m.connection, m.layers)
+(m::Parallel)(xs::Vararg{<:AbstractArray}) = mapreduce((f, x) -> f(x), m.connection, m.layers, xs)
+(m::Parallel)(xs::Tuple) = m(xs...)
+
+Base.getindex(m::Parallel, i::Integer) = m.layers[i]
+Base.getindex(m::Parallel, i::AbstractVector) = Parallel(m.connection, m.layers[i]...)
+
+function Base.show(io::IO, m::Parallel)
+  print(io, "Parallel(", m.connection, ", ")
+  join(io, m.layers, ", ")
+  print(io, ")")
+end
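A small usage sketch (not from the diff, with made-up layer sizes) exercising the new layer's vararg call and the `getindex` methods defined above:

```julia
using Flux

# Two-path Parallel that sums the branch outputs.
m = Parallel(+, Dense(10, 2), Dense(5, 2))

size(m(rand(10), rand(5)))  # (2,): inputs are zipped with the layers, outputs reduced with +
m[1]                        # indexing with an Integer returns the first branch, a Dense(10, 2)
m[1:2]                      # indexing with a range rebuilds a Parallel over the selected branches
```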

test/layers/basic.jl

Lines changed: 18 additions & 1 deletion

@@ -106,4 +106,21 @@ import Flux: activations
       @test size(SkipConnection(Dense(10,10), (a,b) -> cat(a, b, dims = 2))(input)) == (10,4)
     end
   end
-end
+
+  @testset "Parallel" begin
+    @testset "zero sum" begin
+      input = randn(10, 10, 10, 10)
+      @test Parallel(+, x -> zeros(size(x)), identity)(input) == input
+    end
+
+    @testset "concat size" begin
+      input = randn(10, 2)
+      @test size(Parallel((a, b) -> cat(a, b; dims=2), Dense(10, 10), identity)(input)) == (10, 4)
+    end
+
+    @testset "vararg input" begin
+      inputs = randn(10), randn(5), randn(4)
+      @test size(Parallel(+, Dense(10, 2), Dense(5, 2), Dense(4, 2))(inputs)) == (2,)
+    end
+  end
+end

test/outputsize.jl

Lines changed: 3 additions & 0 deletions

@@ -28,6 +28,9 @@
 
   m = SkipConnection(Conv((3, 3), 3 => 16; pad = 1), (mx, x) -> cat(mx, x; dims = 3))
   @test outputsize(m, (10, 10, 3, 1)) == (10, 10, 19, 1)
+
+  m = Parallel((mx, x) -> cat(mx, x; dims = 3), Conv((3, 3), 3 => 16; pad = 1), identity)
+  @test outputsize(m, (10, 10, 3, 1)) == (10, 10, 19, 1)
 end
 
 @testset "activations" begin
