@@ -249,7 +249,83 @@ function Base.show(io::IO, b::SkipConnection)
end

"""
-    Parallel(connection, layers...)
+    Bilinear(in1, in2, out)
+
+Creates a Bilinear layer, which operates on two inputs at the same time.
+It has parameters `W` and `b`, and its output, given vectors `x` and `y`, is of the form
+
+    z[i] = σ.(x' * W[i,:,:] * y .+ b[i])
+
+If `x` and `y` are matrices, then each column of the output `z = B(x, y)` is of this form,
+given that `B` is a Bilinear layer of appropriate size.
+
+If `y` is not given, it is taken to be equal to `x`, i.e. `B(x) == B(x, x)`.
+The two inputs may also be provided as a tuple, `B((x, y)) == B(x, y)`,
+which is accepted as the input to a `Chain`.
+
+```julia
+# using Bilinear to generate interactions, on one input
+x = randn(Float32, 11, 7)
+B = Bilinear(11, 11, 3)
+size(B(x)) == (3, 7)
+
+# using Bilinear on two data streams at once, as a tuple
+x = randn(Float32, 10, 9)
+y = randn(Float32, 2, 9)
+m = Chain(Bilinear(10, 2, 3), Dense(3, 1))
+size(m((x, y))) == (1, 9)
+
+# using Bilinear as the recombinator in a SkipConnection
+x = randn(Float32, 10, 9)
+sc = SkipConnection(Dense(10, 10), Bilinear(10, 10, 5))
+size(sc(x)) == (5, 9)
+```
+"""
+struct Bilinear{A,B,S}
+  W::A
+  b::B
+  σ::S
+end
+
+@functor Bilinear
+
+Bilinear(W, b) = Bilinear(W, b, identity)
+
+function Bilinear(in1::Integer, in2::Integer, out::Integer, σ = identity;
+                  initW = glorot_uniform, initb = zeros)
+  return Bilinear(initW(out, in1, in2), initb(out), σ)
+end
+
+function (a::Bilinear)(x::AbstractMatrix, y::AbstractMatrix)
+  W, b, σ = a.W, a.b, a.σ
+
+  d_z, d_x, d_y = size(W)
+  d_x == size(x,1) && d_y == size(y,1) || throw(DimensionMismatch("number of rows in data must match W"))
+  size(x,2) == size(y,2) || throw(DimensionMismatch("Data inputs must agree on number of columns, got $(size(x,2)) and $(size(y,2))"))
+
+  # @einsum Wy[o,i,s] := W[o,i,j] * y[j,s]
+  Wy = reshape(reshape(W, (:, d_y)) * y, (d_z, d_x, :))
+
+  # @einsum Z[o,s] := Wy[o,i,s] * x[i,s]
+  Wyx = batched_mul(Wy, reshape(x, (d_x, 1, :)))
+  Z = reshape(Wyx, (d_z, :))
+
+  # @einsum out[o,s] := σ(Z[o,s] + b[o])
+  σ.(Z .+ b)
+end
+
+(a::Bilinear)(x::AbstractVecOrMat) = a(x, x)
+(a::Bilinear)(x::AbstractVector, y::AbstractVector) = vec(a(reshape(x, :, 1), reshape(y, :, 1)))
+(a::Bilinear)(x::NTuple{2, AbstractArray}) = a(x[1], x[2])
+
+function Base.show(io::IO, l::Bilinear)
+  print(io, "Bilinear(", size(l.W, 2), ", ", size(l.W, 3), ", ", size(l.W, 1))
+  l.σ == identity || print(io, ", ", l.σ)
+  print(io, ")")
+end
+
+"""
+    Parallel(connection, layers...)
Create a 'Parallel' layer that passes an input array to each path in
`layers`, reducing the output with `connection`.
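
The reshapes and the `batched_mul` call in the new forward pass implement the `@einsum` contractions written in its comments. As a sanity check (not part of the patch, and assuming the patch is applied so that `Flux.Bilinear` is available), the layer's output can be compared against the naive per-element definition from the docstring:

```julia
# Sketch only: compare the batched forward pass with the docstring formula
#     z[i] = σ(x' * W[i,:,:] * y + b[i])   applied column by column.
using Flux

W = randn(Float32, 3, 5, 7)    # (out, in1, in2)
b = randn(Float32, 3)
B = Flux.Bilinear(W, b, tanh)  # three-argument constructor of the struct above

x = randn(Float32, 5, 9)       # in1 × batch
y = randn(Float32, 7, 9)       # in2 × batch

# Naive double loop over output index i and batch column s.
z_naive = [tanh(x[:, s]' * W[i, :, :] * y[:, s] + b[i]) for i in 1:3, s in 1:9]

B(x, y) ≈ z_naive              # true, up to floating-point error
```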
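
`batched_mul` itself comes from NNlib, a dependency of Flux; it multiplies the matching matrix slices of two 3-d arrays, which is what turns the `Wy[o,i,s] * x[i,s]` contraction into a single call. A minimal illustration, using only NNlib:

```julia
# batched_mul multiplies slice k of A with slice k of B: C[:, :, k] = A[:, :, k] * B[:, :, k]
using NNlib: batched_mul

A = randn(Float32, 3, 5, 9)
B = randn(Float32, 5, 1, 9)
C = batched_mul(A, B)

size(C) == (3, 1, 9)
all(C[:, :, k] ≈ A[:, :, k] * B[:, :, k] for k in 1:9)  # true
```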