@@ -359,11 +359,11 @@ function generic_matmatmul!(C::AbstractArray{R}, A::AbstractArray{T}, B::Abstrac
359
359
360
360
@inbounds if i <= size (A,1 ) && j <= size (B,2 )
361
361
z2 = zero (A[i, 1 ]* B[1 , j] + A[i, 1 ]* B[1 , j])
362
- Ctmp = convert (promote_type (R, typeof (z2)), z2)
362
+ Cij = convert (promote_type (R, typeof (z2)), z2)
363
363
for k in 1 : size (A,2 )
364
- Ctmp += A[i, k]* B[k, j]
364
+ Cij += A[i, k]* B[k, j]
365
365
end
366
- C[i,j] = add (Ctmp , C[i,j])
366
+ C[i,j] = add (Cij , C[i,j])
367
367
end
368
368
369
369
return
@@ -388,7 +388,184 @@ end
388
388
# GPU-array method of `LinearAlgebra.generic_matmatmul!` that takes the scalar
# coefficients `a` and `b` directly.
#
# `A` and `B` are passed through `wrap` together with their tags `tA` / `tB`, and the
# call is forwarded to the `generic_matmatmul!` visible in this scope with
# `MulAddMul(a, b)` as the last argument. The value of that call is returned.
function LinearAlgebra.generic_matmatmul!(C::AbstractGPUVecOrMat, tA, tB,
                                          A::AbstractGPUVecOrMat, B::AbstractGPUVecOrMat,
                                          a::Number, b::Number)
    lhs = wrap(A, tA)
    rhs = wrap(B, tB)
    # NOTE(review): `MulAddMul(a, b)` is kept spelled out inside the call because the
    # macro presumably rewrites that literal sub-expression -- confirm against the
    # LinearAlgebra source before restructuring this line.
    return LinearAlgebra.@stable_muladdmul generic_matmatmul!(C, lhs, rhs, MulAddMul(a, b))
end
391
- end
391
+ end
392
+
393
# Compute `C = op(tri(A)) * B` on the device and return `C`.
#
# `tri(A)` is the triangular part of `A` selected by `uploc` / `isunitc`, and `op` is
# `tfun`. One work item is launched per element of `C`; each computes a single
# entry of the product and stores it.
#
# Arguments
# - `C`:       destination, must be `size(A, 1)` x `size(B, 2)`. Its previous contents
#              are overwritten.
# - `uploc`:   `'U'` selects the upper triangle of `tfun(A)` when `tfun === identity`;
#              for any other `tfun` the meaning is flipped (see `upper` below).
# - `isunitc`: `'U'` means the diagonal is taken to be one instead of being read from `A`.
# - `tfun`:    `identity`, `transpose` or `adjoint`; anything else raises
#              `error("Not supported")`.
# - `A`, `B`:  left (triangular) and right operands.
#
# Throws `DimensionMismatch` when the operand or result sizes do not agree.
# When `A` or `B` is empty, `C` is filled with `zero(R)` and returned.
function generic_trimatmul!(C::AbstractGPUVecOrMat{R}, uploc, isunitc, tfun::Function,
                            A::AbstractGPUMatrix{T},
                            B::AbstractGPUVecOrMat{S}) where {T,S,R}
    if size(A,2) != size(B,1)
        throw(DimensionMismatch(lazy"matrix A has dimensions $(size(A)), matrix B has dimensions $(size(B))"))
    end
    if size(C,1) != size(A,1) || size(C,2) != size(B,2)
        throw(DimensionMismatch(lazy"result C has dimensions $(size(C)), needs $((size(A,1),size(B,2)))"))
    end
    if isempty(A) || isempty(B)
        return fill!(C, zero(R))
    end

    # `upper` describes the triangle of `tfun(A)`: transposing/adjointing swaps it.
    upper = tfun === identity ? uploc == 'U' : uploc != 'U'
    unit  = isunitc == 'U'

    # Kernel for tfun === identity: C[i,j] = sum_k A[i,k] * B[k,j] over the triangle.
    function trimatmul(ctx, C, A, B)
        idx = @linearidx C
        assume.(size(C) .> 0)
        # The trailing `1` supplies the column index when `C` is a vector.
        i, j = @inbounds Tuple(CartesianIndices(C)[idx])..., 1
        l, m, n = size(A, 1), size(B, 1), size(B, 2)

        @inbounds if i <= l && j <= n
            # Build a zero of the accumulator type from a representative product.
            z2 = zero(A[i,1] * B[1,j] + A[i,1] * B[1,j])
            Cij = convert(promote_type(R, typeof(z2)), z2)
            # Diagonal contribution first, then the strictly off-diagonal part of row i.
            Cij += (unit ? one(Cij) : A[i,i]) * B[i,j]
            for k in (upper ? (i + 1) : 1):(upper ? m : (i - 1))
                Cij += A[i,k] * B[k,j]
            end
            # Overwrite rather than accumulate: `C` may hold stale or uninitialised
            # data, and the empty-input branch above also overwrites.
            C[i,j] = Cij
        end

        return
    end

    # Kernel for tfun === transpose: reads A[k,i] and transposes each element.
    function trimatmul_t(ctx, C, A, B)
        idx = @linearidx C
        assume.(size(C) .> 0)
        i, j = @inbounds Tuple(CartesianIndices(C)[idx])..., 1
        l, m, n = size(A, 1), size(B, 1), size(B, 2)

        @inbounds if i <= l && j <= n
            z2 = zero(A[i,1] * B[1,j] + A[i,1] * B[1,j])
            Cij = convert(promote_type(R, typeof(z2)), z2)
            Cij += (unit ? one(Cij) : transpose(A[i,i])) * B[i,j]
            for k in (upper ? (i + 1) : 1):(upper ? m : (i - 1))
                Cij += transpose(A[k,i]) * B[k,j]
            end
            C[i,j] = Cij
        end

        return
    end

    # Kernel for tfun === adjoint: reads A[k,i] and takes the adjoint of each element.
    function trimatmul_a(ctx, C, A, B)
        idx = @linearidx C
        assume.(size(C) .> 0)
        i, j = @inbounds Tuple(CartesianIndices(C)[idx])..., 1
        l, m, n = size(A, 1), size(B, 1), size(B, 2)

        @inbounds if i <= l && j <= n
            z2 = zero(A[i,1] * B[1,j] + A[i,1] * B[1,j])
            Cij = convert(promote_type(R, typeof(z2)), z2)
            Cij += (unit ? one(Cij) : adjoint(A[i,i])) * B[i,j]
            for k in (upper ? (i + 1) : 1):(upper ? m : (i - 1))
                Cij += adjoint(A[k,i]) * B[k,j]
            end
            C[i,j] = Cij
        end

        return
    end

    if tfun === identity
        gpu_call(trimatmul, C, A, B; name="trimatmul")
    elseif tfun === transpose
        gpu_call(trimatmul_t, C, A, B; name="trimatmul_t")
    elseif tfun === adjoint
        gpu_call(trimatmul_a, C, A, B; name="trimatmul_a")
    else
        error("Not supported")
    end

    C
end
476
+
477
# Compute `C = A * op(tri(B))` on the device and return `C`.
#
# `tri(B)` is the triangular part of `B` selected by `uploc` / `isunitc`, and `op` is
# `tfun`. One work item is launched per element of `C`; each computes a single
# entry of the product and stores it.
#
# Arguments
# - `C`:       destination, must be `size(A, 1)` x `size(B, 2)`. Its previous contents
#              are overwritten.
# - `uploc`:   `'U'` selects the upper triangle of `tfun(B)` when `tfun === identity`;
#              for any other `tfun` the meaning is flipped (see `upper` below).
# - `isunitc`: `'U'` means the diagonal is taken to be one instead of being read from `B`.
# - `tfun`:    `identity`, `transpose` or `adjoint`; anything else raises
#              `error("Not supported")`.
# - `A`, `B`:  left and right (triangular) operands.
#
# Throws `DimensionMismatch` when the operand or result sizes do not agree.
# When `A` or `B` is empty, `C` is filled with `zero(R)` and returned.
function generic_mattrimul!(C::AbstractGPUVecOrMat{R}, uploc, isunitc, tfun::Function,
                            A::AbstractGPUMatrix{T},
                            B::AbstractGPUVecOrMat{S}) where {T,S,R}
    if size(A,2) != size(B,1)
        throw(DimensionMismatch(lazy"matrix A has dimensions $(size(A)), matrix B has dimensions $(size(B))"))
    end
    if size(C,1) != size(A,1) || size(C,2) != size(B,2)
        throw(DimensionMismatch(lazy"result C has dimensions $(size(C)), needs $((size(A,1),size(B,2)))"))
    end
    if isempty(A) || isempty(B)
        return fill!(C, zero(R))
    end

    # `upper` describes the triangle of `tfun(B)`: transposing/adjointing swaps it.
    upper = tfun === identity ? uploc == 'U' : uploc != 'U'
    unit  = isunitc == 'U'

    # Kernel for tfun === identity: C[i,j] = sum_k A[i,k] * B[k,j] over the triangle.
    function mattrimul(ctx, C, A, B)
        idx = @linearidx C
        assume.(size(C) .> 0)
        # The trailing `1` supplies the column index when `C` is a vector.
        i, j = @inbounds Tuple(CartesianIndices(C)[idx])..., 1
        l, m, n = size(A, 1), size(B, 1), size(B, 2)

        @inbounds if i <= l && j <= n
            # Build a zero of the accumulator type from a representative product.
            z2 = zero(A[i,1] * B[1,j] + A[i,1] * B[1,j])
            Cij = convert(promote_type(R, typeof(z2)), z2)
            # Diagonal contribution first, then the strictly off-diagonal part of column j.
            Cij += A[i,j] * (unit ? one(Cij) : B[j,j])
            for k in (upper ? 1 : (j + 1)):(upper ? (j - 1) : m)
                Cij += A[i,k] * B[k,j]
            end
            # Overwrite rather than accumulate: `C` may hold stale or uninitialised
            # data, and the empty-input branch above also overwrites.
            C[i,j] = Cij
        end

        return
    end

    # Kernel for tfun === transpose: reads B[j,k] and transposes each element.
    function mattrimul_t(ctx, C, A, B)
        idx = @linearidx C
        assume.(size(C) .> 0)
        i, j = @inbounds Tuple(CartesianIndices(C)[idx])..., 1
        l, m, n = size(A, 1), size(B, 1), size(B, 2)

        @inbounds if i <= l && j <= n
            z2 = zero(A[i,1] * B[1,j] + A[i,1] * B[1,j])
            Cij = convert(promote_type(R, typeof(z2)), z2)
            Cij += A[i,j] * (unit ? one(Cij) : transpose(B[j,j]))
            for k in (upper ? 1 : (j + 1)):(upper ? (j - 1) : m)
                Cij += A[i,k] * transpose(B[j,k])
            end
            C[i,j] = Cij
        end

        return
    end

    # Kernel for tfun === adjoint: reads B[j,k] and takes the adjoint of each element.
    function mattrimul_a(ctx, C, A, B)
        idx = @linearidx C
        assume.(size(C) .> 0)
        i, j = @inbounds Tuple(CartesianIndices(C)[idx])..., 1
        l, m, n = size(A, 1), size(B, 1), size(B, 2)

        @inbounds if i <= l && j <= n
            z2 = zero(A[i,1] * B[1,j] + A[i,1] * B[1,j])
            Cij = convert(promote_type(R, typeof(z2)), z2)
            Cij += A[i,j] * (unit ? one(Cij) : adjoint(B[j,j]))
            for k in (upper ? 1 : (j + 1)):(upper ? (j - 1) : m)
                Cij += A[i,k] * adjoint(B[j,k])
            end
            C[i,j] = Cij
        end

        return
    end

    if tfun === identity
        gpu_call(mattrimul, C, A, B; name="mattrimul")
    elseif tfun === transpose
        gpu_call(mattrimul_t, C, A, B; name="mattrimul_t")
    elseif tfun === adjoint
        gpu_call(mattrimul_a, C, A, B; name="mattrimul_a")
    else
        error("Not supported")
    end

    C
end
560
+
561
# From Julia 1.10 (pre-releases included) onwards, add GPU-array methods to
# `LinearAlgebra.generic_trimatmul!` and `LinearAlgebra.generic_mattrimul!`.
# Each method only forwards its arguments, unchanged and in the same order, to the
# same-named function visible in this scope and returns that function's result.
if VERSION >= v"1.10-"
    # Triangular-times-matrix (or vector): `C` and `B` may be vectors or matrices.
    function LinearAlgebra.generic_trimatmul!(C::AbstractGPUVecOrMat, uploc, isunitc,
                                              tfun::Function, A::AbstractGPUMatrix,
                                              B::AbstractGPUVecOrMat)
        return generic_trimatmul!(C, uploc, isunitc, tfun, A, B)
    end

    # Matrix-times-triangular: all three array arguments are restricted to matrices.
    function LinearAlgebra.generic_mattrimul!(C::AbstractGPUMatrix, uploc, isunitc,
                                              tfun::Function, A::AbstractGPUMatrix,
                                              B::AbstractGPUMatrix)
        return generic_mattrimul!(C, uploc, isunitc, tfun, A, B)
    end
end
392
569
393
570
if VERSION < v " 1.10.0-DEV.1365"
394
571
# catch other functions that are called by LinearAlgebra's mul!
0 commit comments