@@ -397,9 +397,10 @@ upsample_trilinear_whdcn!(y, x) = upsample_linear_kernel!(y, x)
397
397
@uniform in_width:: UInt32 , channels:: UInt32 , batch:: UInt32 = size (x)
398
398
@uniform out_width:: UInt32 = size (y, 1 )
399
399
c:: UInt32 , n:: UInt32 = @index (Global, NTuple)
400
+ yv, xv = @view (y[:, c, n]), @view (x[:, c, n])
400
401
@inbounds for i in UnitRange {UInt32} (one (UInt32), out_width)
401
- iw0, iw1, w0lambda, w1lambda = source_index_and_lambda (rwidth, i - one (UInt32), align, in_width)
402
- y[i, c, n ] = w0lambda * x [iw0, c, n ] + w1lambda * x [iw1, c, n ]
402
+ iw0, iw1, w0λ, w1λ = source_idx_and_λ (rwidth, i - one (UInt32), align, in_width)
403
+ yv[i ] = w0λ * xv [iw0] + w1λ * xv [iw1]
403
404
end
404
405
end
405
406
@@ -409,11 +410,12 @@ end
409
410
@uniform in_width:: UInt32 , channels:: UInt32 , batch:: UInt32 = size (Δ)
410
411
@uniform out_width:: UInt32 = size (dx, 1 )
411
412
c:: UInt32 , n:: UInt32 = @index (Global, NTuple)
413
+ Δv, dxv = @view (Δ[:, c, n]), @view (dx[:, c, n])
412
414
@inbounds for i in UnitRange {UInt32} (one (UInt32), in_width)
413
- ow0, ow1, w0lambda, w1lambda = source_index_and_lambda (rwidth, i - one (UInt32), align, out_width)
414
- val = Δ[i, c, n ]
415
- dx [ow0, c, n ] += w0lambda * val
416
- dx [ow1, c, n ] += w1lambda * val
415
+ ow0, ow1, w0λ, w1λ = source_idx_and_λ (rwidth, i - one (UInt32), align, out_width)
416
+ val = Δv[i ]
417
+ dxv [ow0] += w0λ * val
418
+ dxv [ow1] += w1λ * val
417
419
end
418
420
end
419
421
425
427
}
426
428
@uniform in_width:: UInt32 , channels:: UInt32 , batch:: UInt32 = size (x)
427
429
i:: UInt32 = @index (Global)
428
- iw0, iw1, w0lambda, w1lambda = source_index_and_lambda (rwidth, i - one (UInt32), align, in_width)
430
+ iw0, iw1, w0λ, w1λ = source_idx_and_λ (rwidth, i - one (UInt32), align, in_width)
429
431
@inbounds for n in UnitRange {UInt32} (one (UInt32), batch), c in UnitRange {UInt32} (one (UInt32), channels)
430
- y[i, c, n] = w0lambda * x[iw0, c, n] + w1lambda * x[iw1, c, n]
432
+ y[i, c, n] = w0λ * x[iw0, c, n] + w1λ * x[iw1, c, n]
431
433
end
432
434
end
433
435
@@ -437,11 +439,11 @@ end
437
439
@uniform in_width:: UInt32 , channels:: UInt32 , batch:: UInt32 = size (Δ)
438
440
@uniform out_width:: UInt32 = size (dx, 1 )
439
441
i:: UInt32 = @index (Global)
440
- ow0, ow1, w0lambda, w1lambda = source_index_and_lambda (rwidth, i - one (UInt32), align, out_width)
442
+ ow0, ow1, w0λ, w1λ = source_idx_and_λ (rwidth, i - one (UInt32), align, out_width)
441
443
@inbounds for n in UnitRange {UInt32} (one (UInt32), batch), c in UnitRange {UInt32} (one (UInt32), channels)
442
444
val = Δ[i, c, n]
443
- @atomic dx[ow0, c, n] += w0lambda * val
444
- @atomic dx[ow1, c, n] += w1lambda * val
445
+ @atomic dx[ow0, c, n] += w0λ * val
446
+ @atomic dx[ow1, c, n] += w1λ * val
445
447
end
446
448
end
447
449
@@ -453,13 +455,14 @@ end
453
455
@uniform in_width:: UInt32 , in_height:: UInt32 , channels:: UInt32 , batch:: UInt32 = size (x)
454
456
@uniform out_width:: UInt32 , out_height:: UInt32 = size (y)[1 : 2 ]
455
457
c:: UInt32 , n:: UInt32 = @index (Global, NTuple)
458
+ yv, xv = @view (y[:, :, c, n]), @view (x[:, :, c, n])
456
459
for j in UnitRange {UInt32} (one (UInt32), out_height)
457
- ih0, ih1, h0lambda, h1lambda = source_index_and_lambda (rheight, j - one (UInt32), align, in_height)
460
+ ih0, ih1, h0λ, h1λ = source_idx_and_λ (rheight, j - one (UInt32), align, in_height)
458
461
for i in UnitRange {UInt32} (one (UInt32), out_width)
459
- iw0, iw1, w0lambda, w1lambda = source_index_and_lambda (rwidth, i - one (UInt32), align, in_width)
460
- @inbounds y [i, j, c, n ] =
461
- h0lambda * (w0lambda * x [iw0, ih0, c, n ] + w1lambda * x [iw1, ih0, c, n ]) +
462
- h1lambda * (w0lambda * x [iw0, ih1, c, n ] + w1lambda * x [iw1, ih1, c, n ])
462
+ iw0, iw1, w0λ, w1λ = source_idx_and_λ (rwidth, i - one (UInt32), align, in_width)
463
+ @inbounds yv [i, j] =
464
+ h0λ * (w0λ * xv [iw0, ih0] + w1λ * xv [iw1, ih0]) +
465
+ h1λ * (w0λ * xv [iw0, ih1] + w1λ * xv [iw1, ih1])
463
466
end
464
467
end
465
468
end
@@ -470,15 +473,16 @@ end
470
473
@uniform in_width:: UInt32 , in_height:: UInt32 , channels:: UInt32 , batch:: UInt32 = size (Δ)
471
474
@uniform out_width:: UInt32 , out_height:: UInt32 = size (dx)[1 : 2 ]
472
475
c:: UInt32 , n:: UInt32 = @index (Global, NTuple)
476
+ Δv, dxv = @view (Δ[:, :, c, n]), @view (dx[:, :, c, n])
473
477
for j in UnitRange {UInt32} (one (UInt32), in_height)
474
- oh0, oh1, h0lambda, h1lambda = source_index_and_lambda (rheight, j - one (UInt32), align, out_height)
475
- for i in UnitRange {UInt32} (one (UInt32), in_width)
476
- ow0, ow1, w0lambda, w1lambda = source_index_and_lambda (rwidth, i - one (UInt32), align, out_width)
477
- val = Δ [i, j, c, n ]
478
- dx [ow0, oh0, c, n ] += w0lambda * h0lambda * val
479
- dx [ow1, oh0, c, n ] += w1lambda * h0lambda * val
480
- dx [ow0, oh1, c, n ] += w0lambda * h1lambda * val
481
- dx [ow1, oh1, c, n ] += w1lambda * h1lambda * val
478
+ oh0, oh1, h0λ, h1λ = source_idx_and_λ (rheight, j - one (UInt32), align, out_height)
479
+ @inbounds for i in UnitRange {UInt32} (one (UInt32), in_width)
480
+ ow0, ow1, w0λ, w1λ = source_idx_and_λ (rwidth, i - one (UInt32), align, out_width)
481
+ val = Δv [i, j]
482
+ dxv [ow0, oh0] += w0λ * h0λ * val
483
+ dxv [ow1, oh0] += w1λ * h0λ * val
484
+ dxv [ow0, oh1] += w0λ * h1λ * val
485
+ dxv [ow1, oh1] += w1λ * h1λ * val
482
486
end
483
487
end
484
488
end
@@ -490,12 +494,12 @@ end
490
494
}
491
495
@uniform in_width:: UInt32 , in_height:: UInt32 , channels:: UInt32 , batch:: UInt32 = size (x)
492
496
i:: UInt32 , j:: UInt32 = @index (Global, NTuple)
493
- iw0, iw1, w0lambda, w1lambda = source_index_and_lambda (rwidth, i - one (UInt32), align, in_width)
494
- ih0, ih1, h0lambda, h1lambda = source_index_and_lambda (rheight, j - one (UInt32), align, in_height)
497
+ iw0, iw1, w0λ, w1λ = source_idx_and_λ (rwidth, i - one (UInt32), align, in_width)
498
+ ih0, ih1, h0λ, h1λ = source_idx_and_λ (rheight, j - one (UInt32), align, in_height)
495
499
@inbounds for n in UnitRange {UInt32} (one (UInt32), batch), c in UnitRange {UInt32} (one (UInt32), channels)
496
500
y[i, j, c, n] =
497
- h0lambda * (w0lambda * x[iw0, ih0, c, n] + w1lambda * x[iw1, ih0, c, n]) +
498
- h1lambda * (w0lambda * x[iw0, ih1, c, n] + w1lambda * x[iw1, ih1, c, n])
501
+ h0λ * (w0λ * x[iw0, ih0, c, n] + w1λ * x[iw1, ih0, c, n]) +
502
+ h1λ * (w0λ * x[iw0, ih1, c, n] + w1λ * x[iw1, ih1, c, n])
499
503
end
500
504
end
501
505
@@ -505,14 +509,14 @@ end
505
509
@uniform in_width:: UInt32 , in_height:: UInt32 , channels:: UInt32 , batch:: UInt32 = size (Δ)
506
510
@uniform out_width:: UInt32 , out_height:: UInt32 = size (dx)[1 : 2 ]
507
511
i:: UInt32 , j:: UInt32 = @index (Global, NTuple)
508
- ow0, ow1, w0lambda, w1lambda = source_index_and_lambda (rwidth, i - one (UInt32), align, out_width)
509
- oh0, oh1, h0lambda, h1lambda = source_index_and_lambda (rheight, j - one (UInt32), align, out_height)
512
+ ow0, ow1, w0λ, w1λ = source_idx_and_λ (rwidth, i - one (UInt32), align, out_width)
513
+ oh0, oh1, h0λ, h1λ = source_idx_and_λ (rheight, j - one (UInt32), align, out_height)
510
514
@inbounds for n in UnitRange {UInt32} (one (UInt32), batch), c in UnitRange {UInt32} (one (UInt32), channels)
511
515
val = Δ[i, j, c, n]
512
- @atomic dx[ow0, oh0, c, n] += w0lambda * h0lambda * val
513
- @atomic dx[ow1, oh0, c, n] += w1lambda * h0lambda * val
514
- @atomic dx[ow0, oh1, c, n] += w0lambda * h1lambda * val
515
- @atomic dx[ow1, oh1, c, n] += w1lambda * h1lambda * val
516
+ @atomic dx[ow0, oh0, c, n] += w0λ * h0λ * val
517
+ @atomic dx[ow1, oh0, c, n] += w1λ * h0λ * val
518
+ @atomic dx[ow0, oh1, c, n] += w0λ * h1λ * val
519
+ @atomic dx[ow1, oh1, c, n] += w1λ * h1λ * val
516
520
end
517
521
end
518
522
@@ -525,19 +529,20 @@ end
525
529
@uniform channels:: UInt32 , batch:: UInt32 = size (x, 4 ), size (x, 5 )
526
530
@uniform out_width:: UInt32 , out_height:: UInt32 , out_depth:: UInt32 = size (y)[1 : 3 ]
527
531
c:: UInt32 , n:: UInt32 = @index (Global, NTuple)
532
+ yv, xv = @view (y[:, :, :, c, n]), @view (x[:, :, :, c, n])
528
533
for k in UnitRange {UInt32} (one (UInt32), out_depth)
529
- id0, id1, d0lambda, d1lambda = source_index_and_lambda (rdepth, k - one (UInt32), align, in_depth)
534
+ id0, id1, d0λ, d1λ = source_idx_and_λ (rdepth, k - one (UInt32), align, in_depth)
530
535
for j in UnitRange {UInt32} (one (UInt32), out_height)
531
- ih0, ih1, h0lambda, h1lambda = source_index_and_lambda (rheight, j - one (UInt32), align, in_height)
536
+ ih0, ih1, h0λ, h1λ = source_idx_and_λ (rheight, j - one (UInt32), align, in_height)
532
537
for i in UnitRange {UInt32} (one (UInt32), out_width)
533
- iw0, iw1, w0lambda, w1lambda = source_index_and_lambda (rwidth, i - one (UInt32), align, in_width)
534
- @inbounds y [i, j, k, c, n ] =
535
- d0lambda * (
536
- h0lambda * (w0lambda * x [iw0, ih0, id0, c, n ] + w1lambda * x [iw1, ih0, id0, c, n ]) +
537
- h1lambda * (w0lambda * x [iw0, ih1, id0, c, n ] + w1lambda * x [iw1, ih1, id0, c, n ])) +
538
- d1lambda * (
539
- h0lambda * (w0lambda * x [iw0, ih0, id1, c, n ] + w1lambda * x [iw1, ih0, id1, c, n ]) +
540
- h1lambda * (w0lambda * x [iw0, ih1, id1, c, n ] + w1lambda * x [iw1, ih1, id1, c, n ]))
538
+ iw0, iw1, w0λ, w1λ = source_idx_and_λ (rwidth, i - one (UInt32), align, in_width)
539
+ @inbounds yv [i, j, k] =
540
+ d0λ * (
541
+ h0λ * (w0λ * xv [iw0, ih0, id0] + w1λ * xv [iw1, ih0, id0]) +
542
+ h1λ * (w0λ * xv [iw0, ih1, id0] + w1λ * xv [iw1, ih1, id0])) +
543
+ d1λ * (
544
+ h0λ * (w0λ * xv [iw0, ih0, id1] + w1λ * xv [iw1, ih0, id1]) +
545
+ h1λ * (w0λ * xv [iw0, ih1, id1] + w1λ * xv [iw1, ih1, id1]))
541
546
end
542
547
end
543
548
end
@@ -550,22 +555,23 @@ end
550
555
@uniform channels:: UInt32 , batch:: UInt32 = size (Δ, 4 ), size (Δ, 5 )
551
556
@uniform out_width:: UInt32 , out_height:: UInt32 , out_depth:: UInt32 = size (dx)[1 : 3 ]
552
557
c:: UInt32 , n:: UInt32 = @index (Global, NTuple)
558
+ Δv, dxv = @view (Δ[:, :, :, c, n]), @view (dx[:, :, :, c, n])
553
559
for k in UnitRange {UInt32} (one (UInt32), in_depth)
554
- od0, od1, d0lambda, d1lambda = source_index_and_lambda (rdepth, k - one (UInt32), align, out_depth)
560
+ od0, od1, d0λ, d1λ = source_idx_and_λ (rdepth, k - one (UInt32), align, out_depth)
555
561
for j in UnitRange {UInt32} (one (UInt32), in_height)
556
- oh0, oh1, h0lambda, h1lambda = source_index_and_lambda (rheight, j - one (UInt32), align, out_height)
562
+ oh0, oh1, h0λ, h1λ = source_idx_and_λ (rheight, j - one (UInt32), align, out_height)
557
563
@inbounds for i in UnitRange {UInt32} (one (UInt32), in_width)
558
- ow0, ow1, w0lambda, w1lambda = source_index_and_lambda (rwidth, i - one (UInt32), align, out_width)
559
- val = Δ [i, j, k, c, n ]
560
- dx [ow0, oh0, od0, c, n ] += w0lambda * h0lambda * d0lambda * val
561
- dx [ow1, oh0, od0, c, n ] += w1lambda * h0lambda * d0lambda * val
562
- dx [ow0, oh1, od0, c, n ] += w0lambda * h1lambda * d0lambda * val
563
- dx [ow1, oh1, od0, c, n ] += w1lambda * h1lambda * d0lambda * val
564
-
565
- dx [ow0, oh0, od1, c, n ] += w0lambda * h0lambda * d1lambda * val
566
- dx [ow1, oh0, od1, c, n ] += w1lambda * h0lambda * d1lambda * val
567
- dx [ow0, oh1, od1, c, n ] += w0lambda * h1lambda * d1lambda * val
568
- dx [ow1, oh1, od1, c, n ] += w1lambda * h1lambda * d1lambda * val
564
+ ow0, ow1, w0λ, w1λ = source_idx_and_λ (rwidth, i - one (UInt32), align, out_width)
565
+ val = Δv [i, j, k]
566
+ dxv [ow0, oh0, od0] += w0λ * h0λ * d0λ * val
567
+ dxv [ow1, oh0, od0] += w1λ * h0λ * d0λ * val
568
+ dxv [ow0, oh1, od0] += w0λ * h1λ * d0λ * val
569
+ dxv [ow1, oh1, od0] += w1λ * h1λ * d0λ * val
570
+
571
+ dxv [ow0, oh0, od1] += w0λ * h0λ * d1λ * val
572
+ dxv [ow1, oh0, od1] += w1λ * h0λ * d1λ * val
573
+ dxv [ow0, oh1, od1] += w0λ * h1λ * d1λ * val
574
+ dxv [ow1, oh1, od1] += w1λ * h1λ * d1λ * val
569
575
end
570
576
end
571
577
end
@@ -579,17 +585,17 @@ end
579
585
@uniform in_width:: UInt32 , in_height:: UInt32 , in_depth:: UInt32 = size (x)[1 : 3 ]
580
586
@uniform channels:: UInt32 , batch:: UInt32 = size (x, 4 ), size (x, 5 )
581
587
i:: UInt32 , j:: UInt32 , k:: UInt32 = @index (Global, NTuple)
582
- iw0, iw1, w0lambda, w1lambda = source_index_and_lambda (rwidth, i - one (UInt32), align, in_width)
583
- ih0, ih1, h0lambda, h1lambda = source_index_and_lambda (rheight, j - one (UInt32), align, in_height)
584
- id0, id1, d0lambda, d1lambda = source_index_and_lambda (rdepth, k - one (UInt32), align, in_depth)
588
+ iw0, iw1, w0λ, w1λ = source_idx_and_λ (rwidth, i - one (UInt32), align, in_width)
589
+ ih0, ih1, h0λ, h1λ = source_idx_and_λ (rheight, j - one (UInt32), align, in_height)
590
+ id0, id1, d0λ, d1λ = source_idx_and_λ (rdepth, k - one (UInt32), align, in_depth)
585
591
@inbounds for n in UnitRange {UInt32} (one (UInt32), batch), c in UnitRange {UInt32} (one (UInt32), channels)
586
592
y[i, j, k, c, n] =
587
- d0lambda * (
588
- h0lambda * (w0lambda * x[iw0, ih0, id0, c, n] + w1lambda * x[iw1, ih0, id0, c, n]) +
589
- h1lambda * (w0lambda * x[iw0, ih1, id0, c, n] + w1lambda * x[iw1, ih1, id0, c, n])) +
590
- d1lambda * (
591
- h0lambda * (w0lambda * x[iw0, ih0, id1, c, n] + w1lambda * x[iw1, ih0, id1, c, n]) +
592
- h1lambda * (w0lambda * x[iw0, ih1, id1, c, n] + w1lambda * x[iw1, ih1, id1, c, n]))
593
+ d0λ * (
594
+ h0λ * (w0λ * x[iw0, ih0, id0, c, n] + w1λ * x[iw1, ih0, id0, c, n]) +
595
+ h1λ * (w0λ * x[iw0, ih1, id0, c, n] + w1λ * x[iw1, ih1, id0, c, n])) +
596
+ d1λ * (
597
+ h0λ * (w0λ * x[iw0, ih0, id1, c, n] + w1λ * x[iw1, ih0, id1, c, n]) +
598
+ h1λ * (w0λ * x[iw0, ih1, id1, c, n] + w1λ * x[iw1, ih1, id1, c, n]))
593
599
end
594
600
end
595
601
@@ -600,24 +606,24 @@ end
600
606
@uniform channels:: UInt32 , batch:: UInt32 = size (Δ, 4 ), size (Δ, 5 )
601
607
@uniform out_width:: UInt32 , out_height:: UInt32 , out_depth:: UInt32 = size (dx)[1 : 3 ]
602
608
i:: UInt32 , j:: UInt32 , k:: UInt32 = @index (Global, NTuple)
603
- ow0, ow1, w0lambda, w1lambda = source_index_and_lambda (rwidth, i - one (UInt32), align, out_width)
604
- oh0, oh1, h0lambda, h1lambda = source_index_and_lambda (rheight, j - one (UInt32), align, out_height)
605
- od0, od1, d0lambda, d1lambda = source_index_and_lambda (rdepth, k - one (UInt32), align, out_depth)
609
+ ow0, ow1, w0λ, w1λ = source_idx_and_λ (rwidth, i - one (UInt32), align, out_width)
610
+ oh0, oh1, h0λ, h1λ = source_idx_and_λ (rheight, j - one (UInt32), align, out_height)
611
+ od0, od1, d0λ, d1λ = source_idx_and_λ (rdepth, k - one (UInt32), align, out_depth)
606
612
@inbounds for n in UnitRange {UInt32} (one (UInt32), batch), c in UnitRange {UInt32} (one (UInt32), channels)
607
613
val = Δ[i, j, k, c, n]
608
- @atomic dx[ow0, oh0, od0, c, n] += w0lambda * h0lambda * d0lambda * val
609
- @atomic dx[ow1, oh0, od0, c, n] += w1lambda * h0lambda * d0lambda * val
610
- @atomic dx[ow0, oh1, od0, c, n] += w0lambda * h1lambda * d0lambda * val
611
- @atomic dx[ow1, oh1, od0, c, n] += w1lambda * h1lambda * d0lambda * val
612
-
613
- @atomic dx[ow0, oh0, od1, c, n] += w0lambda * h0lambda * d1lambda * val
614
- @atomic dx[ow1, oh0, od1, c, n] += w1lambda * h0lambda * d1lambda * val
615
- @atomic dx[ow0, oh1, od1, c, n] += w0lambda * h1lambda * d1lambda * val
616
- @atomic dx[ow1, oh1, od1, c, n] += w1lambda * h1lambda * d1lambda * val
614
+ @atomic dx[ow0, oh0, od0, c, n] += w0λ * h0λ * d0λ * val
615
+ @atomic dx[ow1, oh0, od0, c, n] += w1λ * h0λ * d0λ * val
616
+ @atomic dx[ow0, oh1, od0, c, n] += w0λ * h1λ * d0λ * val
617
+ @atomic dx[ow1, oh1, od0, c, n] += w1λ * h1λ * d0λ * val
618
+
619
+ @atomic dx[ow0, oh0, od1, c, n] += w0λ * h0λ * d1λ * val
620
+ @atomic dx[ow1, oh0, od1, c, n] += w1λ * h0λ * d1λ * val
621
+ @atomic dx[ow0, oh1, od1, c, n] += w0λ * h1λ * d1λ * val
622
+ @atomic dx[ow1, oh1, od1, c, n] += w1λ * h1λ * d1λ * val
617
623
end
618
624
end
619
625
620
- @inline function source_index_and_lambda (
626
# Map an output position to the two source indices it interpolates between and
# the weight of each one.
#
# Arguments (as visible in the signature):
#   ratio    - scale factor of type `T`; callers pass `rwidth`/`rheight`/`rdepth`.
#   out_idx  - `UInt32` position; every call site passes `i - one(UInt32)`,
#              i.e. a zero-based index.
#   Val{align} - compile-time flag selecting how `real_index` is computed.
#   in_width - `UInt32` extent of the dimension being indexed.
#
# Returns `(iw0 + 1, iw1, w0lambda, w1lambda)`; call sites combine them as
# `w0 * a[first_index] + w1 * a[second_index]`.
#
# NOTE(review): this view is a diff hunk; the lines that define the two
# branches of `real_index`, `iw0` and `offset` are elided and are not
# documented here.
+ @inline function source_idx_and_λ (
621
627
ratio:: T , out_idx:: UInt32 , :: Val{align} , in_width:: UInt32 ,
622
628
) where {T, align}
623
629
# `align` picks one of two formulas for the fractional source position
# (branches not visible in this hunk).
real_index = align ?
629
635
# Second source index; `offset` comes from the elided lines — presumably it
# keeps `iw1` inside the valid range, TODO confirm.
iw1 = iw0 + offset + one (UInt32)
630
636
631
637
# Weight of the second index is the fractional distance from `iw0`.
w1lambda = real_index - iw0
632
- w0lambda = T ( 1 ) - w1lambda
638
# Weight of the first index is the complement, so the two weights sum to one.
+ w0lambda = one (T ) - w1lambda
633
639
# `iw0` is shifted by one on return, while `iw1` already had the +1 applied above.
return iw0 + one (UInt32), iw1, w0lambda, w1lambda
634
640
end
0 commit comments