@@ -173,7 +173,7 @@ function FunctionOperator(op,
173
173
msg = """ `FunctionOperator` constructed with `batch = true` only
174
174
accepts `AbstractVecOrMat` types with
175
175
`size(L, 2) == size(u, 1)`."""
176
- ArgumentError (msg) |> throw
176
+ throw ( ArgumentError (msg))
177
177
end
178
178
179
179
if input isa AbstractMatrix
@@ -184,7 +184,7 @@ function FunctionOperator(op,
184
184
array, $(typeof (input)) , has size $(size (input)) , whereas
185
185
output array, $(typeof (output)) , has size
186
186
$(size (output)) ."""
187
- ArgumentError (msg) |> throw
187
+ throw ( ArgumentError (msg))
188
188
end
189
189
end
190
190
end
@@ -340,14 +340,14 @@ function _cache_operator(L::FunctionOperator, u::AbstractArray)
340
340
if ! isa (u, AbstractVecOrMat)
341
341
msg = """ $L constructed with `batch = true` only accepts
342
342
`AbstractVecOrMat` types with `size(L, 2) == size(u, 1)`."""
343
- ArgumentError (msg) |> throw
343
+ throw ( ArgumentError (msg))
344
344
end
345
345
346
346
if size (L, 2 ) != size (u, 1 )
347
347
msg = """ Second dimension of $L of size $(size (L))
348
348
is not consistent with first dimension of input array `u`
349
349
of size $(size (u)) ."""
350
- DimensionMismatch (msg) |> throw
350
+ throw ( DimensionMismatch (msg))
351
351
end
352
352
353
353
M = size (L, 1 )
@@ -486,7 +486,7 @@ function Base.resize!(L::FunctionOperator, n::Integer)
486
486
if length (L. traits. sizes[1 ]) != 1
487
487
msg = """ `Base.resize!` is only supported by $L whose input/output
488
488
arrays are `AbstractVector`s."""
489
- MethodError (msg) |> throw
489
+ throw ( MethodError (msg))
490
490
end
491
491
492
492
for op in getops (L)
@@ -534,131 +534,187 @@ has_ldiv(L::FunctionOperator{iip}) where{iip} = !(L.op_inverse isa Nothing)
534
534
has_ldiv! (L:: FunctionOperator{iip} ) where {iip} = iip & ! (L. op_inverse isa Nothing)
535
535
536
536
function _sizecheck (L:: FunctionOperator , u, v)
537
-
537
+ sizes = L . traits . sizes
538
538
if L. traits. batch
539
539
if ! isnothing (u)
540
540
if ! isa (u, AbstractVecOrMat)
541
541
msg = """ $L constructed with `batch = true` only
542
542
accept input arrays that are `AbstractVecOrMat`s with
543
- `size(L, 2) == size(u, 1)`."""
544
- ArgumentError (msg) |> throw
543
+ `size(L, 2) == size(u, 1)`. Recieved $( typeof (u)) . """
544
+ throw ( ArgumentError (msg))
545
545
end
546
546
547
- if size (u ) != L . traits . sizes[ 1 ]
548
- msg = """ $L expects input arrays of size $(L . traits . sizes[ 1 ]) .
549
- Recievd array of size $(size (u)) ."""
550
- DimensionMismatch (msg) |> throw
547
+ if size (L, 2 ) != size (u, 1 )
548
+ msg = """ $L accepts input `AbstractVecOrMat`s of size
549
+ ( $( size (L, 2 )) , K). Recievd array of size $(size (u)) ."""
550
+ throw ( DimensionMismatch (msg))
551
551
end
552
- end
552
+ end # u
553
553
554
554
if ! isnothing (v)
555
- if size (v) != L. traits. sizes[2 ]
556
- msg = """ $L expects input arrays of size $(L. traits. sizes[1 ]) .
557
- Recievd array of size $(size (v)) ."""
558
- DimensionMismatch (msg) |> throw
555
+ if ! isa (v, AbstractVecOrMat)
556
+ msg = """ $L constructed with `batch = true` only
557
+ returns output arrays that are `AbstractVecOrMat`s with
558
+ `size(L, 1) == size(v, 1)`. Recieved $(typeof (v)) ."""
559
+ throw (ArgumentError (msg))
559
560
end
560
- end
561
+
562
+ if size (L, 1 ) != size (v, 1 )
563
+ msg = """ $L accepts output `AbstractVecOrMat`s of size
564
+ ($(size (L, 1 )) , K). Recievd array of size $(size (v)) ."""
565
+ throw (DimensionMismatch (msg))
566
+ end
567
+ end # v
568
+
569
+ if ! isnothing (u) & ! isnothing (v)
570
+ if size (u, 2 ) != size (v, 2 )
571
+ msg = """ input array $u , and output array, $v , must have the
572
+ same batch size (i.e. length of second dimension). Got
573
+ $(size (u)) , $(size (v)) . If you encounter this error during
574
+ an in-place evaluation (`LinearAlgebra.mul!`, `ldiv!`),
575
+ ensure that the operator $L has been cached with an input
576
+ array of the correct size. Do so by calling
577
+ `L = cache_operator(L, u)`."""
578
+ throw (DimensionMismatch (msg))
579
+ end
580
+ end # u, v
581
+
561
582
else # !batch
583
+
562
584
if ! isnothing (u)
563
- if size (u) != L. traits. sizes[1 ]
564
- msg = """ $L expects input arrays of size $(L. traits. sizes[1 ]) .
565
- Recievd array of size $(size (u)) ."""
566
- DimensionMismatch (msg) |> throw
585
+ if size (u) ∉ (sizes[1 ], tuple (size (L, 2 )),)
586
+ msg = """ $L recievd input array of size $(size (u)) , but only
587
+ accepts input arrays of size $(sizes[1 ]) , or vectors like
588
+ `vec(u)` of size $(tuple (prod (sizes[1 ]))) ."""
589
+ throw (DimensionMismatch (msg))
567
590
end
568
- end
591
+ end # u
569
592
570
593
if ! isnothing (v)
571
- if size (v) != L. traits. sizes[2 ]
572
- msg = """ $L expects input arrays of size $(L. traits. sizes[1 ]) .
573
- Recievd array of size $(size (v)) ."""
574
- DimensionMismatch (msg) |> throw
594
+ if size (v) ∉ (sizes[2 ], tuple (size (L, 1 )),)
595
+ msg = """ $L recievd output array of size $(size (v)) , but only
596
+ accepts output arrays of size $(sizes[2 ]) , or vectors like
597
+ `vec(u)` of size $(tuple (prod (sizes[2 ]))) """
598
+ throw (DimensionMismatch (msg))
575
599
end
576
- end
600
+ end # v
577
601
end # batch
578
602
579
603
return
580
604
end
581
605
582
- # operator application
583
- function Base.:* (L:: FunctionOperator{iip,true} , u:: AbstractArray ) where {iip}
584
- _sizecheck (L, u, nothing )
606
+ function _unvec (L:: FunctionOperator , u, v)
607
+ if L. traits. batch
608
+ return u, v, false
609
+ else
610
+ sizes = L. traits. sizes
585
611
586
- L. op (u, L. p, L. t; L. traits. kwargs... )
587
- end
612
+ # no need to vec since expected input/output are AbstractVectors
613
+ if length (sizes[1 ]) == 1
614
+ return u, v, false
615
+ end
588
616
589
- function Base.: \ (L :: FunctionOperator{iip,true} , u :: AbstractArray ) where {iip}
590
- _sizecheck (L, nothing , u)
617
+ vec_u = isnothing (u) ? false : size (u) != sizes[ 1 ]
618
+ vec_v = isnothing (v) ? false : size (v) != sizes[ 2 ]
591
619
592
- L. op_inverse (u, L. p, L. t; L. traits. kwargs... )
593
- end
620
+ if ! isnothing (u) & ! isnothing (v)
621
+ if (vec_u & ! vec_v) | (! vec_u & vec_v)
622
+ msg = """ Input / output to $L can either be of sizes
623
+ $(sizes[1 ]) / $(sizes[2 ]) , or
624
+ $(tuple (prod (sizes[1 ]))) / $(tuple (prod (sizes[2 ]))) . Got
625
+ $(size (u)) , $(size (v)) ."""
626
+ throw (DimensionMismatch (msg))
627
+ end
628
+ end
594
629
595
- # fallback *, \ for FunctionOperator with no OOP method
630
+ U = vec_u ? reshape (u, sizes[1 ]) : u
631
+ V = vec_v ? reshape (v, sizes[2 ]) : v
632
+ vec_output = vec_u | vec_v
596
633
597
- function Base.: * (L :: FunctionOperator{true,false} , u :: AbstractArray )
598
- _, co = L . cache
599
- v = zero (co)
634
+ return U, V, vec_output
635
+ end
636
+ end
600
637
601
- _sizecheck (L, u, v)
638
+ # operator application
639
+ function Base.:* (L:: FunctionOperator{iip,true} , u:: AbstractArray ) where {iip}
640
+ _sizecheck (L, u, nothing )
641
+ U, _, vec_output = _unvec (L, u, nothing )
602
642
603
- L. op (v, u, L. p, L. t; L. traits. kwargs... )
643
+ V = L. op (U, L. p, L. t; L. traits. kwargs... )
644
+
645
+ vec_output ? vec (V) : V
604
646
end
605
647
606
- function Base.:\ (L:: FunctionOperator{true,false } , u :: AbstractArray )
607
- ci, _ = L . cache
608
- v = zero (ci )
648
+ function Base.:\ (L:: FunctionOperator{iip,true } , v :: AbstractArray ) where {iip}
649
+ _sizecheck (L, nothing , v)
650
+ _, V, vec_output = _unvec (L, nothing , v )
609
651
610
- _sizecheck (L, v, u )
652
+ U = L . op_inverse (V, L . p, L . t; L . traits . kwargs ... )
611
653
612
- L . op_inverse (v, u, L . p, L . t; L . traits . kwargs ... )
654
+ vec_output ? vec (U) : U
613
655
end
614
656
615
657
function LinearAlgebra. mul! (v:: AbstractArray , L:: FunctionOperator{true} , u:: AbstractArray )
616
-
617
658
_sizecheck (L, u, v)
659
+ U, V, vec_output = _unvec (L, u, v)
660
+
661
+ L. op (V, U, L. p, L. t; L. traits. kwargs... )
618
662
619
- L . op (v, u, L . p, L . t; L . traits . kwargs ... )
663
+ vec_output ? vec (V) : V
620
664
end
621
665
622
- function LinearAlgebra. mul! (v :: AbstractArray , L:: FunctionOperator{false} , u :: AbstractArray , args... )
623
- @error " LinearAlgebra.mul! not defined for out-of-place FunctionOperators "
666
+ function LinearAlgebra. mul! (:: AbstractArray , L:: FunctionOperator{false} , :: AbstractArray , args... )
667
+ @error " LinearAlgebra.mul! not defined for out-of-place operator $L "
624
668
end
625
669
626
670
function LinearAlgebra. mul! (v:: AbstractArray , L:: FunctionOperator{true, oop, false} , u:: AbstractArray , α, β) where {oop}
627
- _, co = L. cache
671
+ _, Co = L. cache
628
672
629
673
_sizecheck (L, u, v)
674
+ U, V, _ = _unvec (L, u, v)
630
675
631
- copy! (co, v)
632
- mul! (v, L, u)
633
- axpby! (β, co, α, v)
676
+ copy! (Co, V)
677
+ L. op (V, U, L. p, L. t; L. traits. kwargs... ) # mul!(V, L, U)
678
+ axpby! (β, Co, α, V)
679
+
680
+ v
634
681
end
635
682
636
683
function LinearAlgebra. mul! (v:: AbstractArray , L:: FunctionOperator{true, oop, true} , u:: AbstractArray , α, β) where {oop}
637
-
638
684
_sizecheck (L, u, v)
685
+ U, V, _ = _unvec (L, u, v)
639
686
640
- L. op (v, u, L. p, L. t, α, β; L. traits. kwargs... )
687
+ L. op (V, U, L. p, L. t, α, β; L. traits. kwargs... )
688
+
689
+ v
641
690
end
642
691
643
- function LinearAlgebra. ldiv! (v:: AbstractArray , L:: FunctionOperator{true} , u:: AbstractArray )
692
+ function LinearAlgebra. ldiv! (u:: AbstractArray , L:: FunctionOperator{true} , v:: AbstractArray )
693
+ _sizecheck (L, u, v)
694
+ U, V, _ = _unvec (L, u, v)
644
695
645
- _sizecheck (L, v, u )
696
+ L . op_inverse (U, V, L . p, L . t; L . traits . kwargs ... )
646
697
647
- L . op_inverse (v, u, L . p, L . t; L . traits . kwargs ... )
698
+ u
648
699
end
649
700
650
701
function LinearAlgebra. ldiv! (L:: FunctionOperator{true} , u:: AbstractArray )
651
- ci, _ = L. cache
702
+ V, _ = L. cache
703
+
704
+ _sizecheck (L, u, V)
705
+ U, _, _ = _unvec (L, u, nothing )
706
+
707
+ copy! (V, U)
708
+ L. op_inverse (U, V, L. p, L. t; L. traits. kwargs... ) # ldiv!(U, L, V)
652
709
653
- copy! (ci, u)
654
- ldiv! (u, L, ci)
710
+ u
655
711
end
656
712
657
713
function LinearAlgebra. ldiv! (v:: AbstractArray , L:: FunctionOperator{false} , u:: AbstractArray )
658
- @error " LinearAlgebra.ldiv! not defined for out-of-place FunctionOperators "
714
+ @error " LinearAlgebra.ldiv! not defined for out-of-place $L "
659
715
end
660
716
661
717
function LinearAlgebra. ldiv! (L:: FunctionOperator{false} , u:: AbstractArray )
662
- @error " LinearAlgebra.ldiv! not defined for out-of-place FunctionOperators "
718
+ @error " LinearAlgebra.ldiv! not defined for out-of-place $L "
663
719
end
664
720
#
0 commit comments