Skip to content

Commit fa9279c

Browse files
committed
move code
1 parent 2c270a9 commit fa9279c

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

src/layers/basic.jl

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -698,9 +698,6 @@ Embedding((in, out)::Pair{<:Integer, <:Integer}; init = randn32) = Embedding(ini
698698
(m::Embedding)(x::AbstractVector{<:Integer}) = NNlib.gather(m.weight, x)
699699
(m::Embedding)(x::AbstractArray{<:Integer}) = reshape(m(vec(x)), :, size(x)...)
700700

701-
(m::Embedding)(x::Nil) = similar(m.weight, Nil, size(m.weight, 1))
702-
(m::Embedding)(x::AbstractArray{Nil}) = similar(m.weight, Nil, size(m.weight, 1), size(x)...)
703-
704701
(m::Embedding)(x::AbstractVector{Bool}) = m.weight * x # usually OneHotVector
705702
(m::Embedding)(x::AbstractMatrix{Bool}) = m.weight * x # usually OneHotMatrix
706703
(m::Embedding)(x::AbstractArray{Bool}) = reshape(m(reshape(x, size(x,1), :)), :, size(x)[2:end]...)

src/outputsize.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,9 @@ end
163163

164164
## fixes for layers that don't work out of the box
165165

166+
(m::Embedding)(x::Nil) = similar(m.weight, Nil, size(m.weight, 1))
167+
(m::Embedding)(x::AbstractArray{Nil}) = similar(m.weight, Nil, size(m.weight, 1), size(x)...)
168+
166169
for (fn, Dims) in ((:conv, DenseConvDims),)
167170
@eval begin
168171
function NNlib.$fn(a::AbstractArray{Nil}, b::AbstractArray{Nil}, dims::$Dims)

0 commit comments

Comments
 (0)