Skip to content

Commit 8cae9ed

Browse files
committed
debug
1 parent 7b018de commit 8cae9ed

File tree

2 files changed

+28
-13
lines changed

2 files changed

+28
-13
lines changed

src/solver/highlevel.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -226,9 +226,9 @@ for (fname, elty) in (
226226
(mA == nA == mB == nB == mC == nC) || throw(DimensionMismatch("The first two dimensions of A, B and C must match"))
227227
(nblocksA == nblocksB - 1 == nblocksC) || throw(DimensionMismatch("Inconsistency for the last dimension of A, B and C"))
228228

229-
lda = max(1, stride(A, 2))
230-
ldb = max(1, stride(B, 2))
231-
ldc = max(1, stride(C, 2))
229+
lda = max(1, stride(A, 3))
230+
ldb = max(1, stride(B, 3))
231+
ldc = max(1, stride(C, 3))
232232

233233
devinfo = ROCArray{Cint}(undef, 1)
234234
$fname(rocBLAS.handle(), mB, nblocksB, A, lda, B, ldb, C, ldc, devinfo)
@@ -257,10 +257,10 @@ for (fname, elty) in (
257257
(mX == mA) || throw(DimensionMismatch("The first dimension of X is inconsistent with first two dimensions of A, B and C"))
258258
(nblocksA == nblocksB - 1 == nblocksX - 1 == nblocksC) || throw(DimensionMismatch("Inconsistency for the number of blocks in A, B, C and X"))
259259

260-
lda = max(1, stride(A, 2))
261-
ldb = max(1, stride(B, 2))
262-
ldc = max(1, stride(C, 2))
263-
ldx = max(1, stride(X, 2))
260+
lda = max(1, stride(A, 3))
261+
ldb = max(1, stride(B, 3))
262+
ldc = max(1, stride(C, 3))
263+
ldx = max(1, stride(X, 3))
264264

265265
$fname(rocBLAS.handle(), mB, nblocksB, nrhs, A, lda, B, ldb, C, ldc, X, ldx)
266266
X

test/rocarray/solver.jl

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -204,13 +204,15 @@ end
204204
@testset "elty = $elty" for elty in [Float32, Float64, ComplexF32, ComplexF64]
205205
nblocks = 5
206206
n = 16
207+
nrhs = 4
207208
p = n * nblocks
208209
A = rand(elty, n, n, nblocks-1)
209210
B = rand(elty, n, n, nblocks)
210211
C = rand(elty, n, n, nblocks-1)
211-
RHS = rand(elty, p, 4)
212+
R = rand(elty, n, nblocks, nrhs)
212213

213214
M = zeros(elty, p, p)
215+
RHS = zeros(elty, p, nrhs)
214216
for k in 1:nblocks
215217
offset = (k-1)*n
216218
for i = 1:n
@@ -221,14 +223,16 @@ end
221223
M[offset+i,offset+n+j] = C[i,j,k]
222224
end
223225
end
226+
for j = 1:nrhs
227+
RHS[offset+i,j] = R[i,k,j]
228+
end
224229
end
225230
end
226-
X = M \ RHS
227231

228232
d_A = ROCArray(A)
229233
d_B = ROCArray(B)
230234
d_C = ROCArray(C)
231-
d_X = ROCArray(RHS)
235+
d_R = ROCArray(R)
232236
rocSOLVER.geblttrf!(d_A, d_B, d_C)
233237

234238
L = zeros(elty, p, p)
@@ -251,10 +255,21 @@ end
251255
end
252256
end
253257
end
254-
@test L * U ≈ M
258+
N = L * U
259+
@test N ≈ M
255260

256-
rocSOLVER.geblttrs!(d_A, d_B, d_C, d_X)
257-
@test X ≈ collect(d_X)
261+
X = N \ RHS
262+
Y = similar(R)
263+
for k in 1:nblocks
264+
for i = 1:n
265+
for j = 1:nrhs
266+
l = (k-1)*n + i
267+
Y[i, k, j] = X[l,j]
268+
end
269+
end
270+
end
271+
rocSOLVER.geblttrs!(d_A, d_B, d_C, d_R)
272+
@test Y ≈ collect(d_R)
258273
end
259274
end
260275

0 commit comments

Comments
 (0)