@@ -7,8 +7,8 @@ Given a Linear Problem ``A x = b`` computes the sensitivities for ``A`` and ``b`
7
7
8
8
```math
9
9
\b egin{align}
10
- A^T \l ambda &= \p artial x \\
11
- \p artial A &= -\l ambda x^T \\
10
+ A' \l ambda &= \p artial x \\
11
+ \p artial A &= -\l ambda x' \\
12
12
\p artial b &= \l ambda
13
13
\e nd{align}
14
14
```
@@ -20,7 +20,7 @@ For more details, check [these notes](https://math.mit.edu/~stevenj/18.336/adjoi
20
20
Note that in most cases, it makes sense to use the same linear solver for the adjoint as the
21
21
forward solve (this is done by keeping the linsolve as `missing`). For example, if the
22
22
forward solve was performed via a Factorization, then we can reuse the factorization for the
23
- adjoint solve. However, for specific structured matrices if ``A^T `` is known to have a
23
+ adjoint solve. However, for specific structured matrices if ``A' `` is known to have a
24
24
specific structure distinct from ``A`` then passing in a `linsolve` will be more efficient.
25
25
"""
26
26
@kwdef struct LinearSolveAdjoint{L} < :
@@ -62,21 +62,21 @@ function CRC.rrule(::typeof(SciMLBase.solve), prob::LinearProblem,
62
62
elseif cache. cacheval isa Tuple && cache. cacheval[1 ] isa Factorization
63
63
first (cache. cacheval)' \ ∂u
64
64
elseif alg isa AbstractKrylovSubspaceMethod
65
- invprob = LinearProblem (transpose (cache. A), ∂u)
65
+ invprob = LinearProblem (adjoint (cache. A), ∂u)
66
66
solve (invprob, alg; cache. abstol, cache. reltol, cache. verbose). u
67
67
elseif alg isa DefaultLinearSolver
68
68
LinearSolve. defaultalg_adjoint_eval (cache, ∂u)
69
69
else
70
- invprob = LinearProblem (transpose (A_), ∂u) # We cached `A`
70
+ invprob = LinearProblem (adjoint (A_), ∂u) # We cached `A`
71
71
solve (invprob, alg; cache. abstol, cache. reltol, cache. verbose). u
72
72
end
73
73
else
74
- invprob = LinearProblem (transpose (A_), ∂u) # We cached `A`
74
+ invprob = LinearProblem (adjoint (A_), ∂u) # We cached `A`
75
75
λ = solve (
76
76
invprob, sensealg. linsolve; cache. abstol, cache. reltol, cache. verbose). u
77
77
end
78
78
79
- tu = transpose (sol. u)
79
+ tu = adjoint (sol. u)
80
80
∂A = BroadcastArray (@~ .- (λ .* tu))
81
81
∂b = λ
82
82
∂prob = LinearProblem (∂A, ∂b, ∂∅)
0 commit comments