Commit f36036c

Support ⊙ and ⊗ as elementwise- and tensor-product operators (#35150)
While we have broadcasting and `a*b'`, sometimes you need to pass an operator as an argument to a function. Since we already have `dot` or `⋅` for the inner product, these elementwise and tensor products fill out the space of possibilities.
1 parent 9507225 commit f36036c
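
For illustration, a minimal sketch of that motivating use case, assuming the `⊙` and `⊗` exported by this change are in scope (Julia 1.5 with this commit, or the `Compat` package on 1.0-1.4):

```julia
using LinearAlgebra

a = [2, 3]; b = [5, 7]

# ⊙ and ⊗ are ordinary binary functions, so they can be handed to
# higher-order functions, which broadcasting syntax alone cannot do.
map(⊙, [a, b], [b, a])    # elementwise products of paired vectors
foldl(⊗, (a, b, [1, 0]))  # iterated tensor product, a 2×2×2 array
```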

5 files changed: +206 −3 lines
NEWS.md

Lines changed: 2 additions & 0 deletions
@@ -169,6 +169,8 @@ Standard library changes
 * The BLAS submodule now supports the level-2 BLAS subroutine `spmv!` ([#34320]).
 * The BLAS submodule now supports the level-1 BLAS subroutine `rot!` ([#35124]).
 * New generic `rotate!(x, y, c, s)` and `reflect!(x, y, c, s)` functions ([#35124]).
+* `hadamard` or `⊙` (`\odotTAB`) can be used as an elementwise multiplication operator,
+  and `tensor` or `⊗` (`\otimesTAB`) as the tensor product operator ([#35150]).
 
 #### Markdown
 

stdlib/LinearAlgebra/docs/src/index.md

Lines changed: 4 additions & 0 deletions
@@ -321,6 +321,8 @@ LinearAlgebra.PosDefException
 LinearAlgebra.ZeroPivotException
 LinearAlgebra.dot
 LinearAlgebra.cross
+LinearAlgebra.hadamard
+LinearAlgebra.tensor
 LinearAlgebra.factorize
 LinearAlgebra.Diagonal
 LinearAlgebra.Bidiagonal
@@ -474,6 +476,8 @@ LinearAlgebra.lmul!
 LinearAlgebra.rmul!
 LinearAlgebra.ldiv!
 LinearAlgebra.rdiv!
+LinearAlgebra.hadamard!
+LinearAlgebra.tensor!
 ```
 
 ## BLAS functions

stdlib/LinearAlgebra/src/LinearAlgebra.jl

Lines changed: 137 additions & 0 deletions
@@ -388,6 +388,143 @@ const ⋅ = dot
 const × = cross
 export ⋅, ×
 
+"""
+    hadamard(a, b)
+    a ⊙ b
+
+For arrays `a` and `b`, perform elementwise multiplication.
+`a` and `b` must have identical `axes`.
+
+`⊙` can be passed as an operator to higher-order functions.
+
+```jldoctest
+julia> a = [2, 3]; b = [5, 7];
+
+julia> a ⊙ b
+2-element Array{$Int,1}:
+ 10
+ 21
+
+julia> a ⊙ [5]
+ERROR: DimensionMismatch("Axes of `A` and `B` must match, got (Base.OneTo(2),) and (Base.OneTo(1),)")
+[...]
+```
+
+!!! compat "Julia 1.5"
+    This function requires at least Julia 1.5. In Julia 1.0-1.4 it is available from
+    the `Compat` package.
+"""
+function hadamard(A::AbstractArray, B::AbstractArray)
+    @noinline throw_dmm(axA, axB) = throw(DimensionMismatch("Axes of `A` and `B` must match, got $axA and $axB"))
+
+    axA, axB = axes(A), axes(B)
+    axA == axB || throw_dmm(axA, axB)
+    return map(*, A, B)
+end
+const ⊙ = hadamard
+
+"""
+    hadamard!(dest, A, B)
+
+Similar to `hadamard(A, B)` (which can also be written `A ⊙ B`), but stores its results in
+the pre-allocated array `dest`.
+
+!!! compat "Julia 1.5"
+    This function requires at least Julia 1.5. In Julia 1.0-1.4 it is available from
+    the `Compat` package.
+"""
+function hadamard!(dest::AbstractArray, A::AbstractArray, B::AbstractArray)
+    @noinline function throw_dmm(axA, axB, axdest)
+        throw(DimensionMismatch("`axes(dest) = $axdest` must be equal to `axes(A) = $axA` and `axes(B) = $axB`"))
+    end
+
+    axA, axB, axdest = axes(A), axes(B), axes(dest)
+    ((axdest == axA) & (axdest == axB)) || throw_dmm(axA, axB, axdest)
+    @simd for I in eachindex(dest, A, B)
+        @inbounds dest[I] = A[I] * B[I]
+    end
+    return dest
+end
+
+export ⊙, hadamard, hadamard!
+
+"""
+    tensor(A, B)
+    A ⊗ B
+
+Compute the tensor product of `A` and `B`.
+If `C = A ⊗ B`, then `C[i1, ..., im, j1, ..., jn] = A[i1, ..., im] * B[j1, ..., jn]`.
+
+```jldoctest
+julia> a = [2, 3]; b = [5, 7, 11];
+
+julia> a ⊗ b
+2×3 Array{$Int,2}:
+ 10  14  22
+ 15  21  33
+```
+
+See also: [`kron`](@ref).
+
+!!! compat "Julia 1.5"
+    This function requires at least Julia 1.5. In Julia 1.0-1.4 it is available from
+    the `Compat` package.
+"""
+tensor(A::AbstractArray, B::AbstractArray) = [a*b for a in A, b in B]
+const ⊗ = tensor
+
+const CovectorLike{T} = Union{Adjoint{T,<:AbstractVector},Transpose{T,<:AbstractVector}}
+function tensor(u::AbstractArray, v::CovectorLike)
+    # If `v` is thought of as a covector, you might want this to be two-dimensional,
+    # but thought of as a matrix it should be three-dimensional.
+    # The safest is to avoid supporting it at all. See discussion in #35150.
+    error("`tensor` is not defined for co-vectors, perhaps you meant `*`?")
+end
+function tensor(u::CovectorLike, v::AbstractArray)
+    error("`tensor` is not defined for co-vectors, perhaps you meant `*`?")
+end
+function tensor(u::CovectorLike, v::CovectorLike)
+    error("`tensor` is not defined for co-vectors, perhaps you meant `*`?")
+end
+
+"""
+    tensor!(dest, A, B)
+
+Similar to `tensor(A, B)` (which can also be written `A ⊗ B`), but stores its results in
+the pre-allocated array `dest`.
+
+!!! compat "Julia 1.5"
+    This function requires at least Julia 1.5. In Julia 1.0-1.4 it is available from
+    the `Compat` package.
+"""
+function tensor!(dest::AbstractArray, A::AbstractArray, B::AbstractArray)
+    @noinline function throw_dmm(axA, axB, axdest)
+        throw(DimensionMismatch("`axes(dest) = $axdest` must concatenate `axes(A) = $axA` and `axes(B) = $axB`"))
+    end
+
+    axA, axB, axdest = axes(A), axes(B), axes(dest)
+    axes(dest) == (axA..., axB...) || throw_dmm(axA, axB, axdest)
+    if IndexStyle(dest) === IndexCartesian()
+        for IB in CartesianIndices(axB)
+            @inbounds b = B[IB]
+            @simd for IA in CartesianIndices(axA)
+                @inbounds dest[IA,IB] = A[IA]*b
+            end
+        end
+    else
+        i = firstindex(dest)
+        @inbounds for b in B
+            @simd for a in A
+                dest[i] = a*b
+                i += 1
+            end
+        end
+    end
+    return dest
+end
+
+export ⊗, tensor, tensor!
+
 """
     LinearAlgebra.peakflops(n::Integer=2000; parallel::Bool=false)
 
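As a usage note (not part of the diff): a minimal sketch of how the in-place variants above are meant to be called, assuming this commit is loaded. `hadamard!` needs `dest` to share the axes of `A` and `B`, while `tensor!` needs `axes(dest) == (axes(A)..., axes(B)...)`:

```julia
using LinearAlgebra

A = [1.0 2.0; 3.0 4.0]
B = [5.0 6.0; 7.0 8.0]

C = similar(A)
hadamard!(C, A, B)        # C[i, j] == A[i, j] * B[i, j]

D = Array{Float64}(undef, 2, 2, 2, 2)   # axes(D) == (axes(A)..., axes(B)...)
tensor!(D, A, B)          # D[i1, i2, j1, j2] == A[i1, i2] * B[j1, j2]
```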

stdlib/LinearAlgebra/src/dense.jl

Lines changed: 3 additions & 3 deletions
@@ -341,9 +341,9 @@ end
 
 Kronecker tensor product of two vectors or two matrices.
 
-For vectors v and w, the Kronecker product is related to the outer product by
-`kron(v,w) == vec(w*transpose(v))` or
-`w*transpose(v) == reshape(kron(v,w), (length(w), length(v)))`.
+For vectors v and w, the Kronecker product is related to the tensor product [`tensor`](@ref), or `⊗`, by
+`kron(v,w) == vec(w ⊗ v)` or
+`w ⊗ v == reshape(kron(v,w), (length(w), length(v)))`.
 Note how the ordering of `v` and `w` differs on the left and right
 of these expressions (due to column-major storage).
 
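A concrete check of the relation stated in the updated docstring, using different values from the new tests below (assuming the `⊗` from this commit is in scope):

```julia
using LinearAlgebra

v, w = [1, 10], [2, 3, 5]

# Column-major storage is why the roles of v and w swap between
# kron and the tensor product:
kron(v, w) == vec(w ⊗ v)                               # true
w ⊗ v == reshape(kron(v, w), (length(w), length(v)))   # true
```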

stdlib/LinearAlgebra/test/addmul.jl

Lines changed: 60 additions & 0 deletions
@@ -131,6 +131,66 @@ for cmat in mattypes,
     push!(testdata, (cmat{celt}, amat{aelt}, bmat{belt}))
 end
 
+@testset "Alternative multiplication operators" begin
+    for T in (Int, Float32, Float64, BigFloat)
+        a = [T[1, 2], T[-3, 7]]
+        b = [T[5, 11], T[-13, 17]]
+        @test map(⋅, a, b) == map(dot, a, b) == [27, 158]
+        @test map(⊙, a, b) == map(hadamard, a, b) == [a[1].*b[1], a[2].*b[2]]
+        @test map(⊗, a, b) == map(tensor, a, b) == [a[1]*transpose(b[1]), a[2]*transpose(b[2])]
+        @test hadamard!(fill(typemax(Int), 2), T[1, 2], T[-3, 7]) == [-3, 14]
+        @test tensor!(fill(typemax(Int), 2, 2), T[1, 2], T[-3, 7]) == [-3 7; -6 14]
+    end
+
+    @test_throws DimensionMismatch [1,2] ⊙ [3]
+    @test_throws DimensionMismatch hadamard!([0, 0, 0], [1,2], [-3,7])
+    @test_throws DimensionMismatch hadamard!([0, 0], [1,2], [-3])
+    @test_throws DimensionMismatch hadamard!([0, 0], [1], [-3,7])
+    @test_throws DimensionMismatch tensor!(Matrix{Int}(undef, 2, 2), [1], [-3,7])
+    @test_throws DimensionMismatch tensor!(Matrix{Int}(undef, 2, 2), [1,2], [-3])
+
+    u, v = [2+2im, 3+5im], [1-3im, 7+3im]
+    @test u ⋅ v == conj(u[1])*v[1] + conj(u[2])*v[2]
+    @test u ⊙ v == [u[1]*v[1], u[2]*v[2]]
+    @test u ⊗ v == [u[1]*v[1] u[1]*v[2]; u[2]*v[1] u[2]*v[2]]
+    @test hadamard(u, v) == u ⊙ v
+    @test tensor(u, v) == u ⊗ v
+    dest = similar(u)
+    @test hadamard!(dest, u, v) == u ⊙ v
+    dest = Matrix{Complex{Int}}(undef, 2, 2)
+    @test tensor!(dest, u, v) == u ⊗ v
+
+    for (A, B, b) in (([1 2; 3 4], [5 6; 7 8], [5,6]),
+                      ([1+0.8im 2+0.7im; 3+0.6im 4+0.5im],
+                       [5+0.4im 6+0.3im; 7+0.2im 8+0.1im],
+                       [5+0.6im,6+0.3im]))
+        @test A ⊗ b == cat(A*b[1], A*b[2]; dims=3)
+        @test A ⊗ B == cat(cat(A*B[1,1], A*B[2,1]; dims=3),
+                           cat(A*B[1,2], A*B[2,2]; dims=3); dims=4)
+    end
+
+    A, B = reshape(1:27, 3, 3, 3), reshape(1:4, 2, 2)
+    @test A ⊗ B == [a*b for a in A, b in B]
+
+    # Adjoint/transpose is a dual vector, not an AbstractMatrix
+    v = [1,2]
+    @test_throws ErrorException v ⊗ v'
+    @test_throws ErrorException v ⊗ transpose(v)
+    @test_throws ErrorException v' ⊗ v
+    @test_throws ErrorException transpose(v) ⊗ v
+    @test_throws ErrorException v' ⊗ v'
+    @test_throws ErrorException transpose(v) ⊗ transpose(v)
+    @test_throws ErrorException v' ⊗ transpose(v)
+    @test_throws ErrorException transpose(v) ⊗ v'
+    @test_throws ErrorException A ⊗ v'
+    @test_throws ErrorException A ⊗ transpose(v)
+
+    # Docs comparison to `kron`
+    v, w = [1,2,3], [5,7]
+    @test kron(v,w) == vec(w ⊗ v)
+    @test w ⊗ v == reshape(kron(v,w), (length(w), length(v)))
+end
+
 @testset "mul!(::$TC, ::$TA, ::$TB, α, β)" for (TC, TA, TB) in testdata
     if needsquare(TA)
         na1 = na2 = rand(sizecandidates)
