Skip to content

Commit b1ba771

Browse files
authored
Queue operations on the current device, not the one owning the array. (#460)
When using host or shared buffer types, there isn't a clear owning device, so use the CUDA semantics of executing on the currently active device.
1 parent 260a4dd commit b1ba771

File tree

4 files changed

+80
-76
lines changed

4 files changed

+80
-76
lines changed

lib/mkl/wrappers_blas.jl

Lines changed: 42 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ for (fname, elty) in ((:onemklSsymm, :Float32),
153153
lda = max(1,stride(A,2))
154154
ldb = max(1,stride(B,2))
155155
ldc = max(1,stride(C,2))
156-
queue = global_queue(context(A), device(A))
156+
queue = global_queue(context(A), device())
157157
$fname(sycl_queue(queue), side, uplo, m, n, alpha, A, lda, B, ldb,
158158
beta, C, ldc)
159159
C
@@ -193,7 +193,7 @@ for (fname, elty) in ((:onemklSsyrk, :Float32),
193193
k = size(A, trans == 'N' ? 2 : 1)
194194
lda = max(1,stride(A,2))
195195
ldc = max(1,stride(C,2))
196-
queue = global_queue(context(A), device(A))
196+
queue = global_queue(context(A), device())
197197
$fname(sycl_queue(queue), uplo, trans, n, k, alpha, A, lda, beta, C, ldc)
198198
C
199199
end
@@ -234,7 +234,7 @@ for (fname, elty) in ((:onemklDsyr2k,:Float64),
234234
lda = max(1,stride(A,2))
235235
ldb = max(1,stride(B,2))
236236
ldc = max(1,stride(C,2))
237-
queue = global_queue(context(A), device(A))
237+
queue = global_queue(context(A), device())
238238
$fname(sycl_queue(queue), uplo, trans, n, k, alpha, A, lda, B, ldb, beta, C, ldc)
239239
C
240240
end
@@ -268,7 +268,7 @@ for (fname, elty) in ((:onemklZherk, :ComplexF64),
268268
k = size(A, trans == 'N' ? 2 : 1)
269269
lda = max(1,stride(A,2))
270270
ldc = max(1,stride(C,2))
271-
queue = global_queue(context(A), device(A))
271+
queue = global_queue(context(A), device())
272272
$fname(sycl_queue(queue), uplo, trans, n, k, alpha, A, lda, beta, C, ldc)
273273
C
274274
end
@@ -305,7 +305,7 @@ for (fname, elty) in ((:onemklZher2k,:ComplexF64),
305305
lda = max(1,stride(A,2))
306306
ldb = max(1,stride(B,2))
307307
ldc = max(1,stride(C,2))
308-
queue = global_queue(context(A), device(A))
308+
queue = global_queue(context(A), device())
309309
$fname(sycl_queue(queue), uplo, trans, n, k, alpha, A, lda, B, ldb, beta, C, ldc)
310310
C
311311
end
@@ -336,7 +336,7 @@ for (fname, elty) in ((:onemklSgemv, :Float32),
336336
x::oneStridedArray{$elty},
337337
beta::Number,
338338
y::oneStridedArray{$elty})
339-
queue = global_queue(context(x), device(x))
339+
queue = global_queue(context(x), device())
340340
# handle trans
341341
m,n = size(a)
342342
# check dimensions
@@ -380,7 +380,7 @@ for (fname, elty) in ((:onemklChemv,:ComplexF32),
380380
lda = max(1,stride(A,2))
381381
incx = stride(x,1)
382382
incy = stride(y,1)
383-
queue = global_queue(context(x), device(x))
383+
queue = global_queue(context(x), device())
384384
$fname(sycl_queue(queue), uplo, n, alpha, A, lda, x, incx, beta, y, incy)
385385
y
386386
end
@@ -414,7 +414,7 @@ for (fname, elty) in ((:onemklChbmv,:ComplexF32),
414414
lda = max(1,stride(A,2))
415415
incx = stride(x,1)
416416
incy = stride(y,1)
417-
queue = global_queue(context(x), device(x))
417+
queue = global_queue(context(x), device())
418418
$fname(sycl_queue(queue), uplo, n, k, alpha, A, lda, x, incx, beta, y, incy)
419419
y
420420
end
@@ -443,7 +443,7 @@ for (fname, elty) in ((:onemklCher,:ComplexF32),
443443
length(x) == n || throw(DimensionMismatch("Length of vector must be the same as the matrix dimensions"))
444444
incx = stride(x,1)
445445
lda = max(1,stride(A,2))
446-
queue = global_queue(context(x), device(x))
446+
queue = global_queue(context(x), device())
447447
$fname(sycl_queue(queue), uplo, n, alpha, x, incx, A, lda)
448448
A
449449
end
@@ -466,7 +466,7 @@ for (fname, elty) in ((:onemklCher2,:ComplexF32),
466466
incx = stride(x,1)
467467
incy = stride(y,1)
468468
lda = max(1,stride(A,2))
469-
queue = global_queue(context(x), device(x))
469+
queue = global_queue(context(x), device())
470470
$fname(sycl_queue(queue), uplo, n, alpha, x, incx, y, incy, A, lda)
471471
A
472472
end
@@ -486,7 +486,7 @@ for (fname, elty) in
486486
alpha::Number,
487487
x::oneStridedArray{$elty},
488488
y::oneStridedArray{$elty})
489-
queue = global_queue(context(x), device(x))
489+
queue = global_queue(context(x), device())
490490
alpha = $elty(alpha)
491491
$fname(sycl_queue(queue), n, alpha, x, stride(x,1), y, stride(y,1))
492492
y
@@ -506,7 +506,7 @@ for (fname, elty) in
506506
x::oneStridedArray{$elty},
507507
beta::Number,
508508
y::oneStridedArray{$elty})
509-
queue = global_queue(context(x), device(x))
509+
queue = global_queue(context(x), device())
510510
alpha = $elty(alpha)
511511
beta = $elty(beta)
512512
$fname(sycl_queue(queue), n, alpha, x, stride(x,1), beta, y, stride(y,1))
@@ -528,7 +528,7 @@ for (fname, elty, cty, sty, supty) in ((:onemklSrot,:Float32,:Float32,:Float32,:
528528
y::oneStridedArray{$elty},
529529
c::Real,
530530
s::$supty)
531-
queue = global_queue(context(x), device(x))
531+
queue = global_queue(context(x), device())
532532
c = $cty(c)
533533
s = $sty(s)
534534
$fname(sycl_queue(queue), n, x, stride(x, 1), y, stride(y, 1), c, s)
@@ -560,7 +560,7 @@ for (fname, elty) in
560560
function scal!(n::Integer,
561561
alpha::$elty,
562562
x::oneStridedArray{$elty})
563-
queue = global_queue(context(x), device(x))
563+
queue = global_queue(context(x), device())
564564
$fname(sycl_queue(queue), n, alpha, x, stride(x,1))
565565
x
566566
end
@@ -586,7 +586,7 @@ for (fname, elty, ret_type) in
586586
(:onemklZnrm2, :ComplexF64,:Float64))
587587
@eval begin
588588
function nrm2(n::Integer, x::oneStridedArray{$elty})
589-
queue = global_queue(context(x), device(x))
589+
queue = global_queue(context(x), device())
590590
result = oneArray{$ret_type}([0]);
591591
$fname(sycl_queue(queue), n, x, stride(x,1), result)
592592
res = Array(result)
@@ -616,7 +616,7 @@ for (jname, fname, elty) in
616616
function $jname(n::Integer,
617617
x::oneStridedArray{$elty},
618618
y::oneStridedArray{$elty})
619-
queue = global_queue(context(x), device(x))
619+
queue = global_queue(context(x), device())
620620
result = oneArray{$elty}([0]);
621621
$fname(sycl_queue(queue), n, x, stride(x,1), y, stride(y,1), result)
622622
res = Array(result)
@@ -649,7 +649,7 @@ for (fname, elty) in ((:onemklSsbmv, :Float32),
649649
if !(1<=(1+k)<=n) throw(DimensionMismatch("Incorrect number of bands")) end
650650
if m < 1+k throw(DimensionMismatch("Array A has fewer than 1+k rows")) end
651651
if n != length(x) || n != length(y) throw(DimensionMismatch("")) end
652-
queue = global_queue(context(x), device(x))
652+
queue = global_queue(context(x), device())
653653
lda = max(1, stride(a,2))
654654
incx = stride(x,1)
655655
incy = stride(y,1)
@@ -676,7 +676,7 @@ for (fname, elty, celty) in ((:onemklCSscal, :Float32, :ComplexF32),
676676
function scal!(n::Integer,
677677
alpha::$elty,
678678
x::oneStridedArray{$celty})
679-
queue = global_queue(context(x), device(x))
679+
queue = global_queue(context(x), device())
680680
$fname(sycl_queue(queue), n, alpha, x, stride(x,1))
681681
end
682682
end
@@ -696,7 +696,7 @@ for (fname, elty) in ((:onemklSger, :Float32),
696696
m,n = size(a)
697697
m == length(x) || throw(DimensionMismatch(""))
698698
n == length(y) || throw(DimensionMismatch(""))
699-
queue = global_queue(context(x), device(x))
699+
queue = global_queue(context(x), device())
700700
$fname(sycl_queue(queue), m, n, alpha, x, stride(x,1), y, stride(y,1), a, max(1,stride(a,2)))
701701
a
702702
end
@@ -714,7 +714,7 @@ for (fname, elty) in ((:onemklSspr, :Float32),
714714
n = round(Int, (sqrt(8*length(A))-1)/2)
715715
length(x) == n || throw(DimensionMismatch("Length of vector must be the same as the matrix dimensions"))
716716
incx = stride(x,1)
717-
queue = global_queue(context(x), device(x))
717+
queue = global_queue(context(x), device())
718718
$fname(sycl_queue(queue), uplo, n, alpha, x, incx, A)
719719
A
720720
end
@@ -738,7 +738,7 @@ for (fname, elty) in ((:onemklSsymv,:Float32),
738738
lda = max(1,stride(A,2))
739739
incx = stride(x,1)
740740
incy = stride(y,1)
741-
queue = global_queue(context(x), device(x))
741+
queue = global_queue(context(x), device())
742742
$fname(sycl_queue(queue), uplo, n, alpha, A, lda, x, incx, beta, y, incy)
743743
y
744744
end
@@ -764,7 +764,7 @@ for (fname, elty) in ((:onemklSsyr,:Float32),
764764
length(x) == n || throw(DimensionMismatch("Length of vector must be the same as the matrix dimensions"))
765765
incx = stride(x,1)
766766
lda = max(1,stride(A,2))
767-
queue = global_queue(context(x), device(x))
767+
queue = global_queue(context(x), device())
768768
$fname(sycl_queue(queue), uplo, n, alpha, x, incx, A, lda)
769769
A
770770
end
@@ -786,7 +786,7 @@ for (fname, elty) in
786786
function copy!(n::Integer,
787787
x::oneStridedArray{$elty},
788788
y::oneStridedArray{$elty})
789-
queue = global_queue(context(x), device(x))
789+
queue = global_queue(context(x), device())
790790
$fname(sycl_queue(queue), n, x, stride(x, 1), y, stride(y, 1))
791791
y
792792
end
@@ -807,7 +807,7 @@ for (fname, elty, ret_type) in
807807
function asum(n::Integer,
808808
x::oneStridedArray{$elty})
809809
result = oneArray{$ret_type}([0])
810-
queue = global_queue(context(x), device(x))
810+
queue = global_queue(context(x), device())
811811
$fname(sycl_queue(queue), n, x, stride(x, 1), result)
812812
res = Array(result)
813813
return res[1]
@@ -824,7 +824,7 @@ for (fname, elty) in
824824
@eval begin
825825
function iamax(x::oneStridedArray{$elty})
826826
n = length(x)
827-
queue = global_queue(context(x), device(x))
827+
queue = global_queue(context(x), device())
828828
result = oneArray{Int64}([0]);
829829
$fname(sycl_queue(queue), n, x, stride(x, 1), result, 'O')
830830
return Array(result)[1]
@@ -842,7 +842,7 @@ for (fname, elty) in
842842
function iamin(x::StridedArray{$elty})
843843
n = length(x)
844844
result = oneArray{Int64}([0]);
845-
queue = global_queue(context(x), device(x))
845+
queue = global_queue(context(x), device())
846846
$fname(sycl_queue(queue),n, x, stride(x, 1), result, 'O')
847847
return Array(result)[1]
848848
end
@@ -859,7 +859,7 @@ for (fname, elty) in ((:onemklSswap,:Float32),
859859
x::oneStridedArray{$elty},
860860
y::oneStridedArray{$elty})
861861
# Assuming both memory allocated on same device & context
862-
queue = global_queue(context(x), device(x))
862+
queue = global_queue(context(x), device())
863863
$fname(sycl_queue(queue), n, x, stride(x, 1), y, stride(y, 1))
864864
x, y
865865
end
@@ -885,7 +885,7 @@ for (fname, elty) in ((:onemklSgbmv, :Float32),
885885
n = size(a,2)
886886
length(x) == (trans == 'N' ? n : m) && length(y) ==
887887
(trans == 'N' ? m : n) || throw(DimensionMismatch(""))
888-
queue = global_queue(context(x), device(x))
888+
queue = global_queue(context(x), device())
889889
lda = max(1, stride(a,2))
890890
incx = stride(x,1)
891891
incy = stride(y,1)
@@ -903,7 +903,7 @@ function gbmv(trans::Char,
903903
x::oneStridedArray{T}) where T
904904
n = size(a,2)
905905
leny = trans == 'N' ? m : n
906-
queue = global_queue(context(x), device(x))
906+
queue = global_queue(context(x), device())
907907
gbmv!(trans, m, kl, ku, alpha, a, x, zero(T), similar(x, leny))
908908
end
909909
function gbmv(trans::Char,
@@ -912,7 +912,7 @@ function gbmv(trans::Char,
912912
ku::Integer,
913913
a::oneStridedArray{T},
914914
x::oneStridedArray{T}) where T
915-
queue = global_queue(context(x), device(x))
915+
queue = global_queue(context(x), device())
916916
gbmv(trans, m, kl, ku, one(T), a, x)
917917
end
918918

@@ -932,7 +932,7 @@ for (fname, elty) in ((:onemklSspmv, :Float32),
932932
end
933933
incx = stride(x,1)
934934
incy = stride(y,1)
935-
queue = global_queue(context(x), device(x))
935+
queue = global_queue(context(x), device())
936936
$fname(sycl_queue(queue), uplo, n, alpha, A, x, incx, beta, y, incy)
937937
y
938938
end
@@ -966,7 +966,7 @@ for (fname, elty) in ((:onemklStbsv, :Float32),
966966
if n != length(x) throw(DimensionMismatch("")) end
967967
lda = max(1,stride(A,2))
968968
incx = stride(x,1)
969-
queue = global_queue(context(x), device(x))
969+
queue = global_queue(context(x), device())
970970
$fname(sycl_queue(queue), uplo, trans, diag, n, k, A, lda, x, incx)
971971
x
972972
end
@@ -996,7 +996,7 @@ for (fname, elty) in ((:onemklStbmv,:Float32),
996996
if n != length(x) throw(DimensionMismatch("")) end
997997
lda = max(1,stride(A,2))
998998
incx = stride(x,1)
999-
queue = global_queue(context(x), device(x))
999+
queue = global_queue(context(x), device())
10001000
$fname(sycl_queue(queue), uplo, trans, diag, n, k, A, lda, x, incx)
10011001
x
10021002
end
@@ -1029,7 +1029,7 @@ for (fname, elty) in ((:onemklStrmv, :Float32),
10291029
end
10301030
lda = max(1,stride(A,2))
10311031
incx = stride(x,1)
1032-
queue = global_queue(context(x), device(x))
1032+
queue = global_queue(context(x), device())
10331033
$fname(sycl_queue(queue), uplo, trans, diag, n, A, lda, x, incx)
10341034
x
10351035
end
@@ -1061,7 +1061,7 @@ for (fname, elty) in ((:onemklStrsv, :Float32),
10611061
end
10621062
lda = max(1,stride(A,2))
10631063
incx = stride(x,1)
1064-
queue = global_queue(context(x), device(x))
1064+
queue = global_queue(context(x), device())
10651065
$fname(sycl_queue(queue), uplo, trans, diag, n, A, lda, x, incx)
10661066
x
10671067
end
@@ -1096,7 +1096,7 @@ for (mmname, smname, elty) in
10961096
if nA != (side == 'L' ? m : n) throw(DimensionMismatch("trmm!")) end
10971097
lda = max(1,stride(A,2))
10981098
ldb = max(1,stride(B,2))
1099-
queue = global_queue(context(A), device(A))
1099+
queue = global_queue(context(A), device())
11001100
$mmname(sycl_queue(queue), side, uplo, transa, diag, m, n, alpha, A, lda, B, ldb)
11011101
B
11021102
end
@@ -1114,7 +1114,7 @@ for (mmname, smname, elty) in
11141114
if nA != (side == 'L' ? m : n) throw(DimensionMismatch("trsm!")) end
11151115
lda = max(1,stride(A,2))
11161116
ldb = max(1,stride(B,2))
1117-
queue = global_queue(context(A), device(A))
1117+
queue = global_queue(context(A), device())
11181118
$smname(sycl_queue(queue), side, uplo, transa, diag, m, n, alpha, A, lda, B, ldb)
11191119
B
11201120
end
@@ -1160,7 +1160,7 @@ for (fname, elty) in ((:onemklZhemm,:ComplexF64),
11601160
lda = max(1,stride(A,2))
11611161
ldb = max(1,stride(B,2))
11621162
ldc = max(1,stride(C,2))
1163-
queue = global_queue(context(A), device(A))
1163+
queue = global_queue(context(A), device())
11641164
$fname(sycl_queue(queue), side, uplo, m, n, alpha, A, lda, B, ldb, beta, C, ldc)
11651165
C
11661166
end
@@ -1202,9 +1202,9 @@ for (fname, elty) in
12021202
ldb = max(1,stride(B,2))
12031203
ldc = max(1,stride(C,2))
12041204

1205-
device(A) == device(B) == device(C) || error("Multi-device GEMM not supported")
1205+
device() == device(B) == device(C) || error("Multi-device GEMM not supported")
12061206
context(A) == context(B) == context(C) || error("Multi-context GEMM not supported")
1207-
queue = global_queue(context(A), device(A))
1207+
queue = global_queue(context(A), device())
12081208

12091209
alpha = $elty(alpha)
12101210
beta = $elty(beta)
@@ -1249,7 +1249,7 @@ for (fname, elty) in ((:onemklSdgmm, :Float32),
12491249
lda = max(1,stride(A,2))
12501250
incx = stride(X,1)
12511251
ldc = max(1,stride(C,2))
1252-
queue = global_queue(context(A), device(A))
1252+
queue = global_queue(context(A), device())
12531253
$fname(sycl_queue(queue), mode, m, n, A, lda, X, incx, C, ldc)
12541254
C
12551255
end
@@ -1292,7 +1292,7 @@ for (fname, elty) in
12921292
strideB = size(B, 3) == 1 ? 0 : stride(B, 3)
12931293
strideC = stride(C, 3)
12941294
batchCount = size(C, 3)
1295-
queue = global_queue(context(A), device(A))
1295+
queue = global_queue(context(A), device())
12961296
alpha = $elty(alpha)
12971297
beta = $elty(beta)
12981298
$fname(sycl_queue(queue), transA, transB, m, n, k, alpha, A, lda, strideA, B,

0 commit comments

Comments
 (0)