Skip to content

Commit 3c364f7

Browse files
authored
Ensure scalar_rule rules are Number typed (#62)
Currently, `@scalar_rule` generates `rrule` methods with no type constraints. This poses a problem for rules with different definitions that have not yet been implemented. For example, the rule for matrix exponential is quite different from that for scalar exponential, but in the absence of an `rrule(::typeof(exp), ::AbstractMatrix)` method, the incorrect fallback generated by `@scalar_rule` is used. To solve this, `@scalar_rule` now allows type constaints, e.g. `@scalar_rule(f(x::Complex), g(x))`, and it adds explicit `::Number` constraints to the generated methods if no such constraints are provided.
1 parent 6053597 commit 3c364f7

File tree

3 files changed

+35
-4
lines changed

3 files changed

+35
-4
lines changed

src/rules.jl

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -410,22 +410,29 @@ A convenience macro that generates simple scalar forward or reverse rules using
410410
the provided partial derivatives. Specifically, generates the corresponding
411411
methods for `frule` and `rrule`:
412412
413-
function ChainRules.frule(::typeof(f), x₁, x₂, ...)
413+
function ChainRules.frule(::typeof(f), x₁::Number, x₂::Number, ...)
414414
Ω = f(x₁, x₂, ...)
415415
\$(statement₁, statement₂, ...)
416416
return Ω, (Rule((Δx₁, Δx₂, ...) -> ∂f₁_∂x₁ * Δx₁ + ∂f₁_∂x₂ * Δx₂ + ...),
417417
Rule((Δx₁, Δx₂, ...) -> ∂f₂_∂x₁ * Δx₁ + ∂f₂_∂x₂ * Δx₂ + ...),
418418
...)
419419
end
420420
421-
function ChainRules.rrule(::typeof(f), x₁, x₂, ...)
421+
function ChainRules.rrule(::typeof(f), x₁::Number, x₂::Number, ...)
422422
Ω = f(x₁, x₂, ...)
423423
\$(statement₁, statement₂, ...)
424424
return Ω, (Rule((ΔΩ₁, ΔΩ₂, ...) -> ∂f₁_∂x₁ * ΔΩ₁ + ∂f₂_∂x₁ * ΔΩ₂ + ...),
425425
Rule((ΔΩ₁, ΔΩ₂, ...) -> ∂f₁_∂x₂ * ΔΩ₁ + ∂f₂_∂x₂ * ΔΩ₂ + ...),
426426
...)
427427
end
428428
429+
If no type constraints in `f(x₁, x₂, ...)` within the call to `@scalar_rule` are
430+
provided, each parameter in the resulting `frule`/`rrule` definition is given a
431+
type constraint of `Number`.
432+
Constraints may also be explicitly be provided to override the `Number` constraint,
433+
e.g. `f(x₁::Complex, x₂)`, which will constrain `x₁` to `Complex` and `x₂` to
434+
`Number`.
435+
429436
Note that the result of `f(x₁, x₂, ...)` is automatically bound to `Ω`. This
430437
allows the primal result to be conveniently referenced (as `Ω`) within the
431438
derivative/setup expressions.
@@ -458,7 +465,19 @@ macro scalar_rule(call, maybe_setup, partials...)
458465
partials = (maybe_setup, partials...)
459466
end
460467
@assert Meta.isexpr(call, :call)
461-
f, inputs = esc(call.args[1]), esc.(call.args[2:end])
468+
f = esc(call.args[1])
469+
# Annotate all arguments in the signature as scalars
470+
inputs = map(call.args[2:end]) do arg
471+
esc(Meta.isexpr(arg, :(::)) ? arg : Expr(:(::), arg, :Number))
472+
end
473+
# Remove annotations and escape names for the call
474+
for (i, arg) in enumerate(call.args)
475+
if Meta.isexpr(arg, :(::))
476+
call.args[i] = esc(first(arg.args))
477+
else
478+
call.args[i] = esc(arg)
479+
end
480+
end
462481
if all(Meta.isexpr(partial, :tuple) for partial in partials)
463482
forward_rules = Any[rule_from_partials(partial.args...) for partial in partials]
464483
reverse_rules = Any[]

src/rules/base.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
@scalar_rule(adjoint(x), Wirtinger(Zero(), One()))
3939
@scalar_rule(transpose(x), One())
4040
@scalar_rule(abs(x), sign(x))
41-
@scalar_rule(rem2pi(x, r), (One(), DNE()))
41+
@scalar_rule(rem2pi(x, r::RoundingMode), (One(), DNE()))
4242
@scalar_rule(+(x), One())
4343
@scalar_rule(-(x), -1)
4444
@scalar_rule(+(x, y), (One(), One()))

test/rules.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,25 @@
11
cool(x) = x + 1
22
cool(x, y) = x + y + 1
33

4+
_second(t) = Base.tuple_type_head(Base.tuple_type_tail(t))
5+
46
@testset "rules" begin
57
@testset "frule and rrule" begin
68
@test frule(cool, 1) === nothing
79
@test frule(cool, 1; iscool=true) === nothing
810
@test rrule(cool, 1) === nothing
911
@test rrule(cool, 1; iscool=true) === nothing
12+
1013
ChainRules.@scalar_rule(Main.cool(x), one(x))
14+
@test hasmethod(rrule, Tuple{typeof(cool),Number})
15+
ChainRules.@scalar_rule(Main.cool(x::String), "wow such dfdx")
16+
@test hasmethod(rrule, Tuple{typeof(cool),String})
17+
# Ensure those are the *only* methods that have been defined
18+
cool_methods = Set(m.sig for m in methods(rrule) if _second(m.sig) == typeof(cool))
19+
only_methods = Set([Tuple{typeof(rrule),typeof(cool),Number},
20+
Tuple{typeof(rrule),typeof(cool),String}])
21+
@test cool_methods == only_methods
22+
1123
frx, fr = frule(cool, 1)
1224
@test frx == 2
1325
@test fr(1) == 1

0 commit comments

Comments
 (0)