
Commit a281cd4

refactored TabularModel
1 parent 3541938 commit a281cd4

2 files changed: +28 −22 lines


src/models/Models.jl

Lines changed: 2 additions & 2 deletions
@@ -17,8 +17,8 @@ include("unet.jl")
 include("tabularmodel.jl")
 
 
-export xresnet18, xresnet50, UNetDynamic, TabularModel,
-    get_emb_sz
+export xresnet18, xresnet50, UNetDynamic,
+    TabularModel, get_emb_sz, embeddingbackbone, continuousbackbone
 
 
 end

src/models/tabularmodel.jl

Lines changed: 26 additions & 20 deletions
@@ -18,48 +18,54 @@ end
 # [_one_emb_sz(catdict, catcol, sz_dict) for catcol in cols]
 # end
 
+function embeddingbackbone(embedding_sizes, dropoutprob)
+    embedslist = [Embedding(ni, nf) for (ni, nf) in embedding_sizes]
+    emb_drop = Dropout(dropoutprob)
+    Chain(
+        x -> tuple(eachrow(x)...),
+        Parallel(vcat, embedslist),
+        emb_drop
+    )
+end
+
+function continuousbackbone(n_cont)
+    n_cont > 0 ? BatchNorm(n_cont) : identity
+end
+
 function TabularModel(
+        catbackbone,
+        contbackbone,
         layers;
-        emb_szs,
+        n_cat,
         n_cont,
         out_sz,
         ps=0,
-        embed_p=0.,
         use_bn=true,
         bn_final=false,
-        bn_cont=true,
         act_cls=Flux.relu,
         lin_first=true,
-        final_activation=identity)
-
-    embedslist = [Embedding(ni, nf) for (ni, nf) in emb_szs]
-    n_emb = sum(size(embedlayer.weight)[1] for embedlayer in embedslist)
-    # n_emb = first(Flux.outputsize(embeds, (length(emb_szs), 1)))
-    emb_drop = Dropout(embed_p)
-    embeds = Chain(
-        x -> collect(eachrow(x)),
-        x -> ntuple(i -> x[i], length(x)),
-        Parallel(vcat, embedslist),
-        emb_drop
+        final_activation=identity
     )
 
-    bn_cont = bn_cont && n_cont>0 ? BatchNorm(n_cont) : identity
-
+    tabularbackbone = Parallel(vcat, catbackbone, contbackbone)
+
+    catoutsize = first(Flux.outputsize(catbackbone, (n_cat, 1)))
     ps = Iterators.cycle(ps)
     classifiers = []
 
     first_ps, ps = Iterators.peel(ps)
-    push!(classifiers, linbndrop(n_emb+n_cont, first(layers); use_bn=use_bn, p=first_ps, lin_first=lin_first, act=act_cls))
+    push!(classifiers, linbndrop(catoutsize+n_cont, first(layers); use_bn=use_bn, p=first_ps, lin_first=lin_first, act=act_cls))
+
     for (isize, osize, p) in zip(layers[1:(end-1)], layers[2:(end)], ps)
         layer = linbndrop(isize, osize; use_bn=use_bn, p=p, act=act_cls, lin_first=lin_first)
         push!(classifiers, layer)
     end
+
     push!(classifiers, linbndrop(last(layers), out_sz; use_bn=bn_final, lin_first=lin_first))
+
     layers = Chain(
-        x -> tuple(x...),
-        Parallel(vcat, embeds, Chain(x -> ndims(x)==1 ? Flux.unsqueeze(x, 2) : x, bn_cont)),
+        tabularbackbone,
         classifiers...,
         final_activation
     )
-    layers
 end
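
For reference, a rough, untested usage sketch of how the refactored pieces are intended to compose. It assumes the refactored Models module above (embeddingbackbone, continuousbackbone, TabularModel) and Flux are loaded; the embedding sizes, layer widths, and batch shapes are made-up illustrative values, and inputs are assumed to be passed as a (categorical, continuous) tuple to match Parallel(vcat, catbackbone, contbackbone):

# Hypothetical sizes: 2 categorical columns with cardinalities 10 and 8,
# embedded into 6 and 5 dimensions; 3 continuous columns.
embedding_sizes = [(10, 6), (8, 5)]
catbackbone = embeddingbackbone(embedding_sizes, 0.1)
contbackbone = continuousbackbone(3)

model = TabularModel(
    catbackbone,
    contbackbone,
    [200, 100];   # hidden layer sizes of the classifier head
    n_cat=2,
    n_cont=3,
    out_sz=2
)

# Batch of 4 samples: integer category indices and Float32 continuous features,
# passed as a (categorical, continuous) tuple.
catX = rand(1:8, 2, 4)
contX = rand(Float32, 3, 4)
out = model((catX, contX))   # expected size: (out_sz, batchsize) = (2, 4)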
