Skip to content

Commit 7eacff7

Browse files
committed
added tabular model
1 parent 2cf84e2 commit 7eacff7

File tree

2 files changed

+62
-1
lines changed

2 files changed

+62
-1
lines changed

src/models/Models.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
module Models
22

3+
using Base: Bool
34
using ..FastAI
45

56
using BSON
@@ -13,9 +14,10 @@ include("blocks.jl")
1314

1415
include("xresnet.jl")
1516
include("unet.jl")
17+
include("tabularmodel.jl")
1618

1719

18-
export xresnet18, xresnet50, UNetDynamic
20+
export xresnet18, xresnet50, UNetDynamic, TabularModel
1921

2022

2123
end

src/models/tabularmodel.jl

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
struct TabularModel
2+
embeds
3+
emb_drop
4+
bn_cont
5+
n_emb
6+
n_cont
7+
layers
8+
end
9+
10+
function TabularModel(
11+
layers;
12+
emb_szs,
13+
n_cont,
14+
out_sz,
15+
ps::Union{Tuple, Vector, Number, Nothing}=nothing,
16+
embed_p::Float64=0.,
17+
y_range=nothing,
18+
use_bn::Bool=true,
19+
bn_final::Bool=false,
20+
bn_cont::Bool=true,
21+
act_cls=Flux.relu,
22+
lin_first::Bool=true)
23+
24+
n_cont = Int64(n_cont)
25+
if isnothing(ps)
26+
ps = zeros(length(layers))
27+
end
28+
if ps isa Number
29+
ps = fill(ps, length(layers))
30+
end
31+
embedslist = [Embedding(ni, nf) for (ni, nf) in emb_szs]
32+
emb_drop = Dropout(embed_p)
33+
bn_cont = bn_cont && BatchNorm(n_cont)
34+
n_emb = sum(size(embedlayer.weight)[1] for embedlayer in embedslist)
35+
sizes = append!(zeros(0), [n_emb+n_cont], layers, [out_sz])
36+
actns = append!([], [act_cls for i in 1:(length(sizes)-1)], [nothing])
37+
_layers = [linbndrop(Int64(sizes[i]), Int64(sizes[i+1]), use_bn=(use_bn && ((i!=(length(actns)-1)) || bn_final)), p=p, act=a, lin_first=lin_first) for (i, (p, a)) in enumerate(zip(push!(ps, 0.), actns))]
38+
if !isnothing(y_range)
39+
push!(_layers, Chain(@. x->Flux.sigmoid(x) * (y_range[2] - y_range[1]) + y_range[1]))
40+
end
41+
layers = Chain(_layers...)
42+
TabularModel(embedslist, emb_drop, bn_cont, n_emb, n_cont, layers)
43+
end
44+
45+
function (tm::TabularModel)(x)
46+
x_cat, x_cont = x
47+
if tm.n_emb != 0
48+
x = [e(x_cat[i, :]) for (i, e) in enumerate(tm.embeds)]
49+
x = vcat(x...)
50+
x = tm.emb_drop(x)
51+
end
52+
if tm.n_cont != 0
53+
if !isnothing(tm.bn_cont)
54+
x_cont = tm.bn_cont(x_cont)
55+
end
56+
x = tm.n_emb!=0 ? vcat(x, x_cont) : x_cont
57+
end
58+
tm.layers(x)
59+
end

0 commit comments

Comments
 (0)