1
+ struct TabularModel
2
+ embeds
3
+ emb_drop
4
+ bn_cont
5
+ n_emb
6
+ n_cont
7
+ layers
8
+ end
9
+
10
+ function TabularModel (
11
+ layers;
12
+ emb_szs,
13
+ n_cont,
14
+ out_sz,
15
+ ps:: Union{Tuple, Vector, Number, Nothing} = nothing ,
16
+ embed_p:: Float64 = 0. ,
17
+ y_range= nothing ,
18
+ use_bn:: Bool = true ,
19
+ bn_final:: Bool = false ,
20
+ bn_cont:: Bool = true ,
21
+ act_cls= Flux. relu,
22
+ lin_first:: Bool = true )
23
+
24
+ n_cont = Int64 (n_cont)
25
+ if isnothing (ps)
26
+ ps = zeros (length (layers))
27
+ end
28
+ if ps isa Number
29
+ ps = fill (ps, length (layers))
30
+ end
31
+ embedslist = [Embedding (ni, nf) for (ni, nf) in emb_szs]
32
+ emb_drop = Dropout (embed_p)
33
+ bn_cont = bn_cont && BatchNorm (n_cont)
34
+ n_emb = sum (size (embedlayer. weight)[1 ] for embedlayer in embedslist)
35
+ sizes = append! (zeros (0 ), [n_emb+ n_cont], layers, [out_sz])
36
+ actns = append! ([], [act_cls for i in 1 : (length (sizes)- 1 )], [nothing ])
37
+ _layers = [linbndrop (Int64 (sizes[i]), Int64 (sizes[i+ 1 ]), use_bn= (use_bn && ((i!= (length (actns)- 1 )) || bn_final)), p= p, act= a, lin_first= lin_first) for (i, (p, a)) in enumerate (zip (push! (ps, 0. ), actns))]
38
+ if ! isnothing (y_range)
39
+ push! (_layers, Chain (@. x-> Flux. sigmoid (x) * (y_range[2 ] - y_range[1 ]) + y_range[1 ]))
40
+ end
41
+ layers = Chain (_layers... )
42
+ TabularModel (embedslist, emb_drop, bn_cont, n_emb, n_cont, layers)
43
+ end
44
+
45
+ function (tm:: TabularModel )(x)
46
+ x_cat, x_cont = x
47
+ if tm. n_emb != 0
48
+ x = [e (x_cat[i, :]) for (i, e) in enumerate (tm. embeds)]
49
+ x = vcat (x... )
50
+ x = tm. emb_drop (x)
51
+ end
52
+ if tm. n_cont != 0
53
+ if ! isnothing (tm. bn_cont)
54
+ x_cont = tm. bn_cont (x_cont)
55
+ end
56
+ x = tm. n_emb!= 0 ? vcat (x, x_cont) : x_cont
57
+ end
58
+ tm. layers (x)
59
+ end
0 commit comments