Skip to content

Commit c0b2f26

Browse files
committed
Use @non_differentiable function for fill! in ClassTokens
Should solve at least part of #165
1 parent ff264cb commit c0b2f26

File tree

3 files changed

+7
-2
lines changed

3 files changed

+7
-2
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ version = "0.7.1"
55
[deps]
66
Artifacts = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"
77
BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0"
8+
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
89
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
910
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
1011
LazyArtifacts = "4af54fe1-eca0-43a8-85a7-787d91b784e3"

src/layers/Layers.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ using Flux: outputsize, Zygote
55
using Functors
66
using Statistics
77
using MLUtils
8+
using ChainRulesCore
89

910
include("../utilities.jl")
1011

src/layers/embeddings.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,9 +59,12 @@ end
5959

6060
ClassTokens(dim::Integer; init = Flux.zeros32) = ClassTokens(init(dim, 1, 1))
6161

62+
_fill_like(y::AbstractArray{T, 3}) where {T} = fill!(similar(y, 1, 1, size(y, 3)), one(T))
63+
ChainRulesCore.@non_differentiable _fill_like(y)
64+
6265
function (m::ClassTokens)(x::AbstractArray{T, 3}) where {T}
63-
tokens = m.token .* fill(one(T), (1, 1, size(x, 3)))
64-
return hcat(tokens, x)
66+
tokens = m.token .* _fill_like(x)
67+
return hcat(tokens, x)
6568
end
6669

6770
@functor ClassTokens

0 commit comments

Comments
 (0)