Skip to content

Commit 3321fc8

Browse files
authored
Bump CUTENSOR to v1.7. (#1960)
[skip julia] [skip cuda] [skip downstream] [skip benchmarks] [skip docs]
1 parent 1ef29ec commit 3321fc8

File tree

7 files changed

+59
-32
lines changed

7 files changed

+59
-32
lines changed

.buildkite/pipeline.yml

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@ steps:
2626
# use the CUDA installation from the CI image
2727
using CUDA
2828
CUDA.set_runtime_version!("local")'
29-
if: build.message !~ /\[skip tests\]/
29+
if: build.message !~ /\[skip tests\]/ &&
30+
build.message !~ /\[skip julia\]/
3031
timeout_in_minutes: 120
3132
matrix:
3233
setup:
@@ -63,7 +64,9 @@ steps:
6364
agents:
6465
queue: "juliagpu"
6566
cuda: "*"
66-
if: build.message !~ /\[skip tests\]/ && !build.pull_request.draft
67+
if: build.message !~ /\[skip tests\]/ &&
68+
build.message !~ /\[skip cuda\]/ &&
69+
!build.pull_request.draft
6770
timeout_in_minutes: 120
6871
matrix:
6972
setup:
@@ -99,6 +102,9 @@ steps:
99102
- with:
100103
cuda: "12.0"
101104
package: "cuDNN"
105+
- with:
106+
cuda: "12.0"
107+
package: "cuTENSOR"
102108
plugins:
103109
- JuliaCI/julia#v1:
104110
version: "1.8"
@@ -110,7 +116,9 @@ steps:
110116
agents:
111117
queue: "juliagpu"
112118
cuda: "*"
113-
if: build.message !~ /\[skip tests\]/ && !build.pull_request.draft
119+
if: build.message !~ /\[skip tests\]/ &&
120+
build.message !~ /\[skip subpackages\]/ &&
121+
!build.pull_request.draft
114122
timeout_in_minutes: 120
115123
commands: |
116124
julia -e '
@@ -161,7 +169,9 @@ steps:
161169
agents:
162170
queue: "juliagpu"
163171
cuda: "*"
164-
if: build.message !~ /\[skip tests\]/ && !build.pull_request.draft
172+
if: build.message !~ /\[skip tests\]/ &&
173+
build.message !~ /\[skip downstream\]/ &&
174+
!build.pull_request.draft
165175
timeout_in_minutes: 60
166176

167177
- group: ":eyes: Special"
@@ -212,7 +222,8 @@ steps:
212222
agents:
213223
queue: "juliagpu"
214224
cuda: "*"
215-
if: build.message !~ /\[skip docs\]/ && !build.pull_request.draft
225+
if: build.message !~ /\[skip docs\]/ &&
226+
!build.pull_request.draft
216227
timeout_in_minutes: 30
217228

218229
# XXX: fails often, and is very slow

lib/cutensor/Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "cuTENSOR"
22
uuid = "011b41b2-24ef-40a8-b3eb-fa098493e9e1"
33
authors = ["Tim Besard <tim.besard@gmail.com>"]
4-
version = "1.0.4"
4+
version = "1.1.0"
55

66
[deps]
77
CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82"
@@ -10,7 +10,7 @@ CUTENSOR_jll = "35b6c64b-1ee1-5834-92a3-3f624899209a"
1010
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1111

1212
[compat]
13-
CUTENSOR_jll = "~1.6"
13+
CUTENSOR_jll = "~1.7"
1414
CUDA = "~4.3"
1515
CEnum = "0.2, 0.3, 0.4"
1616
julia = "1.6"

lib/cutensor/src/cuTENSOR.jl

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,28 +31,26 @@ include("wrappers.jl")
3131
include("interfaces.jl")
3232

3333
# cache for created, but unused handles
34-
const idle_handles = HandleCache{CuContext,Base.RefValue{cutensorHandle_t}}()
34+
const idle_handles = HandleCache{CuContext,Ptr{cutensorHandle_t}}()
3535

3636
function handle()
3737
cuda = CUDA.active_state()
3838

3939
# every task maintains library state per device
40-
LibraryState = @NamedTuple{handle::Base.RefValue{cutensorHandle_t}}
40+
LibraryState = @NamedTuple{handle::Ptr{cutensorHandle_t}}
4141
states = get!(task_local_storage(), :cuTENSOR) do
4242
Dict{CuContext,LibraryState}()
4343
end::Dict{CuContext,LibraryState}
4444

4545
# get library state
4646
@noinline function new_state(cuda)
4747
new_handle = pop!(idle_handles, cuda.context) do
48-
handle = Ref{cutensorHandle_t}()
49-
cutensorInit(handle)
50-
handle
48+
cutensorCreate()
5149
end
5250

5351
finalizer(current_task()) do task
5452
push!(idle_handles, cuda.context, new_handle) do
55-
# cuTENSOR doesn't need to actively destroy its handle
53+
cutensorDestroy(new_handle)
5654
end
5755
end
5856

lib/cutensor/src/libcutensor.jl

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
using CEnum
22

3-
# cuTENSOR uses CUDA runtime objects, which are compatible with our driver usage
3+
# CUTENSOR uses CUDA runtime objects, which are compatible with our driver usage
44
const cudaStream_t = CUstream
55

66
# outlined functionality to avoid GC frame allocation
@@ -14,7 +14,7 @@ end
1414

1515
function check(f, errs...)
1616
res = retry_reclaim(in((CUTENSOR_STATUS_ALLOC_FAILED, errs...))) do
17-
f()
17+
return f()
1818
end
1919

2020
if res != CUTENSOR_STATUS_SUCCESS
@@ -164,9 +164,14 @@ end
164164
# typedef void ( * cutensorLoggerCallback_t ) ( int32_t logLevel , const char * functionName , const char * message )
165165
const cutensorLoggerCallback_t = Ptr{Cvoid}
166166

167-
@checked function cutensorInit(handle)
167+
@checked function cutensorCreate(handle)
168168
initialize_context()
169-
@ccall libcutensor.cutensorInit(handle::Ptr{cutensorHandle_t})::cutensorStatus_t
169+
@ccall libcutensor.cutensorCreate(handle::Ptr{Ptr{cutensorHandle_t}})::cutensorStatus_t
170+
end
171+
172+
@checked function cutensorDestroy(handle)
173+
initialize_context()
174+
@ccall libcutensor.cutensorDestroy(handle::Ptr{cutensorHandle_t})::cutensorStatus_t
170175
end
171176

172177
@checked function cutensorHandleDetachPlanCachelines(handle)
@@ -392,12 +397,12 @@ function cutensorGetErrorString(error)
392397
@ccall libcutensor.cutensorGetErrorString(error::cutensorStatus_t)::Cstring
393398
end
394399

395-
# no prototype is found for this function at cutensor.h:745:8, please use with caution
400+
# no prototype is found for this function at cutensor.h:794:8, please use with caution
396401
function cutensorGetVersion()
397402
@ccall libcutensor.cutensorGetVersion()::Csize_t
398403
end
399404

400-
# no prototype is found for this function at cutensor.h:751:8, please use with caution
405+
# no prototype is found for this function at cutensor.h:800:8, please use with caution
401406
function cutensorGetCudartVersion()
402407
@ccall libcutensor.cutensorGetCudartVersion()::Csize_t
403408
end
@@ -427,7 +432,7 @@ end
427432
@ccall libcutensor.cutensorLoggerSetMask(mask::Int32)::cutensorStatus_t
428433
end
429434

430-
# no prototype is found for this function at cutensor.h:799:18, please use with caution
435+
# no prototype is found for this function at cutensor.h:848:18, please use with caution
431436
@checked function cutensorLoggerForceDisable()
432437
initialize_context()
433438
@ccall libcutensor.cutensorLoggerForceDisable()::cutensorStatus_t
@@ -460,3 +465,8 @@ end
460465
typeCompute::cutensorComputeType_t,
461466
workspaceSize::Ptr{UInt64})::cutensorStatus_t
462467
end
468+
469+
@checked function cutensorInit(handle)
470+
initialize_context()
471+
@ccall libcutensor.cutensorInit(handle::Ptr{cutensorHandle_t})::cutensorStatus_t
472+
end

lib/cutensor/src/wrappers.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,14 @@ function cuda_version()
1616
VersionNumber(major, minor, patch)
1717
end
1818

19+
function cutensorCreate()
20+
handle_ref = Ref{Ptr{cutensorHandle_t}}()
21+
check(CUTENSOR_STATUS_NOT_INITIALIZED) do
22+
unsafe_cutensorCreate(handle_ref)
23+
end
24+
handle_ref[]
25+
end
26+
1927

2028
abstract type CuTensorPlan end
2129

lib/cutensornet/Project.toml

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,18 @@
1+
authors = ["Katharine Hyatt <kshyatt@gmail.com>"]
12
name = "cuTensorNet"
23
uuid = "448d79b3-4b49-4e06-a5ea-00c62c0dc3db"
3-
authors = ["Katharine Hyatt <kshyatt@gmail.com>"]
44
version = "1.0.4"
55

6+
[compat]
7+
CEnum = "0.2, 0.3, 0.4"
8+
CUDA = "~4.3"
9+
cuQuantum_jll = "~22.11"
10+
cuTENSOR = "~1.0, ~1.1"
11+
julia = "1.6"
12+
613
[deps]
714
CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82"
815
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
9-
cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1"
1016
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1117
cuQuantum_jll = "b75408ef-6fdf-5d74-b65a-7df000ad18e6"
12-
13-
[compat]
14-
cuQuantum_jll = "~22.11"
15-
cuTENSOR = "~1.0"
16-
CUDA = "~4.3"
17-
CEnum = "0.2, 0.3, 0.4"
18-
julia = "1.6"
18+
cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1"

res/wrap/Manifest.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,10 @@ uuid = "62b44479-cb7b-5706-934f-f13b2eb2e645"
4646
version = "8.9.2+0"
4747

4848
[[CUTENSOR_jll]]
49-
deps = ["Artifacts", "CUDA_Runtime_jll", "CompilerSupportLibraries_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "Pkg", "TOML"]
50-
git-tree-sha1 = "cd25ea92f25cc47e92e366f2a7a337eae897cd68"
49+
deps = ["Artifacts", "CUDA_Runtime_jll", "CompilerSupportLibraries_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"]
50+
git-tree-sha1 = "d2c1a1de70a8716acaeee873def6608d854212d5"
5151
uuid = "35b6c64b-1ee1-5834-92a3-3f624899209a"
52-
version = "1.6.1+1"
52+
version = "1.7.0+0"
5353

5454
[[Clang]]
5555
deps = ["CEnum", "Clang_jll", "Downloads", "Pkg", "TOML"]

0 commit comments

Comments
 (0)