Skip to content

Commit b208630

Browse files
Merge pull request #41 from vpuri3/tensor
Tensor Product Operator
2 parents 6bdba5e + 264d845 commit b208630

File tree

6 files changed

+320
-29
lines changed

6 files changed

+320
-29
lines changed

src/SciMLOperators.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,13 @@ using LinearAlgebra
66
import StaticArrays
77
import SparseArrays
88
import ArrayInterfaceCore
9+
import Base: ReshapedArray
910
import Lazy: @forward
1011
import Setfield: @set!
1112

1213
# overload
13-
import Base: size, +, -, *, /, \, adjoint, , inv, one, convert, Matrix, iszero, ==
14+
import Base: +, -, *, /, \, , ==
15+
import Base: conj, one, iszero, inv, adjoint, transpose, size, convert, Matrix
1416
import LinearAlgebra: mul!, ldiv!, lmul!, rmul!, factorize, exp, Diagonal
1517
import SparseArrays: sparse
1618

@@ -34,6 +36,7 @@ $(TYPEDEF)
3436
"""
3537
abstract type AbstractMatrixFreeOperator{T} <: AbstractSciMLOperator{T} end
3638

39+
include("utils.jl")
3740
include("interface.jl")
3841
include("basic.jl")
3942
include("sciml.jl")
@@ -42,7 +45,8 @@ export ScalarOperator,
4245
MatrixOperator,
4346
DiagonalOperator,
4447
AffineOperator,
45-
FunctionOperator
48+
FunctionOperator,
49+
TensorProductOperator
4650

4751
export update_coefficients!,
4852
update_coefficients,

src/basic.jl

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,10 @@ Base.convert(::Type{AbstractMatrix}, ::IdentityOperator{N}) where{N} = Diagonal(
1818
# traits
1919
Base.size(::IdentityOperator{N}) where{N} = (N, N)
2020
Base.adjoint(A::IdentityOperator) = A
21+
Base.transpose(A::IdentityOperator) = A
2122
LinearAlgebra.opnorm(::IdentityOperator{N}, p::Real=2) where{N} = true
2223
for pred in (
23-
:isreal, :issymmetric, :ishermitian, :isposdef,
24+
:issymmetric, :ishermitian, :isposdef,
2425
)
2526
@eval LinearAlgebra.$pred(::IdentityOperator) = true
2627
end
@@ -98,9 +99,10 @@ Base.convert(::Type{AbstractMatrix}, ::NullOperator{N}) where{N} = Diagonal(zero
9899
# traits
99100
Base.size(::NullOperator{N}) where{N} = (N, N)
100101
Base.adjoint(A::NullOperator) = A
102+
Base.transpose(A::NullOperator) = A
101103
LinearAlgebra.opnorm(::NullOperator{N}, p::Real=2) where{N} = false
102104
for pred in (
103-
:isreal, :issymmetric, :ishermitian,
105+
:issymmetric, :ishermitian,
104106
)
105107
@eval LinearAlgebra.$pred(::NullOperator) = true
106108
end
@@ -184,6 +186,7 @@ function Base.adjoint(α::ScalarOperator) # TODO - test
184186
update_func = (oldval,u,p,t) -> α.update_func(oldval',u,p,t)'
185187
ScalarOperator(val; update_func=update_func)
186188
end
189+
Base.transpose::ScalarOperator) = α
187190

188191
getops::ScalarOperator) =.val)
189192
islinear(L::ScalarOperator) = true
@@ -286,7 +289,12 @@ SparseArrays.sparse(L::ScaledOperator) = L.λ * sparse(L.L)
286289

287290
# traits
288291
Base.size(L::ScaledOperator) = size(L.L)
289-
Base.adjoint(L::ScaledOperator) = ScaledOperator(L.λ', L.op')
292+
for op in (
293+
:adjoint,
294+
:transpose,
295+
)
296+
@eval Base.$op(L::ScaledOperator) = ScaledOperator($op(L.λ), $op(L.op))
297+
end
290298
LinearAlgebra.opnorm(L::ScaledOperator, p::Real=2) = abs(L.λ) * opnorm(L.L, p)
291299

292300
getops(L::ScaledOperator) = (L.λ, L.L)
@@ -404,7 +412,12 @@ SparseArrays.sparse(L::AddedOperator) = sum(_sparse, L.ops)
404412

405413
# traits
406414
Base.size(L::AddedOperator) = size(first(L.ops))
407-
Base.adjoint(L::AddedOperator) = AddedOperator(adjoint.(L.ops)...)
415+
for op in (
416+
:adjoint,
417+
:transpose,
418+
)
419+
@eval Base.$op(L::AddedOperator) = AddedOperator($op.(L.ops)...)
420+
end
408421

409422
getops(L::AddedOperator) = L.ops
410423
Base.iszero(L::AddedOperator) = all(iszero, getops(L))
@@ -486,7 +499,12 @@ SparseArrays.sparse(L::ComposedOperator) = prod(_sparse, L.ops)
486499

487500
# traits
488501
Base.size(L::ComposedOperator) = (size(first(L.ops), 1), size(last(L.ops),2))
489-
Base.adjoint(L::ComposedOperator) = ComposedOperator(adjoint.(reverse(L.ops)))
502+
for op in (
503+
:adjoint,
504+
:transpose,
505+
)
506+
@eval Base.$op(L::ComposedOperator) = ComposedOperator($op.(reverse(L.ops))...)
507+
end
490508
LinearAlgebra.opnorm(L::ComposedOperator) = prod(opnorm, L.ops)
491509

492510
getops(L::ComposedOperator) = L.ops
@@ -495,7 +513,7 @@ Base.iszero(L::ComposedOperator) = all(iszero, getops(L))
495513
has_adjoint(L::ComposedOperator) = all(has_adjoint, L.ops)
496514
has_mul!(L::ComposedOperator) = all(has_mul!, L.ops)
497515
has_ldiv(L::ComposedOperator) = all(has_ldiv, L.ops)
498-
has_ldiv!(L::ComposedOperator) = all(has_mul!, L.ops)
516+
has_ldiv!(L::ComposedOperator) = all(has_ldiv!, L.ops)
499517

500518
factorize(L::ComposedOperator) = prod(factorize, reverse(L.ops))
501519
for fact in (
@@ -593,6 +611,11 @@ end
593611
AbstractAdjointedVector = Adjoint{ <:Number, <:AbstractVector}
594612
AbstractTransposedVector = Transpose{<:Number, <:AbstractVector}
595613

614+
has_adjoint(::AdjointedOperator) = true
615+
616+
Base.transpose(L::AdjointedOperator) = conj(L.L)
617+
Base.adjoint(L::TransposedOperator) = conj(L.L)
618+
596619
for (op, LType, VType) in (
597620
(:adjoint, :AdjointedOperator, :AbstractAdjointedVector ),
598621
(:transpose, :TransposedOperator, :AbstractTransposedVector),
@@ -606,12 +629,10 @@ for (op, LType, VType) in (
606629
@eval Base.size(L::$LType) = size(L.L) |> reverse
607630
@eval Base.$op(L::$LType) = L.L
608631

609-
@eval has_adjoint(L::$LType) = true
610632
@eval getops(L::$LType) = (L.L,)
611633

612634
@eval @forward $LType.L (
613635
# LinearAlgebra
614-
LinearAlgebra.isreal,
615636
LinearAlgebra.issymmetric,
616637
LinearAlgebra.ishermitian,
617638
LinearAlgebra.isposdef,
@@ -690,7 +711,6 @@ has_ldiv!(L::InvertedOperator) = has_mul!(L.L)
690711

691712
@forward InvertedOperator.L (
692713
# LinearAlgebra
693-
LinearAlgebra.isreal,
694714
LinearAlgebra.issymmetric,
695715
LinearAlgebra.ishermitian,
696716
LinearAlgebra.isposdef,

src/interface.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,11 @@ issquare(A...) = @. (&)(issquare(A)...)
161161
# default linear operator traits
162162
###
163163

164+
Base.isreal(L::AbstractSciMLOperator{T}) where{T} = T <: Real
165+
function Base.conj(L::AbstractSciMLOperator)
166+
isreal(L) && return L
167+
convert(AbstractMatrix, L) |> conj
168+
end
164169
function Base.:(==)(L1::AbstractSciMLOperator, L2::AbstractSciMLOperator)
165170
size(L1) != size(L2) && return false
166171
convert(AbstractMatrix, L1) == convert(AbstractMatrix, L1)
@@ -172,7 +177,6 @@ expmv(L::AbstractSciMLLinearOperator,u,p,t) = exp(L,t)*u
172177
expmv!(v,L::AbstractSciMLLinearOperator,u,p,t) = mul!(v,exp(L,t),u)
173178

174179
Base.Matrix(L::AbstractSciMLLinearOperator) = Matrix(convert(AbstractMatrix, L))
175-
Base.adjoint(A::AbstractSciMLLinearOperator) = Adjoint(A) # TODO write lazy adjoint operator interface here
176180

177181
Base.@propagate_inbounds function Base.getindex(L::AbstractSciMLLinearOperator, I::Vararg{Any,N}) where {N}
178182
convert(AbstractMatrix, L)[I...]
@@ -184,7 +188,6 @@ end
184188
LinearAlgebra.exp(L::AbstractSciMLLinearOperator) = exp(Matrix(L))
185189
LinearAlgebra.opnorm(L::AbstractSciMLLinearOperator, p::Real=2) = opnorm(convert(AbstractMatrix,L), p)
186190
for pred in (
187-
:isreal,
188191
:issymmetric,
189192
:ishermitian,
190193
:isposdef,

0 commit comments

Comments
 (0)