diff --git a/src/adjoint.jl b/src/adjoint.jl index 4781602fb..02e4b068b 100644 --- a/src/adjoint.jl +++ b/src/adjoint.jl @@ -28,6 +28,14 @@ specific structure distinct from ``A`` then passing in a `linsolve` will be more linsolve::L = missing end +function CRC.rrule(T::typeof(SciMLBase.solve), prob::LinearProblem, alg::Nothing, args...; kwargs...) + @show "here?" + assump = OperatorAssumptions(issquare(prob.A)) + alg = defaultalg(prob.A, prob.b, assump) + @show alg + CRC.rrule(T, prob, alg, args...; kwargs...) +end + function CRC.rrule(::typeof(SciMLBase.solve), prob::LinearProblem, alg::SciMLLinearSolveAlgorithm, args...; alias_A = default_alias_A( alg, prob.A, prob.b), kwargs...) diff --git a/src/default.jl b/src/default.jl index f5772ade4..7b642425d 100644 --- a/src/default.jl +++ b/src/default.jl @@ -364,12 +364,20 @@ end @generated function defaultalg_adjoint_eval(cache::LinearCache, dy) ex = :() for alg in first.(EnumX.symbol_map(DefaultAlgorithmChoice.T)) - newex = if alg in Symbol.((DefaultAlgorithmChoice.MKLLUFactorization, - DefaultAlgorithmChoice.AppleAccelerateLUFactorization, - DefaultAlgorithmChoice.RFLUFactorization)) + newex = if alg == Symbol(DefaultAlgorithmChoice.RFLUFactorization) quote getproperty(cache.cacheval, $(Meta.quot(alg)))[1]' \ dy end + elseif alg == Symbol(DefaultAlgorithmChoice.MKLLUFactorization) + quote + A = getproperty(cache.cacheval, $(Meta.quot(alg)))[1] + getrs!('T', A.factors, A.ipiv, dy) + end + elseif alg == Symbol(DefaultAlgorithmChoice.AppleAccelerateLUFactorization) + quote + A = getproperty(cache.cacheval, $(Meta.quot(alg)))[1] + aa_getrs!('T', A.factors, A.ipiv, dy) + end elseif alg in Symbol.((DefaultAlgorithmChoice.LUFactorization, DefaultAlgorithmChoice.QRFactorization, DefaultAlgorithmChoice.KLUFactorization,