@@ -167,38 +167,44 @@ let
167
167
# literal_pow is in base.jl
168
168
function frule ((_, Δx, Δp), :: typeof (^ ), x:: Number , p:: Number )
169
169
y = x ^ p
170
- dx = _pow_grad_x (x, p, float (y))
170
+ _dx = _pow_grad_x (x, p, float (y))
171
171
# When x < 0 && isinteger(p), could decide p is non-differentiable, isolated
172
172
# points, but chose to match what the rrule with ProjectTo gives, real(log(...)):
173
- dp = Δp isa AbstractZero ? Δp : _pow_grad_p (x, p, float (y))
174
- return y, muladd (dp , Δp, dx * Δx)
173
+ _dp = Δp isa AbstractZero ? Δp : _pow_grad_p (x, p, float (y))
174
+ return y, muladd (_dp , Δp, _dx * Δx)
175
175
end
176
176
177
177
function rrule (:: typeof (^ ), x:: Number , p:: Number )
178
178
y = x^ p
179
- project_x, project_p = ProjectTo (x), ProjectTo (p)
179
+ project_x = ProjectTo (x)
180
+ project_p = ProjectTo (p)
180
181
@inline function power_pullback (dy)
181
- dx = project_x (conj (_pow_grad_x (x,p,float (y))) * dy)
182
- dp = @thunk project_p (conj (_pow_grad_p (x,p,float (y))) * dy)
183
- return (NoTangent (), dx, dp)
182
+ _dx = _pow_grad_x (x, p, float (y))
183
+ _dy = _pow_grad_p (x, p, float (y))
184
+ return (
185
+ NoTangent (),
186
+ project_x (conj (_dx) * dy),
187
+ @thunk project_p (conj (_dy) * dy)
188
+ )
184
189
end
185
190
return y, power_pullback
186
191
end
187
192
193
+ # # `rem`
188
194
@scalar_rule (
189
195
rem (x, y),
190
196
@setup ((u, nan) = promote (x / y, NaN16 ), isint = isinteger (x / y)),
191
197
(ifelse (isint, nan, one (u)), ifelse (isint, nan, - trunc (u))),
192
198
)
199
+ # # `min`, `max`
193
200
@scalar_rule max (x, y) @setup (gt = x > y) (gt, ! gt)
194
201
@scalar_rule min (x, y) @setup (gt = x > y) (! gt, gt)
195
202
196
203
# Unary functions
197
204
@scalar_rule + x true
198
205
@scalar_rule - x - 1
199
206
200
- # `sign`
201
-
207
+ # # `sign`
202
208
function frule ((_, Δx), :: typeof (sign), x)
203
209
n = ifelse (iszero (x), one (real (x)), abs (x))
204
210
Ω = x isa Real ? sign (x) : x / n
263
269
# Thes functions need to be defined outside the eval() block.
264
270
# The special cases they aim to hit are in POWERGRADS in tests.
265
271
_pow_grad_x (x, p, y) = (p * y / x)
266
- # function _pow_grad_x(x::Real, p::Real, y)
267
- # return ifelse(!iszero(x) | (p<0), (p * y / x),
268
- # ifelse(isone(p), one(y),
269
- # ifelse((0<p) | (p<1), oftype(y, Inf), zero(y) )))
270
- # end
271
272
function _pow_grad_x (x:: Real , p:: Real , y)
272
273
return if ! iszero (x) || p < 0
273
274
p * y / x
@@ -281,10 +282,6 @@ function _pow_grad_x(x::Real, p::Real, y)
281
282
end
282
283
283
284
_pow_grad_p (x, p, y) = y * log (complex (x))
284
- # function _pow_grad_p(x::Real, p::Real, y)
285
- # return ifelse(!iszero(x), y * real(log(complex(x))),
286
- # ifelse(p>0, zero(y), oftype(y, NaN) ))
287
- # end
288
285
function _pow_grad_p (x:: Real , p:: Real , y)
289
286
return if ! iszero (x)
290
287
y * real (log (complex (x)))
0 commit comments