Skip to content

Commit f97a61d

Browse files
authored
Merge pull request #166 from theabhirath/clstoken-fix
2 parents ff264cb + 0a864d5 commit f97a61d

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

src/layers/embeddings.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,8 @@ end
6060
ClassTokens(dim::Integer; init = Flux.zeros32) = ClassTokens(init(dim, 1, 1))
6161

6262
function (m::ClassTokens)(x::AbstractArray{T, 3}) where {T}
63-
tokens = m.token .* fill(one(T), (1, 1, size(x, 3)))
64-
return hcat(tokens, x)
63+
tokens = m.token .* MLUtils.ones_like(x, T, (1, 1, size(x, 3)))
64+
return hcat(tokens, x)
6565
end
6666

6767
@functor ClassTokens

src/vit-based/vit.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ function vit(imsize::Dims{2} = (256, 256); inchannels = 3, patch_size::Dims{2} =
5353
ViPosEmbedding(embedplanes, npatches + 1),
5454
Dropout(emb_dropout),
5555
transformer_encoder(embedplanes, depth, nheads; mlp_ratio, dropout),
56-
(pool == :class) ? x -> selectdim(x, 2, 1) : seconddimmean),
56+
(pool == :class) ? x -> x[:, 1, :] : seconddimmean),
5757
Chain(LayerNorm(embedplanes), Dense(embedplanes, nclasses, tanh_fast)))
5858
end
5959

0 commit comments

Comments
 (0)