@@ -111,16 +111,12 @@ Finally, we define the forward pass. For `Join`, this means applying each `path`
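For context, the `Join` layer and the forward pass mentioned in the hunk header are defined earlier in the guide. A minimal sketch of that definition (reconstructed for reference, not part of this diff):

```julia
using Flux

# Custom Join layer: each path receives its own input,
# and the outputs are merged with `combine` (e.g. vcat).
struct Join{T, F}
  combine::F
  paths::T
end

# allow Join(combine, path1, path2, ...)
Join(combine, paths...) = Join(combine, paths)

Flux.@functor Join

# forward pass: apply each path to its matching input, then combine
(m::Join)(xs::Tuple) = m.combine(map((f, x) -> f(x), m.paths, xs)...)
(m::Join)(xs...) = m(xs)
```

With this in place, `Join(vcat, ...)` in the example below accepts a tuple of inputs, one per branch.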
Lastly, we can test our new layer. Thanks to the proper abstractions in Julia, our layer works on GPU arrays out of the box!
```julia
model = Chain(
-  Join(vcat,
-    Chain(
-      Dense(1, 5),
-      Dense(5, 1)
-    ),
-    Dense(1, 2),
-    Dense(1, 1),
-  ),
-  Dense(4, 1)
-) |> gpu
+  Join(vcat,
+    Chain(Dense(1, 5), Dense(5, 1)), # branch 1
+    Dense(1, 2),                     # branch 2
+    Dense(1, 1)),                    # branch 3
+  Dense(4, 1)
+) |> gpu

xs = map(gpu, (rand(1), rand(1), rand(1)))
@@ -137,16 +133,13 @@ Join(combine, paths...) = Join(combine, paths)
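The context line in this hunk header belongs to the alternative `Join` built on Flux's `Parallel` layer. As a sketch, that definition is roughly:

```julia
using Flux

# Join as a thin wrapper over Parallel: Parallel already applies each
# path to its own input and merges the outputs with `combine`
Join(combine, paths) = Parallel(combine, paths)
Join(combine, paths...) = Join(combine, paths)
```

This reuses `Parallel`'s vararg/tuple forward pass, so no custom call overload is needed.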
# use vararg/tuple version of Parallel forward pass
model = Chain(
-  Join(vcat,
-    Chain(
-      Dense(1, 5),
-      Dense(5, 1)
-    ),
-    Dense(1, 2),
-    Dense(1, 1),
-  ),
-  Dense(4, 1)
-) |> gpu
+  Join(vcat,
+    Chain(Dense(1, 5), Dense(5, 1)),
+    Dense(1, 2),
+    Dense(1, 1)
+  ),
+  Dense(4, 1)
+) |> gpu

xs = map(gpu, (rand(1), rand(1), rand(1)))
@@ -178,13 +171,9 @@ Flux.@functor Split
Now we can test to see that our `Split` does indeed produce multiple outputs.

```julia
model = Chain(
-  Dense(10, 5),
-  Split(
-    Dense(5, 1),
-    Dense(5, 3),
-    Dense(5, 2)
-  )
-) |> gpu
+  Dense(10, 5),
+  Split(Dense(5, 1), Dense(5, 3), Dense(5, 2))
+) |> gpu
model(gpu(rand(10)))
# returns a tuple with three float vectors
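For reference, the `Split` layer tested above is defined earlier in the guide; a minimal sketch consistent with the `Flux.@functor Split` context line in the hunk header:

```julia
using Flux

# Custom Split layer: apply every path to the same input,
# returning a tuple of outputs
struct Split{T}
  paths::T
end

Split(paths...) = Split(paths)

Flux.@functor Split

(m::Split)(x::AbstractArray) = map(f -> f(x), m.paths)
```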