Skip to content

Commit 0be6765

Browse files
fix ctc_gpu
1 parent f28343f commit 0be6765

File tree

3 files changed

+10
-7
lines changed

3 files changed

+10
-7
lines changed

src/losses/ctc-gpu.jl

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -204,14 +204,15 @@ end
204204
function ctc_alpha(ŷ::CuArray, y)
205205
= logsoftmax(ŷ)
206206
blank = size(ŷ, 1)
207-
z′ = fill(blank, 2 * length(y) + 1)
208-
z′[eachindex(y) .* 2] = y
207+
ycu = cu(y)
208+
z′ = CUDA.fill(blank, 2 * length(y) + 1)
209+
z′[eachindex(y) .* 2] .= ycu
209210
T = size(ŷ, 2)
210211
U′ = 2*length(y) + 1
211-
alphas = CUDA.fill(log(zero(ŷ[1])), U′,T)
212-
nRepeats = count_repeats(y)
212+
alphas = CUDA.fill(log(zero(eltype(ŷ))), U′,T)
213+
nRepeats = count_repeats(cpu(y))
213214
nThreads = min(U′, MAX_THREADS)
214-
@cuda blocks=1 threads=nThreads compute_alpha_kernel(ŷ, length(y), T, nRepeats, CuArray(y), CuArray(z′), alphas, blank)
215+
@cuda blocks=1 threads=nThreads compute_alpha_kernel(ŷ, length(y), T, nRepeats, ycu, z′, alphas, blank)
215216
return (loss=-1 * logsumexp(alphas[end-1:end]), alpha=alphas, z′=z′, yhat=ŷ, nRepeats=nRepeats)
216217
end
217218

@@ -221,7 +222,7 @@ function ∇ctc_loss(ŷ::CuArray, y, out)
221222
loss, alphas, z′, ŷ, nRepeats = out
222223
U′, T = size(alphas)
223224
blank = size(ŷ, 1)
224-
typed_zero = zero(first(ŷ))
225+
typed_zero = zero(eltype(ŷ))
225226
betas = CUDA.fill(log(typed_zero), U′, T)
226227
output = CUDA.fill(log(typed_zero), U′, T)
227228
nThreads = min(U′, MAX_THREADS)

test/outputsize.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
@test outputsize(m, (10,); padbatch=true) == (2, 1)
1111
@test outputsize(m, (10, 30)) == (2, 30)
1212

13+
@info "Don't mind the following error, it's for testing purpose."
1314
m = Chain(Dense(10, 8, σ), Dense(8, 4), Dense(5, 2))
1415
@test_throws DimensionMismatch outputsize(m, (10,))
1516

test/runtests.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using Flux
22
using Flux.Data
3+
using Flux: OneHotArray, OneHotMatrix, OneHotVector
34
using Test
45
using Random, Statistics, LinearAlgebra
56
using IterTools: ncycle
@@ -50,7 +51,7 @@ end
5051
end
5152
end
5253

53-
@static if VERSION == v"1.5"
54+
@static if VERSION == v"1.6"
5455
using Documenter
5556
@testset "Docs" begin
5657
DocMeta.setdocmeta!(Flux, :DocTestSetup, :(using Flux); recursive=true)

0 commit comments

Comments
 (0)