Skip to content

Commit a13f333

Browse files
Merge pull request #210 from vpuri3/vec
allow functionoperator (batch = false) to accept/return vec'd arrays
2 parents f169b36 + 734205b commit a13f333

File tree

4 files changed

+173
-100
lines changed

4 files changed

+173
-100
lines changed

docs/src/tutorials/fftw.md

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -19,21 +19,22 @@ x = range(start=-L/2, stop=L/2-dx, length=n) |> Array
1919
u = @. sin(5x)cos(7x);
2020
du = @. 5cos(5x)cos(7x) - 7sin(5x)sin(7x);
2121
22-
k = rfftfreq(n, 2π*n/L) |> Array
23-
m = length(k)
24-
transform = plan_rfft(x)
22+
k = rfftfreq(n, 2π*n/L) |> Array
23+
m = length(k)
24+
P = plan_rfft(x)
25+
26+
F = FunctionOperator(fwd, x, im*k;
27+
T=ComplexF64,
2528
26-
T = FunctionOperator((du,u,p,t) -> mul!(du, transform, u), x, im*k;
27-
isinplace=true,
28-
T=ComplexF64,
29+
op_adjoint = bwd,
30+
op_inverse = bwd,
31+
op_adjoint_inverse = fwd,
2932
30-
op_adjoint = (du,u,p,t) -> ldiv!(du, transform, u),
31-
op_inverse = (du,u,p,t) -> ldiv!(du, transform, u),
32-
op_adjoint_inverse = (du,u,p,t) -> ldiv!(du, transform, u),
33-
)
33+
islinear=true,
34+
)
3435
3536
ik = im * DiagonalOperator(k)
36-
Dx = T \ ik * T
37+
Dx = F \ ik * F
3738
3839
Dx = cache_operator(Dx, x)
3940
@@ -79,18 +80,17 @@ Now we are ready to define our wrapper for the FFT object. To `FunctionOperator`
7980
pass the in-place forward application of the transform,
8081
`(du,u,p,t) -> mul!(du, transform, u)`, its inverse application,
8182
`(du,u,p,t) -> ldiv!(du, transform, u)`, as well as input and output prototype vectors.
82-
We also set the flag `isinplace` to `true` to signal that we intend to use the operator
83-
in a non-allocating way, and pass in the element-type and size of the operator.
8483

8584
```
86-
T = FunctionOperator((du,u,p,t) -> mul!(du, transform, u), x, im*k;
87-
isinplace=true,
88-
T=ComplexF64,
89-
90-
op_adjoint = (du,u,p,t) -> ldiv!(du, transform, u),
91-
op_inverse = (du,u,p,t) -> ldiv!(du, transform, u),
92-
op_adjoint_inverse = (du,u,p,t) -> ldiv!(du, transform, u),
93-
)
85+
F = FunctionOperator(fwd, x, im*k;
86+
T=ComplexF64,
87+
88+
op_adjoint = bwd,
89+
op_inverse = bwd,
90+
op_adjoint_inverse = fwd,
91+
92+
islinear=true,
93+
)
9494
```
9595

9696
After wrapping the FFT with `FunctionOperator`, we are ready to compose it with other
@@ -100,7 +100,7 @@ both in-place, and out-of-place by comparing its output to the analytical deriva
100100

101101
```
102102
ik = im * DiagonalOperator(k)
103-
Dx = T \ ik * T
103+
Dx = F \ ik * F
104104
105105
@show ≈(Dx * u, du; atol=1e-8)
106106
@show ≈(mul!(copy(u), Dx, u), du; atol=1e-8)

src/func.jl

Lines changed: 122 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ function FunctionOperator(op,
173173
msg = """`FunctionOperator` constructed with `batch = true` only
174174
accepts `AbstractVecOrMat` types with
175175
`size(L, 2) == size(u, 1)`."""
176-
ArgumentError(msg) |> throw
176+
throw(ArgumentError(msg))
177177
end
178178

179179
if input isa AbstractMatrix
@@ -184,7 +184,7 @@ function FunctionOperator(op,
184184
array, $(typeof(input)), has size $(size(input)), whereas
185185
output array, $(typeof(output)), has size
186186
$(size(output))."""
187-
ArgumentError(msg) |> throw
187+
throw(ArgumentError(msg))
188188
end
189189
end
190190
end
@@ -340,14 +340,14 @@ function _cache_operator(L::FunctionOperator, u::AbstractArray)
340340
if !isa(u, AbstractVecOrMat)
341341
msg = """$L constructed with `batch = true` only accepts
342342
`AbstractVecOrMat` types with `size(L, 2) == size(u, 1)`."""
343-
ArgumentError(msg) |> throw
343+
throw(ArgumentError(msg))
344344
end
345345

346346
if size(L, 2) != size(u, 1)
347347
msg = """Second dimension of $L of size $(size(L))
348348
is not consistent with first dimension of input array `u`
349349
of size $(size(u))."""
350-
DimensionMismatch(msg) |> throw
350+
throw(DimensionMismatch(msg))
351351
end
352352

353353
M = size(L, 1)
@@ -486,7 +486,7 @@ function Base.resize!(L::FunctionOperator, n::Integer)
486486
if length(L.traits.sizes[1]) != 1
487487
msg = """`Base.resize!` is only supported by $L whose input/output
488488
arrays are `AbstractVector`s."""
489-
MethodError(msg) |> throw
489+
throw(MethodError(msg))
490490
end
491491

492492
for op in getops(L)
@@ -534,131 +534,187 @@ has_ldiv(L::FunctionOperator{iip}) where{iip} = !(L.op_inverse isa Nothing)
534534
has_ldiv!(L::FunctionOperator{iip}) where{iip} = iip & !(L.op_inverse isa Nothing)
535535

536536
function _sizecheck(L::FunctionOperator, u, v)
537-
537+
sizes = L.traits.sizes
538538
if L.traits.batch
539539
if !isnothing(u)
540540
if !isa(u, AbstractVecOrMat)
541541
msg = """$L constructed with `batch = true` only
542542
accept input arrays that are `AbstractVecOrMat`s with
543-
`size(L, 2) == size(u, 1)`."""
544-
ArgumentError(msg) |> throw
543+
`size(L, 2) == size(u, 1)`. Recieved $(typeof(u))."""
544+
throw(ArgumentError(msg))
545545
end
546546

547-
if size(u) != L.traits.sizes[1]
548-
msg = """$L expects input arrays of size $(L.traits.sizes[1]).
549-
Recievd array of size $(size(u))."""
550-
DimensionMismatch(msg) |> throw
547+
if size(L, 2) != size(u, 1)
548+
msg = """$L accepts input `AbstractVecOrMat`s of size
549+
($(size(L, 2)), K). Recievd array of size $(size(u))."""
550+
throw(DimensionMismatch(msg))
551551
end
552-
end
552+
end # u
553553

554554
if !isnothing(v)
555-
if size(v) != L.traits.sizes[2]
556-
msg = """$L expects input arrays of size $(L.traits.sizes[1]).
557-
Recievd array of size $(size(v))."""
558-
DimensionMismatch(msg) |> throw
555+
if !isa(v, AbstractVecOrMat)
556+
msg = """$L constructed with `batch = true` only
557+
returns output arrays that are `AbstractVecOrMat`s with
558+
`size(L, 1) == size(v, 1)`. Recieved $(typeof(v))."""
559+
throw(ArgumentError(msg))
559560
end
560-
end
561+
562+
if size(L, 1) != size(v, 1)
563+
msg = """$L accepts output `AbstractVecOrMat`s of size
564+
($(size(L, 1)), K). Recievd array of size $(size(v))."""
565+
throw(DimensionMismatch(msg))
566+
end
567+
end # v
568+
569+
if !isnothing(u) & !isnothing(v)
570+
if size(u, 2) != size(v, 2)
571+
msg = """input array $u, and output array, $v, must have the
572+
same batch size (i.e. length of second dimension). Got
573+
$(size(u)), $(size(v)). If you encounter this error during
574+
an in-place evaluation (`LinearAlgebra.mul!`, `ldiv!`),
575+
ensure that the operator $L has been cached with an input
576+
array of the correct size. Do so by calling
577+
`L = cache_operator(L, u)`."""
578+
throw(DimensionMismatch(msg))
579+
end
580+
end # u, v
581+
561582
else # !batch
583+
562584
if !isnothing(u)
563-
if size(u) != L.traits.sizes[1]
564-
msg = """$L expects input arrays of size $(L.traits.sizes[1]).
565-
Recievd array of size $(size(u))."""
566-
DimensionMismatch(msg) |> throw
585+
if size(u) (sizes[1], tuple(size(L, 2)),)
586+
msg = """$L recievd input array of size $(size(u)), but only
587+
accepts input arrays of size $(sizes[1]), or vectors like
588+
`vec(u)` of size $(tuple(prod(sizes[1])))."""
589+
throw(DimensionMismatch(msg))
567590
end
568-
end
591+
end # u
569592

570593
if !isnothing(v)
571-
if size(v) != L.traits.sizes[2]
572-
msg = """$L expects input arrays of size $(L.traits.sizes[1]).
573-
Recievd array of size $(size(v))."""
574-
DimensionMismatch(msg) |> throw
594+
if size(v) (sizes[2], tuple(size(L, 1)),)
595+
msg = """$L recievd output array of size $(size(v)), but only
596+
accepts output arrays of size $(sizes[2]), or vectors like
597+
`vec(u)` of size $(tuple(prod(sizes[2])))"""
598+
throw(DimensionMismatch(msg))
575599
end
576-
end
600+
end # v
577601
end # batch
578602

579603
return
580604
end
581605

582-
# operator application
583-
function Base.:*(L::FunctionOperator{iip,true}, u::AbstractArray) where{iip}
584-
_sizecheck(L, u, nothing)
606+
function _unvec(L::FunctionOperator, u, v)
607+
if L.traits.batch
608+
return u, v, false
609+
else
610+
sizes = L.traits.sizes
585611

586-
L.op(u, L.p, L.t; L.traits.kwargs...)
587-
end
612+
# no need to vec since expected input/output are AbstractVectors
613+
if length(sizes[1]) == 1
614+
return u, v, false
615+
end
588616

589-
function Base.:\(L::FunctionOperator{iip,true}, u::AbstractArray) where{iip}
590-
_sizecheck(L, nothing, u)
617+
vec_u = isnothing(u) ? false : size(u) != sizes[1]
618+
vec_v = isnothing(v) ? false : size(v) != sizes[2]
591619

592-
L.op_inverse(u, L.p, L.t; L.traits.kwargs...)
593-
end
620+
if !isnothing(u) & !isnothing(v)
621+
if (vec_u & !vec_v) | (!vec_u & vec_v)
622+
msg = """Input / output to $L can either be of sizes
623+
$(sizes[1]) / $(sizes[2]), or
624+
$(tuple(prod(sizes[1]))) / $(tuple(prod(sizes[2]))). Got
625+
$(size(u)), $(size(v))."""
626+
throw(DimensionMismatch(msg))
627+
end
628+
end
594629

595-
# fallback *, \ for FunctionOperator with no OOP method
630+
U = vec_u ? reshape(u, sizes[1]) : u
631+
V = vec_v ? reshape(v, sizes[2]) : v
632+
vec_output = vec_u | vec_v
596633

597-
function Base.:*(L::FunctionOperator{true,false}, u::AbstractArray)
598-
_, co = L.cache
599-
v = zero(co)
634+
return U, V, vec_output
635+
end
636+
end
600637

601-
_sizecheck(L, u, v)
638+
# operator application
639+
function Base.:*(L::FunctionOperator{iip,true}, u::AbstractArray) where{iip}
640+
_sizecheck(L, u, nothing)
641+
U, _, vec_output = _unvec(L, u, nothing)
602642

603-
L.op(v, u, L.p, L.t; L.traits.kwargs...)
643+
V = L.op(U, L.p, L.t; L.traits.kwargs...)
644+
645+
vec_output ? vec(V) : V
604646
end
605647

606-
function Base.:\(L::FunctionOperator{true,false}, u::AbstractArray)
607-
ci, _ = L.cache
608-
v = zero(ci)
648+
function Base.:\(L::FunctionOperator{iip,true}, v::AbstractArray) where{iip}
649+
_sizecheck(L, nothing, v)
650+
_, V, vec_output = _unvec(L, nothing, v)
609651

610-
_sizecheck(L, v, u)
652+
U = L.op_inverse(V, L.p, L.t; L.traits.kwargs...)
611653

612-
L.op_inverse(v, u, L.p, L.t; L.traits.kwargs...)
654+
vec_output ? vec(U) : U
613655
end
614656

615657
function LinearAlgebra.mul!(v::AbstractArray, L::FunctionOperator{true}, u::AbstractArray)
616-
617658
_sizecheck(L, u, v)
659+
U, V, vec_output = _unvec(L, u, v)
660+
661+
L.op(V, U, L.p, L.t; L.traits.kwargs...)
618662

619-
L.op(v, u, L.p, L.t; L.traits.kwargs...)
663+
vec_output ? vec(V) : V
620664
end
621665

622-
function LinearAlgebra.mul!(v::AbstractArray, L::FunctionOperator{false}, u::AbstractArray, args...)
623-
@error "LinearAlgebra.mul! not defined for out-of-place FunctionOperators"
666+
function LinearAlgebra.mul!(::AbstractArray, L::FunctionOperator{false}, ::AbstractArray, args...)
667+
@error "LinearAlgebra.mul! not defined for out-of-place operator $L"
624668
end
625669

626670
function LinearAlgebra.mul!(v::AbstractArray, L::FunctionOperator{true, oop, false}, u::AbstractArray, α, β) where{oop}
627-
_, co = L.cache
671+
_, Co = L.cache
628672

629673
_sizecheck(L, u, v)
674+
U, V, _ = _unvec(L, u, v)
630675

631-
copy!(co, v)
632-
mul!(v, L, u)
633-
axpby!(β, co, α, v)
676+
copy!(Co, V)
677+
L.op(V, U, L.p, L.t; L.traits.kwargs...) # mul!(V, L, U)
678+
axpby!(β, Co, α, V)
679+
680+
v
634681
end
635682

636683
function LinearAlgebra.mul!(v::AbstractArray, L::FunctionOperator{true, oop, true}, u::AbstractArray, α, β) where{oop}
637-
638684
_sizecheck(L, u, v)
685+
U, V, _ = _unvec(L, u, v)
639686

640-
L.op(v, u, L.p, L.t, α, β; L.traits.kwargs...)
687+
L.op(V, U, L.p, L.t, α, β; L.traits.kwargs...)
688+
689+
v
641690
end
642691

643-
function LinearAlgebra.ldiv!(v::AbstractArray, L::FunctionOperator{true}, u::AbstractArray)
692+
function LinearAlgebra.ldiv!(u::AbstractArray, L::FunctionOperator{true}, v::AbstractArray)
693+
_sizecheck(L, u, v)
694+
U, V, _ = _unvec(L, u, v)
644695

645-
_sizecheck(L, v, u)
696+
L.op_inverse(U, V, L.p, L.t; L.traits.kwargs...)
646697

647-
L.op_inverse(v, u, L.p, L.t; L.traits.kwargs...)
698+
u
648699
end
649700

650701
function LinearAlgebra.ldiv!(L::FunctionOperator{true}, u::AbstractArray)
651-
ci, _ = L.cache
702+
V, _ = L.cache
703+
704+
_sizecheck(L, u, V)
705+
U, _, _ = _unvec(L, u, nothing)
706+
707+
copy!(V, U)
708+
L.op_inverse(U, V, L.p, L.t; L.traits.kwargs...) # ldiv!(U, L, V)
652709

653-
copy!(ci, u)
654-
ldiv!(u, L, ci)
710+
u
655711
end
656712

657713
function LinearAlgebra.ldiv!(v::AbstractArray, L::FunctionOperator{false}, u::AbstractArray)
658-
@error "LinearAlgebra.ldiv! not defined for out-of-place FunctionOperators"
714+
@error "LinearAlgebra.ldiv! not defined for out-of-place $L"
659715
end
660716

661717
function LinearAlgebra.ldiv!(L::FunctionOperator{false}, u::AbstractArray)
662-
@error "LinearAlgebra.ldiv! not defined for out-of-place FunctionOperators"
718+
@error "LinearAlgebra.ldiv! not defined for out-of-place $L"
663719
end
664720
#

test/func.jl

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,13 +47,25 @@ NK = N * K
4747
L = FunctionOperator(f, u, v; kw...)
4848
L = cache_operator(L, u)
4949

50+
# test with ND-arrays
5051
@test _mul(A, u) L(u, p, t) L * u mul!(zero(v), L, u)
5152
@test α * _mul(A, u)+ β * v mul!(copy(v), L, u, α, β)
5253

5354
if sz_in == sz_out
5455
@test _div(A, v) L \ v ldiv!(zero(u), L, v) ldiv!(L, copy(v))
5556
end
56-
end
57+
58+
# test with vec(Array)
59+
@test vec(_mul(A, u)) L(vec(u), p, t) L * vec(u) mul!(vec(zero(v)), L, vec(u))
60+
@test vec* _mul(A, u)+ β * v) mul!(vec(copy(v)), L, vec(u), α, β)
61+
62+
if sz_in == sz_out
63+
@test vec(_div(A, v)) L \ vec(v) ldiv!(vec(zero(u)), L, vec(v)) ldiv!(L, vec(copy(v)))
64+
end
65+
66+
@test_throws DimensionMismatch mul!(vec(v), L, u)
67+
@test_throws DimensionMismatch mul!(v, L, vec(u))
68+
end # for
5769

5870
end
5971

0 commit comments

Comments
 (0)