Skip to content

Commit 0a864d5

Browse files
committed
Use MLUtils.ones_like
Also go back to indexing instead of `selectdim` to prevent scalar indexing on the GPU
1 parent c0b2f26 commit 0a864d5

File tree

4 files changed

+2
-7
lines changed

4 files changed

+2
-7
lines changed

Project.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ version = "0.7.1"
55
[deps]
66
Artifacts = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"
77
BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0"
8-
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
98
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
109
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
1110
LazyArtifacts = "4af54fe1-eca0-43a8-85a7-787d91b784e3"

src/layers/Layers.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ using Flux: outputsize, Zygote
55
using Functors
66
using Statistics
77
using MLUtils
8-
using ChainRulesCore
98

109
include("../utilities.jl")
1110

src/layers/embeddings.jl

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,11 +59,8 @@ end
5959

6060
ClassTokens(dim::Integer; init = Flux.zeros32) = ClassTokens(init(dim, 1, 1))
6161

62-
_fill_like(y::AbstractArray{T, 3}) where {T} = fill!(similar(y, 1, 1, size(y, 3)), one(T))
63-
ChainRulesCore.@non_differentiable _fill_like(y)
64-
6562
function (m::ClassTokens)(x::AbstractArray{T, 3}) where {T}
66-
tokens = m.token .* _fill_like(x)
63+
tokens = m.token .* MLUtils.ones_like(x, T, (1, 1, size(x, 3)))
6764
return hcat(tokens, x)
6865
end
6966

src/vit-based/vit.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ function vit(imsize::Dims{2} = (256, 256); inchannels = 3, patch_size::Dims{2} =
5353
ViPosEmbedding(embedplanes, npatches + 1),
5454
Dropout(emb_dropout),
5555
transformer_encoder(embedplanes, depth, nheads; mlp_ratio, dropout),
56-
(pool == :class) ? x -> selectdim(x, 2, 1) : seconddimmean),
56+
(pool == :class) ? x -> x[:, 1, :] : seconddimmean),
5757
Chain(LayerNorm(embedplanes), Dense(embedplanes, nclasses, tanh_fast)))
5858
end
5959

0 commit comments

Comments
 (0)