Closed
Description
Hello,
I have created custom BiLSTM layer like this:
using Flux
mutable struct BiLSTM{L<:Flux.Recur{T} where T<:Flux.LSTMCell}
fw::L
bw::L
end
function BiLSTM(in::Int, out::Int)
if out % 2 != 0
throw(ArgumentError("Output size must be divisible by 2."))
end
return BiLSTM(
LSTM(in, out ÷ 2),
LSTM(in, out ÷ 2)
)
end
Flux.@functor BiLSTM
Flux.trainable(b::BiLSTM) = (b.fw, b.bw)
Base.show(io::IO, ::MIME"text/plain", b::BiLSTM) =
print(io, "BiLSTM($(size(b.fw.cell.Wi, 2)), $(size(b.fw.cell.Wi, 1)÷2))")
# ...
When I instantiate BiLSTM
cell using BiLSTM
function the Base.show
function is used as expected.
julia> BiLSTM(12, 16)
BiLSTM(12, 16)
But if I create a Chain
of BiLSTM cells my custom Base.show
function is ignored.
julia> Chain(BiLSTM(12, 16), BiLSTM(16, 16), BiLSTM(16, 16))
Chain(
BiLSTM(
Recur(
LSTMCell(12, 8), # 688 parameters
),
Recur(
LSTMCell(12, 8), # 688 parameters
),
),
BiLSTM(
Recur(
LSTMCell(16, 8), # 816 parameters
),
Recur(
LSTMCell(16, 8), # 816 parameters
),
),
BiLSTM(
Recur(
LSTMCell(16, 8), # 816 parameters
),
Recur(
LSTMCell(16, 8), # 816 parameters
),
),
) # Total: 30 trainable arrays, 4_640 parameters,
# plus 12 non-trainable, 96 parameters, summarysize 20.141 KiB.
Is there any way to make this less verbose? Something like this would be nice.
julia> Chain(BiLSTM(12, 16), BiLSTM(16, 16), BiLSTM(16, 16))
Chain(
BiLSTM(12, 16), # 1376 parameters
BiLSTM(16, 16), # 1632 parameters
BiLSTM(16, 16), # 1632 parameters
) # Total: 30 trainable arrays, 4_640 parameters,
# plus 12 non-trainable, 96 parameters, summarysize 20.141 KiB.
Thank you for help!
Metadata
Metadata
Assignees
Labels
No labels