
Commit 3bc42f2

Merge pull request #1009 from bhvieira/billinear
Added Bilinear layer
2 parents ddb5e9c + f4a60c7 · commit 3bc42f2

4 files changed: +117 −1 lines changed


docs/src/models/layers.md

Lines changed: 1 addition & 0 deletions
@@ -57,6 +57,7 @@ But in contrast to the layers described in the other sections are not readily gr
 Maxout
 SkipConnection
 Parallel
+Bilinear
 ```

 ## Normalisation & Regularisation

src/layers/basic.jl

Lines changed: 77 additions & 1 deletion
@@ -249,7 +249,83 @@ function Base.show(io::IO, b::SkipConnection)
 end

 """
-    Parallel(connection, layers...)
+    Bilinear(in1, in2, out)
+
+Creates a Bilinear layer, which operates on two inputs at the same time.
+It has parameters `W` and `b`, and its output given vectors `x`, `y` is of the form
+
+    z[i] = σ.(x' * W[i,:,:] * y .+ b[i])
+
+If `x` and `y` are matrices, then each column of the output `z = B(x, y)` is of this form,
+given that `B` is a Bilinear layer of appropriate size.
+
+If `y` is not given, it is taken to be equal to `x`, i.e. `B(x) == B(x, x)`.
+The two inputs may also be provided as a tuple, `B((x, y)) == B(x, y)`,
+which is accepted as the input to a `Chain`.
+
+```julia
+# using Bilinear to generate interactions, on one input
+x = randn(Float32, 11, 7)
+B = Bilinear(11, 11, 3)
+size(B(x)) == (3, 7)
+
+# using Bilinear on two data streams at once, as a tuple
+x = randn(Float32, 10, 9)
+y = randn(Float32, 2, 9)
+m = Chain(Bilinear(10, 2, 3), Dense(3, 1))
+size(m((x, y))) == (1, 9)
+
+# using Bilinear as the recombinator in a SkipConnection
+x = randn(Float32, 10, 9)
+sc = SkipConnection(Dense(10, 10), Bilinear(10, 10, 5))
+size(sc(x)) == (5, 9)
+```
+"""
+struct Bilinear{A,B,S}
+  W::A
+  b::B
+  σ::S
+end
+
+@functor Bilinear
+
+Bilinear(W, b) = Bilinear(W, b, identity)
+
+function Bilinear(in1::Integer, in2::Integer, out::Integer, σ = identity;
+                  initW = glorot_uniform, initb = zeros)
+  return Bilinear(initW(out, in1, in2), initb(out), σ)
+end
+
+function (a::Bilinear)(x::AbstractMatrix, y::AbstractMatrix)
+  W, b, σ = a.W, a.b, a.σ
+
+  d_z, d_x, d_y = size(W)
+  d_x == size(x,1) && d_y == size(y,1) || throw(DimensionMismatch("number of rows in data must match W"))
+  size(x,2) == size(y,2) || throw(DimensionMismatch("Data inputs must agree on number of columns, got $(size(x,2)) and $(size(y,2))"))
+
+  # @einsum Wy[o,i,s] := W[o,i,j] * y[j,s]
+  Wy = reshape(reshape(W, (:, d_y)) * y, (d_z, d_x, :))
+
+  # @einsum Z[o,s] := Wy[o,i,s] * x[i,s]
+  Wyx = batched_mul(Wy, reshape(x, (d_x, 1, :)))
+  Z = reshape(Wyx, (d_z, :))
+
+  # @einsum out[o,s] := σ(Z[o,s] + b[o])
+  σ.(Z .+ b)
+end
+
+(a::Bilinear)(x::AbstractVecOrMat) = a(x, x)
+(a::Bilinear)(x::AbstractVector, y::AbstractVector) = vec(a(reshape(x, :,1), reshape(y, :,1)))
+(a::Bilinear)(x::NTuple{2, AbstractArray}) = a(x[1], x[2])
+
+function Base.show(io::IO, l::Bilinear)
+  print(io, "Bilinear(", size(l.W, 2), ", ", size(l.W, 3), ", ", size(l.W, 1))
+  l.σ == identity || print(io, ", ", l.σ)
+  print(io, ")")
+end
+
+"""
+    Parallel(connection, layers...)

 Create a 'Parallel' layer that passes an input array to each path in
 `layers`, reducing the output with `connection`.
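The `reshape`/`batched_mul` steps above implement the einsum comments without an explicit loop. As a sanity check (not part of the diff), here is a minimal sketch comparing that vectorised path against the naive per-sample formula from the docstring; it assumes only Flux's NNlib dependency, and the sizes are illustrative:

```julia
using NNlib  # provides batched_mul, which the forward pass above uses

W = randn(Float32, 3, 11, 2)   # (out, in1, in2), as for Bilinear(11, 2, 3)
x = randn(Float32, 11, 9)      # first input: in1 × batch
y = randn(Float32, 2, 9)       # second input: in2 × batch

# naive evaluation of z[o,s] = x[:,s]' * W[o,:,:] * y[:,s]
naive = [x[:, s]' * W[o, :, :] * y[:, s] for o in 1:3, s in 1:9]

# the layer's vectorised path (bias and activation omitted)
Wy = reshape(reshape(W, (:, 2)) * y, (3, 11, :))  # Wy[o,i,s] = sum_j W[o,i,j] * y[j,s]
Z  = reshape(batched_mul(Wy, reshape(x, (11, 1, :))), (3, :))  # Z[o,s] = sum_i Wy[o,i,s] * x[i,s]

Z ≈ naive  # true, up to floating-point rounding
```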

test/cuda/layers.jl

Lines changed: 14 additions & 0 deletions
@@ -178,3 +178,17 @@ end
   gs = gradient(() -> sum(l(ip)), Flux.params(l))
   @test l.b ∉ gs.params
 end
+
+@testset "Two-streams Bilinear" begin
+  x = zeros(Float32,10,9) |> gpu
+  y = zeros(Float32,2,9) |> gpu
+  b = Flux.Bilinear(10, 2, 3) |> gpu
+  @test size(b(x,y)) == (3,9)
+  @test sum(abs2, b(x,y)) ≈ 0f0
+  gs_gpu = gradient(() -> sum(abs2.(b(x, y))), params(b))
+  b_cpu, x_cpu, y_cpu = b |> cpu, x |> cpu, y |> cpu
+  gs_cpu = gradient(() -> sum(abs2.(b_cpu(x_cpu, y_cpu))), params(b_cpu))
+  for (pgpu, pcpu) in zip(params(b), params(b_cpu))
+    @test gs_cpu[pcpu] ≈ Array(gs_gpu[pgpu])
+  end
+end
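The CPU-versus-GPU gradient comparison above is the usual pattern for checking CUDA support. The same check could be extended to the layer's other call methods; a brief sketch (not part of the diff), assuming the same test environment with `gpu`, `cpu`, and `Test` in scope and illustrative sizes:

```julia
x = randn(Float32, 11, 7) |> gpu
b = Flux.Bilinear(11, 11, 3) |> gpu
@test size(b(x)) == (3, 7)  # single-input method: B(x) == B(x, x)
@test b((x, x)) ≈ b(x)      # tuple method agrees with the two-argument form
```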

test/layers/basic.jl

Lines changed: 25 additions & 0 deletions
@@ -112,6 +112,31 @@ import Flux: activations
   end
 end

+@testset "Bilinear" begin
+  @testset "SkipConnection recombinator" begin
+    d = Dense(10, 10)
+    b = Flux.Bilinear(10, 10, 5)
+    x = randn(Float32,10,9)
+    sc = SkipConnection(d, b)
+    @test size(sc(x)) == (5,9)
+  end
+
+  @testset "Two-streams zero sum" begin
+    x = zeros(Float32,10,9)
+    y = zeros(Float32,2,9)
+    b = Flux.Bilinear(10, 2, 3)
+    @test size(b(x,y)) == (3,9)
+    @test sum(abs2, b(x,y)) == 0f0
+  end
+
+  @testset "Inner interactions" begin
+    x = randn(Float32,11,7)
+    b = Flux.Bilinear(11, 11, 3)
+    @test size(b(x)) == (3,7)
+    @test_nowarn gs = gradient(() -> sum(abs2.(b(x))), params(b))
+  end
+end
+
 @testset "Parallel" begin
   @testset "zero sum" begin
     input = randn(10, 10, 10, 10)
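Beyond the shape and zero-sum assertions in the Bilinear testsets, the tuple method composes with training as usual. A minimal sketch (not part of the diff; names and sizes are illustrative, mirroring the docstring's two-stream example):

```julia
using Flux

x = randn(Float32, 10, 9)
y = randn(Float32, 2, 9)
m = Chain(Flux.Bilinear(10, 2, 3), Dense(3, 1))

# gradients flow through both streams via the tuple input
gs = gradient(() -> sum(abs2, m((x, y))), Flux.params(m))
```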
