Skip to content

Commit 98c407b

Browse files
authored
Speed up permutedims. (#387)
1 parent 35f25ab commit 98c407b

File tree

5 files changed

+165
-19
lines changed

5 files changed

+165
-19
lines changed

Manifest.toml

Lines changed: 123 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,69 @@
22

33
[[Adapt]]
44
deps = ["LinearAlgebra"]
5-
git-tree-sha1 = "f1b523983a58802c4695851926203b36e28f09db"
5+
git-tree-sha1 = "af92965fb30777147966f58acb05da51c5616b5f"
66
uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
7-
version = "3.3.0"
7+
version = "3.3.3"
8+
9+
[[ArgTools]]
10+
uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f"
11+
12+
[[Artifacts]]
13+
uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"
14+
15+
[[Base64]]
16+
uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
17+
18+
[[CEnum]]
19+
git-tree-sha1 = "215a9aa4a1f23fbd05b92769fdd62559488d70e9"
20+
uuid = "fa961155-64e5-5f13-b03f-caf6b980ea82"
21+
version = "0.4.1"
22+
23+
[[Dates]]
24+
deps = ["Printf"]
25+
uuid = "ade2ca70-3891-5945-98fb-dc099432e06a"
26+
27+
[[Downloads]]
28+
deps = ["ArgTools", "LibCURL", "NetworkOptions"]
29+
uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6"
30+
31+
[[InteractiveUtils]]
32+
deps = ["Markdown"]
33+
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
34+
35+
[[JLLWrappers]]
36+
deps = ["Preferences"]
37+
git-tree-sha1 = "22df5b96feef82434b07327e2d3c770a9b21e023"
38+
uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210"
39+
version = "1.4.0"
40+
41+
[[LLVM]]
42+
deps = ["CEnum", "LLVMExtra_jll", "Libdl", "Printf", "Unicode"]
43+
git-tree-sha1 = "f8dcd7adfda0dddaf944e62476d823164cccc217"
44+
uuid = "929cbde3-209d-540e-8aea-75f648917ca0"
45+
version = "4.7.1"
46+
47+
[[LLVMExtra_jll]]
48+
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
49+
git-tree-sha1 = "62115afed394c016c2d3096c5b85c407b48be96b"
50+
uuid = "dad2f222-ce93-54a1-a47d-0025e8a3acab"
51+
version = "0.0.13+1"
52+
53+
[[LibCURL]]
54+
deps = ["LibCURL_jll", "MozillaCACerts_jll"]
55+
uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21"
56+
57+
[[LibCURL_jll]]
58+
deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"]
59+
uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0"
60+
61+
[[LibGit2]]
62+
deps = ["Base64", "NetworkOptions", "Printf", "SHA"]
63+
uuid = "76f85450-5226-5b5a-8eaa-529ad045b433"
64+
65+
[[LibSSH2_jll]]
66+
deps = ["Artifacts", "Libdl", "MbedTLS_jll"]
67+
uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8"
868

969
[[Libdl]]
1070
uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
@@ -13,17 +73,54 @@ uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
1373
deps = ["Libdl"]
1474
uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1575

76+
[[Logging]]
77+
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"
78+
79+
[[Markdown]]
80+
deps = ["Base64"]
81+
uuid = "d6f4376e-aef5-505a-96c1-9c027394607a"
82+
83+
[[MbedTLS_jll]]
84+
deps = ["Artifacts", "Libdl"]
85+
uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1"
86+
87+
[[MozillaCACerts_jll]]
88+
uuid = "14a3606d-f60d-562e-9121-12d972cd8159"
89+
90+
[[NetworkOptions]]
91+
uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908"
92+
93+
[[Pkg]]
94+
deps = ["Artifacts", "Dates", "Downloads", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"]
95+
uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
96+
97+
[[Preferences]]
98+
deps = ["TOML"]
99+
git-tree-sha1 = "2cf929d64681236a2e074ffafb8d568733d2e6af"
100+
uuid = "21216c6a-2e73-6563-6e65-726566657250"
101+
version = "1.2.3"
102+
16103
[[Printf]]
17104
deps = ["Unicode"]
18105
uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7"
19106

107+
[[REPL]]
108+
deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"]
109+
uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"
110+
20111
[[Random]]
21112
deps = ["Serialization"]
22113
uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
23114

115+
[[SHA]]
116+
uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"
117+
24118
[[Serialization]]
25119
uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
26120

121+
[[Sockets]]
122+
uuid = "6462fe0b-24de-5631-8697-dd941f90decc"
123+
27124
[[SparseArrays]]
28125
deps = ["LinearAlgebra", "Random"]
29126
uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
@@ -32,5 +129,29 @@ uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
32129
deps = ["LinearAlgebra", "SparseArrays"]
33130
uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
34131

132+
[[TOML]]
133+
deps = ["Dates"]
134+
uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76"
135+
136+
[[Tar]]
137+
deps = ["ArgTools", "SHA"]
138+
uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e"
139+
140+
[[UUIDs]]
141+
deps = ["Random", "SHA"]
142+
uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
143+
35144
[[Unicode]]
36145
uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"
146+
147+
[[Zlib_jll]]
148+
deps = ["Libdl"]
149+
uuid = "83775a58-1f1d-513f-b197-d71354ab007a"
150+
151+
[[nghttp2_jll]]
152+
deps = ["Artifacts", "Libdl"]
153+
uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d"
154+
155+
[[p7zip_jll]]
156+
deps = ["Artifacts", "Libdl"]
157+
uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0"

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ version = "8.1.3"
44

55
[deps]
66
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
7+
LLVM = "929cbde3-209d-540e-8aea-75f648917ca0"
78
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
89
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
910
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -12,4 +13,5 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1213

1314
[compat]
1415
Adapt = "2.0, 3.0"
16+
LLVM = "3.9, 4"
1517
julia = "1.6"

src/GPUArrays.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ using LinearAlgebra.BLAS
99
using Base.Cartesian
1010

1111
using Adapt
12+
using LLVM.Interop
1213

1314
# device functionality
1415
include("device/execution.jl")

src/host/linalg.jl

Lines changed: 26 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -197,33 +197,42 @@ LinearAlgebra.lmul!(a::Number, B::AbstractGPUArray) = generic_lmul!(a, B)
197197
LinearAlgebra.permutedims!(dest::AbstractGPUArray, src::AbstractGPUArray, perm) =
198198
permutedims!(dest, src, Tuple(perm))
199199

200+
@inline @generated function permute_linearindex(size::NTuple{N,T}, l::T, strides::NTuple{N,T}) where {N,T}
201+
quote
202+
l -= one(T)
203+
res = one(T)
204+
Base.Cartesian.@nexprs $(N-1) i->begin
205+
assume(size[i] > 0)
206+
@inbounds l, s = divrem(l, size[i])
207+
@inbounds res += s * strides[i]
208+
end
209+
return @inbounds res + strides[N] * l
210+
end
211+
end
200212
function LinearAlgebra.permutedims!(dest::AbstractGPUArray, src::AbstractGPUArray,
201213
perm::NTuple{N}) where N
202-
Base.checkdims_perm(dest, src, perm)
214+
length(dest) <= typemax(UInt32) ? _permutedims!(UInt32, dest, src, perm) : _permutedims!(UInt64, dest, src, perm)
215+
end
203216

204-
# get the new strides of destination tensor
217+
function _permutedims!(::Type{IT}, dest::AbstractGPUArray, src::AbstractGPUArray,
218+
perm::NTuple{N}) where {IT,N}
219+
@assert length(src) <= typemax(IT)
220+
Base.checkdims_perm(dest, src, perm)
205221
dest_strides = ntuple(k->k==1 ? 1 : prod(i->size(dest, i), 1:k-1), N)
206-
dest_strides_perm = ntuple(i->dest_strides[findfirst(==(i), perm)], N)
207-
208-
function permutedims_kernel(ctx, dest, src, dest_strides_perm)
209-
# find the cartesian index in source tensor
210-
LI = @linearidx src
211-
I = @inbounds CartesianIndices(src)[LI]
212-
213-
# the corresponding linear index in the destination tensor
214-
dest_index = map_index(I.I, dest_strides_perm)
222+
dest_strides_perm = ntuple(i->IT(dest_strides[findfirst(==(i), perm)]), N)
223+
size_src = IT.(size(src))
224+
function permutedims_kernel(ctx, dest, src, size_src, dest_strides_perm)
225+
SLI = @linearidx dest
226+
assume(0 < SLI <= typemax(IT))
227+
LI = IT(SLI)
228+
dest_index = permute_linearindex(size_src, LI, dest_strides_perm)
215229
@inbounds dest[dest_index] = src[LI]
216230
return
217231
end
218-
gpu_call(permutedims_kernel, dest, src, dest_strides_perm)
232+
gpu_call(permutedims_kernel, vec(dest), vec(src), size_src, dest_strides_perm)
219233
return dest
220234
end
221235

222-
# get linear index from cartesian indices and strides.
223-
@inline @generated function map_index(I::NTuple{N}, dest_strides::NTuple{N,T}) where {N,T}
224-
Expr(:call, :+, one(T), [:(@inbounds (I[$i]-1) * dest_strides[$i]) for i in 1:N]...)
225-
end
226-
227236
## norm
228237

229238
function LinearAlgebra.norm(v::AbstractGPUArray{T}, p::Real=2) where {T}

test/testsuite/linalg.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,22 @@
1515
@test compare(x -> permutedims(x, (2, 1, 3)), AT, rand(Float32, 4, 5, 6))
1616
@test compare(x -> permutedims(x, (3, 1, 2)), AT, rand(Float32, 4, 5, 6))
1717
@test compare(x -> permutedims(x, [2,1,4,3]), AT, randn(ComplexF32,3,4,5,1))
18+
# Exercise the UInt64 index path directly; it is normally selected only when the array length exceeds typemax(UInt32).
19+
AT <: GPUArrays.AbstractGPUArray && @test let
20+
x = randn(ComplexF32,3,4,5,1)
21+
y = permutedims(x, (2,1,4,3))
22+
Array(GPUArrays._permutedims!(UInt64, AT(zero(y)), AT(x), (2,1,4,3))) ≈ y
23+
end
1824
# high dimensional tensor
1925
@static if VERSION >= v"1.7"
2026
@test compare(x -> permutedims(x, 18:-1:1), AT, rand(Float32, 4, [2 for _ = 2:18]...))
27+
# Test the UInt64 index path of permutedims on a high-dimensional array.
28+
AT <: GPUArrays.AbstractGPUArray && @test let
29+
x = rand(Float32, 4, [2 for _ = 2:18]...)
30+
pm = (18:-1:1...,)
31+
y = permutedims(x, pm)
32+
Array(GPUArrays._permutedims!(UInt64, AT(zero(y)), AT(x), pm)) ≈ y
33+
end
2134
end
2235
end
2336

0 commit comments

Comments
 (0)