Skip to content

Commit bf4d6cf

Browse files
committed
Optimize CPU kernels
1 parent 5fbf12f commit bf4d6cf

File tree

1 file changed

+86
-80
lines changed

1 file changed

+86
-80
lines changed

src/upsample.jl

Lines changed: 86 additions & 80 deletions
Original file line number | Diff line number | Diff line change
@@ -397,9 +397,10 @@ upsample_trilinear_whdcn!(y, x) = upsample_linear_kernel!(y, x)
397397
@uniform in_width::UInt32, channels::UInt32, batch::UInt32 = size(x)
398398
@uniform out_width::UInt32 = size(y, 1)
399399
c::UInt32, n::UInt32 = @index(Global, NTuple)
400+
yv, xv = @view(y[:, c, n]), @view(x[:, c, n])
400401
@inbounds for i in UnitRange{UInt32}(one(UInt32), out_width)
401-
iw0, iw1, w0lambda, w1lambda = source_index_and_lambda(rwidth, i - one(UInt32), align, in_width)
402-
y[i, c, n] = w0lambda * x[iw0, c, n] + w1lambda * x[iw1, c, n]
402+
iw0, iw1, w0λ, w1λ = source_idx_and_λ(rwidth, i - one(UInt32), align, in_width)
403+
yv[i] = w0λ * xv[iw0] + w1λ * xv[iw1]
403404
end
404405
end
405406

@@ -409,11 +410,12 @@ end
409410
@uniform in_width::UInt32, channels::UInt32, batch::UInt32 = size(Δ)
410411
@uniform out_width::UInt32 = size(dx, 1)
411412
c::UInt32, n::UInt32 = @index(Global, NTuple)
413+
Δv, dxv = @view(Δ[:, c, n]), @view(dx[:, c, n])
412414
@inbounds for i in UnitRange{UInt32}(one(UInt32), in_width)
413-
ow0, ow1, w0lambda, w1lambda = source_index_and_lambda(rwidth, i - one(UInt32), align, out_width)
414-
val = Δ[i, c, n]
415-
dx[ow0, c, n] += w0lambda * val
416-
dx[ow1, c, n] += w1lambda * val
415+
ow0, ow1, w0λ, w1λ = source_idx_and_λ(rwidth, i - one(UInt32), align, out_width)
416+
val = Δv[i]
417+
dxv[ow0] += w0λ * val
418+
dxv[ow1] += w1λ * val
417419
end
418420
end
419421

@@ -425,9 +427,9 @@ end
425427
}
426428
@uniform in_width::UInt32, channels::UInt32, batch::UInt32 = size(x)
427429
i::UInt32 = @index(Global)
428-
iw0, iw1, w0lambda, w1lambda = source_index_and_lambda(rwidth, i - one(UInt32), align, in_width)
430+
iw0, iw1, w0λ, w1λ = source_idx_and_λ(rwidth, i - one(UInt32), align, in_width)
429431
@inbounds for n in UnitRange{UInt32}(one(UInt32), batch), c in UnitRange{UInt32}(one(UInt32), channels)
430-
y[i, c, n] = w0lambda * x[iw0, c, n] + w1lambda * x[iw1, c, n]
432+
y[i, c, n] = w0λ * x[iw0, c, n] + w1λ * x[iw1, c, n]
431433
end
432434
end
433435

@@ -437,11 +439,11 @@ end
437439
@uniform in_width::UInt32, channels::UInt32, batch::UInt32 = size(Δ)
438440
@uniform out_width::UInt32 = size(dx, 1)
439441
i::UInt32 = @index(Global)
440-
ow0, ow1, w0lambda, w1lambda = source_index_and_lambda(rwidth, i - one(UInt32), align, out_width)
442+
ow0, ow1, w0λ, w1λ = source_idx_and_λ(rwidth, i - one(UInt32), align, out_width)
441443
@inbounds for n in UnitRange{UInt32}(one(UInt32), batch), c in UnitRange{UInt32}(one(UInt32), channels)
442444
val = Δ[i, c, n]
443-
@atomic dx[ow0, c, n] += w0lambda * val
444-
@atomic dx[ow1, c, n] += w1lambda * val
445+
@atomic dx[ow0, c, n] += w0λ * val
446+
@atomic dx[ow1, c, n] += w1λ * val
445447
end
446448
end
447449

@@ -453,13 +455,14 @@ end
453455
@uniform in_width::UInt32, in_height::UInt32, channels::UInt32, batch::UInt32 = size(x)
454456
@uniform out_width::UInt32, out_height::UInt32 = size(y)[1:2]
455457
c::UInt32, n::UInt32 = @index(Global, NTuple)
458+
yv, xv = @view(y[:, :, c, n]), @view(x[:, :, c, n])
456459
for j in UnitRange{UInt32}(one(UInt32), out_height)
457-
ih0, ih1, h0lambda, h1lambda = source_index_and_lambda(rheight, j - one(UInt32), align, in_height)
460+
ih0, ih1, h0λ, h1λ = source_idx_and_λ(rheight, j - one(UInt32), align, in_height)
458461
for i in UnitRange{UInt32}(one(UInt32), out_width)
459-
iw0, iw1, w0lambda, w1lambda = source_index_and_lambda(rwidth, i - one(UInt32), align, in_width)
460-
@inbounds y[i, j, c, n] =
461-
h0lambda * (w0lambda * x[iw0, ih0, c, n] + w1lambda * x[iw1, ih0, c, n]) +
462-
h1lambda * (w0lambda * x[iw0, ih1, c, n] + w1lambda * x[iw1, ih1, c, n])
462+
iw0, iw1, w0λ, w1λ = source_idx_and_λ(rwidth, i - one(UInt32), align, in_width)
463+
@inbounds yv[i, j] =
464+
h0λ * (w0λ * xv[iw0, ih0] + w1λ * xv[iw1, ih0]) +
465+
h1λ * (w0λ * xv[iw0, ih1] + w1λ * xv[iw1, ih1])
463466
end
464467
end
465468
end
@@ -470,15 +473,16 @@ end
470473
@uniform in_width::UInt32, in_height::UInt32, channels::UInt32, batch::UInt32 = size(Δ)
471474
@uniform out_width::UInt32, out_height::UInt32 = size(dx)[1:2]
472475
c::UInt32, n::UInt32 = @index(Global, NTuple)
476+
Δv, dxv = @view(Δ[:, :, c, n]), @view(dx[:, :, c, n])
473477
for j in UnitRange{UInt32}(one(UInt32), in_height)
474-
oh0, oh1, h0lambda, h1lambda = source_index_and_lambda(rheight, j - one(UInt32), align, out_height)
475-
for i in UnitRange{UInt32}(one(UInt32), in_width)
476-
ow0, ow1, w0lambda, w1lambda = source_index_and_lambda(rwidth, i - one(UInt32), align, out_width)
477-
val = Δ[i, j, c, n]
478-
dx[ow0, oh0, c, n] += w0lambda * h0lambda * val
479-
dx[ow1, oh0, c, n] += w1lambda * h0lambda * val
480-
dx[ow0, oh1, c, n] += w0lambda * h1lambda * val
481-
dx[ow1, oh1, c, n] += w1lambda * h1lambda * val
478+
oh0, oh1, h0λ, h1λ = source_idx_and_λ(rheight, j - one(UInt32), align, out_height)
479+
@inbounds for i in UnitRange{UInt32}(one(UInt32), in_width)
480+
ow0, ow1, w0λ, w1λ = source_idx_and_λ(rwidth, i - one(UInt32), align, out_width)
481+
val = Δv[i, j]
482+
dxv[ow0, oh0] += w0λ * h0λ * val
483+
dxv[ow1, oh0] += w1λ * h0λ * val
484+
dxv[ow0, oh1] += w0λ * h1λ * val
485+
dxv[ow1, oh1] += w1λ * h1λ * val
482486
end
483487
end
484488
end
@@ -490,12 +494,12 @@ end
490494
}
491495
@uniform in_width::UInt32, in_height::UInt32, channels::UInt32, batch::UInt32 = size(x)
492496
i::UInt32, j::UInt32 = @index(Global, NTuple)
493-
iw0, iw1, w0lambda, w1lambda = source_index_and_lambda(rwidth, i - one(UInt32), align, in_width)
494-
ih0, ih1, h0lambda, h1lambda = source_index_and_lambda(rheight, j - one(UInt32), align, in_height)
497+
iw0, iw1, w0λ, w1λ = source_idx_and_λ(rwidth, i - one(UInt32), align, in_width)
498+
ih0, ih1, h0λ, h1λ = source_idx_and_λ(rheight, j - one(UInt32), align, in_height)
495499
@inbounds for n in UnitRange{UInt32}(one(UInt32), batch), c in UnitRange{UInt32}(one(UInt32), channels)
496500
y[i, j, c, n] =
497-
h0lambda * (w0lambda * x[iw0, ih0, c, n] + w1lambda * x[iw1, ih0, c, n]) +
498-
h1lambda * (w0lambda * x[iw0, ih1, c, n] + w1lambda * x[iw1, ih1, c, n])
501+
h0λ * (w0λ * x[iw0, ih0, c, n] + w1λ * x[iw1, ih0, c, n]) +
502+
h1λ * (w0λ * x[iw0, ih1, c, n] + w1λ * x[iw1, ih1, c, n])
499503
end
500504
end
501505

@@ -505,14 +509,14 @@ end
505509
@uniform in_width::UInt32, in_height::UInt32, channels::UInt32, batch::UInt32 = size(Δ)
506510
@uniform out_width::UInt32, out_height::UInt32 = size(dx)[1:2]
507511
i::UInt32, j::UInt32 = @index(Global, NTuple)
508-
ow0, ow1, w0lambda, w1lambda = source_index_and_lambda(rwidth, i - one(UInt32), align, out_width)
509-
oh0, oh1, h0lambda, h1lambda = source_index_and_lambda(rheight, j - one(UInt32), align, out_height)
512+
ow0, ow1, w0λ, w1λ = source_idx_and_λ(rwidth, i - one(UInt32), align, out_width)
513+
oh0, oh1, h0λ, h1λ = source_idx_and_λ(rheight, j - one(UInt32), align, out_height)
510514
@inbounds for n in UnitRange{UInt32}(one(UInt32), batch), c in UnitRange{UInt32}(one(UInt32), channels)
511515
val = Δ[i, j, c, n]
512-
@atomic dx[ow0, oh0, c, n] += w0lambda * h0lambda * val
513-
@atomic dx[ow1, oh0, c, n] += w1lambda * h0lambda * val
514-
@atomic dx[ow0, oh1, c, n] += w0lambda * h1lambda * val
515-
@atomic dx[ow1, oh1, c, n] += w1lambda * h1lambda * val
516+
@atomic dx[ow0, oh0, c, n] += w0λ * h0λ * val
517+
@atomic dx[ow1, oh0, c, n] += w1λ * h0λ * val
518+
@atomic dx[ow0, oh1, c, n] += w0λ * h1λ * val
519+
@atomic dx[ow1, oh1, c, n] += w1λ * h1λ * val
516520
end
517521
end
518522

@@ -525,19 +529,20 @@ end
525529
@uniform channels::UInt32, batch::UInt32 = size(x, 4), size(x, 5)
526530
@uniform out_width::UInt32, out_height::UInt32, out_depth::UInt32 = size(y)[1:3]
527531
c::UInt32, n::UInt32 = @index(Global, NTuple)
532+
yv, xv = @view(y[:, :, :, c, n]), @view(x[:, :, :, c, n])
528533
for k in UnitRange{UInt32}(one(UInt32), out_depth)
529-
id0, id1, d0lambda, d1lambda = source_index_and_lambda(rdepth, k - one(UInt32), align, in_depth)
534+
id0, id1, d0λ, d1λ = source_idx_and_λ(rdepth, k - one(UInt32), align, in_depth)
530535
for j in UnitRange{UInt32}(one(UInt32), out_height)
531-
ih0, ih1, h0lambda, h1lambda = source_index_and_lambda(rheight, j - one(UInt32), align, in_height)
536+
ih0, ih1, h0λ, h1λ = source_idx_and_λ(rheight, j - one(UInt32), align, in_height)
532537
for i in UnitRange{UInt32}(one(UInt32), out_width)
533-
iw0, iw1, w0lambda, w1lambda = source_index_and_lambda(rwidth, i - one(UInt32), align, in_width)
534-
@inbounds y[i, j, k, c, n] =
535-
d0lambda * (
536-
h0lambda * (w0lambda * x[iw0, ih0, id0, c, n] + w1lambda * x[iw1, ih0, id0, c, n]) +
537-
h1lambda * (w0lambda * x[iw0, ih1, id0, c, n] + w1lambda * x[iw1, ih1, id0, c, n])) +
538-
d1lambda * (
539-
h0lambda * (w0lambda * x[iw0, ih0, id1, c, n] + w1lambda * x[iw1, ih0, id1, c, n]) +
540-
h1lambda * (w0lambda * x[iw0, ih1, id1, c, n] + w1lambda * x[iw1, ih1, id1, c, n]))
538+
iw0, iw1, w0λ, w1λ = source_idx_and_λ(rwidth, i - one(UInt32), align, in_width)
539+
@inbounds yv[i, j, k] =
540+
d0λ * (
541+
h0λ * (w0λ * xv[iw0, ih0, id0] + w1λ * xv[iw1, ih0, id0]) +
542+
h1λ * (w0λ * xv[iw0, ih1, id0] + w1λ * xv[iw1, ih1, id0])) +
543+
d1λ * (
544+
h0λ * (w0λ * xv[iw0, ih0, id1] + w1λ * xv[iw1, ih0, id1]) +
545+
h1λ * (w0λ * xv[iw0, ih1, id1] + w1λ * xv[iw1, ih1, id1]))
541546
end
542547
end
543548
end
@@ -550,22 +555,23 @@ end
550555
@uniform channels::UInt32, batch::UInt32 = size(Δ, 4), size(Δ, 5)
551556
@uniform out_width::UInt32, out_height::UInt32, out_depth::UInt32 = size(dx)[1:3]
552557
c::UInt32, n::UInt32 = @index(Global, NTuple)
558+
Δv, dxv = @view(Δ[:, :, :, c, n]), @view(dx[:, :, :, c, n])
553559
for k in UnitRange{UInt32}(one(UInt32), in_depth)
554-
od0, od1, d0lambda, d1lambda = source_index_and_lambda(rdepth, k - one(UInt32), align, out_depth)
560+
od0, od1, d0λ, d1λ = source_idx_and_λ(rdepth, k - one(UInt32), align, out_depth)
555561
for j in UnitRange{UInt32}(one(UInt32), in_height)
556-
oh0, oh1, h0lambda, h1lambda = source_index_and_lambda(rheight, j - one(UInt32), align, out_height)
562+
oh0, oh1, h0λ, h1λ = source_idx_and_λ(rheight, j - one(UInt32), align, out_height)
557563
@inbounds for i in UnitRange{UInt32}(one(UInt32), in_width)
558-
ow0, ow1, w0lambda, w1lambda = source_index_and_lambda(rwidth, i - one(UInt32), align, out_width)
559-
val = Δ[i, j, k, c, n]
560-
dx[ow0, oh0, od0, c, n] += w0lambda * h0lambda * d0lambda * val
561-
dx[ow1, oh0, od0, c, n] += w1lambda * h0lambda * d0lambda * val
562-
dx[ow0, oh1, od0, c, n] += w0lambda * h1lambda * d0lambda * val
563-
dx[ow1, oh1, od0, c, n] += w1lambda * h1lambda * d0lambda * val
564-
565-
dx[ow0, oh0, od1, c, n] += w0lambda * h0lambda * d1lambda * val
566-
dx[ow1, oh0, od1, c, n] += w1lambda * h0lambda * d1lambda * val
567-
dx[ow0, oh1, od1, c, n] += w0lambda * h1lambda * d1lambda * val
568-
dx[ow1, oh1, od1, c, n] += w1lambda * h1lambda * d1lambda * val
564+
ow0, ow1, w0λ, w1λ = source_idx_and_λ(rwidth, i - one(UInt32), align, out_width)
565+
val = Δv[i, j, k]
566+
dxv[ow0, oh0, od0] += w0λ * h0λ * d0λ * val
567+
dxv[ow1, oh0, od0] += w1λ * h0λ * d0λ * val
568+
dxv[ow0, oh1, od0] += w0λ * h1λ * d0λ * val
569+
dxv[ow1, oh1, od0] += w1λ * h1λ * d0λ * val
570+
571+
dxv[ow0, oh0, od1] += w0λ * h0λ * d1λ * val
572+
dxv[ow1, oh0, od1] += w1λ * h0λ * d1λ * val
573+
dxv[ow0, oh1, od1] += w0λ * h1λ * d1λ * val
574+
dxv[ow1, oh1, od1] += w1λ * h1λ * d1λ * val
569575
end
570576
end
571577
end
@@ -579,17 +585,17 @@ end
579585
@uniform in_width::UInt32, in_height::UInt32, in_depth::UInt32 = size(x)[1:3]
580586
@uniform channels::UInt32, batch::UInt32 = size(x, 4), size(x, 5)
581587
i::UInt32, j::UInt32, k::UInt32 = @index(Global, NTuple)
582-
iw0, iw1, w0lambda, w1lambda = source_index_and_lambda(rwidth, i - one(UInt32), align, in_width)
583-
ih0, ih1, h0lambda, h1lambda = source_index_and_lambda(rheight, j - one(UInt32), align, in_height)
584-
id0, id1, d0lambda, d1lambda = source_index_and_lambda(rdepth, k - one(UInt32), align, in_depth)
588+
iw0, iw1, w0λ, w1λ = source_idx_and_λ(rwidth, i - one(UInt32), align, in_width)
589+
ih0, ih1, h0λ, h1λ = source_idx_and_λ(rheight, j - one(UInt32), align, in_height)
590+
id0, id1, d0λ, d1λ = source_idx_and_λ(rdepth, k - one(UInt32), align, in_depth)
585591
@inbounds for n in UnitRange{UInt32}(one(UInt32), batch), c in UnitRange{UInt32}(one(UInt32), channels)
586592
y[i, j, k, c, n] =
587-
d0lambda * (
588-
h0lambda * (w0lambda * x[iw0, ih0, id0, c, n] + w1lambda * x[iw1, ih0, id0, c, n]) +
589-
h1lambda * (w0lambda * x[iw0, ih1, id0, c, n] + w1lambda * x[iw1, ih1, id0, c, n])) +
590-
d1lambda * (
591-
h0lambda * (w0lambda * x[iw0, ih0, id1, c, n] + w1lambda * x[iw1, ih0, id1, c, n]) +
592-
h1lambda * (w0lambda * x[iw0, ih1, id1, c, n] + w1lambda * x[iw1, ih1, id1, c, n]))
593+
d0λ * (
594+
h0λ * (w0λ * x[iw0, ih0, id0, c, n] + w1λ * x[iw1, ih0, id0, c, n]) +
595+
h1λ * (w0λ * x[iw0, ih1, id0, c, n] + w1λ * x[iw1, ih1, id0, c, n])) +
596+
d1λ * (
597+
h0λ * (w0λ * x[iw0, ih0, id1, c, n] + w1λ * x[iw1, ih0, id1, c, n]) +
598+
h1λ * (w0λ * x[iw0, ih1, id1, c, n] + w1λ * x[iw1, ih1, id1, c, n]))
593599
end
594600
end
595601

@@ -600,24 +606,24 @@ end
600606
@uniform channels::UInt32, batch::UInt32 = size(Δ, 4), size(Δ, 5)
601607
@uniform out_width::UInt32, out_height::UInt32, out_depth::UInt32 = size(dx)[1:3]
602608
i::UInt32, j::UInt32, k::UInt32 = @index(Global, NTuple)
603-
ow0, ow1, w0lambda, w1lambda = source_index_and_lambda(rwidth, i - one(UInt32), align, out_width)
604-
oh0, oh1, h0lambda, h1lambda = source_index_and_lambda(rheight, j - one(UInt32), align, out_height)
605-
od0, od1, d0lambda, d1lambda = source_index_and_lambda(rdepth, k - one(UInt32), align, out_depth)
609+
ow0, ow1, w0λ, w1λ = source_idx_and_λ(rwidth, i - one(UInt32), align, out_width)
610+
oh0, oh1, h0λ, h1λ = source_idx_and_λ(rheight, j - one(UInt32), align, out_height)
611+
od0, od1, d0λ, d1λ = source_idx_and_λ(rdepth, k - one(UInt32), align, out_depth)
606612
@inbounds for n in UnitRange{UInt32}(one(UInt32), batch), c in UnitRange{UInt32}(one(UInt32), channels)
607613
val = Δ[i, j, k, c, n]
608-
@atomic dx[ow0, oh0, od0, c, n] += w0lambda * h0lambda * d0lambda * val
609-
@atomic dx[ow1, oh0, od0, c, n] += w1lambda * h0lambda * d0lambda * val
610-
@atomic dx[ow0, oh1, od0, c, n] += w0lambda * h1lambda * d0lambda * val
611-
@atomic dx[ow1, oh1, od0, c, n] += w1lambda * h1lambda * d0lambda * val
612-
613-
@atomic dx[ow0, oh0, od1, c, n] += w0lambda * h0lambda * d1lambda * val
614-
@atomic dx[ow1, oh0, od1, c, n] += w1lambda * h0lambda * d1lambda * val
615-
@atomic dx[ow0, oh1, od1, c, n] += w0lambda * h1lambda * d1lambda * val
616-
@atomic dx[ow1, oh1, od1, c, n] += w1lambda * h1lambda * d1lambda * val
614+
@atomic dx[ow0, oh0, od0, c, n] += w0λ * h0λ * d0λ * val
615+
@atomic dx[ow1, oh0, od0, c, n] += w1λ * h0λ * d0λ * val
616+
@atomic dx[ow0, oh1, od0, c, n] += w0λ * h1λ * d0λ * val
617+
@atomic dx[ow1, oh1, od0, c, n] += w1λ * h1λ * d0λ * val
618+
619+
@atomic dx[ow0, oh0, od1, c, n] += w0λ * h0λ * d1λ * val
620+
@atomic dx[ow1, oh0, od1, c, n] += w1λ * h0λ * d1λ * val
621+
@atomic dx[ow0, oh1, od1, c, n] += w0λ * h1λ * d1λ * val
622+
@atomic dx[ow1, oh1, od1, c, n] += w1λ * h1λ * d1λ * val
617623
end
618624
end
619625

620-
@inline function source_index_and_lambda(
626+
@inline function source_idx_and_λ(
621627
ratio::T, out_idx::UInt32, ::Val{align}, in_width::UInt32,
622628
) where {T, align}
623629
real_index = align ?
@@ -629,6 +635,6 @@ end
629635
iw1 = iw0 + offset + one(UInt32)
630636

631637
w1lambda = real_index - iw0
632-
w0lambda = T(1) - w1lambda
638+
w0lambda = one(T) - w1lambda
633639
return iw0 + one(UInt32), iw1, w0lambda, w1lambda
634640
end

0 commit comments

Comments
 (0)