Skip to content

Commit 9b2fe0e

Browse files
Merge pull request #582 from cmrace/adjoint_complex
Update adjoint of Linear Solve for complex matrices
2 parents 137aa4e + 2663bbb commit 9b2fe0e

File tree

2 files changed

+14
-7
lines changed

2 files changed

+14
-7
lines changed

src/adjoint.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@ Given a Linear Problem ``A x = b`` computes the sensitivities for ``A`` and ``b`
77
88
```math
99
\begin{align}
10-
A^T \lambda &= \partial x \\
11-
\partial A &= -\lambda x^T \\
10+
A' \lambda &= \partial x \\
11+
\partial A &= -\lambda x' \\
1212
\partial b &= \lambda
1313
\end{align}
1414
```
@@ -20,7 +20,7 @@ For more details, check [these notes](https://math.mit.edu/~stevenj/18.336/adjoi
2020
Note that in most cases, it makes sense to use the same linear solver for the adjoint as the
2121
forward solve (this is done by keeping the linsolve as `missing`). For example, if the
2222
forward solve was performed via a Factorization, then we can reuse the factorization for the
23-
adjoint solve. However, for specific structured matrices if ``A^T`` is known to have a
23+
adjoint solve. However, for specific structured matrices if ``A'`` is known to have a
2424
specific structure distinct from ``A`` then passing in a `linsolve` will be more efficient.
2525
"""
2626
@kwdef struct LinearSolveAdjoint{L} <:
@@ -62,21 +62,21 @@ function CRC.rrule(::typeof(SciMLBase.solve), prob::LinearProblem,
6262
elseif cache.cacheval isa Tuple && cache.cacheval[1] isa Factorization
6363
first(cache.cacheval)' \ ∂u
6464
elseif alg isa AbstractKrylovSubspaceMethod
65-
invprob = LinearProblem(transpose(cache.A), ∂u)
65+
invprob = LinearProblem(adjoint(cache.A), ∂u)
6666
solve(invprob, alg; cache.abstol, cache.reltol, cache.verbose).u
6767
elseif alg isa DefaultLinearSolver
6868
LinearSolve.defaultalg_adjoint_eval(cache, ∂u)
6969
else
70-
invprob = LinearProblem(transpose(A_), ∂u) # We cached `A`
70+
invprob = LinearProblem(adjoint(A_), ∂u) # We cached `A`
7171
solve(invprob, alg; cache.abstol, cache.reltol, cache.verbose).u
7272
end
7373
else
74-
invprob = LinearProblem(transpose(A_), ∂u) # We cached `A`
74+
invprob = LinearProblem(adjoint(A_), ∂u) # We cached `A`
7575
λ = solve(
7676
invprob, sensealg.linsolve; cache.abstol, cache.reltol, cache.verbose).u
7777
end
7878

79-
tu = transpose(sol.u)
79+
tu = adjoint(sol.u)
8080
∂A = BroadcastArray(@~ .-(λ .* tu))
8181
∂b = λ
8282
∂prob = LinearProblem(∂A, ∂b, ∂∅)

test/adjoint.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,10 @@ db12 = ForwardDiff.gradient(x -> f(eltype(x).(A), x), copy(b1))
4444
@test dA ≈ dA2
4545
@test db1 ≈ db12
4646

47+
# Test complex numbers
48+
A = rand(n, n) + 1im*rand(n, n);
49+
b1 = rand(n) + 1im*rand(n);
50+
4751
function f3(A, b1, b2; alg = KrylovJL_GMRES())
4852
prob = LinearProblem(A, b1)
4953
sol1 = solve(prob, alg)
@@ -66,6 +70,9 @@ db22 = FiniteDiff.finite_difference_gradient(
6670
@test db1 ≈ db12
6771
@test db2 ≈ db22
6872

73+
A = rand(n, n);
74+
b1 = rand(n);
75+
6976
function f4(A, b1, b2; alg = LUFactorization())
7077
prob = LinearProblem(A, b1)
7178
sol1 = solve(prob, alg; sensealg = LinearSolveAdjoint(; linsolve = KrylovJL_LSMR()))

0 commit comments

Comments
 (0)