Skip to content

Commit c498d9e

Browse files
committed
fixed batchnorm in tabular model
1 parent 323d99e commit c498d9e

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

src/models/tabularmodel.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ function TabularModel(
3030
end
3131
embedslist = [Embedding(ni, nf) for (ni, nf) in emb_szs]
3232
emb_drop = Dropout(embed_p)
33-
bn_cont = bn_cont && BatchNorm(n_cont)
33+
bn_cont = bn_cont ? BatchNorm(n_cont) : false
3434
n_emb = sum(size(embedlayer.weight)[1] for embedlayer in embedslist)
3535
sizes = append!(zeros(0), [n_emb+n_cont], layers, [out_sz])
3636
actns = append!([], [act_cls for i in 1:(length(sizes)-1)], [nothing])
@@ -50,7 +50,7 @@ function (tm::TabularModel)(x)
5050
x = tm.emb_drop(x)
5151
end
5252
if tm.n_cont != 0
53-
if !isnothing(tm.bn_cont)
53+
if (tm.bn_cont != false)
5454
x_cont = tm.bn_cont(x_cont)
5555
end
5656
x = tm.n_emb!=0 ? vcat(x, x_cont) : x_cont

0 commit comments

Comments
 (0)