Skip to content

Commit 05ba598

Browse files
Merge branch 'master' into qr_views
2 parents d9ca07a + 96a4d0c commit 05ba598

File tree

3 files changed

+46
-13
lines changed

3 files changed

+46
-13
lines changed

Manifest.toml

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,10 @@ uuid = "621f4979-c628-5d54-868e-fcf4e3e8185c"
77
version = "1.2.1"
88

99
[[Adapt]]
10-
deps = ["LinearAlgebra"]
11-
git-tree-sha1 = "0310e08cb19f5da31d08341c6120c047598f5b9c"
10+
deps = ["LinearAlgebra", "Requires"]
11+
git-tree-sha1 = "cc37d689f599e8df4f464b2fa3870ff7db7492ef"
1212
uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
13-
version = "3.5.0"
13+
version = "3.6.1"
1414

1515
[[ArgTools]]
1616
uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f"
@@ -58,9 +58,9 @@ version = "1.15.7"
5858

5959
[[ChangesOfVariables]]
6060
deps = ["ChainRulesCore", "LinearAlgebra", "Test"]
61-
git-tree-sha1 = "844b061c104c408b24537482469400af6075aae4"
61+
git-tree-sha1 = "485193efd2176b88e6622a39a246f8c5b600e74e"
6262
uuid = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0"
63-
version = "0.1.5"
63+
version = "0.1.6"
6464

6565
[[Compat]]
6666
deps = ["Dates", "LinearAlgebra", "UUIDs"]
@@ -124,9 +124,9 @@ uuid = "3587e190-3f89-42d0-90ee-14403ec27112"
124124
version = "0.1.8"
125125

126126
[[IrrationalConstants]]
127-
git-tree-sha1 = "7fd44fd4ff43fc60815f8e764c0f352b83c49151"
127+
git-tree-sha1 = "630b497eafcc20001bba38a4651b327dcfc491d2"
128128
uuid = "92d709cd-6900-40b7-9082-c6be49f344b6"
129-
version = "0.1.1"
129+
version = "0.2.2"
130130

131131
[[JLLWrappers]]
132132
deps = ["Preferences"]
@@ -142,9 +142,9 @@ version = "4.16.0"
142142

143143
[[LLVMExtra_jll]]
144144
deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "Pkg", "TOML"]
145-
git-tree-sha1 = "771bfe376249626d3ca12bcd58ba243d3f961576"
145+
git-tree-sha1 = "7718cf44439c676bc0ec66a87099f41015a522d6"
146146
uuid = "dad2f222-ce93-54a1-a47d-0025e8a3acab"
147-
version = "0.0.16+0"
147+
version = "0.0.16+2"
148148

149149
[[LazyArtifacts]]
150150
deps = ["Artifacts", "Pkg"]
@@ -175,9 +175,9 @@ uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
175175

176176
[[LogExpFunctions]]
177177
deps = ["ChainRulesCore", "ChangesOfVariables", "DocStringExtensions", "InverseFunctions", "IrrationalConstants", "LinearAlgebra"]
178-
git-tree-sha1 = "680e733c3a0a9cea9e935c8c2184aea6a63fa0b5"
178+
git-tree-sha1 = "0a1b7c2863e44523180fdb3146534e265a91870b"
179179
uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
180-
version = "0.3.21"
180+
version = "0.3.23"
181181

182182
[[Logging]]
183183
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"
@@ -266,9 +266,9 @@ uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
266266

267267
[[SpecialFunctions]]
268268
deps = ["ChainRulesCore", "IrrationalConstants", "LogExpFunctions", "OpenLibm_jll", "OpenSpecFun_jll"]
269-
git-tree-sha1 = "d75bda01f8c31ebb72df80a46c88b25d1c79c56d"
269+
git-tree-sha1 = "ef28127915f4229c971eb43f3fc075dd3fe91880"
270270
uuid = "276daf66-3868-5448-9aa4-cd146d93841b"
271-
version = "2.1.7"
271+
version = "2.2.0"
272272

273273
[[Statistics]]
274274
deps = ["LinearAlgebra", "SparseArrays"]

lib/cublas/linalg.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -591,3 +591,15 @@ end
591591
error("only supports BLAS type, got $T")
592592
end
593593
end
594+
595+
op_wrappers = ((identity, T -> 'N', identity),
596+
(T -> :(Transpose{T, <:$T}), T -> 'T', A -> :(parent($A))),
597+
(T -> :(Adjoint{T, <:$T}), T -> T <: Real ? 'T' : 'C', A -> :(parent($A))))
598+
599+
for op in (:(+), :(-))
600+
for (wrapa, transa, unwrapa) in op_wrappers, (wrapb, transb, unwrapb) in op_wrappers
601+
TypeA = wrapa(:(CuMatrix{T}))
602+
TypeB = wrapb(:(CuMatrix{T}))
603+
@eval Base.$op(A::$TypeA, B::$TypeB) where {T <: CublasFloat} = CUBLAS.geam($transa(T), $transb(T), one(T), $(unwrapa(:A)), $(op)(one(T)), $(unwrapb(:B)))
604+
end
605+
end

test/cublas.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1198,6 +1198,27 @@ end
11981198
h_C = Array(d_C)
11991199
@test D h_C
12001200
end
1201+
@testset "CuMatrix -- A ± B -- $elty" begin
1202+
for opa in (identity, transpose, adjoint)
1203+
for opb in (identity, transpose, adjoint)
1204+
n = 10
1205+
m = 20
1206+
geam_A = opa == identity ? rand(elty, n, m) : rand(elty, m, n)
1207+
geam_B = opb == identity ? rand(elty, n, m) : rand(elty, m, n)
1208+
1209+
geam_dA = CuMatrix{elty}(geam_A)
1210+
geam_dB = CuMatrix{elty}(geam_B)
1211+
1212+
geam_C = opa(geam_A) + opb(geam_B)
1213+
geam_dC = opa(geam_dA) + opb(geam_dB)
1214+
@test geam_C collect(geam_dC)
1215+
1216+
geam_C = opa(geam_A) - opb(geam_B)
1217+
geam_dC = opa(geam_dA) - opb(geam_dB)
1218+
@test geam_C collect(geam_dC)
1219+
end
1220+
end
1221+
end
12011222
A = rand(elty,m,k)
12021223
d_A = CuArray(A)
12031224
@testset "syrkx!" begin

0 commit comments

Comments
 (0)