|
90 | 90 | ##### `Adjoint`
|
91 | 91 | #####
|
92 | 92 |
|
93 |
| -# ✖️✖️✖️TODO: Deal with complex-valued arrays as well |
94 |
| -function rrule(::Type{<:Adjoint}, A::AbstractMatrix{<:Real}) |
95 |
| - function Adjoint_pullback(ȳ) |
96 |
| - return (NO_FIELDS, adjoint(ȳ)) |
97 |
| - end |
| 93 | +function rrule(::Type{<:Adjoint}, A::AbstractMatrix{<:Number}) |
| 94 | + Adjoint_pullback(ȳ::Composite) = (NO_FIELDS, ȳ.parent) |
| 95 | + Adjoint_pullback(ȳ::AbstractVecOrMat) = (NO_FIELDS, adjoint(ȳ)) |
98 | 96 | return Adjoint(A), Adjoint_pullback
|
99 | 97 | end
|
100 | 98 |
|
101 |
| -function rrule(::Type{<:Adjoint}, A::AbstractVector{<:Real}) |
102 |
| - function Adjoint_pullback(ȳ) |
103 |
| - return (NO_FIELDS, vec(adjoint(ȳ))) |
104 |
| - end |
| 99 | +function rrule(::Type{<:Adjoint}, A::AbstractVector{<:Number}) |
| 100 | + Adjoint_pullback(ȳ::Composite) = (NO_FIELDS, vec(ȳ.parent)) |
| 101 | + Adjoint_pullback(ȳ::AbstractMatrix) = (NO_FIELDS, vec(adjoint(ȳ))) |
105 | 102 | return Adjoint(A), Adjoint_pullback
|
106 | 103 | end
|
107 | 104 |
|
108 |
| -function rrule(::typeof(adjoint), A::AbstractMatrix{<:Real}) |
109 |
| - function adjoint_pullback(ȳ) |
110 |
| - return (NO_FIELDS, adjoint(ȳ)) |
111 |
| - end |
| 105 | +function rrule(::typeof(adjoint), A::AbstractMatrix{<:Number}) |
| 106 | + adjoint_pullback(ȳ::Composite) = (NO_FIELDS, ȳ.parent) |
| 107 | + adjoint_pullback(ȳ::AbstractVecOrMat) = (NO_FIELDS, adjoint(ȳ)) |
112 | 108 | return adjoint(A), adjoint_pullback
|
113 | 109 | end
|
114 | 110 |
|
115 |
| -function rrule(::typeof(adjoint), A::AbstractVector{<:Real}) |
116 |
| - function adjoint_pullback(ȳ) |
117 |
| - return (NO_FIELDS, vec(adjoint(ȳ))) |
118 |
| - end |
| 111 | +function rrule(::typeof(adjoint), A::AbstractVector{<:Number}) |
| 112 | + adjoint_pullback(ȳ::Composite) = (NO_FIELDS, vec(ȳ.parent)) |
| 113 | + adjoint_pullback(ȳ::AbstractMatrix) = (NO_FIELDS, vec(adjoint(ȳ))) |
119 | 114 | return adjoint(A), adjoint_pullback
|
120 | 115 | end
|
121 | 116 |
|
122 | 117 | #####
|
123 | 118 | ##### `Transpose`
|
124 | 119 | #####
|
125 | 120 |
|
126 |
| -function rrule(::Type{<:Transpose}, A::AbstractMatrix) |
127 |
| - function Transpose_pullback(ȳ) |
128 |
| - return (NO_FIELDS, transpose(ȳ)) |
129 |
| - end |
| 121 | +function rrule(::Type{<:Transpose}, A::AbstractMatrix{<:Number}) |
| 122 | + Transpose_pullback(ȳ::Composite) = (NO_FIELDS, ȳ.parent) |
| 123 | + Transpose_pullback(ȳ::AbstractVecOrMat) = (NO_FIELDS, Transpose(ȳ)) |
130 | 124 | return Transpose(A), Transpose_pullback
|
131 | 125 | end
|
132 | 126 |
|
133 |
| -function rrule(::Type{<:Transpose}, A::AbstractVector) |
134 |
| - function Transpose_pullback(ȳ) |
135 |
| - return (NO_FIELDS, vec(transpose(ȳ))) |
136 |
| - end |
| 127 | +function rrule(::Type{<:Transpose}, A::AbstractVector{<:Number}) |
| 128 | + Transpose_pullback(ȳ::Composite) = (NO_FIELDS, vec(ȳ.parent)) |
| 129 | + Transpose_pullback(ȳ::AbstractMatrix) = (NO_FIELDS, vec(Transpose(ȳ))) |
137 | 130 | return Transpose(A), Transpose_pullback
|
138 | 131 | end
|
139 | 132 |
|
140 |
| -function rrule(::typeof(transpose), A::AbstractMatrix) |
141 |
| - function transpose_pullback(ȳ) |
142 |
| - return (NO_FIELDS, transpose(ȳ)) |
143 |
| - end |
| 133 | +function rrule(::typeof(transpose), A::AbstractMatrix{<:Number}) |
| 134 | + transpose_pullback(ȳ::Composite) = (NO_FIELDS, ȳ.parent) |
| 135 | + transpose_pullback(ȳ::AbstractVecOrMat) = (NO_FIELDS, transpose(ȳ)) |
144 | 136 | return transpose(A), transpose_pullback
|
145 | 137 | end
|
146 | 138 |
|
147 |
| -function rrule(::typeof(transpose), A::AbstractVector) |
148 |
| - function transpose_pullback(ȳ) |
149 |
| - return (NO_FIELDS, vec(transpose(ȳ))) |
150 |
| - end |
| 139 | +function rrule(::typeof(transpose), A::AbstractVector{<:Number}) |
| 140 | + transpose_pullback(ȳ::Composite) = (NO_FIELDS, vec(ȳ.parent)) |
| 141 | + transpose_pullback(ȳ::AbstractMatrix) = (NO_FIELDS, vec(transpose(ȳ))) |
151 | 142 | return transpose(A), transpose_pullback
|
152 | 143 | end
|
153 | 144 |
|
|
0 commit comments