|
18 | 18 | # [_one_emb_sz(catdict, catcol, sz_dict) for catcol in cols]
|
19 | 19 | # end
|
20 | 20 |
|
"""
    embeddingbackbone(embedding_sizes, dropoutprob)

Build the categorical backbone of a tabular model.

`embedding_sizes` is an iterable of `(ni, nf)` pairs, one per categorical
column, where `ni` is the number of categories and `nf` is the embedding
width. Returns a `Chain` that splits an input matrix into one row per
categorical feature, runs each row through its own `Embedding` in parallel,
concatenates the embeddings with `vcat`, and applies `Dropout(dropoutprob)`.
"""
function embeddingbackbone(embedding_sizes, dropoutprob)
    embedslist = [Embedding(ni, nf) for (ni, nf) in embedding_sizes]
    emb_drop = Dropout(dropoutprob)
    return Chain(
        # One tuple element per categorical column (row of the input batch).
        x -> Tuple(eachrow(x)),
        # `Parallel` takes its layers splatted; passing the Vector itself
        # would register a single layer and mismatch the input tuple length.
        Parallel(vcat, embedslist...),
        emb_drop,
    )
end
| 30 | + |
"""
    continuousbackbone(n_cont)

Backbone for the continuous columns of a tabular model: a `BatchNorm` over
the `n_cont` continuous features, or `identity` when there are none.
"""
function continuousbackbone(n_cont)
    if n_cont > 0
        return BatchNorm(n_cont)
    else
        return identity
    end
end
| 34 | + |
"""
    TabularModel(catbackbone, contbackbone, layers;
                 n_cat, n_cont, out_sz, ps=0, use_bn=true, bn_final=false,
                 act_cls=Flux.relu, lin_first=true, final_activation=identity)

Assemble a tabular model: `catbackbone` (categorical/embedding backbone) and
`contbackbone` (continuous backbone) run in parallel and their outputs are
concatenated with `vcat`, then passed through a stack of `linbndrop` blocks
with the hidden widths given in `layers`, a final `out_sz`-wide block, and
`final_activation`.

# Keywords
- `n_cat`: number of categorical input columns; used only to probe the
  output width of `catbackbone` via `Flux.outputsize`.
- `n_cont`: number of continuous input columns.
- `out_sz`: output dimension of the model.
- `ps`: dropout probability (or iterable of probabilities, cycled) for the
  hidden classifier blocks.
- `use_bn` / `bn_final`: enable batch norm in the hidden / final block.
- `act_cls`: activation for the hidden blocks.
- `lin_first`: forwarded to `linbndrop` (linear layer before norm/dropout).
- `final_activation`: applied after the last block.
"""
function TabularModel(
    catbackbone,
    contbackbone,
    layers;
    n_cat,
    n_cont,
    out_sz,
    ps=0,
    use_bn=true,
    bn_final=false,
    act_cls=Flux.relu,
    lin_first=true,
    final_activation=identity
)
    tabularbackbone = Parallel(vcat, catbackbone, contbackbone)

    # Width of the concatenated categorical output, probed with a dummy
    # batch of one sample and `n_cat` categorical rows.
    catoutsize = first(Flux.outputsize(catbackbone, (n_cat, 1)))

    # Cycle the dropout probabilities so a scalar `ps` covers every block;
    # peel off the probability for the first block without rebinding `ps`.
    first_p, rest_ps = Iterators.peel(Iterators.cycle(ps))

    classifiers = Any[]  # heterogeneous layer types, explicit eltype
    push!(classifiers, linbndrop(catoutsize + n_cont, first(layers);
                                 use_bn=use_bn, p=first_p,
                                 lin_first=lin_first, act=act_cls))
    for (isize, osize, p) in zip(layers[1:(end-1)], layers[2:end], rest_ps)
        push!(classifiers, linbndrop(isize, osize; use_bn=use_bn, p=p,
                                     act=act_cls, lin_first=lin_first))
    end
    push!(classifiers, linbndrop(last(layers), out_sz;
                                 use_bn=bn_final, lin_first=lin_first))

    # Explicit return; avoid shadowing the `layers` argument as the original
    # did by rebinding it and relying on the assignment's value.
    return Chain(
        tabularbackbone,
        classifiers...,
        final_activation,
    )
end
|
0 commit comments