
Commit 3541938
updated tabular model
1 parent 75f4821

1 file changed

src/models/tabularmodel.jl

Lines changed: 23 additions & 14 deletions
@@ -21,24 +21,29 @@ end
 function TabularModel(
         layers;
         emb_szs,
-        n_cont::Int64,
+        n_cont,
         out_sz,
-        ps::Union{Tuple, Vector, Number}=0,
-        embed_p::Float64=0.,
-        y_range=nothing,
-        use_bn::Bool=true,
-        bn_final::Bool=false,
-        bn_cont::Bool=true,
+        ps=0,
+        embed_p=0.,
+        use_bn=true,
+        bn_final=false,
+        bn_cont=true,
         act_cls=Flux.relu,
-        lin_first::Bool=true)
+        lin_first=true,
+        final_activation=identity)
 
     embedslist = [Embedding(ni, nf) for (ni, nf) in emb_szs]
+    n_emb = sum(size(embedlayer.weight)[1] for embedlayer in embedslist)
+    # n_emb = first(Flux.outputsize(embeds, (length(emb_szs), 1)))
     emb_drop = Dropout(embed_p)
-    embeds = Chain(x -> ntuple(i -> x[i, :], length(emb_szs)), Parallel(vcat, embedslist...), emb_drop)
-
-    bn_cont = bn_cont ? BatchNorm(n_cont) : identity
+    embeds = Chain(
+        x -> collect(eachrow(x)),
+        x -> ntuple(i -> x[i], length(x)),
+        Parallel(vcat, embedslist),
+        emb_drop
+    )
 
-    n_emb = sum(size(embedlayer.weight)[1] for embedlayer in embedslist)
+    bn_cont = bn_cont && n_cont>0 ? BatchNorm(n_cont) : identity
 
     ps = Iterators.cycle(ps)
     classifiers = []
@@ -50,7 +55,11 @@ function TabularModel(
         push!(classifiers, layer)
     end
     push!(classifiers, linbndrop(last(layers), out_sz; use_bn=bn_final, lin_first=lin_first))
-
-    layers = isnothing(y_range) ? Chain(Parallel(vcat, embeds, bn_cont), classifiers...) : Chain(Parallel(vcat, embeds, bn_cont), classifiers..., @. x->Flux.sigmoid(x) * (y_range[2] - y_range[1]) + y_range[1])
+    layers = Chain(
+        x -> tuple(x...),
+        Parallel(vcat, embeds, Chain(x -> ndims(x)==1 ? Flux.unsqueeze(x, 2) : x, bn_cont)),
+        classifiers...,
+        final_activation
+    )
     layers
 end
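For reference, a minimal sketch of calling the updated constructor. It assumes the surrounding FastAI.jl definitions (TabularModel, Embedding, linbndrop) are in scope; the layer sizes, embedding sizes, target range, and batch shapes below are illustrative assumptions, not part of the commit.

using Flux

# Illustrative embedding sizes (assumption): two categorical columns with
# cardinalities 10 and 7, embedded into 5 and 4 dimensions respectively.
emb_szs = [(10, 5), (7, 4)]

model = TabularModel(
    [200, 100];                      # hidden layer sizes (assumed)
    emb_szs = emb_szs,
    n_cont = 3,                      # number of continuous columns
    out_sz = 1,
    # The removed y_range rescaling can be recovered through the new
    # final_activation keyword, e.g. for targets in (0, 5):
    final_activation = x -> Flux.sigmoid.(x) .* (5f0 - 0f0) .+ 0f0,
)

# Assumed call pattern, given the `x -> tuple(x...)` entry layer: a pair of
# (categorical indices, continuous features), one column per sample.
xcat  = rand(1:7, 2, 16)             # (n_categorical, batch) integer matrix
xcont = randn(Float32, 3, 16)        # (n_cont, batch) matrix
ŷ = model((xcat, xcont))             # (out_sz, batch) output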
