Commit 75f4821

simplified tabular model
1 parent 4ecdf59 commit 75f4821

File tree

1 file changed: +34, -30 lines

src/models/tabularmodel.jl

Lines changed: 34 additions & 30 deletions
@@ -19,34 +19,38 @@ end
 # end

 function TabularModel(
-    layers;
-    emb_szs,
-    n_cont::Int64,
-    out_sz,
-    ps::Union{Tuple, Vector, Number}=0,
-    embed_p::Float64=0.,
-    y_range=nothing,
-    use_bn::Bool=true,
-    bn_final::Bool=false,
-    bn_cont::Bool=true,
-    act_cls=Flux.relu,
-    lin_first::Bool=true)
-
-    embedslist = [Embedding(ni, nf) for (ni, nf) in emb_szs]
-    emb_drop = Dropout(embed_p)
-    embeds = Chain(x -> ntuple(i -> x[i, :], length(emb_szs)), Parallel(vcat, embedslist...), emb_drop)
-
-    bn_cont = bn_cont ? BatchNorm(n_cont) : identity
-
-    n_emb = sum(size(embedlayer.weight)[1] for embedlayer in embedslist)
-    sizes = append!(zeros(0), [n_emb+n_cont], layers)
-
-    _layers = []
-    for (i, (p, a)) in enumerate(zip(Iterators.cycle(ps), actns))
-        layer = linbndrop(Int64(sizes[i]), Int64(sizes[i+1]), use_bn=use_bn, p=p, act=act_cls, lin_first=lin_first)
-        push!(_layers, layer)
-    end
-    push!(_layers, linbndrop(Int64(last(sizes)), Int64(out_sz), use_bn=bn_final, lin_first=lin_first))
-    layers = isnothing(y_range) ? Chain(Parallel(vcat, embeds, bn_cont), _layers...) : Chain(Parallel(vcat, embeds, bn_cont), _layers..., @. x->Flux.sigmoid(x) * (y_range[2] - y_range[1]) + y_range[1])
-    layers
+    layers;
+    emb_szs,
+    n_cont::Int64,
+    out_sz,
+    ps::Union{Tuple, Vector, Number}=0,
+    embed_p::Float64=0.,
+    y_range=nothing,
+    use_bn::Bool=true,
+    bn_final::Bool=false,
+    bn_cont::Bool=true,
+    act_cls=Flux.relu,
+    lin_first::Bool=true)
+
+    embedslist = [Embedding(ni, nf) for (ni, nf) in emb_szs]
+    emb_drop = Dropout(embed_p)
+    embeds = Chain(x -> ntuple(i -> x[i, :], length(emb_szs)), Parallel(vcat, embedslist...), emb_drop)
+
+    bn_cont = bn_cont ? BatchNorm(n_cont) : identity
+
+    n_emb = sum(size(embedlayer.weight)[1] for embedlayer in embedslist)
+
+    ps = Iterators.cycle(ps)
+    classifiers = []
+
+    first_ps, ps = Iterators.peel(ps)
+    push!(classifiers, linbndrop(n_emb+n_cont, first(layers); use_bn=use_bn, p=first_ps, lin_first=lin_first, act=act_cls))
+    for (isize, osize, p) in zip(layers[1:(end-1)], layers[2:(end)], ps)
+        layer = linbndrop(isize, osize; use_bn=use_bn, p=p, act=act_cls, lin_first=lin_first)
+        push!(classifiers, layer)
+    end
+    push!(classifiers, linbndrop(last(layers), out_sz; use_bn=bn_final, lin_first=lin_first))
+
+    layers = isnothing(y_range) ? Chain(Parallel(vcat, embeds, bn_cont), classifiers...) : Chain(Parallel(vcat, embeds, bn_cont), classifiers..., @. x->Flux.sigmoid(x) * (y_range[2] - y_range[1]) + y_range[1])
+    layers
 end
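The heart of the simplification is in how the hidden blocks are assembled: rather than materializing a `sizes` vector via `append!(zeros(0), ...)` and indexing into it with `Int64` casts (alongside the old loop's reference to the undefined `actns`), the new code zips consecutive entries of `layers` directly and threads the dropout probabilities through `Iterators.cycle` and `Iterators.peel`. A minimal, self-contained sketch of that iterator pattern, runnable in plain Julia; `hidden_sizes` and `make_block` are illustrative stand-ins, not part of the commit:

hidden_sizes = [200, 100, 64]
ps = Iterators.cycle(0.1)               # a scalar rate repeats for every block

first_p, rest_ps = Iterators.peel(ps)   # split off the rate for the input block
make_block(isize, osize, p) = (isize, osize, p)  # stand-in for linbndrop

blocks = Any[make_block(16 + 4, first(hidden_sizes), first_p)]  # e.g. 16 embedding dims + 4 continuous
for (isize, osize, p) in zip(hidden_sizes[1:end-1], hidden_sizes[2:end], rest_ps)
    push!(blocks, make_block(isize, osize, p))
end
# blocks == [(20, 200, 0.1), (200, 100, 0.1), (100, 64, 0.1)]

Because a `Number` in Julia iterates as a single element, `Iterators.cycle(0.1)` repeats the scalar indefinitely; this is what lets the `ps` keyword accept a `Number` as well as a `Tuple` or `Vector`.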

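For reference, a call consistent with the new keyword signature might look like the sketch below; the sizes are made up for illustration and assume the surrounding package exports `TabularModel`:

model = TabularModel(
    [200, 100];                  # hidden layer sizes
    emb_szs=[(10, 6), (5, 3)],   # (cardinality, embedding size) per categorical column
    n_cont=4,                    # number of continuous columns
    out_sz=2,
    ps=0.1,                      # recycled across all hidden blocks
)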