|
19 | 19 | # end
|
20 | 20 |
|
21 | 21 | function TabularModel(
|
22 |
| - layers; |
23 |
| - emb_szs, |
24 |
| - n_cont::Int64, |
25 |
| - out_sz, |
26 |
| - ps::Union{Tuple, Vector, Number}=0, |
27 |
| - embed_p::Float64=0., |
28 |
| - y_range=nothing, |
29 |
| - use_bn::Bool=true, |
30 |
| - bn_final::Bool=false, |
31 |
| - bn_cont::Bool=true, |
32 |
| - act_cls=Flux.relu, |
33 |
| - lin_first::Bool=true) |
34 |
| - |
35 |
| - embedslist = [Embedding(ni, nf) for (ni, nf) in emb_szs] |
36 |
| - emb_drop = Dropout(embed_p) |
37 |
| - embeds = Chain(x -> ntuple(i -> x[i, :], length(emb_szs)), Parallel(vcat, embedslist...), emb_drop) |
38 |
| - |
39 |
| - bn_cont = bn_cont ? BatchNorm(n_cont) : identity |
40 |
| - |
41 |
| - n_emb = sum(size(embedlayer.weight)[1] for embedlayer in embedslist) |
42 |
| - sizes = append!(zeros(0), [n_emb+n_cont], layers) |
43 |
| - |
44 |
| - _layers = [] |
45 |
| - for (i, (p, a)) in enumerate(zip(Iterators.cycle(ps), actns)) |
46 |
| - layer = linbndrop(Int64(sizes[i]), Int64(sizes[i+1]), use_bn=use_bn, p=p, act=act_cls, lin_first=lin_first) |
47 |
| - push!(_layers, layer) |
48 |
| - end |
49 |
| - push!(_layers, linbndrop(Int64(last(sizes)), Int64(out_sz), use_bn=bn_final, lin_first=lin_first)) |
50 |
| - layers = isnothing(y_range) ? Chain(Parallel(vcat, embeds, bn_cont), _layers...) : Chain(Parallel(vcat, embeds, bn_cont), _layers..., @. x->Flux.sigmoid(x) * (y_range[2] - y_range[1]) + y_range[1]) |
51 |
| - layers |
| 22 | + layers; |
| 23 | + emb_szs, |
| 24 | + n_cont::Int64, |
| 25 | + out_sz, |
| 26 | + ps::Union{Tuple, Vector, Number}=0, |
| 27 | + embed_p::Float64=0., |
| 28 | + y_range=nothing, |
| 29 | + use_bn::Bool=true, |
| 30 | + bn_final::Bool=false, |
| 31 | + bn_cont::Bool=true, |
| 32 | + act_cls=Flux.relu, |
| 33 | + lin_first::Bool=true) |
| 34 | + |
| 35 | + embedslist = [Embedding(ni, nf) for (ni, nf) in emb_szs] |
| 36 | + emb_drop = Dropout(embed_p) |
| 37 | + embeds = Chain(x -> ntuple(i -> x[i, :], length(emb_szs)), Parallel(vcat, embedslist...), emb_drop) |
| 38 | + |
| 39 | + bn_cont = bn_cont ? BatchNorm(n_cont) : identity |
| 40 | + |
| 41 | + n_emb = sum(size(embedlayer.weight)[1] for embedlayer in embedslist) |
| 42 | + |
| 43 | + ps = Iterators.cycle(ps) |
| 44 | + classifiers = [] |
| 45 | + |
| 46 | + first_ps, ps = Iterators.peel(ps) |
| 47 | + push!(classifiers, linbndrop(n_emb+n_cont, first(layers); use_bn=use_bn, p=first_ps, lin_first=lin_first, act=act_cls)) |
| 48 | + for (isize, osize, p) in zip(layers[1:(end-1)], layers[2:(end)], ps) |
| 49 | + layer = linbndrop(isize, osize; use_bn=use_bn, p=p, act=act_cls, lin_first=lin_first) |
| 50 | + push!(classifiers, layer) |
| 51 | + end |
| 52 | + push!(classifiers, linbndrop(last(layers), out_sz; use_bn=bn_final, lin_first=lin_first)) |
| 53 | + |
| 54 | + layers = isnothing(y_range) ? Chain(Parallel(vcat, embeds, bn_cont), classifiers...) : Chain(Parallel(vcat, embeds, bn_cont), classifiers..., @. x->Flux.sigmoid(x) * (y_range[2] - y_range[1]) + y_range[1]) |
| 55 | + layers |
52 | 56 | end
|
0 commit comments