Skip to content

Commit 7d74f90

Browse files
maleadt and ChrisPsa
authored
Update to CUTENSOR 2.0 (#2178)
Co-authored-by: Christos Psarras <cpsarras@nvidia.com>
1 parent e718b0d commit 7d74f90

21 files changed

+1257
-1134
lines changed

.buildkite/pipeline.yml

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -139,12 +139,7 @@ steps:
139139
withenv("JULIA_PKG_PRECOMPILE_AUTO" => 0) do
140140
Pkg.instantiate()
141141
142-
pkgs = [PackageSpec(path=joinpath(pwd(), "lib", lowercase("{{matrix.package}}")))]
143-
if "{{matrix.package}}" == "cuTensorNet"
144-
# cuTensorNet depends on a development version of cuTENSOR
145-
push!(pkgs, PackageSpec(path=joinpath(pwd(), "lib", "cutensor")))
146-
end
147-
Pkg.develop(pkgs)
142+
Pkg.develop(path=joinpath(pwd(), "lib", lowercase("{{matrix.package}}")))
148143
149144
write("LocalPreferences.toml", "[CUDA_Runtime_jll]\nversion = \"{{matrix.cuda}}\"")
150145
end

lib/cutensor/Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "cuTENSOR"
22
uuid = "011b41b2-24ef-40a8-b3eb-fa098493e9e1"
33
authors = ["Tim Besard <tim.besard@gmail.com>"]
4-
version = "1.2.1"
4+
version = "2.0"
55

66
[deps]
77
CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82"
@@ -14,6 +14,6 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1414
CEnum = "0.2, 0.3, 0.4"
1515
CUDA = "~5.1"
1616
CUDA_Runtime_Discovery = "0.2"
17-
CUTENSOR_jll = "~1.7"
17+
CUTENSOR_jll = "~2.0"
1818
julia = "1.6"
1919
LinearAlgebra = "1"

lib/cutensor/src/cuTENSOR.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,20 +27,21 @@ include("libcutensor.jl")
2727

2828
# low-level wrappers
2929
include("error.jl")
30-
include("tensor.jl")
31-
include("wrappers.jl")
30+
include("utils.jl")
31+
include("types.jl")
32+
include("operations.jl")
3233

3334
# high-level integrations
3435
include("interfaces.jl")
3536

3637
# cache for created, but unused handles
37-
const idle_handles = HandleCache{CuContext,Ptr{cutensorHandle_t}}()
38+
const idle_handles = HandleCache{CuContext,cutensorHandle_t}()
3839

3940
function handle()
4041
cuda = CUDA.active_state()
4142

4243
# every task maintains library state per device
43-
LibraryState = @NamedTuple{handle::Ptr{cutensorHandle_t}}
44+
LibraryState = @NamedTuple{handle::cutensorHandle_t}
4445
states = get!(task_local_storage(), :cuTENSOR) do
4546
Dict{CuContext,LibraryState}()
4647
end::Dict{CuContext,LibraryState}

lib/cutensor/src/error.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ function description(err::CUTENSORError)
4141
"insufficient workspace memory for this operation"
4242
elseif err.code == CUTENSOR_STATUS_INSUFFICIENT_DRIVER
4343
"insufficient driver version"
44+
elseif err.code == CUTENSOR_STATUS_IO_ERROR
45+
"file not found"
4446
else
4547
"no description for this error"
4648
end

lib/cutensor/src/interfaces.jl

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,16 @@ function Base.:(+)(A::CuTensor, B::CuTensor)
66
α = convert(eltype(A), 1.0)
77
γ = convert(eltype(B), 1.0)
88
C = similar(B)
9-
elementwiseBinary!(α, A, CUTENSOR_OP_IDENTITY, γ, B, CUTENSOR_OP_IDENTITY, C, CUTENSOR_OP_ADD)
9+
elementwise_binary!(α, A.data, A.inds, CUTENSOR_OP_IDENTITY, γ, B.data, B.inds, CUTENSOR_OP_IDENTITY, C.data, C.inds, CUTENSOR_OP_ADD)
10+
C
1011
end
1112

1213
function Base.:(-)(A::CuTensor, B::CuTensor)
1314
α = convert(eltype(A), 1.0)
1415
γ = convert(eltype(B), -1.0)
1516
C = similar(B)
16-
elementwiseBinary!(α, A, CUTENSOR_OP_IDENTITY, γ, B, CUTENSOR_OP_IDENTITY, C, CUTENSOR_OP_ADD)
17+
elementwise_binary!(α, A.data, A.inds, CUTENSOR_OP_IDENTITY, γ, B.data, B.inds, CUTENSOR_OP_IDENTITY, C.data, C.inds, CUTENSOR_OP_ADD)
18+
C
1719
end
1820

1921
function Base.:(*)(A::CuTensor, B::CuTensor)
@@ -33,8 +35,15 @@ end
3335

3436
using LinearAlgebra
3537

36-
LinearAlgebra.axpy!(a, X::CuTensor, Y::CuTensor) = elementwiseBinary!(a, X, CUTENSOR_OP_IDENTITY, one(eltype(Y)), Y, CUTENSOR_OP_IDENTITY, Y, CUTENSOR_OP_ADD)
37-
LinearAlgebra.axpby!(a, X::CuTensor, b, Y::CuTensor) = elementwiseBinary!(a, X, CUTENSOR_OP_IDENTITY, b, Y, CUTENSOR_OP_IDENTITY, Y, CUTENSOR_OP_ADD)
38+
function LinearAlgebra.axpy!(a, X::CuTensor, Y::CuTensor)
39+
elementwise_binary!(a, X.data, X.inds, CUTENSOR_OP_IDENTITY, one(eltype(Y)), Y.data, Y.inds, CUTENSOR_OP_IDENTITY, Y.data, Y.inds, CUTENSOR_OP_ADD)
40+
return Y
41+
end
42+
43+
function LinearAlgebra.axpby!(a, X::CuTensor, b, Y::CuTensor)
44+
elementwise_binary!(a, X.data, X.inds, CUTENSOR_OP_IDENTITY, b, Y.data, Y.inds, CUTENSOR_OP_IDENTITY, Y.data, Y.inds, CUTENSOR_OP_ADD)
45+
return Y
46+
end
3847

3948
function LinearAlgebra.mul!(C::CuTensor, A::CuTensor, B::CuTensor)
4049
contraction!(one(eltype(C)), A.data, A.inds, CUTENSOR_OP_IDENTITY, B.data, B.inds, CUTENSOR_OP_IDENTITY, zero(eltype(C)), C.data, C.inds, CUTENSOR_OP_IDENTITY, CUTENSOR_OP_IDENTITY)

0 commit comments

Comments (0)