21
21
function TabularModel (
22
22
layers;
23
23
emb_szs,
24
- n_cont:: Int64 ,
24
+ n_cont,
25
25
out_sz,
26
- ps:: Union{Tuple, Vector, Number} = 0 ,
27
- embed_p:: Float64 = 0. ,
28
- y_range= nothing ,
29
- use_bn:: Bool = true ,
30
- bn_final:: Bool = false ,
31
- bn_cont:: Bool = true ,
26
+ ps= 0 ,
27
+ embed_p= 0. ,
28
+ use_bn= true ,
29
+ bn_final= false ,
30
+ bn_cont= true ,
32
31
act_cls= Flux. relu,
33
- lin_first:: Bool = true )
32
+ lin_first= true ,
33
+ final_activation= identity)
34
34
35
35
embedslist = [Embedding (ni, nf) for (ni, nf) in emb_szs]
36
+ n_emb = sum (size (embedlayer. weight)[1 ] for embedlayer in embedslist)
37
+ # n_emb = first(Flux.outputsize(embeds, (length(emb_szs), 1)))
36
38
emb_drop = Dropout (embed_p)
37
- embeds = Chain (x -> ntuple (i -> x[i, :], length (emb_szs)), Parallel (vcat, embedslist... ), emb_drop)
38
-
39
- bn_cont = bn_cont ? BatchNorm (n_cont) : identity
39
+ embeds = Chain (
40
+ x -> collect (eachrow (x)),
41
+ x -> ntuple (i -> x[i], length (x)),
42
+ Parallel (vcat, embedslist),
43
+ emb_drop
44
+ )
40
45
41
- n_emb = sum ( size (embedlayer . weight)[ 1 ] for embedlayer in embedslist)
46
+ bn_cont = bn_cont && n_cont > 0 ? BatchNorm (n_cont) : identity
42
47
43
48
ps = Iterators. cycle (ps)
44
49
classifiers = []
@@ -50,7 +55,11 @@ function TabularModel(
50
55
push! (classifiers, layer)
51
56
end
52
57
push! (classifiers, linbndrop (last (layers), out_sz; use_bn= bn_final, lin_first= lin_first))
53
-
54
- layers = isnothing (y_range) ? Chain (Parallel (vcat, embeds, bn_cont), classifiers... ) : Chain (Parallel (vcat, embeds, bn_cont), classifiers... , @. x-> Flux. sigmoid (x) * (y_range[2 ] - y_range[1 ]) + y_range[1 ])
58
+ layers = Chain (
59
+ x -> tuple (x... ),
60
+ Parallel (vcat, embeds, Chain (x -> ndims (x)== 1 ? Flux. unsqueeze (x, 2 ) : x, bn_cont)),
61
+ classifiers... ,
62
+ final_activation
63
+ )
55
64
layers
56
65
end
0 commit comments