From 9c1cd1893f59f924874a73d7a35b4c65412b9f83 Mon Sep 17 00:00:00 2001 From: Michael Abbott Date: Thu, 8 Oct 2020 09:33:21 +0200 Subject: [PATCH 1/2] add simple 3-arg and 4-arg * methods --- src/matrix_multiply.jl | 9 ++++++--- test/matrix_multiply.jl | 3 +++ 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/src/matrix_multiply.jl b/src/matrix_multiply.jl index 080a5a04..c3b2d217 100644 --- a/src/matrix_multiply.jl +++ b/src/matrix_multiply.jl @@ -15,11 +15,15 @@ import LinearAlgebra: BlasFloat, matprod, mul! @inline *(A::StaticArray{Tuple{N,1},<:Any,2}, B::Adjoint{<:Any,<:StaticVector}) where {N} = vec(A) * B @inline *(A::StaticArray{Tuple{N,1},<:Any,2}, B::Transpose{<:Any,<:StaticVector}) where {N} = vec(A) * B +# Avoid LinearAlgebra._quad_matmul's order calculation on equal sizes +@inline *(A::StaticMatrix{N,N}, B::StaticMatrix{N,N}, C::StaticMatrix{N,N}) where {N} = (A*B)*C +@inline *(A::StaticMatrix{N,N}, B::StaticMatrix{N,N}, C::StaticMatrix{N,N}, D::StaticMatrix{N,N}) where {N} = ((A*B)*C)*D + """ mul_result_structure(a::Type, b::Type) Get a structure wrapper that should be applied to the result of multiplication of matrices -of given types (a*b). +of given types (a*b). """ function mul_result_structure(a, b) return identity @@ -114,7 +118,6 @@ end b::Union{Transpose{Tb, <:StaticVector}, Adjoint{Tb, <:StaticVector}}) where {sa, sb, Ta, Tb} newsize = (sa[1], sb[2]) exprs = [:(a[$i]*b[$j]) for i = 1:sa[1], j = 1:sb[2]] - return quote @_inline_meta T = promote_op(*, Ta, Tb) @@ -209,7 +212,7 @@ end while m < M mu = min(M, m + M_r) mrange = m+1:mu - + atemps_init = [:($(atemps[k1]) = a[$k1]) for k1 = mrange] exprs_init = [:($(tmps[k1,k2]) = $(atemps[k1]) * b[$(1 + (k2-1) * sb[1])]) for k1 = mrange, k2 = nrange] atemps_loop_init = [:($(atemps[k1]) = a[$(k1-sa[1]) + $(sa[1])*j]) for k1 = mrange] diff --git a/test/matrix_multiply.jl b/test/matrix_multiply.jl index 0f94bd7e..439289aa 100644 --- a/test/matrix_multiply.jl +++ b/test/matrix_multiply.jl @@ -173,6 +173,9 @@ mul_wrappers = [ @test m*transpose(n) === @SMatrix [8 14; 18 32] @test transpose(m)*transpose(n) === @SMatrix [11 19; 16 28] + @test @inferred(m*n*m) === @SMatrix [49 72; 109 160] + @test @inferred(m*n*m*n) === @SMatrix [386 507; 858 1127] + # check different sizes because there are multiple implementations for matrices of different sizes for (mm, nn) in [ (m, n), From fbea54435111caa67d06416e697b54bd1ba1692b Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Tue, 8 Jun 2021 09:19:17 -0400 Subject: [PATCH 2/2] widen to StaticMatMulLike --- src/matrix_multiply.jl | 4 ++-- test/matrix_multiply.jl | 3 +++ 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/src/matrix_multiply.jl b/src/matrix_multiply.jl index c3b2d217..f60b1719 100644 --- a/src/matrix_multiply.jl +++ b/src/matrix_multiply.jl @@ -16,8 +16,8 @@ import LinearAlgebra: BlasFloat, matprod, mul! @inline *(A::StaticArray{Tuple{N,1},<:Any,2}, B::Transpose{<:Any,<:StaticVector}) where {N} = vec(A) * B # Avoid LinearAlgebra._quad_matmul's order calculation on equal sizes -@inline *(A::StaticMatrix{N,N}, B::StaticMatrix{N,N}, C::StaticMatrix{N,N}) where {N} = (A*B)*C -@inline *(A::StaticMatrix{N,N}, B::StaticMatrix{N,N}, C::StaticMatrix{N,N}, D::StaticMatrix{N,N}) where {N} = ((A*B)*C)*D +@inline *(A::StaticMatMulLike{N,N}, B::StaticMatMulLike{N,N}, C::StaticMatMulLike{N,N}) where {N} = (A*B)*C +@inline *(A::StaticMatMulLike{N,N}, B::StaticMatMulLike{N,N}, C::StaticMatMulLike{N,N}, D::StaticMatMulLike{N,N}) where {N} = ((A*B)*C)*D """ mul_result_structure(a::Type, b::Type) diff --git a/test/matrix_multiply.jl b/test/matrix_multiply.jl index 439289aa..c80586f4 100644 --- a/test/matrix_multiply.jl +++ b/test/matrix_multiply.jl @@ -173,8 +173,11 @@ mul_wrappers = [ @test m*transpose(n) === @SMatrix [8 14; 18 32] @test transpose(m)*transpose(n) === @SMatrix [11 19; 16 28] + # 3- and 4-arg * @test @inferred(m*n*m) === @SMatrix [49 72; 109 160] @test @inferred(m*n*m*n) === @SMatrix [386 507; 858 1127] + @test @inferred(m*n'*UpperTriangular(m)) === @SMatrix [8 72; 18 164] + @test @inferred(Diagonal(m)*n*m'*transpose(n)) === @SMatrix [70 122; 496 864] # check different sizes because there are multiple implementations for matrices of different sizes for (mm, nn) in [