Skip to content

Commit cab706e

Browse files
authored
Merge pull request #162 from theabhirath/vit-fix
2 parents edf83e0 + 25a0682 commit cab706e

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

src/layers/embeddings.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ end
6060
ClassTokens(dim::Integer; init = Flux.zeros32) = ClassTokens(init(dim, 1, 1))
6161

6262
function (m::ClassTokens)(x::AbstractArray{T, 3}) where {T}
63-
tokens = m.token .* fill!(similar(x, 1, 1, size(x, 3)), one(T))
63+
tokens = m.token .* fill(one(T), (1, 1, size(x, 3)))
6464
return hcat(tokens, x)
6565
end
6666

src/vit-based/vit.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ function vit(imsize::Dims{2} = (256, 256); inchannels = 3, patch_size::Dims{2} =
5353
ViPosEmbedding(embedplanes, npatches + 1),
5454
Dropout(emb_dropout),
5555
transformer_encoder(embedplanes, depth, nheads; mlp_ratio, dropout),
56-
(pool == :class) ? x -> x[:, 1, :] : seconddimmean),
56+
(pool == :class) ? x -> selectdim(x, 2, 1) : seconddimmean),
5757
Chain(LayerNorm(embedplanes), Dense(embedplanes, nclasses, tanh_fast)))
5858
end
5959

0 commit comments

Comments
 (0)