@@ -180,26 +180,35 @@ LinearAlgebra.lmul!(a::Number, B::AbstractGPUArray) = generic_lmul!(a, B)
180
180
181
181
182
182
# # permutedims
183
+ LinearAlgebra. permutedims! (dest:: AbstractGPUArray , src:: AbstractGPUArray , perm) =
184
+ permutedims! (dest, src, Tuple (perm))
183
185
184
186
function LinearAlgebra. permutedims! (dest:: AbstractGPUArray , src:: AbstractGPUArray ,
185
- perm:: NTuple )
187
+ perm:: NTuple{N} ) where N
186
188
Base. checkdims_perm (dest, src, perm)
187
- function permutedims_kernel (ctx, dest, src, :: Val{perm} ) where {perm}
188
- I = @cartesianidx src
189
- @inbounds begin
190
- J = CartesianIndex (map (i-> I[i], perm))
191
- dest[J] = src[I]
192
- end
189
+
190
+ # get the new strides of destination tensor
191
+ dest_strides = ntuple (k-> k== 1 ? 1 : prod (i-> size (dest, i), 1 : k- 1 ), N)
192
+ dest_strides_perm = ntuple (i-> dest_strides[findfirst (== (i), perm)], N)
193
+
194
+ function permutedims_kernel (ctx, dest, src, dest_strides_perm)
195
+ # find the cartesian index in source tensor
196
+ LI = @linearidx src
197
+ I = @inbounds CartesianIndices (src)[LI]
198
+
199
+ # the corresponding linear index in the destination tensor
200
+ dest_index = map_index (I. I, dest_strides_perm)
201
+ @inbounds dest[dest_index] = src[LI]
193
202
return
194
203
end
195
- gpu_call (permutedims_kernel, dest, src, Val (perm) )
204
+ gpu_call (permutedims_kernel, dest, src, dest_strides_perm )
196
205
return dest
197
206
end
198
207
199
- # TODO : implementation without the memory copy
200
- LinearAlgebra . permutedims! (dest :: AbstractGPUArray , src :: AbstractGPUArray , perm) =
201
- permutedims! (dest, src, Tuple (perm) )
202
-
208
+ # get linear index from cartesian indices and strides.
209
+ @inline @generated function map_index (I :: NTuple{N} , dest_strides :: NTuple{N,T} ) where {N,T}
210
+ Expr ( :call , : + , one (T), [:( @inbounds (I[ $ i] - 1 ) * dest_strides[ $ i]) for i in 1 : N] . .. )
211
+ end
203
212
204
213
# # norm
205
214
0 commit comments