"""
    emb_sz_rule(n_cat)

Compute an embedding size corresponding to the number of classes for a
categorical variable using the rule of thumb present in Python fastai
(see https://github.com/fastai/fastai/blob/2742fe844573d06e700f869839fb9ec5f3a9bca9/fastai/tabular/model.py#L12).
"""
emb_sz_rule(n_cat) = min(600, round(Int, 1.6 * n_cat^0.56))
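# A quick illustration of the rule (these values follow directly from the
# formula above and are easy to verify in the REPL):
#
#   emb_sz_rule(10)      # == 6    (1.6 * 10^0.56 ≈ 5.8)
#   emb_sz_rule(1_000)   # == 77   (1.6 * 1000^0.56 ≈ 76.6)
#   emb_sz_rule(10^6)    # == 600  (capped at 600)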

"""
    get_emb_sz(cardinalities::AbstractVector, [size_overrides::AbstractVector])

Given a vector of `cardinalities` of each categorical column
(i.e. each element of `cardinalities` is the number of classes in that categorical column),
compute the output embedding size according to [`emb_sz_rule`](#).
Return a vector of tuples where each element is `(in_size, out_size)` for an embedding layer.

## Optional arguments

- `size_overrides`: A collection of integers (or `nothing` to skip the override) where the value
  present at any index will be used as the output embedding size for that column.
"""
get_emb_sz(cardinalities::AbstractVector{<:Integer}, size_overrides=fill(nothing, length(cardinalities))) =
    map(zip(cardinalities, size_overrides)) do (cardinality, override)
        # +1 reserves an extra embedding row for unseen or missing categories
        emb_dim = isnothing(override) ? emb_sz_rule(cardinality + 1) : Int64(override)
        return (cardinality + 1, emb_dim)
    end
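# Expected behaviour, sketched as comments (values computed from `emb_sz_rule`;
# note the extra class added by the method above):
#
#   get_emb_sz([4, 10])                # == [(5, 4), (11, 6)]
#   get_emb_sz([4, 10], [nothing, 3])  # == [(5, 4), (11, 3)]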

"""
    get_emb_sz(cardinalities::Dict, [size_overrides::Dict])

Given a map from columns to `cardinalities`, compute the output embedding size according to [`emb_sz_rule`](#).
Return a vector of tuples where each element is `(in_size, out_size)` for an embedding layer.

## Optional arguments

- `size_overrides`: A map of output embedding size overrides
  (i.e. `size_overrides[col]` is the output embedding size for `col`).
"""
function get_emb_sz(cardinalities::Dict{<:Any, <:Integer}, size_overrides=Dict())
    values_and_overrides = map(pairs(cardinalities)) do (col, cardinality)
        cardinality, get(size_overrides, col, nothing)
    end
    get_emb_sz(first.(values_and_overrides), last.(values_and_overrides))
end
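# For the `Dict` method, overrides are looked up by column key. A sketch with
# hypothetical column names (the order of the returned vector follows the
# `Dict`'s iteration order, which is unspecified):
#
#   get_emb_sz(Dict(:sex => 2, :city => 100), Dict(:city => 16))
#   # contains (3, 3) for :sex and (101, 16) for :city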

sigmoidrange(x, low, high) = @. Flux.sigmoid(x) * (high - low) + low
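# `sigmoidrange` squashes raw network outputs into the open interval (low, high);
# e.g. `sigmoidrange(ŷ, 0, 5)` could serve as a `finalclassifier` tail for a
# regression target known to lie in [0, 5].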

function tabular_embedding_backbone(embedding_sizes, dropout_rate=0.)
    embedslist = [Flux.Embedding(ni, nf) for (ni, nf) in embedding_sizes]
    emb_drop = iszero(dropout_rate) ? identity : Dropout(dropout_rate)
    Chain(
        x -> tuple(eachrow(x)...),
        Parallel(vcat, embedslist),
        emb_drop
    )
end
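# Minimal sketch of how the embedding backbone consumes a batch, assuming
# label-encoded categorical input of shape (n_categorical_cols, batchsize);
# each row of the input matrix is routed to its own `Embedding`:
#
#   backbone = tabular_embedding_backbone([(5, 4), (11, 6)])
#   x = [1 3; 2 7]                  # 2 categorical columns, batch of 2
#   size(backbone(x)) == (10, 2)    # 4 + 6 embedding dims after `vcat`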

tabular_continuous_backbone(n_cont) = BatchNorm(n_cont)

"""
    TabularModel(catbackbone, contbackbone, [finalclassifier]; kwargs...)

Create a tabular model which operates on a tuple of categorical values
(label-encoded or one-hot encoded) and continuous values.
The categorical backbone (`catbackbone`) and continuous backbone (`contbackbone`) operate on their respective elements of the input tuple.
The output from these backbones is then passed through a series of linear-batch norm-dropout layers before a `finalclassifier` block.

## Keyword arguments

- `outsize`: The output size of the final classifier block. For single-label classification tasks,
  this would be the number of classes, and for regression tasks, this would be the
  number of target continuous variables.
- `layersizes`: A vector of sizes for each hidden layer in the sequence of linear layers.
- `dropout_rates`: Dropout probabilities for the linear-batch norm-dropout layers.
  This can either be a single number, which is used for all the layers,
  or a collection of numbers, which is cycled through for each layer.
- `batchnorm`: Set to `false` to skip each batch norm in the linear-batch norm-dropout sequence.
- `activation`: The activation function to use in the classifier layers.
- `linear_first`: Controls whether the linear layer comes before or after batch norm and dropout.
"""
function TabularModel(
        catbackbone,
        contbackbone;
        outsize,
        layersizes=(200, 100),
        kwargs...)
    TabularModel(catbackbone, contbackbone, Dense(layersizes[end], outsize); layersizes=layersizes, kwargs...)
end

function TabularModel(
        catbackbone,
        contbackbone,
        finalclassifier;
        layersizes=(200, 100),
        dropout_rates=0.,
        batchnorm=true,
        activation=Flux.relu,
        linear_first=true)

    tabularbackbone = Parallel(vcat, catbackbone, contbackbone)

    # Input size of the first classifier layer: the total embedding width of the
    # categorical backbone plus the number of continuous channels.
    classifierin = mapreduce(layer -> size(layer.weight)[1], +, catbackbone[2].layers;
                             init = contbackbone.chs)
    dropout_rates = Iterators.cycle(dropout_rates)
    classifiers = []

    first_ps, dropout_rates = Iterators.peel(dropout_rates)
    push!(classifiers, linbndrop(classifierin, first(layersizes);
        use_bn=batchnorm, p=first_ps, lin_first=linear_first, act=activation))

    for (isize, osize, p) in zip(layersizes[1:(end-1)], layersizes[2:end], dropout_rates)
        layer = linbndrop(isize, osize; use_bn=batchnorm, p=p, act=activation, lin_first=linear_first)
        push!(classifiers, layer)
    end

    Chain(
        tabularbackbone,
        classifiers...,
        finalclassifier
    )
end
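# Rough usage sketch (shapes and values are illustrative assumptions, not part
# of the API). The model expects a `(categorical, continuous)` tuple, with the
# categorical matrix label-encoded:
#
#   catback  = tabular_embedding_backbone([(5, 4), (11, 6)])
#   contback = tabular_continuous_backbone(3)
#   model    = TabularModel(catback, contback; outsize=2)
#
#   catx  = [1 3; 2 7]               # (2 categorical cols, batchsize 2)
#   contx = randn(Float32, 3, 2)     # (3 continuous cols, batchsize 2)
#   size(model((catx, contx))) == (2, 2)   # 4 + 6 + 3 = 13 features in, 2 out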

"""
    TabularModel(n_cont, outsize, [layersizes]; kwargs...)

Create a tabular model which operates on a tuple of categorical values
(label-encoded or one-hot encoded) and continuous values. The default categorical backbone (`catbackbone`) is
a [`Flux.Parallel`](https://fluxml.ai/Flux.jl/stable/models/layers/#Flux.Parallel) set of `Flux.Embedding` layers corresponding to each categorical variable.
The default continuous backbone (`contbackbone`) is a single [`Flux.BatchNorm`](https://fluxml.ai/Flux.jl/stable/models/layers/#Flux.BatchNorm).
The output from these backbones is concatenated, then passed through a series of linear-batch norm-dropout layers before a `finalclassifier` block.

## Arguments

- `n_cont`: The number of continuous columns.
- `outsize`: The output size of the model.
- `layersizes`: A vector of sizes for each hidden layer in the sequence of linear layers.

## Keyword arguments

- `cardinalities`: A collection of sizes (number of classes) for each categorical column.
- `size_overrides`: A collection aligned with `cardinalities` whose entries override the
  embedding size computed by the "rule of thumb" for the column at the same index
  (use `nothing` to keep the computed size).
"""
function TabularModel(
        n_cont::Number,
        outsize::Number,
        layersizes=(200, 100);
        cardinalities,
        size_overrides=fill(nothing, length(cardinalities)))
    embedszs = get_emb_sz(cardinalities, size_overrides)
    catback = tabular_embedding_backbone(embedszs)
    contback = tabular_continuous_backbone(n_cont)

    TabularModel(catback, contback; layersizes=layersizes, outsize=outsize)
end
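# Hypothetical end-to-end use of the convenience constructor: 3 continuous
# columns, two categorical columns with 4 and 10 classes, and 2 output classes:
#
#   model = TabularModel(3, 2; cardinalities=[4, 10])
#   # equivalent to building the backbones by hand as in the sketch above and
#   # calling `TabularModel(catback, contback; outsize=2)`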