Skip to content

Commit 1a57e9f

Browse files
authored
Add header for cublasLt. (#2324)
Also significantly shorten the API listings by using a template mechanism.
1 parent 0f5ec3a commit 1a57e9f

File tree

12 files changed

+2259
-5212
lines changed

12 files changed

+2259
-5212
lines changed

lib/cublas/CUBLAS.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ using LLVM.Interop: assume
2121
using CEnum: @cenum
2222

2323

24+
# Bind the name `cudaDataType_t` to `cudaDataType` as a module-level constant.
# NOTE(review): presumably needed so the wrappers included below can use the
# `_t` spelling — confirm against libcublas.jl.
const cudaDataType_t = cudaDataType
25+
2426
# core library
2527
include("libcublas.jl")
2628
include("libcublas_deprecated.jl")

lib/cublas/libcublas.jl

Lines changed: 790 additions & 26 deletions
Large diffs are not rendered by default.

lib/cusparse/libcusparse.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1065,7 +1065,7 @@ end
10651065
bsrSortedColInd::CuPtr{Cint},
10661066
blockSize::Cint,
10671067
info::bsrsm2Info_t,
1068-
pBufferSize::CuPtr{Csize_t})::cusparseStatus_t
1068+
pBufferSize::Ptr{Csize_t})::cusparseStatus_t
10691069
end
10701070

10711071
@checked function cusparseDbsrsm2_bufferSizeExt(handle, dirA, transA, transB, mb, n, nnzb,

lib/cutensor/Project.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,20 @@
11
name = "cuTENSOR"
22
uuid = "011b41b2-24ef-40a8-b3eb-fa098493e9e1"
33
authors = ["Tim Besard <tim.besard@gmail.com>"]
4-
version = "2.1"
4+
version = "2.1.0"
55

66
[deps]
77
CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82"
88
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
99
CUDA_Runtime_Discovery = "1af6417a-86b4-443c-805f-a4643ffb695f"
1010
CUTENSOR_jll = "35b6c64b-1ee1-5834-92a3-3f624899209a"
1111
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
12+
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
1213

1314
[compat]
1415
CEnum = "0.2, 0.3, 0.4, 0.5"
1516
CUDA = "~5.3"
1617
CUDA_Runtime_Discovery = "0.2"
1718
CUTENSOR_jll = "~2.0"
18-
julia = "1.8"
1919
LinearAlgebra = "1"
20+
julia = "1.8"

lib/cutensor/src/cuTENSOR.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ using CUDA: retry_reclaim, initialize_context, isdebug
77

88
using CEnum: @cenum
99

10+
using Printf: @printf
11+
1012
if CUDA.local_toolkit
1113
using CUDA_Runtime_Discovery
1214
else

lib/cutensor/src/libcutensor.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -535,8 +535,9 @@ for desc in [:CUTENSOR_COMPUTE_DESC_16F,
535535
:CUTENSOR_COMPUTE_DESC_64F]
536536
@eval begin
537537
function $desc()
538-
ptr = Ptr{cutensorComputeDescriptor_t}(cglobal(($(QuoteNode(desc)), libcutensor)))
539-
unsafe_load(ptr)
538+
ptr = Ptr{cutensorComputeDescriptor_t}(cglobal(($(QuoteNode(desc)),
539+
libcutensor)))
540+
return unsafe_load(ptr)
540541
end
541542
end
542543
end

lib/cutensor/src/types.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,8 @@ mutable struct CuTensorPlan
170170
end
171171
end
172172

173+
# Write a compact representation of `plan` to `io`, in the form
# `CuTensorPlan(<handle>)`, with `plan.handle` formatted via `%p`.
function Base.show(io::IO, plan::CuTensorPlan)
    return @printf(io, "CuTensorPlan(%p)", plan.handle)
end
174+
173175
# Conversion to `cutensorPlan_t` yields the plan's stored `handle` field.
function Base.unsafe_convert(::Type{cutensorPlan_t}, plan::CuTensorPlan)
    return plan.handle
end
174176

175177
# Two plans compare equal exactly when their `handle` fields compare equal.
function Base.:(==)(a::CuTensorPlan, b::CuTensorPlan)
    return a.handle == b.handle
end
@@ -220,6 +222,8 @@ mutable struct CuTensorDescriptor
220222
end
221223
end
222224

225+
# Write a compact representation of `desc` to `io`, in the form
# `CuTensorDescriptor(<handle>)`, with `desc.handle` formatted via `%p`.
function Base.show(io::IO, desc::CuTensorDescriptor)
    return @printf(io, "CuTensorDescriptor(%p)", desc.handle)
end
226+
223227
# Conversion to `cutensorTensorDescriptor_t` yields the descriptor's stored
# `handle` field.
function Base.unsafe_convert(::Type{cutensorTensorDescriptor_t}, obj::CuTensorDescriptor)
    return obj.handle
end
224228

225229
function unsafe_destroy!(obj::CuTensorDescriptor)

res/wrap/Manifest.toml

Lines changed: 28 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -17,57 +17,52 @@ version = "0.5.0"
1717

1818
[[CSTParser]]
1919
deps = ["Tokenize"]
20-
git-tree-sha1 = "b1d309487c04e92253b55c1f803b1d6f0e136920"
20+
git-tree-sha1 = "b544d62417a99d091c569b95109bc9d8c223e9e3"
2121
uuid = "00ebfdb7-1f24-5e51-bd34-a7502290713f"
22-
version = "3.4.1"
22+
version = "3.4.2"
2323

2424
[[CUDA_Driver_jll]]
2525
deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "Pkg"]
26-
git-tree-sha1 = "d01bfc999768f0a31ed36f5d22a76161fc63079c"
26+
git-tree-sha1 = "dc172b558adbf17952001e15cf0d6364e6d78c2f"
2727
uuid = "4ee394cb-3365-5eb0-8335-949819d2adfc"
28-
version = "0.7.0+1"
28+
version = "0.8.1+0"
2929

3030
[[CUDA_Runtime_jll]]
3131
deps = ["Artifacts", "CUDA_Driver_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"]
32-
git-tree-sha1 = "8e25c009d2bf16c2c31a70a6e9e8939f7325cc84"
32+
git-tree-sha1 = "4ca7d6d92075906c2ce871ea8bba971fff20d00c"
3333
uuid = "76a88914-d11a-5bdc-97e0-2f5a05c973a2"
34-
version = "0.11.1+0"
34+
version = "0.12.1+0"
3535

3636
[[CUDA_SDK_jll]]
3737
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
38-
git-tree-sha1 = "a188221ab66c7608f9af75f93fe2bd6d0a14b4b5"
38+
git-tree-sha1 = "752d5810571b8639702ada3b098cf550babb1967"
3939
uuid = "6cbf2f2e-7e60-5632-ac76-dca2274e0be0"
40-
version = "12.3.2+0"
40+
version = "12.4.1+0"
4141

4242
[[CUDNN_jll]]
4343
deps = ["Artifacts", "CUDA_Runtime_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"]
44-
git-tree-sha1 = "b188220f9bd361db61a72326e987d57c0750ac5e"
44+
git-tree-sha1 = "cbf7d75f8c58b147bdf6acea2e5bc96cececa6d4"
4545
uuid = "62b44479-cb7b-5706-934f-f13b2eb2e645"
46-
version = "9.0.0+0"
46+
version = "9.0.0+1"
4747

4848
[[CUTENSOR_jll]]
4949
deps = ["Artifacts", "CUDA_Runtime_jll", "CompilerSupportLibraries_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"]
50-
git-tree-sha1 = "aff33b54e97432cba542bf646c405a6b03cc29ac"
50+
git-tree-sha1 = "2ad02c8180d94cca10336fc5646a7e24ab4aa268"
5151
uuid = "35b6c64b-1ee1-5834-92a3-3f624899209a"
52-
version = "2.0.0+0"
52+
version = "2.0.1+0"
5353

5454
[[Clang]]
5555
deps = ["CEnum", "Clang_jll", "Downloads", "Pkg", "TOML"]
56-
git-tree-sha1 = "846054622cb22aa63b5d51b5d84ec04b42d4d587"
56+
git-tree-sha1 = "be935fd478265159ffdb1a949489a5f91319fb95"
5757
uuid = "40e3b903-d033-50b4-a0cc-940c62c95e31"
58-
version = "0.17.8"
58+
version = "0.18.1"
5959

6060
[[Clang_jll]]
6161
deps = ["Artifacts", "JLLWrappers", "Libdl", "TOML", "Zlib_jll", "libLLVM_jll"]
6262
git-tree-sha1 = "de2204d98741f57e7ddb9a6a738db74ba8a608cb"
6363
uuid = "0ee61d77-7f21-5576-8119-9fcc46b10100"
6464
version = "15.0.7+10"
6565

66-
[[Combinatorics]]
67-
git-tree-sha1 = "08c8b6831dc00bfea825826be0bc8336fc369860"
68-
uuid = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
69-
version = "1.0.2"
70-
7166
[[CommonMark]]
7267
deps = ["Crayons", "JSON", "PrecompileTools", "URIs"]
7368
git-tree-sha1 = "532c4185d3c9037c0237546d817858b23cf9e071"
@@ -76,9 +71,9 @@ version = "0.8.12"
7671

7772
[[Compat]]
7873
deps = ["TOML", "UUIDs"]
79-
git-tree-sha1 = "75bd5b6fc5089df449b5d35fa501c846c9b6549b"
74+
git-tree-sha1 = "c955881e3c981181362ae4088b35995446298b80"
8075
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
81-
version = "4.12.0"
76+
version = "4.14.0"
8277

8378
[Compat.extensions]
8479
CompatLinearAlgebraExt = "LinearAlgebra"
@@ -90,7 +85,7 @@ version = "4.12.0"
9085
[[CompilerSupportLibraries_jll]]
9186
deps = ["Artifacts", "Libdl"]
9287
uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae"
93-
version = "1.0.5+1"
88+
version = "1.1.0+0"
9489

9590
[[Crayons]]
9691
git-tree-sha1 = "249fe38abf76d48563e2f4556bebd215aa317e15"
@@ -99,9 +94,9 @@ version = "4.1.1"
9994

10095
[[DataStructures]]
10196
deps = ["Compat", "InteractiveUtils", "OrderedCollections"]
102-
git-tree-sha1 = "ac67408d9ddf207de5cfa9a97e114352430f01ed"
97+
git-tree-sha1 = "97d79461925cdb635ee32116978fc735b9463a39"
10398
uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
104-
version = "0.18.16"
99+
version = "0.18.19"
105100

106101
[[Dates]]
107102
deps = ["Printf"]
@@ -137,10 +132,10 @@ uuid = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
137132
version = "0.21.4"
138133

139134
[[JuliaFormatter]]
140-
deps = ["CSTParser", "Combinatorics", "CommonMark", "DataStructures", "Glob", "Pkg", "PrecompileTools", "Tokenize"]
141-
git-tree-sha1 = "bf3bdb6d310b8106fa13f69eb9cd9c6a53b82b5b"
135+
deps = ["CSTParser", "CommonMark", "DataStructures", "Glob", "Pkg", "PrecompileTools", "Tokenize"]
136+
git-tree-sha1 = "1c4880cb70a5c6c87ea36deccc3d7f9e7969c18c"
142137
uuid = "98e50ef6-434e-11e9-1051-2b60c6c9e899"
143-
version = "1.0.47"
138+
version = "1.0.56"
144139

145140
[[LazyArtifacts]]
146141
deps = ["Artifacts", "Pkg"]
@@ -238,15 +233,15 @@ version = "1.10.0"
238233

239234
[[PrecompileTools]]
240235
deps = ["Preferences"]
241-
git-tree-sha1 = "03b4c25b43cb84cee5c90aa9b5ea0a78fd848d2f"
236+
git-tree-sha1 = "5aa36f7049a63a1528fe8f7c3f2113413ffd4e1f"
242237
uuid = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
243-
version = "1.2.0"
238+
version = "1.2.1"
244239

245240
[[Preferences]]
246241
deps = ["TOML"]
247-
git-tree-sha1 = "00805cd429dcb4870060ff49ef443486c262e38e"
242+
git-tree-sha1 = "9306f6085165d270f7e3db02af26a400d580f5c6"
248243
uuid = "21216c6a-2e73-6563-6e65-726566657250"
249-
version = "1.4.1"
244+
version = "1.4.3"
250245

251246
[[Printf]]
252247
deps = ["Unicode"]
@@ -299,9 +294,9 @@ uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"
299294

300295
[[XML2_jll]]
301296
deps = ["Artifacts", "JLLWrappers", "Libdl", "Libiconv_jll", "Zlib_jll"]
302-
git-tree-sha1 = "801cbe47eae69adc50f36c3caec4758d2650741b"
297+
git-tree-sha1 = "532e22cf7be8462035d092ff21fada7527e2c488"
303298
uuid = "02c8fc9c-b97f-50b9-bbe4-9be30ff0a78a"
304-
version = "2.12.2+0"
299+
version = "2.12.6+0"
305300

306301
[[XSLT_jll]]
307302
deps = ["Artifacts", "JLLWrappers", "Libdl", "Libgcrypt_jll", "Libgpg_error_jll", "Libiconv_jll", "Pkg", "XML2_jll", "Zlib_jll"]

0 commit comments

Comments
 (0)