Skip to content

Commit b306865

Browse files
Add tabular model (#124)
* added tabular model * fixed batchnorm in tabular model * added function for calculating embedding dimensions * updated tabular model * Apply suggestions from code review Co-authored-by: Kyle Daruwalla <daruwalla.k.public@icloud.com> * simplified tabular model * updated tabular model * refactored TabularModel * updated tabular model, and added tests * added classifierbackbone * update tablemodel tests * export classifierbackbone * refactored TabularModel methods * updated tabular model tests * add TabularModel docstring * renamed args Co-authored-by: Kyle Daruwalla <daruwalla.k.public@icloud.com> * updated docstrings and embed dims calculation, made args usage consistent * docstring fixes * made methods concise Co-authored-by: Kyle Daruwalla <daruwalla.k.public@icloud.com> * updated docstrings and get_emb_sz * updated model test * undo unintentional comments * Docstring updates Co-authored-by: Kyle Daruwalla <daruwalla.k.public@icloud.com> * minor docstring fix Co-authored-by: Kyle Daruwalla <daruwalla.k.public@icloud.com>
1 parent 9fb14af commit b306865

File tree

5 files changed

+208
-2
lines changed

5 files changed

+208
-2
lines changed

src/models/Models.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
module Models
22

3+
using Base: Bool, Symbol
34
using ..FastAI
45

56
using BSON
@@ -13,9 +14,9 @@ include("blocks.jl")
1314

1415
include("xresnet.jl")
1516
include("unet.jl")
17+
include("tabularmodel.jl")
1618

1719

18-
export xresnet18, xresnet50, UNetDynamic
19-
20+
export xresnet18, xresnet50, UNetDynamic, TabularModel
2021

2122
end

src/models/tabularmodel.jl

Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
"""
2+
emb_sz_rule(n_cat)
3+
4+
Compute an embedding size corresponding to the number of classes for a
5+
categorical variable using the rule of thumb present in python fastai.
6+
(see https://github.com/fastai/fastai/blob/2742fe844573d06e700f869839fb9ec5f3a9bca9/fastai/tabular/model.py#L12)
7+
"""
8+
function emb_sz_rule(n_cat)
    # fastai's rule of thumb: the width grows as 1.6 * n_cat^0.56,
    # rounded to the nearest integer ...
    suggested = round(Int, 1.6 * n_cat^0.56)
    # ... and is capped at 600 dimensions.
    return min(600, suggested)
end
9+
10+
"""
11+
get_emb_sz(cardinalities::AbstractVector, [size_overrides::AbstractVector])
12+
13+
Given a vector of `cardinalities` of each categorical column
14+
(i.e. each element of `cardinalities` is the number of classes in that categorical column),
15+
compute the output embedding size according to [`emb_sz_rule`](#).
16+
Return a vector of tuples where each element is `(in_size, out_size)` for an embedding layer.
17+
18+
## Optional arguments
19+
20+
- `size_overrides`: A collection of integers (or `nothing` to skip override) where the value present at any index
21+
will be used as the output embedding size for that column.
22+
"""
23+
function get_emb_sz(
        cardinalities::AbstractVector{<:Integer},
        size_overrides=fill(nothing, length(cardinalities)))
    # One `(in_size, out_size)` tuple per categorical column. The input size is
    # `cardinality + 1`; the output size is the override when one is given,
    # otherwise the rule-of-thumb value computed from the input size.
    return [
        (ncat + 1, isnothing(sz) ? emb_sz_rule(ncat + 1) : Int64(sz))
        for (ncat, sz) in zip(cardinalities, size_overrides)
    ]
end
28+
29+
"""
30+
get_emb_sz(cardinalities::Dict, [size_overrides::Dict])
31+
32+
Given a map from columns to `cardinalities`, compute the output embedding size according to [`emb_sz_rule`](#).
33+
Return a vector of tuples where each element is `(in_size, out_size)` for an embedding layer.
34+
35+
## Optional arguments
36+
37+
- `size_overrides`: A map of output embedding size overrides
38+
(i.e. `size_overrides[col]` is the output embedding size for `col`).
39+
"""
40+
function get_emb_sz(cardinalities::Dict{<:Any, <:Integer}, size_overrides=Dict())
    # `map` is not defined for dictionaries in Julia (it always throws), so the
    # per-column values are collected explicitly instead.
    # `keys` and `values` of the same, unmodified Dict iterate in the same order,
    # so `sizes[i]` and `overrides[i]` refer to the same column.
    sizes = collect(values(cardinalities))
    # Columns without an entry in `size_overrides` get `nothing`, which makes the
    # vector method fall back to the rule-of-thumb embedding size.
    overrides = [get(size_overrides, col, nothing) for col in keys(cardinalities)]
    return get_emb_sz(sizes, overrides)
end
46+
47+
function sigmoidrange(x, low, high)
    # Apply the sigmoid elementwise, then scale and shift the result by
    # `(high - low)` and `low`; everything is fused into a single broadcast.
    return Flux.sigmoid.(x) .* (high .- low) .+ low
end
48+
49+
function tabular_embedding_backbone(embedding_sizes, dropout_rate=0.)
    # One embedding layer per `(in_size, out_size)` pair.
    embedslist = [Flux.Embedding(n_in, n_out) for (n_in, n_out) in embedding_sizes]

    # A zero rate means no dropout layer at all.
    emb_drop = if iszero(dropout_rate)
        identity
    else
        Dropout(dropout_rate)
    end

    return Chain(
        # Split the input into one row per categorical column ...
        batch -> tuple(eachrow(batch)...),
        # ... pass each row to its embedding layer and stack the results ...
        Parallel(vcat, embedslist),
        # ... then optionally apply dropout to the stacked embeddings.
        emb_drop,
    )
end
58+
59+
function tabular_continuous_backbone(n_cont)
    # The default continuous backbone is a single batch norm over the
    # `n_cont` continuous columns.
    return BatchNorm(n_cont)
end
60+
61+
"""
62+
TabularModel(catbackbone, contbackbone, [finalclassifier]; kwargs...)
63+
64+
Create a tabular model which operates on a tuple of categorical values
65+
(label or one-hot encoded) and continuous values.
66+
The categorical backbones (`catbackbone`) and continuous backbone (`contbackbone`) operate on each element of the input tuple.
67+
The output from these backbones is then passed through a series of linear-batch norm-dropout layers before a `finalclassifier` block.
68+
69+
## Keyword arguments
70+
71+
- `outsize`: The output size of the final classifier block. For single classification tasks,
72+
this would be the number of classes, and for regression tasks, this would be the
73+
number of target continuous variables.
74+
- `layersizes`: A vector of sizes for each hidden layer in the sequence of linear layers.
75+
- `dropout_rates`: Dropout probabilities for the linear-batch norm-dropout layers.
76+
This could either be a single number which would be used for all the layers,
77+
or a collection of numbers which are cycled through for each layer.
78+
- `batchnorm`: Set to `false` to skip each batch norm in the linear-batch norm-dropout sequence.
79+
- `activation`: The activation function to use in the classifier layers.
80+
- `linear_first`: Controls if the linear layer comes before or after batch norm and dropout.
81+
"""
82+
function TabularModel(catbackbone, contbackbone; outsize, layersizes=(200, 100), kwargs...)
    # Default head: a dense layer from the last hidden size to `outsize`.
    finalclassifier = Dense(last(layersizes), outsize)

    # Delegate to the method that takes an explicit final classifier; all
    # remaining keyword arguments are passed through unchanged.
    return TabularModel(
        catbackbone,
        contbackbone,
        finalclassifier;
        layersizes=layersizes,
        kwargs...)
end
90+
91+
function TabularModel(
        catbackbone,
        contbackbone,
        finalclassifier;
        layersizes=(200, 100),
        dropout_rates=0.,
        batchnorm=true,
        activation=Flux.relu,
        linear_first=true)

    # Both backbones run side by side; their outputs are concatenated.
    backbone = Parallel(vcat, catbackbone, contbackbone)

    # Input width of the first classifier layer: `contbackbone.chs` plus the
    # first weight dimension of every layer held by `catbackbone[2]`.
    classifierin = contbackbone.chs
    for emb in catbackbone[2].layers
        classifierin += size(emb.weight, 1)
    end

    # Dropout probabilities are cycled, so a single number serves every layer.
    # The first probability goes to the layer fed by the backbone, the rest to
    # the hidden layers that follow.
    first_p, remaining_ps = Iterators.peel(Iterators.cycle(dropout_rates))

    inputlayer = linbndrop(classifierin, first(layersizes);
        use_bn=batchnorm, p=first_p, lin_first=linear_first, act=activation)

    # Consecutive pairs of `layersizes` give each hidden layer's (in, out) sizes.
    hiddenlayers = [
        linbndrop(nin, nout; use_bn=batchnorm, p=p, act=activation, lin_first=linear_first)
        for (nin, nout, p) in zip(layersizes[1:(end-1)], layersizes[2:end], remaining_ps)
    ]

    return Chain(
        backbone,
        inputlayer,
        hiddenlayers...,
        finalclassifier,
    )
end
123+
124+
"""
125+
TabularModel(n_cont, outsize, [layersizes; kwargs...])
126+
127+
Create a tabular model which operates on a tuple of categorical values
128+
(label or one-hot encoded) and continuous values. The default categorical backbone (`catbackbone`) is
129+
a [`Flux.Parallel`](https://fluxml.ai/Flux.jl/stable/models/layers/#Flux.Parallel) set of `Flux.Embedding` layers corresponding to each categorical variable.
130+
The default continuous backbone (`contbackbone`) is a single [`Flux.BatchNorm`](https://fluxml.ai/Flux.jl/stable/models/layers/#Flux.BatchNorm).
131+
The output from these backbones is concatenated then passed through a series of linear-batch norm-dropout layers before a `finalclassifier` block.
132+
133+
## Arguments
134+
135+
- `n_cont`: The number of continuous columns.
136+
- `outsize`: The output size of the model.
137+
- `layersizes`: A vector of sizes for each hidden layer in the sequence of linear layers.
138+
139+
## Keyword arguments
140+
141+
- `cardinalities`: A collection of sizes (number of classes) for each categorical column.
142+
- `size_overrides`: An optional argument which corresponds to a collection containing
143+
embedding sizes to override the value returned by the "rule of thumb" for a particular index
144+
corresponding to `cardinalities`, or `nothing`.
145+
"""
146+
function TabularModel(
        n_cont::Number,
        outsize::Number,
        layersizes=(200, 100);
        cardinalities,
        size_overrides=fill(nothing, length(cardinalities)),
        kwargs...)
    # `(in_size, out_size)` for every categorical column; entries of
    # `size_overrides` that are `nothing` use the rule-of-thumb size.
    embedszs = get_emb_sz(cardinalities, size_overrides)

    # Default backbones: one embedding per categorical column and a batch norm
    # over the continuous columns.
    catback = tabular_embedding_backbone(embedszs)
    contback = tabular_continuous_backbone(n_cont)

    # Forward any extra keyword arguments (e.g. `dropout_rates`, `batchnorm`,
    # `activation`, `linear_first`) to the backbone-based constructor.
    # The explicit `layersizes`/`outsize` come last so the positional arguments
    # of this method always take precedence over same-named entries in `kwargs`.
    return TabularModel(catback, contback; kwargs..., layersizes=layersizes, outsize=outsize)
end

test/imports.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import FastAI: Image, Keypoints, Mask, testencoding, Label, OneHot, ProjectiveTr
66
encodedblock, decodedblock, encode, decode, mockblock, checkblock, Block, Encoding
77
using FilePathsBase
88
using FastAI.Datasets
9+
using FastAI.Models
910
using DLPipelines
1011
import DataAugmentation
1112
import DataAugmentation: getbounds

test/models/tabularmodel.jl

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
include("../imports.jl")
2+
3+
@testset ExtendedTestSet "TabularModel Components" begin
    @testset ExtendedTestSet "embeddingbackbone" begin
        sizes = [(5, 10), (100, 30), (2, 30)]
        backbone = FastAI.Models.tabular_embedding_backbone(sizes, 0.)
        # One random class index per categorical column.
        catinput = [rand(1:ncat) for (ncat, _) in sizes]

        # 10 + 30 + 30 embedding features for a single observation.
        @test size(backbone(catinput)) == (70, 1)
    end

    @testset ExtendedTestSet "continuousbackbone" begin
        n = 5
        backbone = FastAI.Models.tabular_continuous_backbone(n)
        continput = rand(n, 1)
        # The continuous backbone keeps the input shape.
        @test size(backbone(continput)) == (n, 1)
    end

    @testset ExtendedTestSet "TabularModel" begin
        n = 5
        sizes = [(5, 10), (100, 30), (2, 30)]

        catback = FastAI.Models.tabular_embedding_backbone(sizes, 0.)
        contback = FastAI.Models.tabular_continuous_backbone(n)

        # Model input: (categorical indices, continuous values).
        x = ([rand(1:ncat) for (ncat, _) in sizes], rand(n, 1))

        # Default final classifier built from `outsize`.
        tm = TabularModel(catback, contback; outsize=4)
        @test size(tm(x)) == (4, 1)

        # Custom final classifier whose output is squashed into (2, 5).
        head = Chain(Dense(100, 4), out -> FastAI.Models.sigmoidrange(out, 2, 5))
        tm2 = TabularModel(catback, contback, head)
        y2 = tm2(x)
        @test all(y2 .> 2) && all(y2 .< 5)

        # Convenience constructor: cardinalities plus explicit embedding sizes.
        cardinalities = [4, 99, 1]
        tm3 = TabularModel(n, 4, [200, 100]; cardinalities=cardinalities, size_overrides=(10, 30, 30))
        @test size(tm3(x)) == (4, 1)
    end
end
40+
41+

test/runtests.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,4 +72,10 @@ include("imports.jl")
7272
end
7373
# TODO: test learning rate finder
7474
end
75+
76+
@testset ExtendedTestSet "models/" begin
77+
@testset ExtendedTestSet "tabularmodel.jl" begin
78+
include("models/tabularmodel.jl")
79+
end
80+
end
7581
end

0 commit comments

Comments
 (0)