@@ -5,7 +5,7 @@ using LinearAlgebra: Transpose, Adjoint,
5
5
Hermitian, Symmetric,
6
6
LowerTriangular, UnitLowerTriangular,
7
7
UpperTriangular, UnitUpperTriangular,
8
- MulAddMul, wrap
8
+ UpperOrLowerTriangular, MulAddMul, wrap
9
9
10
10
#
11
11
# BLAS 1
@@ -163,12 +163,50 @@ function LinearAlgebra.generic_matmatmul!(C::oneStridedMatrix, tA, tB, A::oneStr
163
163
GPUArrays. generic_matmatmul! (C, wrap (A, tA), wrap (B, tB), alpha, beta)
164
164
end
165
165
166
# Matrix operand accepted by the triangular-triangular multiply below: either a
# `oneStridedMatrix{T}` itself, or an `AdjOrTrans` wrapper whose parent is a
# `oneStridedMatrix`.
const AdjOrTransOroneMatrix{T} = Union{
    oneStridedMatrix{T},
    AdjOrTrans{<:T, <:oneStridedMatrix},
}
167
+
168
# Triangular * triangular product for oneAPI strided matrices; returns `C`.
#
# Arguments
#   C         destination matrix passed to `trmm!`
#   uplocA    'U'/'L' character describing the left operand
#   isunitcA  unit-diagonal character describing the left operand
#   tfunA     `identity`, `transpose` or anything else (treated as adjoint) for A
#   A         data of the left operand
#   triB      right operand, still wrapped as an Upper/Lower triangular matrix
#
# Only the six (uplo, transpose, unit-diagonal) combinations enumerated below
# are handled; every other combination throws an `ArgumentError`.
#
# NOTE(review): each branch calls `triu!`/`tril!` directly on one of the
# operands (`A`, or the parent of `triB`) before the `trmm!` call — presumably
# this overwrites the ignored triangle of the caller's storage; confirm that
# callers tolerate that.
function LinearAlgebra.generic_trimatmul!(
        C::oneStridedMatrix{T}, uplocA, isunitcA,
        tfunA::Function, A::oneStridedMatrix{T},
        triB::UpperOrLowerTriangular{T, <:AdjOrTransOroneMatrix{T}},
    ) where {T<:onemklFloat}
    uplocB = LinearAlgebra.uplo_char(triB)
    isunitcB = LinearAlgebra.isunit_char(triB)
    B = parent(triB)
    tfunB = LinearAlgebra.wrapperop(B)
    # Map the wrapper functions onto the BLAS-style transpose flags.
    transa = tfunA === identity ? 'N' : tfunA === transpose ? 'T' : 'C'
    transb = tfunB === identity ? 'N' : tfunB === transpose ? 'T' : 'C'
    if uplocA == 'L' && tfunA === identity && tfunB === identity && uplocB == 'U' && isunitcB == 'N' # lower * upper
        triu!(B)
        trmm!('L', uplocA, transa, isunitcA, one(T), A, B, C)
    elseif uplocA == 'U' && tfunA === identity && tfunB === identity && uplocB == 'L' && isunitcB == 'N' # upper * lower
        tril!(B)
        trmm!('L', uplocA, transa, isunitcA, one(T), A, B, C)
    elseif uplocA == 'U' && tfunA === identity && tfunB !== identity && uplocB == 'U' && isunitcA == 'N'
        # operation is reversed to avoid executing the transpose
        triu!(A)
        trmm!('R', uplocB, transb, isunitcB, one(T), parent(B), A, C)
    elseif uplocA == 'L' && tfunA !== identity && tfunB === identity && uplocB == 'L' && isunitcB == 'N'
        tril!(B)
        trmm!('L', uplocA, transa, isunitcA, one(T), A, B, C)
    elseif uplocA == 'U' && tfunA !== identity && tfunB === identity && uplocB == 'U' && isunitcB == 'N'
        triu!(B)
        trmm!('L', uplocA, transa, isunitcA, one(T), A, B, C)
    elseif uplocA == 'L' && tfunA === identity && tfunB !== identity && uplocB == 'L' && isunitcA == 'N'
        # same reversal as above: A becomes the dense operand on the right side
        tril!(A)
        trmm!('R', uplocB, transb, isunitcB, one(T), parent(B), A, C)
    else
        # Throw a real exception type rather than a bare String so callers can
        # catch and display it like any other error.
        throw(ArgumentError("mixed triangular-triangular multiplication")) # TODO: rethink
    end
    return C
end
203
+
166
204
# triangular
167
205
# Triangular (left) times dense multiply hook for oneAPI strided matrices.
# Translates `tfun` into a BLAS-style transpose flag and forwards everything to
# the 8-argument `trmm!` with side 'L'; the value of that call is returned.
function LinearAlgebra.generic_trimatmul!(
        C::oneStridedMatrix{T}, uploc, isunitc, tfun::Function,
        A::oneStridedMatrix{T}, B::oneStridedMatrix{T},
    ) where {T<:onemklFloat}
    trans = if tfun === identity
        'N'
    elseif tfun === transpose
        'T'
    else
        'C'
    end
    return trmm!('L', uploc, trans, isunitc, one(T), A, B, C)
end
169
207
# Dense times triangular (right) multiply hook for oneAPI strided matrices.
# `B` is passed to `trmm!` in the triangular-operand position with side 'R',
# `A` as the other operand, and `C` last; the value of that call is returned.
function LinearAlgebra.generic_mattrimul!(
        C::oneStridedMatrix{T}, uploc, isunitc, tfun::Function,
        A::oneStridedMatrix{T}, B::oneStridedMatrix{T},
    ) where {T<:onemklFloat}
    trans = if tfun === identity
        'N'
    elseif tfun === transpose
        'T'
    else
        'C'
    end
    return trmm!('R', uploc, trans, isunitc, one(T), B, A, C)
end
209
# Left triangular solve hook: `A` is a oneAPI strided matrix, `B` may be any
# `AbstractMatrix{T}`. Translates `tfun` into a BLAS-style transpose flag and
# forwards to the 8-argument `trsm!` with side 'L'; returns that call's value.
function LinearAlgebra.generic_trimatdiv!(
        C::oneStridedMatrix{T}, uploc, isunitc, tfun::Function,
        A::oneStridedMatrix{T}, B::AbstractMatrix{T},
    ) where {T<:onemklFloat}
    trans = if tfun === identity
        'N'
    elseif tfun === transpose
        'T'
    else
        'C'
    end
    return trsm!('L', uploc, trans, isunitc, one(T), A, B, C)
end
211
# Right triangular solve hook: `B` is a oneAPI strided matrix, `A` may be any
# `AbstractMatrix{T}`. `B` is passed to `trsm!` in the triangular-operand
# position with side 'R', `A` as the other operand, and `C` last; returns that
# call's value.
function LinearAlgebra.generic_mattridiv!(
        C::oneStridedMatrix{T}, uploc, isunitc, tfun::Function,
        A::AbstractMatrix{T}, B::oneStridedMatrix{T},
    ) where {T<:onemklFloat}
    trans = if tfun === identity
        'N'
    elseif tfun === transpose
        'T'
    else
        'C'
    end
    return trsm!('R', uploc, trans, isunitc, one(T), B, A, C)
end
0 commit comments