Skip to content

Commit 31f8be1

Browse files
authored
Promote hcat vcat and fix complex outer product
* Capture complex outer product Fixes #156 * Fix promotion in hcat and vcat Fixes #155 * Fix outer product
1 parent 190aa71 commit 31f8be1

File tree

4 files changed

+29
-3
lines changed

4 files changed

+29
-3
lines changed

src/linalg.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ end
103103

104104
return quote
105105
@_inline_meta
106-
@inbounds return similar_type(a, Size($Snew))(tuple($(exprs...)))
106+
@inbounds return similar_type(a, promote_type(eltype(a), eltype(b)), Size($Snew))(tuple($(exprs...)))
107107
end
108108
end
109109
# TODO make these more efficient
@@ -129,7 +129,7 @@ end
129129

130130
return quote
131131
@_inline_meta
132-
@inbounds return similar_type(a, Size($Snew))(tuple($(exprs...)))
132+
@inbounds return similar_type(a, promote_type(eltype(a), eltype(b)), Size($Snew))(tuple($(exprs...)))
133133
end
134134
end
135135
# TODO make these more efficient
@@ -297,8 +297,8 @@ end
297297
end
298298
end
299299

300-
# TODO same for `RowVector`?
301300
@inline Size(::Union{RowVector{T, SA}, Type{RowVector{T, SA}}}) where {T, SA <: StaticArray} = Size(1, Size(SA)[1])
301+
@inline Size(::Union{RowVector{T, CA}, Type{RowVector{T, CA}}} where CA <: ConjVector{<:Any, SA}) where {T, SA <: StaticArray} = Size(1, Size(SA)[1])
302302
@inline Size(::Union{Symmetric{T,SA}, Type{Symmetric{T,SA}}}) where {T,SA<:StaticArray} = Size(SA)
303303
@inline Size(::Union{Hermitian{T,SA}, Type{Hermitian{T,SA}}}) where {T,SA<:StaticArray} = Size(SA)
304304

src/matrix_multiply.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,13 @@ promote_matprod{T1,T2}(::Type{T1}, ::Type{T2}) = typeof(zero(T1)*zero(T2) + zero
3636
@inline *(A::StaticMatrix, B::StaticMatrix) = _A_mul_B(Size(A), Size(B), A, B)
3737
@inline *(A::StaticVector, B::StaticMatrix) = *(reshape(A, Size(Size(A)[1], 1)), B)
3838
@inline *(A::StaticVector, B::RowVector{<:Any, <:StaticVector}) = _A_mul_B(Size(A), Size(B), A, B)
39+
@inline *(A::StaticVector, B::RowVector{<:Any, <:ConjVector{<:Any, <:StaticVector}}) = _A_mul_B(Size(A), Size(B), A, B)
3940

4041
@inline A_mul_B!(dest::StaticVecOrMat, A::StaticMatrix, B::StaticVector) = _A_mul_B!(Size(dest), dest, Size(A), Size(B), A, B)
4142
@inline A_mul_B!(dest::StaticVecOrMat, A::StaticMatrix, B::StaticMatrix) = _A_mul_B!(Size(dest), dest, Size(A), Size(B), A, B)
4243
@inline A_mul_B!(dest::StaticVecOrMat, A::StaticVector, B::StaticMatrix) = A_mul_B!(dest, reshape(A, Size(Size(A)[1], 1)), B)
4344
@inline A_mul_B!(dest::StaticVecOrMat, A::StaticVector, B::RowVector{<:Any, <:StaticVector}) = _A_mul_B!(Size(dest), dest, Size(A), Size(B), A, B)
45+
@inline A_mul_B!(dest::StaticVecOrMat, A::StaticVector, B::RowVector{<:Any, <:ConjVector{<:Any, <:StaticVector}}) = _A_mul_B!(Size(dest), dest, Size(A), Size(B), A, B)
4446

4547
#@inline *{TA<:Base.LinAlg.BlasFloat,Tb}(A::StaticMatrix{TA}, b::StaticVector{Tb})
4648

@@ -93,6 +95,18 @@ end
9395
end
9496
end
9597

98+
# complex outer product
99+
@generated function _A_mul_B(::Size{sa}, ::Size{sb}, a::StaticVector{<: Any, Ta}, b::RowVector{Tb, <:ConjVector{<:Any, <:StaticVector}}) where {sa, sb, Ta, Tb}
100+
newsize = (sa[1], sb[2])
101+
exprs = [:(a[$i]*b[$j]) for i = 1:sa[1], j = 1:sb[2]]
102+
103+
return quote
104+
@_inline_meta
105+
T = promote_op(*, Ta, Tb)
106+
@inbounds return similar_type(b, T, Size($newsize))(tuple($(exprs...)))
107+
end
108+
end
109+
96110
@generated function _A_mul_B(Sa::Size{sa}, Sb::Size{sb}, a::StaticMatrix{<:Any, <:Any, Ta}, b::StaticMatrix{<:Any, <:Any, Tb}) where {sa, sb, Ta, Tb}
97111
# Heuristic choice for amount of codegen
98112
if sa[1]*sa[2]*sb[2] <= 8*8*8

test/linalg.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,9 @@
100100

101101
@test @inferred(vcat(SVector(1),SVector(2),SVector(3),SVector(4))) === SVector(1,2,3,4)
102102
@test @inferred(hcat(SVector(1),SVector(2),SVector(3),SVector(4))) === SMatrix{1,4}(1,2,3,4)
103+
104+
vcat(SVector(1.0f0), SVector(1.0)) === SVector(1.0, 1.0)
105+
hcat(SVector(1.0f0), SVector(1.0)) === SMatrix{1,2}(1.0, 1.0)
103106
end
104107

105108
@testset "normalization" begin

test/matrix_multiply.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,15 @@
4444
m = @SMatrix [1 2 3 4]
4545
v = @SVector [1, 2]
4646
@test @inferred(v*m) === @SMatrix [1 2 3 4; 2 4 6 8]
47+
48+
# Outer product
49+
v2 = SVector(1, 2)
50+
v3 = SVector(3, 4)
51+
@test v2 * v3' === @SMatrix [3 4; 6 8]
52+
53+
v4 = SVector(1+0im, 2+0im)
54+
v5 = SVector(3+0im, 4+0im)
55+
@test v4 * v5' === @SMatrix [3+0im 4+0im; 6+0im 8+0im]
4756
end
4857

4958
@testset "Matrix-matrix" begin

0 commit comments

Comments
 (0)