@@ -179,21 +179,48 @@ end
179
179
@scalar_rule floor (x) zero (x)
180
180
@scalar_rule ceil (x) zero (x)
181
181
182
- # note: rules for ^ are defined in the fastmath_able.jl
183
- function frule ((_, _, Δx, _), :: typeof (Base. literal_pow), :: typeof (^ ), x:: Real , pv:: Val{p} ) where p
184
- y = Base. literal_pow (^ , x, pv)
185
- return y, ifelse (iszero (x), zero (y), p * y / x * Δx)
182
+ # `literal_pow`
183
+ # Note that rules for `^` are defined in the fastmath_able.jl
184
+
185
+ function frule ((_, _, Δx, _), :: typeof (Base. literal_pow), :: typeof (^ ), x:: Real , :: Val{p} ) where p
186
+ yox = Base. literal_pow (^ , x, Val (p- 1 ))
187
+ if p < 0 && iseven (p)
188
+ # When p<0 and x==0, using yox * x for the primal gives NaN instead of +-Inf
189
+ y = ifelse (iszero (x), oftype (yox, Inf ), yox * x)
190
+ elseif p < 0
191
+ y = ifelse (iszero (x), copysign (oftype (yox, Inf ), x), yox * x)
192
+ else
193
+ y = yox * x
194
+ end
195
+ return y, p * yox * Δx
186
196
end
187
197
frule ((_, _, Δx, _), :: typeof (Base. literal_pow), :: typeof (^ ), x:: Real , :: Val{1} ) = x^ 1 , Δx
198
+ frule ((_, _, Δx, _), :: typeof (Base. literal_pow), :: typeof (^ ), x:: Real , :: Val{0} ) = x^ 0 , zero (Δx)
188
199
189
- function rrule (:: typeof (Base. literal_pow), :: typeof (^ ), x:: Real , pv:: Val{p} ) where p
190
- y = Base. literal_pow (^ , x, pv)
191
- function literal_pow_pullback (dy)
192
- return NoTangent (), NoTangent (), ifelse (iszero (x), zero (y), p * y / x * dy), NoTangent ()
200
+ function rrule (:: typeof (Base. literal_pow), :: typeof (^ ), x:: Real , :: Val{p} ) where p
201
+ yox = Base. literal_pow (^ , x, Val (p- 1 ))
202
+ project = ProjectTo (x)
203
+ @inline function literal_pow_pullback (dy)
204
+ return NoTangent (), NoTangent (), project (p * yox * dy), NoTangent ()
205
+ end
206
+ if p < 0 && iseven (p)
207
+ # When p<0 and x==0, using yox * x for the primal gives NaN instead of +-Inf
208
+ y = ifelse (iszero (x), oftype (yox, Inf ), yox * x)
209
+ elseif p < 0
210
+ y = ifelse (iszero (x), copysign (oftype (yox, Inf ), x), yox * x)
211
+ else
212
+ y = yox * x
193
213
end
194
214
return y, literal_pow_pullback
195
215
end
196
- function rrule (:: typeof (Base. literal_pow), :: typeof (^ ), x:: Real , pv:: Val{1} )
197
- literal_pow_one_pullback (dy) = NoTangent (), NoTangent (), dy, NoTangent ()
216
+ function rrule (:: typeof (Base. literal_pow), :: typeof (^ ), x:: Real , :: Val{1} )
217
+ project = ProjectTo (x)
218
+ literal_pow_one_pullback (dy) = NoTangent (), NoTangent (), project (dy), NoTangent ()
198
219
return x^ 1 , literal_pow_one_pullback
199
220
end
221
+ function rrule (:: typeof (Base. literal_pow), :: typeof (^ ), x:: Real , :: Val{0} )
222
+ # Since 0^0 == 1 == 0.001^0, this gradient should not be NaN at x==0
223
+ project = ProjectTo (x)
224
+ literal_pow_zero_pullback (dy) = NoTangent (), NoTangent (), project (zero (dy)), NoTangent ()
225
+ return x^ 0 , literal_pow_zero_pullback
226
+ end
0 commit comments