52
52
# exponents
53
53
@scalar_rule cbrt (x) inv (3 * Ω ^ 2 )
54
54
@scalar_rule inv (x) - (Ω ^ 2 )
55
- @scalar_rule sqrt (x) inv (2 Ω)
55
+ @scalar_rule sqrt (x) inv (2 Ω) # gradient +Inf at x==0
56
56
@scalar_rule exp (x) Ω
57
57
@scalar_rule exp10 (x) Ω * log (oftype (x, 10 ))
58
58
@scalar_rule exp2 (x) Ω * log (oftype (x, 2 ))
137
137
138
138
# Binary functions
139
139
140
- # `hypot`
141
-
140
+ # # `hypot`
142
141
function frule (
143
142
(_, Δx, Δy),
144
143
:: typeof (hypot),
@@ -163,17 +162,52 @@ let
163
162
@scalar_rule x + y (true , true )
164
163
@scalar_rule x - y (true , - 1 )
165
164
@scalar_rule x / y (one (x) / y, - (Ω / y))
166
- # log(complex(x)) is required so it gives correct complex answer for x<0
167
- @scalar_rule (x ^ y, (
168
- ifelse (iszero (x), ifelse (isone (y), one (Ω), zero (Ω)), y * Ω / x),
169
- Ω * log (complex (x)),
170
- ))
171
- # x^y for x < 0 errors when y is not an integer, but then derivative wrt y
172
- # is undefined, so we adopt subgradient convention and set derivative to 0.
173
- @scalar_rule (x:: Real ^ y:: Real , (
174
- ifelse (iszero (x), ifelse (isone (y), one (Ω), zero (Ω)), y * Ω / x),
175
- Ω * log (oftype (Ω, ifelse (x ≤ 0 , one (x), x))),
176
- ))
165
+
166
+ # # power
167
+ # literal_pow is in base.jl
168
+ function frule ((_, Δx, Δp), :: typeof (^ ), x:: Number , p:: Number )
169
+ yox = x ^ (p- 1 )
170
+ y = yox * x
171
+ thelog = if Δp isa AbstractZero
172
+ # Then don't waste time computing log
173
+ NoTangent ()
174
+ elseif x isa Real && p isa Real
175
+ # For positive x we'd like a real answer, including any Δp.
176
+ # For negative x, this is a DomainError unless isinteger(p)...
177
+ # could decide that implues that p is non-differentiable:
178
+ # log(ifelse(x<0, one(x), x))
179
+
180
+ # or we could match what the rrule with ProjectTo gives:
181
+ real (log (complex (x)))
182
+ #=
183
+
184
+ julia> frule((0,0,1), ^, -4, 3.0), unthunk.(rrule(^, -4, 3.0)[2](1))
185
+ ((-64.0, 0.0), (NoTangent(), 48.0, -88.722839111673))
186
+
187
+ julia> frule((0,0,1), ^, 4, 3.0), unthunk.(rrule(^, 4, 3.0)[2](1))
188
+ ((64.0, 88.722839111673), (NoTangent(), 48.0, 88.722839111673))
189
+ =#
190
+ else
191
+ # This promotion handles e.g. real x & complex p
192
+ log (oftype (y, x))
193
+ end
194
+ return y, muladd (y * thelog, Δp, p * yox * Δx)
195
+ end
196
+ function rrule (:: typeof (^ ), x:: Number , p:: Number )
197
+ yox = x ^ (p- 1 )
198
+ project_x, project_p = ProjectTo (x), ProjectTo (p)
199
+ @inline function power_pullback (dy)
200
+ dx = project_x (conj (p * yox) * dy)
201
+ dp = @thunk if x isa Real && p isa Real
202
+ project_p (conj (yox * x * log (complex (x))) * dy)
203
+ else
204
+ project_p (conj (yox * x * log (oftype (yox, x))) * dy)
205
+ end
206
+ return (NoTangent (), dx, dp)
207
+ end
208
+ return yox * x, power_pullback
209
+ end
210
+
177
211
@scalar_rule (
178
212
rem (x, y),
179
213
@setup ((u, nan) = promote (x / y, NaN16 ), isint = isinteger (x / y)),
235
269
non_transformed_definitions = intersect (fastable_ast. args, fast_ast. args)
236
270
filter! (expr-> ! (expr isa LineNumberNode), non_transformed_definitions)
237
271
if ! isempty (non_transformed_definitions)
238
- error (
239
- " Non-FastMath compatible rules defined in fastmath_able.jl. \n Definitions:\n " *
240
- join (non_transformed_definitions, " \n " )
272
+ @warn (
273
+ " Non-FastMath compatible rules defined in fastmath_able.jl." , # \n Definitions:\n" *
274
+ # join(non_transformed_definitions, "\n")
275
+ non_transformed_definitions
241
276
)
242
277
end
243
278
0 commit comments