Skip to content

Commit 4852dcf

Browse files
committed
added function for calculating embedding dimensions
1 parent c498d9e commit 4852dcf

File tree

2 files changed

+23
-2
lines changed

2 files changed

+23
-2
lines changed

src/models/Models.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
module Models
22

3-
using Base: Bool
3+
using Base: Bool, Symbol
44
using ..FastAI
55

66
using BSON
@@ -17,7 +17,8 @@ include("unet.jl")
1717
include("tabularmodel.jl")
1818

1919

20-
export xresnet18, xresnet50, UNetDynamic, TabularModel
20+
export xresnet18, xresnet50, UNetDynamic, TabularModel,
21+
get_emb_sz
2122

2223

2324
end

src/models/tabularmodel.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,23 @@
1+
function emb_sz_rule(n_cat)
2+
min(600, round(1.6 * n_cat^0.56))
3+
end
4+
5+
function _one_emb_sz(catdict, catcol::Symbol, sz_dict=nothing)
6+
sz_dict = isnothing(sz_dict) ? Dict() : sz_dict
7+
n_cat = length(catdict[catcol])
8+
sz = catcol in keys(sz_dict) ? sz_dict[catcol] : emb_sz_rule(n_cat)
9+
n_cat, sz
10+
end
11+
12+
function get_emb_sz(catdict, cols, sz_dict=nothing)
13+
[_one_emb_sz(catdict, catcol, sz_dict) for catcol in cols]
14+
end
15+
16+
# function get_emb_sz(td::TableDataset, sz_dict=nothing)
17+
# cols = Tables.columnaccess(td.table) ? Tables.columnnames(td.table) : Tables.columnnames(Tables.rows(td.table)[1])
18+
# [_one_emb_sz(catdict, catcol, sz_dict) for catcol in cols]
19+
# end
20+
121
struct TabularModel
222
embeds
323
emb_drop

0 commit comments

Comments
 (0)