
Commit 377c977

Add tests for vararg Parallel and updated docs
1 parent 49a05b4

3 files changed: +33 −1 lines changed

docs/src/models/advanced.md

Lines changed: 27 additions & 1 deletion
@@ -79,7 +79,7 @@ Naively, we could have a struct that stores the weights of along each path and i
 
 ### Multiple inputs: a custom `Join` layer
 
-Our custom `Join` layer will accept multiple inputs at once, pass each input through a separate path, then combine the results together.
+Our custom `Join` layer will accept multiple inputs at once, pass each input through a separate path, then combine the results together. Note that this layer can already be constructed using [`Parallel`](@ref), but we will first walk through how to do this manually.
 
 We start by defining a new struct, `Join`, that stores the different paths and a combine operation as its fields.
 ```julia
@@ -128,6 +128,32 @@ model(xs)
 # returns a single float vector with one value
 ```
 
+#### Using `Parallel`
+
+Flux already provides [`Parallel`](@ref), which can offer the same functionality. In this case, `Join` is just syntactic sugar for `Parallel`.
+```julia
+Join(combine, paths) = Parallel(combine, paths)
+Join(combine, paths...) = Join(combine, paths)
+
+# use vararg/tuple version of Parallel forward pass
+model = Chain(
+  Join(vcat,
+    Chain(
+      Dense(1, 5),
+      Dense(5, 1)
+    ),
+    Dense(1, 2),
+    Dense(1, 1),
+  ),
+  Dense(4, 1)
+) |> gpu
+
+xs = map(gpu, (rand(1), rand(1), rand(1)))
+
+model(xs)
+# returns a single float vector with one value
+```
+
 ### Multiple outputs: a custom `Split` layer
 
 Our custom `Split` layer will accept a single input, then pass the input through a separate path to produce multiple outputs.
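
The docs example above relies on the tuple forward pass added in `src/layers/basic.jl` below: `Chain` hands `Join` the tuple `xs` as a single argument, and the new tuple method splats it across the paths. As a minimal CPU-only check of the sugar (same model as the docs, with the optional `gpu` calls dropped):

```julia
using Flux

# Join as syntactic sugar for Parallel, as in the diff above
Join(combine, paths) = Parallel(combine, paths)
Join(combine, paths...) = Join(combine, paths)

model = Chain(
  Join(vcat,
    Chain(Dense(1, 5), Dense(5, 1)),  # path 1 -> 1 feature
    Dense(1, 2),                      # path 2 -> 2 features
    Dense(1, 1),                      # path 3 -> 1 feature
  ),
  Dense(4, 1)  # vcat yields 1 + 2 + 1 = 4 features
)

xs = (rand(1), rand(1), rand(1))
model(xs)  # 1-element output vector
```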

src/layers/basic.jl

Lines changed: 1 addition & 0 deletions
@@ -291,6 +291,7 @@ Parallel(connection, layers...) = Parallel(connection, layers)
 
 (m::Parallel)(x::AbstractArray) = mapreduce(f -> f(x), m.connection, m.layers)
 (m::Parallel)(xs::Vararg{<:AbstractArray}) = mapreduce((f, x) -> f(x), m.connection, m.layers, xs)
+(m::Parallel)(xs::Tuple) = m(xs...)
 
 Base.getindex(m::Parallel, i::Integer) = m.layers[i]
 Base.getindex(m::Parallel, i::AbstractVector) = Parallel(m.connection, m.layers[i]...)
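
The one-line tuple method is what lets a `Parallel` sit inside a `Chain`: each layer in a `Chain` receives exactly one argument, so multiple inputs travel as a tuple and are splatted back into the existing vararg forward pass. A minimal sketch of the two equivalent call styles:

```julia
using Flux

p = Parallel(+, Dense(10, 2), Dense(5, 2))
x1, x2 = randn(10), randn(5)

y1 = p(x1, x2)    # vararg method: mapreduce pairs each layer with its input
y2 = p((x1, x2))  # new tuple method: splats into the vararg method
y1 == y2          # true
```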

test/layers/basic.jl

Lines changed: 5 additions & 0 deletions
@@ -117,5 +117,10 @@ import Flux: activations
       input = randn(10, 2)
       @test size(Parallel((a, b) -> cat(a, b; dims=2), Dense(10, 10), identity)(input)) == (10, 4)
     end
+
+    @testset "vararg input" begin
+      inputs = randn(10), randn(5), randn(4)
+      @test size(Parallel(+, Dense(10, 2), Dense(5, 2), Dense(4, 2))(inputs)) == (2,)
+    end
   end
 end
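
Unrolled by hand, the new test's expectation follows from the vararg forward pass: `mapreduce((f, x) -> f(x), +, m.layers, xs)` applies each `Dense` to its matching input and sums the three length-2 outputs. A sketch of the same computation done manually:

```julia
using Flux

d1, d2, d3 = Dense(10, 2), Dense(5, 2), Dense(4, 2)
inputs = randn(10), randn(5), randn(4)

manual = d1(inputs[1]) + d2(inputs[2]) + d3(inputs[3])
result = Parallel(+, d1, d2, d3)(inputs)

result ≈ manual       # true (up to reduction order)
size(result) == (2,)  # matches the @test above
```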
