Skip to content

Commit 886805b

Browse files
authored
Merge pull request #207 from vpuri3/cache
add method `cache_internals(::FunctionOp, ::AbstractArray)`
2 parents 4ec2740 + cbbbe92 commit 886805b

File tree

2 files changed

+63
-6
lines changed

2 files changed

+63
-6
lines changed

src/func.jl

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -384,6 +384,18 @@ function _cache_self(L::FunctionOperator, u::AbstractArray)
384384
@set! L.cache = (_u, _v)
385385
end
386386

387+
# fix method amg bw AbstractArray, AbstractVecOrMat
388+
cache_internals(L::FunctionOperator, u::AbstractArray) = _cache_internals(L, u)
389+
cache_internals(L::FunctionOperator, u::AbstractVecOrMat) = _cache_internals(L, u)
390+
391+
function _cache_internals(L::FunctionOperator, u::AbstractArray)
392+
393+
@set! L.op = cache_operator(L.op, u)
394+
@set! L.op_adjoint = cache_operator(L.op_adjoint, u)
395+
@set! L.op_inverse = cache_operator(L.op_inverse, u)
396+
@set! L.op_adjoint_inverse = cache_operator(L.op_adjoint_inverse, u)
397+
end
398+
387399
function Base.show(io::IO, L::FunctionOperator)
388400
M, N = size(L)
389401
print(io, "FunctionOperator($M × $N)")
@@ -542,7 +554,7 @@ function _sizecheck(L::FunctionOperator, u, v)
542554
if !isnothing(v)
543555
if size(v) != L.traits.sizes[2]
544556
msg = """$L expects input arrays of size $(L.traits.sizes[1]).
545-
Recievd array of size $(size(u))."""
557+
Recievd array of size $(size(v))."""
546558
DimensionMismatch(msg) |> throw
547559
end
548560
end
@@ -558,7 +570,7 @@ function _sizecheck(L::FunctionOperator, u, v)
558570
if !isnothing(v)
559571
if size(v) != L.traits.sizes[2]
560572
msg = """$L expects input arrays of size $(L.traits.sizes[1]).
561-
Recievd array of size $(size(u))."""
573+
Recievd array of size $(size(v))."""
562574
DimensionMismatch(msg) |> throw
563575
end
564576
end
@@ -618,8 +630,7 @@ function LinearAlgebra.mul!(v::AbstractArray, L::FunctionOperator{true, oop, fal
618630

619631
copy!(co, v)
620632
mul!(v, L, u)
621-
lmul!(α, v)
622-
axpy!(β, co, v)
633+
axpby!(β, co, α, v)
623634
end
624635

625636
function LinearAlgebra.mul!(v::AbstractArray, L::FunctionOperator{true, oop, true}, u::AbstractArray, α, β) where{oop}
@@ -639,8 +650,6 @@ end
639650
function LinearAlgebra.ldiv!(L::FunctionOperator{true}, u::AbstractArray)
640651
ci, _ = L.cache
641652

642-
_sizecheck(L, nothing, u)
643-
644653
copy!(ci, u)
645654
ldiv!(u, L, ci)
646655
end

test/func.jl

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,54 @@ N = 8
99
K = 12
1010
NK = N * K
1111

12+
@testset "(Unbatched) FunctionOperator ND array" begin
13+
N1, N2, N3 = 3, 4, 5
14+
M1, M2, M3 = 4, 5, 6
15+
16+
p = nothing
17+
t = 0.0
18+
α = rand()
19+
β = rand()
20+
21+
for (sz_in, sz_out) in (
22+
((N1, N2, N3), (N1, N2, N3)), # equal size
23+
((N1, N2, N3), (M1, M2, M3)), # different size
24+
)
25+
N = prod(sz_in)
26+
M = prod(sz_out)
27+
28+
A = rand(M, N)
29+
u = rand(sz_in... )
30+
v = rand(sz_out...)
31+
32+
_mul(A, u) = reshape(A * vec(u), sz_out)
33+
f(u, p, t) = _mul(A, u)
34+
f(du, u, p, t) = (mul!( vec(du), A, vec(u)); du)
35+
36+
kw = (;) # FunctionOp kwargs
37+
38+
if sz_in == sz_out
39+
F = lu(A)
40+
_div(A, v) = reshape(A \ vec(v), sz_in)
41+
fi(u, p, t) = _div(A, u)
42+
fi(du, u, p, t) = (ldiv!(vec(du), F, vec(u)); du)
43+
44+
kw = (; op_inverse = fi)
45+
end
46+
47+
L = FunctionOperator(f, u, v; kw...)
48+
L = cache_operator(L, u)
49+
50+
@test _mul(A, u) L(u, p, t) L * u mul!(zero(v), L, u)
51+
@test α * _mul(A, u)+ β * v mul!(copy(v), L, u, α, β)
52+
53+
if sz_in == sz_out
54+
@test _div(A, v) L \ v ldiv!(zero(u), L, v) ldiv!(L, copy(v))
55+
end
56+
end
57+
58+
end
59+
1260
@testset "(Unbatched) FunctionOperator" begin
1361
u = rand(N, K)
1462
p = nothing

0 commit comments

Comments
 (0)