Add tabular model #124
Merged: lorenzoh merged 24 commits into FluxML:master from manikyabard:manikyabard/tabularmodel on Aug 22, 2021
Changes from 22 commits
Commits
2c85ed4 added tabular model (manikyabard)
97546c7 fixed batchnorm in tabular model (manikyabard)
c1bd73b added function for calculating embedding dimensions (manikyabard)
2551fbb updated tabular model (manikyabard)
ef11450 Apply suggestions from code review (manikyabard)
c0b2922 simplified tabular model (manikyabard)
c2c95d5 updated tabular model (manikyabard)
a081616 refactored TabularModel (manikyabard)
f565675 updated tabular model, and added tests (manikyabard)
04d27d4 added classifierbackbone (manikyabard)
e1c2263 update tablemodel tests (manikyabard)
b4d7149 export classifierbackbone (manikyabard)
bc250a1 refactored TabularModel methods (manikyabard)
506f889 updated tabular model tests (manikyabard)
725c6dd add TabularModel docstring (manikyabard)
59eb66a renamed args (manikyabard)
d4fded0 updated docstrings and embed dims calculation, made args usage consis… (manikyabard)
96564f0 docstring fixes (manikyabard)
ddb4d62 made methods concise (manikyabard)
825d146 updated docstrings and get_emb_sz (manikyabard)
979c9ba updated model test (manikyabard)
8f8c65a undo unintentional comments (manikyabard)
f2042a4 Docstring updates (manikyabard)
928509b minor docstring fix (manikyabard)
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,158 @@ | ||
""" | ||
emb_sz_rule(n_cat) | ||
|
||
Returns an embedding size for a categorical variable with `n_cat` classes, | ||
using the rule of thumb from Python fastai: `min(600, round(1.6 * n_cat^0.56))`. | ||
(see https://github.com/fastai/fastai/blob/2742fe844573d06e700f869839fb9ec5f3a9bca9/fastai/tabular/model.py#L12) | ||
""" | ||
emb_sz_rule(n_cat) = min(600, round(Int, 1.6 * n_cat^0.56)) | ||
|
||
""" | ||
get_emb_sz(cardinalities, [size_overrides]) | ||
|
||
|
||
Returns a collection of `(n_classes, embedding_dim)` tuples, one per categorical column in | ||
`cardinalities`. Each cardinality is incremented by one to account for NaNs, so `n_classes` is `cardinality + 1`. | ||
|
||
|
||
## Arguments | ||
|
||
- `size_overrides`: A collection of integers and `nothing`, one entry per column. An integer at a given index | ||
overrides the rule-of-thumb embedding size for that column; `nothing` keeps the rule of thumb. | ||
|
||
""" | ||
|
||
|
||
get_emb_sz(cardinalities::AbstractVector{<:Integer}, size_overrides=fill(nothing, length(cardinalities))) = | ||
map(zip(cardinalities, size_overrides)) do (cardinality, override) | ||
emb_dim = isnothing(override) ? emb_sz_rule(cardinality + 1) : Int64(override) | ||
return (cardinality + 1, emb_dim) | ||
end | ||
|
||
""" | ||
get_emb_sz(cardinalities, categorical_cols, [size_overrides]) | ||
|
||
|
||
Returns a collection of `(n_classes, embedding_dim)` tuples, one per categorical column in | ||
`cardinalities`. Each cardinality is incremented by one to account for NaNs, so `n_classes` is `cardinality + 1`. | ||
|
||
|
||
## Arguments | ||
|
||
- `categorical_cols`: A tuple of categorical column names. | ||
- `size_overrides`: An indexable collection with column names as keys and the embedding size | ||
to use for that column as values. Columns without an entry use the rule of thumb. | ||
|
||
""" | ||
|
||
|
||
function get_emb_sz(cardinalities::AbstractVector{<:Integer}, categorical_cols::Tuple, size_overrides=Dict()) | ||
keylist = keys(size_overrides) | ||
overrides = collect(map(categorical_cols) do col | ||
col in keylist ? size_overrides[col] : nothing | ||
end) | ||
get_emb_sz(cardinalities, overrides) | ||
end | ||
|
||
|
||
sigmoidrange(x, low, high) = @. Flux.sigmoid(x) * (high - low) + low | ||
|
||
function tabular_embedding_backbone(embedding_sizes, dropout_rate=0.) | ||
embedslist = [Flux.Embedding(ni, nf) for (ni, nf) in embedding_sizes] | ||
emb_drop = iszero(dropout_rate) ? identity : Dropout(dropout_rate) | ||
Chain( | ||
x -> tuple(eachrow(x)...), | ||
Parallel(vcat, embedslist), | ||
emb_drop | ||
) | ||
end | ||
|
||
tabular_continuous_backbone(n_cont) = BatchNorm(n_cont) | ||
|
||
""" | ||
TabularModel(catbackbone, contbackbone, [finalclassifier]; kwargs...) | ||
|
||
Create a tabular model which takes in a tuple of categorical values | ||
(label or one-hot encoded) and continuous values. The categorical backbone `catbackbone` is by default | ||
a Parallel of Embedding layers, one for each categorical variable, and the continuous | ||
variables are BatchNormed by `contbackbone`. The concatenated output of both backbones is then passed | ||
through the classifier layers and the `finalclassifier` block. | ||
|
||
|
||
## Keyword arguments | ||
|
||
- `outsize`: The output size of the final classifier block. For single classification tasks, | ||
this would just be the number of classes and for regression tasks, this could be the | ||
number of target continuous variables. | ||
|
||
- `layersizes`: The sizes of the hidden layers in the classifier block. | ||
|
||
- `dropout_rates`: Dropout probability. This can be either a single number, which is | ||
used for all the classifier layers, or a collection of numbers, which are cycled through | ||
for each layer. | ||
|
||
- `batchnorm`: Boolean variable which controls whether to use batch normalization in the classifier. | ||
|
||
- `activation`: The activation function to use in the classifier layers. | ||
- `linear_first`: Controls if the linear layer comes before or after BatchNorm and Dropout. | ||
|
||
""" | ||
|
||
|
||
function TabularModel( | ||
catbackbone, | ||
contbackbone; | ||
outsize, | ||
layersizes=(200, 100), | ||
kwargs...) | ||
TabularModel(catbackbone, contbackbone, Dense(layersizes[end], outsize); layersizes=layersizes, kwargs...) | ||
end | ||
|
||
function TabularModel( | ||
catbackbone, | ||
contbackbone, | ||
finalclassifier; | ||
layersizes=[200, 100], | ||
|
||
dropout_rates=0., | ||
batchnorm=true, | ||
activation=Flux.relu, | ||
linear_first=true) | ||
|
||
tabularbackbone = Parallel(vcat, catbackbone, contbackbone) | ||
|
||
classifierin = mapreduce(layer -> size(layer.weight)[1], +, catbackbone[2].layers; | ||
init = contbackbone.chs) | ||
dropout_rates = Iterators.cycle(dropout_rates) | ||
classifiers = [] | ||
|
||
first_ps, dropout_rates = Iterators.peel(dropout_rates) | ||
push!(classifiers, linbndrop(classifierin, first(layersizes); | ||
use_bn=batchnorm, p=first_ps, lin_first=linear_first, act=activation)) | ||
|
||
for (isize, osize, p) in zip(layersizes[1:(end-1)], layersizes[2:end], dropout_rates) | ||
layer = linbndrop(isize, osize; use_bn=batchnorm, p=p, act=activation, lin_first=linear_first) | ||
push!(classifiers, layer) | ||
end | ||
|
||
Chain( | ||
tabularbackbone, | ||
classifiers..., | ||
finalclassifier | ||
) | ||
end | ||
|
||
""" | ||
TabularModel(n_cont, outsize, [layersizes]; kwargs...) | ||
|
||
Create a tabular model which takes in a tuple of categorical values | ||
(label or one-hot encoded) and continuous values. The default categorical backbone is | ||
a Parallel of Embedding layers corresponding to each categorical variable, and continuous | ||
variables are just BatchNormed. The output from these backbones is then passed through | ||
a final classifier block. The model is built from `n_cont`, the number of continuous columns; | ||
`outsize`, the output size of the final classifier block; and `layersizes`, a collection of | ||
classifier layer sizes. | ||
|
||
|
||
## Keyword arguments | ||
|
||
- `cardinalities`: A collection of sizes (number of classes) for each categorical column. | ||
- `size_overrides`: An optional collection with one entry per element of `cardinalities`. An integer | ||
entry overrides the embedding size returned by the "rule of thumb" for that column; | ||
`nothing` keeps the rule-of-thumb value. | ||
""" | ||
|
||
|
||
function TabularModel( | ||
n_cont::Number, | ||
outsize::Number, | ||
layersizes=(200, 100); | ||
cardinalities, | ||
size_overrides=fill(nothing, length(cardinalities))) | ||
embedszs = get_emb_sz(cardinalities, size_overrides) | ||
catback = tabular_embedding_backbone(embedszs) | ||
contback = tabular_continuous_backbone(n_cont) | ||
|
||
TabularModel(catback, contback; layersizes=layersizes, outsize=outsize) | ||
end |
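The classifier-building loop in the three-argument `TabularModel` method uses `Iterators.cycle` and `Iterators.peel` to pair each hidden layer with a dropout probability. The following is a minimal sketch of that pairing in plain Julia, without Flux. `layer_plan` is a hypothetical helper introduced only for illustration; it returns `(insize, outsize, p)` tuples instead of constructing `linbndrop` layers. The input size 75 assumes the embedding sizes used in the PR's test (10 + 30 + 30) plus 5 continuous columns.

```julia
# Hypothetical helper (not part of the PR): returns the (insize, outsize, p)
# triple for every hidden layer instead of building Flux layers.
function layer_plan(classifierin, layersizes, dropout_rates)
    # A scalar is iterable in Julia, so cycle(0.5) repeats 0.5 forever,
    # while cycle([0.1, 0.2]) repeats 0.1, 0.2, 0.1, ...
    rates = Iterators.cycle(dropout_rates)
    # peel splits off the first rate; `rest` continues from the second one.
    first_p, rest = Iterators.peel(rates)
    plan = [(classifierin, first(layersizes), first_p)]
    # Consecutive pairs of layer sizes give the remaining hidden layers.
    for (isize, osize, p) in zip(layersizes[1:end-1], layersizes[2:end], rest)
        push!(plan, (isize, osize, p))
    end
    return plan
end

# A single number is reused for every layer.
println(layer_plan(75, (200, 100), 0.5))
# A collection is cycled when there are more layers than rates.
println(layer_plan(75, [200, 100, 50], [0.1, 0.2]))
```

Because the rates come from an infinite cycle, the `zip` stops when the layer sizes run out, so a mismatch between the number of rates and the number of layers does not raise an error.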
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
include("../imports.jl") | ||
|
||
@testset ExtendedTestSet "TabularModel Components" begin | ||
@testset ExtendedTestSet "embeddingbackbone" begin | ||
embed_szs = [(5, 10), (100, 30), (2, 30)] | ||
embeds = FastAI.Models.tabular_embedding_backbone(embed_szs, 0.) | ||
x = [rand(1:n) for (n, _) in embed_szs] | ||
|
||
@test size(embeds(x)) == (70, 1) | ||
end | ||
|
||
@testset ExtendedTestSet "continuousbackbone" begin | ||
n = 5 | ||
contback = FastAI.Models.tabular_continuous_backbone(n) | ||
x = rand(5, 1) | ||
@test size(contback(x)) == (5, 1) | ||
end | ||
|
||
@testset ExtendedTestSet "TabularModel" begin | ||
n = 5 | ||
embed_szs = [(5, 10), (100, 30), (2, 30)] | ||
|
||
embeds = FastAI.Models.tabular_embedding_backbone(embed_szs, 0.) | ||
contback = FastAI.Models.tabular_continuous_backbone(n) | ||
|
||
x = ([rand(1:n) for (n, _) in embed_szs], rand(5, 1)) | ||
|
||
tm = TabularModel(embeds, contback; outsize=4) | ||
@test size(tm(x)) == (4, 1) | ||
|
||
tm2 = TabularModel(embeds, contback, Chain(Dense(100, 4), x->FastAI.Models.sigmoidrange(x, 2, 5))) | ||
y2 = tm2(x) | ||
@test all(y2 .> 2) && all(y2 .< 5) | ||
|
||
cardinalities = [4, 99, 1] | ||
tm3 = TabularModel(n, 4, [200, 100], cardinalities = cardinalities, size_overrides = (10, 30, 30)) | ||
@test size(tm3(x)) == (4, 1) | ||
end | ||
end | ||
|
||
|
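The `tm3` test above passes `cardinalities = [4, 99, 1]` together with explicit `size_overrides`. As a rough illustration of what the rule of thumb gives with and without overrides, the sketch below copies `emb_sz_rule` and the two-argument `get_emb_sz` from this PR into a standalone script. The only changes, made here for brevity, are dropping the `AbstractVector{<:Integer}` annotation and using `Int` instead of `Int64`. No Flux is required.

```julia
# Rule of thumb from the PR: grows sublinearly with the class count, capped at 600.
emb_sz_rule(n_cat) = min(600, round(Int, 1.6 * n_cat^0.56))

# Standalone copy of the PR's two-argument method (type annotation dropped).
function get_emb_sz(cardinalities, size_overrides=fill(nothing, length(cardinalities)))
    map(zip(cardinalities, size_overrides)) do (cardinality, override)
        # cardinality + 1 reserves one extra class, as in the PR.
        emb_dim = isnothing(override) ? emb_sz_rule(cardinality + 1) : Int(override)
        (cardinality + 1, emb_dim)
    end
end

cardinalities = [4, 99, 1]

# Rule of thumb only.
println(get_emb_sz(cardinalities))                      # [(5, 4), (100, 21), (2, 2)]
# Override only the second column.
println(get_emb_sz(cardinalities, [nothing, 30, nothing]))  # [(5, 4), (100, 30), (2, 2)]
# Overrides from the tm3 test reproduce the embed_szs used elsewhere in the tests.
println(get_emb_sz(cardinalities, (10, 30, 30)))        # [(5, 10), (100, 30), (2, 30)]
```

With the overrides from the test, the embedding output sizes sum to 70, which matches the `(70, 1)` shape checked in the `embeddingbackbone` test set.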