Skip to content

Commit 57bf31a

Browse files
GiggleLiu and maleadt authored
Implement permutedims without specializing on the permutation (#383)
Co-authored-by: Tim Besard <tim.besard@gmail.com>
1 parent 53a4a52 commit 57bf31a

File tree

3 files changed

+27
-14
lines changed

3 files changed

+27
-14
lines changed

.buildkite/pipeline.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ steps:
22
- label: "CUDA.jl"
33
plugins:
44
- JuliaCI/julia#v1:
5-
version: 1.6
5+
version: 1.7
66
- JuliaCI/julia-coverage#v1:
77
codecov: true
88
command: |
@@ -24,7 +24,7 @@ steps:
2424
- label: "oneAPI.jl"
2525
plugins:
2626
- JuliaCI/julia#v1:
27-
version: 1.6
27+
version: 1.7
2828
- JuliaCI/julia-coverage#v1:
2929
codecov: true
3030
command: |

src/host/linalg.jl

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -180,26 +180,35 @@ LinearAlgebra.lmul!(a::Number, B::AbstractGPUArray) = generic_lmul!(a, B)
180180

181181

182182
## permutedims
183+
LinearAlgebra.permutedims!(dest::AbstractGPUArray, src::AbstractGPUArray, perm) =
184+
permutedims!(dest, src, Tuple(perm))
183185

184186
function LinearAlgebra.permutedims!(dest::AbstractGPUArray, src::AbstractGPUArray,
185-
perm::NTuple)
187+
perm::NTuple{N}) where N
186188
Base.checkdims_perm(dest, src, perm)
187-
function permutedims_kernel(ctx, dest, src, ::Val{perm}) where {perm}
188-
I = @cartesianidx src
189-
@inbounds begin
190-
J = CartesianIndex(map(i->I[i], perm))
191-
dest[J] = src[I]
192-
end
189+
190+
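# compute the column-major strides of the destination tensor, then reorder them so that entry i is the stride of the destination dimension that source dimension i maps to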
191+
dest_strides = ntuple(k->k==1 ? 1 : prod(i->size(dest, i), 1:k-1), N)
192+
dest_strides_perm = ntuple(i->dest_strides[findfirst(==(i), perm)], N)
193+
194+
function permutedims_kernel(ctx, dest, src, dest_strides_perm)
195+
# find the cartesian index in source tensor
196+
LI = @linearidx src
197+
I = @inbounds CartesianIndices(src)[LI]
198+
199+
# the corresponding linear index in the destination tensor
200+
dest_index = map_index(I.I, dest_strides_perm)
201+
@inbounds dest[dest_index] = src[LI]
193202
return
194203
end
195-
gpu_call(permutedims_kernel, dest, src, Val(perm))
204+
gpu_call(permutedims_kernel, dest, src, dest_strides_perm)
196205
return dest
197206
end
198207

199-
# TODO: implementation without the memory copy
200-
LinearAlgebra.permutedims!(dest::AbstractGPUArray, src::AbstractGPUArray, perm) =
201-
permutedims!(dest, src, Tuple(perm))
202-
208+
# compute the 1-based linear index from 1-based Cartesian indices and the matching strides.
209+
@inline @generated function map_index(I::NTuple{N}, dest_strides::NTuple{N,T}) where {N,T}
210+
Expr(:call, :+, one(T), [:(@inbounds (I[$i]-1) * dest_strides[$i]) for i in 1:N]...)
211+
end
203212

204213
## norm
205214

test/testsuite/linalg.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,10 @@
1515
@test compare(x -> permutedims(x, (2, 1, 3)), AT, rand(Float32, 4, 5, 6))
1616
@test compare(x -> permutedims(x, (3, 1, 2)), AT, rand(Float32, 4, 5, 6))
1717
@test compare(x -> permutedims(x, [2,1,4,3]), AT, randn(ComplexF32,3,4,5,1))
18+
# high-dimensional tensor
19+
@static if VERSION >= v"1.7"
20+
@test compare(x -> permutedims(x, 18:-1:1), AT, rand(Float32, 4, [2 for _ = 2:18]...))
21+
end
1822
end
1923

2024
@testset "symmetric" begin

0 commit comments

Comments
 (0)