Skip to content

Commit e496616

Browse files
committed
Specialize kernels for CPU & GPU
1 parent efedbc1 commit e496616

File tree

4 files changed

+150
-78
lines changed

4 files changed

+150
-78
lines changed

src/NNlib.jl

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,20 @@
11
module NNlib
22

3-
using Pkg
4-
using Requires
5-
using ChainRulesCore
63
import ChainRulesCore: rrule
4+
75
using Base.Broadcast: broadcasted
86
using Base.Threads
7+
using ChainRulesCore
8+
using KernelAbstractions
9+
using KernelAbstractions: @atomic
10+
using LinearAlgebra
11+
using LinearAlgebra.BLAS: @blasfunc, BlasInt
12+
using LinearAlgebra: AdjOrTransAbsMat, Adjoint, BlasFloat, Transpose
13+
using Pkg
914
using Random
15+
using Requires
1016
using Statistics
1117
using Statistics: mean
12-
using LinearAlgebra
13-
using LinearAlgebra: BlasFloat, Transpose, Adjoint, AdjOrTransAbsMat
14-
using LinearAlgebra.BLAS: BlasInt, @blasfunc
15-
using KernelAbstractions
16-
using KernelAbstractions: @atomic
1718

1819
const libblas = Base.libblas_name
1920

src/upsample.jl

Lines changed: 140 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -355,51 +355,78 @@ upsample_trilinear(x; size, align_corners::Bool = true) = upsample_linear(x; si
355355
function upsample_linear_kernel!(
356356
y::AbstractArray{T, N}, x::AbstractArray{T, N}; align_corners::Bool = true,
357357
) where {T, N}
358-
# ndrange = size(y)[1:N - 2]
359-
ndrange = size(y)[N - 1:end]
358+
backend = KernelAbstractions.get_backend(x)
359+
ndrange = backend isa CPU ?
360+
size(y)[N - 1:end] : # Parallelization along channel x batch.
361+
size(y)[1:N - 2] # Parallelization along WHD.
360362
ratios = align_corners ?
361363
ntuple(i -> real(T)((size(x, i) - 1) / (size(y, i) - 1)), N - 2) :
362364
ntuple(i -> real(T)(size(x, i) / size(y, i)), N - 2)
363-
364-
backend = KernelAbstractions.get_backend(x)
365-
_upsample_linear_kernel!(backend)(y, x, ratios..., Val(align_corners); ndrange)
365+
_upsample_linear_kernel!(backend)(backend, y, x, ratios..., Val(align_corners); ndrange)
366366
return y
367367
end
368368

369369
function ∇upsample_linear_kernel!(
370370
dx::AbstractArray{T, N}, Δ::AbstractArray{T, N}; align_corners::Bool = true,
371371
) where {T, N}
372-
ndrange = size(Δ)[1:N - 2]
372+
backend = KernelAbstractions.get_backend(dx)
373+
ndrange = backend isa CPU ?
374+
size(Δ)[N - 1:end] : # Parallelization along channel x batch.
375+
size(Δ)[1:N - 2] # Parallelization along WHD.
373376
ratios = align_corners ?
374377
ntuple(i -> real(T)((size(dx, i) - 1) / (size(Δ, i) - 1)), N - 2) :
375378
ntuple(i -> real(T)(size(dx, i) / size(Δ, i)), N - 2)
376-
377-
backend = KernelAbstractions.get_backend(dx)
378-
_∇upsample_linear_kernel!(backend)(dx, Δ, ratios..., Val(align_corners); ndrange)
379+
_∇upsample_linear_kernel!(backend)(backend, dx, Δ, ratios..., Val(align_corners); ndrange)
379380
return dx
380381
end
381382

382-
# Linear.
383+
# Linear (CPU): parallelization along channel x batch dimensions.
383384

384-
@kernel function _upsample_linear_kernel!(y::T, x::T, rwidth, align::Val{A}) where {
385-
T <: AbstractArray{<: Any, 3}, A,
385+
@kernel function _upsample_linear_kernel!(::CPU, y::T, x::T, rwidth, align::Val{A}) where {
386+
T <: AbstractArray{<:Any, 3}, A,
386387
}
387388
@uniform in_width::UInt32, channels::UInt32, batch::UInt32 = size(x)
389+
@uniform out_width::UInt32 = size(y, 1)
390+
c::UInt32, n::UInt32 = @index(Global, NTuple)
391+
@inbounds for i in UnitRange{UInt32}(1, out_width)
392+
iw0, iw1, w0lambda, w1lambda = source_index_and_lambda(rwidth, i - 0x1, align, in_width)
393+
y[i, c, n] = w0lambda * x[iw0, c, n] + w1lambda * x[iw1, c, n]
394+
end
395+
end
396+
397+
@kernel function _∇upsample_linear_kernel!(::CPU, dx::T1, Δ::T2, rwidth, align::Val{A}) where {
398+
T1 <: AbstractArray{<:Any, 3}, T2 <: AbstractArray{<:Any, 3}, A,
399+
}
400+
@uniform in_width::UInt32, channels::UInt32, batch::UInt32 = size(Δ)
401+
@uniform out_width::UInt32 = size(dx, 1)
402+
c::UInt32, n::UInt32 = @index(Global, NTuple)
403+
@inbounds for i in UnitRange{UInt32}(1, in_width)
404+
ow0, ow1, w0lambda, w1lambda = source_index_and_lambda(rwidth, i - 0x1, align, out_width)
405+
val = Δ[i, c, n]
406+
@atomic dx[ow0, c, n] += w0lambda * val
407+
@atomic dx[ow1, c, n] += w1lambda * val
408+
end
409+
end
410+
411+
# Linear (GPU): parallelization along width dimension.
412+
# TODO replace AbstractArray -> AbstractGPUArray once device arrays subtype it.
388413

414+
@kernel function _upsample_linear_kernel!(::B, y::T, x::T, rwidth, align::Val{A}) where {
415+
B <: GPU, T <: AbstractArray{<:Any, 3}, A,
416+
}
417+
@uniform in_width::UInt32, channels::UInt32, batch::UInt32 = size(x)
389418
i::UInt32 = @index(Global)
390-
iw0, iw1, w0lambda, w1lambda = source_index_and_lambda( rwidth, i - 0x1, align, in_width)
419+
iw0, iw1, w0lambda, w1lambda = source_index_and_lambda(rwidth, i - 0x1, align, in_width)
391420
@inbounds for n in 1:batch, c in 1:channels
392421
y[i, c, n] = w0lambda * x[iw0, c, n] + w1lambda * x[iw1, c, n]
393422
end
394423
end
395424

396-
@kernel function _∇upsample_linear_kernel!(dx::T1, Δ::T2, rwidth, align::Val{A}) where {
397-
T1 <: AbstractArray{<: Any, 3},
398-
T2 <: AbstractArray{<: Any, 3}, A,
425+
@kernel function _∇upsample_linear_kernel!(::B, dx::T, Δ::T, rwidth, align::Val{A}) where {
426+
B <: GPU, T <: AbstractArray{<:Any, 3}, A,
399427
}
400428
@uniform in_width::UInt32, channels::UInt32, batch::UInt32 = size(Δ)
401429
@uniform out_width::UInt32 = size(dx, 1)
402-
403430
i::UInt32 = @index(Global)
404431
ow0, ow1, w0lambda, w1lambda = source_index_and_lambda(rwidth, i - 0x1, align, out_width)
405432
@inbounds for n in 1:batch, c in 1:channels
@@ -409,16 +436,14 @@ end
409436
end
410437
end
411438

412-
# Bilinear.
439+
# Bilinear (CPU): parallelization along channel x batch dimensions.
413440

414-
@kernel function _upsample_linear_kernel!(y::T, x::T, rwidth, rheight, align::Val{A}) where {
415-
T <: AbstractArray{<: Any, 4}, A,
441+
@kernel function _upsample_linear_kernel!(::CPU, y::T, x::T, rwidth, rheight, align::Val{A}) where {
442+
T <: AbstractArray{<:Any, 4}, A,
416443
}
417444
@uniform in_width::UInt32, in_height::UInt32, channels::UInt32, batch::UInt32 = size(x)
418445
@uniform out_width::UInt32, out_height::UInt32 = size(y)[1:2]
419-
420446
c::UInt32, n::UInt32 = @index(Global, NTuple)
421-
422447
for j in UnitRange{UInt32}(1, out_height)
423448
ih0, ih1, h0lambda, h1lambda = source_index_and_lambda(rheight, j - 0x1, align, in_height)
424449
for i in UnitRange{UInt32}(1, out_width)
@@ -428,48 +453,51 @@ end
428453
h1lambda * (w0lambda * x[iw0, ih1, c, n] + w1lambda * x[iw1, ih1, c, n])
429454
end
430455
end
431-
432-
# i::UInt32, j::UInt32 = @index(Global, NTuple)
433-
434-
# iw0, iw1, w0lambda, w1lambda = source_index_and_lambda(rwidth, i - 0x1, align, in_width)
435-
# ih0, ih1, h0lambda, h1lambda = source_index_and_lambda(rheight, j - 0x1, align, in_height)
436-
437-
# @inbounds for n in 1:batch, c in 1:channels
438-
# y[i, j, c, n] =
439-
# h0lambda * (w0lambda * x[iw0, ih0, c, n] + w1lambda * x[iw1, ih0, c, n]) +
440-
# h1lambda * (w0lambda * x[iw0, ih1, c, n] + w1lambda * x[iw1, ih1, c, n])
441-
# end
442456
end
443457

444-
# @kernel function _upsample_linear_kernel!(y::T, x::T, rwidth, rheight, align::Val{A}) where {
445-
# T <: AbstractArray{<: Any, 4}, A,
446-
# }
447-
# @uniform in_width::UInt32, in_height::UInt32, channels::UInt32, batch::UInt32 = size(x)
448-
449-
# i::UInt32, j::UInt32 = @index(Global, NTuple)
458+
@kernel function _∇upsample_linear_kernel!(::CPU, dx::T1, Δ::T2, rwidth, rheight, align::Val{A}) where {
459+
T1 <: AbstractArray{<:Any, 4}, T2 <: AbstractArray{<:Any, 4}, A,
460+
}
461+
@uniform in_width::UInt32, in_height::UInt32, channels::UInt32, batch::UInt32 = size(Δ)
462+
@uniform out_width::UInt32, out_height::UInt32 = size(dx)[1:2]
463+
c::UInt32, n::UInt32 = @index(Global, NTuple)
464+
for j in UnitRange{UInt32}(1, in_height)
465+
oh0, oh1, h0lambda, h1lambda = source_index_and_lambda(rheight, j - 0x1, align, out_height)
466+
for i in UnitRange{UInt32}(1, in_width)
467+
ow0, ow1, w0lambda, w1lambda = source_index_and_lambda(rwidth, i - 0x1, align, out_width)
468+
val = Δ[i, j, c, n]
469+
@atomic dx[ow0, oh0, c, n] += w0lambda * h0lambda * val
470+
@atomic dx[ow1, oh0, c, n] += w1lambda * h0lambda * val
471+
@atomic dx[ow0, oh1, c, n] += w0lambda * h1lambda * val
472+
@atomic dx[ow1, oh1, c, n] += w1lambda * h1lambda * val
473+
end
474+
end
475+
end
450476

451-
# iw0, iw1, w0lambda, w1lambda = source_index_and_lambda(rwidth, i - 0x1, align, in_width)
452-
# ih0, ih1, h0lambda, h1lambda = source_index_and_lambda(rheight, j - 0x1, align, in_height)
477+
# Bilinear (GPU): parallelization along width, height dimensions.
453478

454-
# @inbounds for n in 1:batch, c in 1:channels
455-
# y[i, j, c, n] =
456-
# h0lambda * (w0lambda * x[iw0, ih0, c, n] + w1lambda * x[iw1, ih0, c, n]) +
457-
# h1lambda * (w0lambda * x[iw0, ih1, c, n] + w1lambda * x[iw1, ih1, c, n])
458-
# end
459-
# end
479+
@kernel function _upsample_linear_kernel!(::B, y::T, x::T, rwidth, rheight, align::Val{A}) where {
480+
B <: GPU, T <: AbstractArray{<:Any, 4}, A,
481+
}
482+
@uniform in_width::UInt32, in_height::UInt32, channels::UInt32, batch::UInt32 = size(x)
483+
i::UInt32, j::UInt32 = @index(Global, NTuple)
484+
iw0, iw1, w0lambda, w1lambda = source_index_and_lambda(rwidth, i - 0x1, align, in_width)
485+
ih0, ih1, h0lambda, h1lambda = source_index_and_lambda(rheight, j - 0x1, align, in_height)
486+
@inbounds for n in 1:batch, c in 1:channels
487+
y[i, j, c, n] =
488+
h0lambda * (w0lambda * x[iw0, ih0, c, n] + w1lambda * x[iw1, ih0, c, n]) +
489+
h1lambda * (w0lambda * x[iw0, ih1, c, n] + w1lambda * x[iw1, ih1, c, n])
490+
end
491+
end
460492

461-
@kernel function _∇upsample_linear_kernel!(dx::T1, Δ::T2, rwidth, rheight, align::Val{A}) where {
462-
T1 <: AbstractArray{<: Any, 4},
463-
T2 <: AbstractArray{<: Any, 4}, A,
493+
@kernel function _∇upsample_linear_kernel!(::B, dx::T, Δ::T, rwidth, rheight, align::Val{A}) where {
494+
B <: GPU, T <: AbstractArray{<:Any, 4}, A,
464495
}
465496
@uniform in_width::UInt32, in_height::UInt32, channels::UInt32, batch::UInt32 = size(Δ)
466497
@uniform out_width::UInt32, out_height::UInt32 = size(dx)[1:2]
467-
468498
i::UInt32, j::UInt32 = @index(Global, NTuple)
469-
470499
ow0, ow1, w0lambda, w1lambda = source_index_and_lambda(rwidth, i - 0x1, align, out_width)
471500
oh0, oh1, h0lambda, h1lambda = source_index_and_lambda(rheight, j - 0x1, align, out_height)
472-
473501
@inbounds for n in 1:batch, c in 1:channels
474502
val = Δ[i, j, c, n]
475503
@atomic dx[ow0, oh0, c, n] += w0lambda * h0lambda * val
@@ -479,20 +507,72 @@ end
479507
end
480508
end
481509

482-
# Trilinear.
510+
# Trilinear (CPU): parallelization along channel x batch dimensions.
483511

484-
@kernel function _upsample_linear_kernel!(y::T, x::T, rwidth, rheight, rdepth, align::Val{A}) where {
485-
T <: AbstractArray{<: Any, 5}, A,
512+
@kernel function _upsample_linear_kernel!(::CPU, y::T, x::T, rwidth, rheight, rdepth, align::Val{A}) where {
513+
T <: AbstractArray{<:Any, 5}, A,
486514
}
487515
@uniform in_width::UInt32, in_height::UInt32, in_depth::UInt32 = size(x)[1:3]
488516
@uniform channels::UInt32, batch::UInt32 = size(x, 4), size(x, 5)
517+
@uniform out_width::UInt32, out_height::UInt32, out_depth::UInt32 = size(y)[1:3]
518+
c::UInt32, n::UInt32 = @index(Global, NTuple)
519+
for k in UnitRange{UInt32}(1, out_depth)
520+
id0, id1, d0lambda, d1lambda = source_index_and_lambda(rdepth, k - 0x1, align, in_depth)
521+
for j in UnitRange{UInt32}(1, out_height)
522+
ih0, ih1, h0lambda, h1lambda = source_index_and_lambda(rheight, j - 0x1, align, in_height)
523+
for i in UnitRange{UInt32}(1, out_width)
524+
iw0, iw1, w0lambda, w1lambda = source_index_and_lambda(rwidth, i - 0x1, align, in_width)
525+
@inbounds y[i, j, k, c, n] =
526+
d0lambda * (
527+
h0lambda * (w0lambda * x[iw0, ih0, id0, c, n] + w1lambda * x[iw1, ih0, id0, c, n]) +
528+
h1lambda * (w0lambda * x[iw0, ih1, id0, c, n] + w1lambda * x[iw1, ih1, id0, c, n])) +
529+
d1lambda * (
530+
h0lambda * (w0lambda * x[iw0, ih0, id1, c, n] + w1lambda * x[iw1, ih0, id1, c, n]) +
531+
h1lambda * (w0lambda * x[iw0, ih1, id1, c, n] + w1lambda * x[iw1, ih1, id1, c, n]))
532+
end
533+
end
534+
end
535+
end
489536

490-
i::UInt32, j::UInt32, k::UInt32 = @index(Global, NTuple)
537+
@kernel function _∇upsample_linear_kernel!(::CPU, dx::T1, Δ::T2, rwidth, rheight, rdepth, align::Val{A}) where {
538+
T1 <: AbstractArray{<:Any, 5}, T2 <: AbstractArray{<:Any, 5}, A,
539+
}
540+
@uniform in_width::UInt32, in_height::UInt32, in_depth::UInt32 = size(Δ)[1:3]
541+
@uniform channels::UInt32, batch::UInt32 = size(Δ, 4), size(Δ, 5)
542+
@uniform out_width::UInt32, out_height::UInt32, out_depth::UInt32 = size(dx)[1:3]
543+
c::UInt32, n::UInt32 = @index(Global, NTuple)
544+
for k in UnitRange{UInt32}(1, in_depth)
545+
od0, od1, d0lambda, d1lambda = source_index_and_lambda(rdepth, k - 0x1, align, out_depth)
546+
for j in UnitRange{UInt32}(1, in_height)
547+
oh0, oh1, h0lambda, h1lambda = source_index_and_lambda(rheight, j - 0x1, align, out_height)
548+
@inbounds for i in UnitRange{UInt32}(1, in_width)
549+
ow0, ow1, w0lambda, w1lambda = source_index_and_lambda(rwidth, i - 0x1, align, out_width)
550+
val = Δ[i, j, k, c, n]
551+
@atomic dx[ow0, oh0, od0, c, n] += w0lambda * h0lambda * d0lambda * val
552+
@atomic dx[ow1, oh0, od0, c, n] += w1lambda * h0lambda * d0lambda * val
553+
@atomic dx[ow0, oh1, od0, c, n] += w0lambda * h1lambda * d0lambda * val
554+
@atomic dx[ow1, oh1, od0, c, n] += w1lambda * h1lambda * d0lambda * val
555+
556+
@atomic dx[ow0, oh0, od1, c, n] += w0lambda * h0lambda * d1lambda * val
557+
@atomic dx[ow1, oh0, od1, c, n] += w1lambda * h0lambda * d1lambda * val
558+
@atomic dx[ow0, oh1, od1, c, n] += w0lambda * h1lambda * d1lambda * val
559+
@atomic dx[ow1, oh1, od1, c, n] += w1lambda * h1lambda * d1lambda * val
560+
end
561+
end
562+
end
563+
end
564+
565+
# Trilinear (GPU): parallelization along width x height x depth dimensions.
491566

567+
@kernel function _upsample_linear_kernel!(::B, y::T, x::T, rwidth, rheight, rdepth, align::Val{A}) where {
568+
B <: GPU, T <: AbstractArray{<:Any, 5}, A,
569+
}
570+
@uniform in_width::UInt32, in_height::UInt32, in_depth::UInt32 = size(x)[1:3]
571+
@uniform channels::UInt32, batch::UInt32 = size(x, 4), size(x, 5)
572+
i::UInt32, j::UInt32, k::UInt32 = @index(Global, NTuple)
492573
iw0, iw1, w0lambda, w1lambda = source_index_and_lambda(rwidth, i - 0x1, align, in_width)
493574
ih0, ih1, h0lambda, h1lambda = source_index_and_lambda(rheight, j - 0x1, align, in_height)
494575
id0, id1, d0lambda, d1lambda = source_index_and_lambda(rdepth, k - 0x1, align, in_depth)
495-
496576
@inbounds for n in 1:batch, c in 1:channels
497577
y[i, j, k, c, n] =
498578
d0lambda * (
@@ -504,20 +584,16 @@ end
504584
end
505585
end
506586

507-
@kernel function _∇upsample_linear_kernel!(dx::T1, Δ::T2, rwidth, rheight, rdepth, align::Val{A}) where {
508-
T1 <: AbstractArray{<: Any, 5},
509-
T2 <: AbstractArray{<: Any, 5}, A,
587+
@kernel function _∇upsample_linear_kernel!(::B, dx::T, Δ::T, rwidth, rheight, rdepth, align::Val{A}) where {
588+
B <: GPU, T <: AbstractArray{<:Any, 5}, A,
510589
}
511590
@uniform in_width::UInt32, in_height::UInt32, in_depth::UInt32 = size(Δ)[1:3]
512591
@uniform channels::UInt32, batch::UInt32 = size(Δ, 4), size(Δ, 5)
513592
@uniform out_width::UInt32, out_height::UInt32, out_depth::UInt32 = size(dx)[1:3]
514-
515593
i::UInt32, j::UInt32, k::UInt32 = @index(Global, NTuple)
516-
517594
ow0, ow1, w0lambda, w1lambda = source_index_and_lambda(rwidth, i - 0x1, align, out_width)
518595
oh0, oh1, h0lambda, h1lambda = source_index_and_lambda(rheight, j - 0x1, align, out_height)
519596
od0, od1, d0lambda, d1lambda = source_index_and_lambda(rdepth, k - 0x1, align, out_depth)
520-
521597
@inbounds for n in 1:batch, c in 1:channels
522598
val = Δ[i, j, k, c, n]
523599
@atomic dx[ow0, oh0, od0, c, n] += w0lambda * h0lambda * d0lambda * val
@@ -545,6 +621,5 @@ end
545621

546622
w1lambda = real_index - iw0
547623
w0lambda = T(1) - w1lambda
548-
549624
return iw0 + 0x1, iw1, w0lambda, w1lambda
550625
end

test/runtests.jl

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -56,11 +56,7 @@ end
5656
if get(ENV, "NNLIB_TEST_AMDGPU", "false") == "true"
5757
import Pkg
5858
test_info = Pkg.project()
59-
# Add MIOpen_jll to AMDGPU.
6059
Pkg.develop("AMDGPU")
61-
Pkg.activate(joinpath(Pkg.devdir(), "AMDGPU"))
62-
Pkg.add("MIOpen_jll")
63-
Pkg.update()
6460
# Update test project.
6561
Pkg.activate(test_info.path)
6662
Pkg.update()

test/upsample.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ function upsample_testsuite(Backend)
2727
end
2828

2929
@testset "Linear upsampling (1D)" begin
30-
x = Float64[1,2,3,4]
30+
x = T[1,2,3,4]
3131
x = hcat(x,x,x)[:,:,:]
3232

3333
y = collect(1:1//3:4)

0 commit comments

Comments
 (0)