Skip to content

Commit 6b67260

Browse files
committed
move @scalar_rule to its own file
1 parent ab27e9f commit 6b67260

File tree

3 files changed

+147
-150
lines changed

3 files changed

+147
-150
lines changed

src/ChainRulesCore.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,5 @@ export extern, cast, store!, Wirtinger, Zero, One, Casted, DNE, Thunk, DNERule
88

99
include("differentials.jl")
1010
include("rules.jl")
11+
include("rule_definition_tools.jl")
1112
end # module

src/rule_definition_tools.jl

Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
# These are some macros (and supporting functions) to make it easier to define rules.
2+
3+
"""
4+
@scalar_rule(f(x₁, x₂, ...),
5+
@setup(statement₁, statement₂, ...),
6+
(∂f₁_∂x₁, ∂f₁_∂x₂, ...),
7+
(∂f₂_∂x₁, ∂f₂_∂x₂, ...),
8+
...)
9+
10+
A convenience macro that generates simple scalar forward or reverse rules using
11+
the provided partial derivatives. Specifically, generates the corresponding
12+
methods for `frule` and `rrule`:
13+
14+
function ChainRulesCore.frule(::typeof(f), x₁::Number, x₂::Number, ...)
15+
Ω = f(x₁, x₂, ...)
16+
\$(statement₁, statement₂, ...)
17+
return Ω, (Rule((Δx₁, Δx₂, ...) -> ∂f₁_∂x₁ * Δx₁ + ∂f₁_∂x₂ * Δx₂ + ...),
18+
Rule((Δx₁, Δx₂, ...) -> ∂f₂_∂x₁ * Δx₁ + ∂f₂_∂x₂ * Δx₂ + ...),
19+
...)
20+
end
21+
22+
function ChainRulesCore.rrule(::typeof(f), x₁::Number, x₂::Number, ...)
23+
Ω = f(x₁, x₂, ...)
24+
\$(statement₁, statement₂, ...)
25+
return Ω, (Rule((ΔΩ₁, ΔΩ₂, ...) -> ∂f₁_∂x₁ * ΔΩ₁ + ∂f₂_∂x₁ * ΔΩ₂ + ...),
26+
Rule((ΔΩ₁, ΔΩ₂, ...) -> ∂f₁_∂x₂ * ΔΩ₁ + ∂f₂_∂x₂ * ΔΩ₂ + ...),
27+
...)
28+
end
29+
30+
If no type constraints in `f(x₁, x₂, ...)` within the call to `@scalar_rule` are
31+
provided, each parameter in the resulting `frule`/`rrule` definition is given a
32+
type constraint of `Number`.
33+
Constraints may also be explicitly provided to override the `Number` constraint,
34+
e.g. `f(x₁::Complex, x₂)`, which will constrain `x₁` to `Complex` and `x₂` to
35+
`Number`.
36+
37+
Note that the result of `f(x₁, x₂, ...)` is automatically bound to `Ω`. This
38+
allows the primal result to be conveniently referenced (as `Ω`) within the
39+
derivative/setup expressions.
40+
41+
Note that the `@setup` argument can be elided if no setup code is needed. In other
42+
words:
43+
44+
@scalar_rule(f(x₁, x₂, ...),
45+
(∂f₁_∂x₁, ∂f₁_∂x₂, ...),
46+
(∂f₂_∂x₁, ∂f₂_∂x₂, ...),
47+
...)
48+
49+
is equivalent to:
50+
51+
@scalar_rule(f(x₁, x₂, ...),
52+
@setup(nothing),
53+
(∂f₁_∂x₁, ∂f₁_∂x₂, ...),
54+
(∂f₂_∂x₁, ∂f₂_∂x₂, ...),
55+
...)
56+
57+
For examples, see ChainRulesCore's `rules` directory.
58+
59+
See also: [`frule`](@ref), [`rrule`](@ref), [`AbstractRule`](@ref)
60+
"""
61+
macro scalar_rule(call, maybe_setup, partials...)
    # The second argument is either an explicit `@setup(...)` block or, when the setup
    # is elided, the first partial-derivative expression.
    if Meta.isexpr(maybe_setup, :macrocall) && maybe_setup.args[1] == Symbol("@setup")
        # Skip the macro name and its line-number node; the rest are the statements.
        setup_stmts = map(esc, maybe_setup.args[3:end])
    else
        setup_stmts = (nothing,)
        partials = (maybe_setup, partials...)
    end
    @assert Meta.isexpr(call, :call) "@scalar_rule expects a call such as `f(x, y)`"
    f = esc(call.args[1])
    # Annotate all arguments in the signature as scalars, unless the user supplied
    # an explicit `x::T` constraint.
    inputs = map(call.args[2:end]) do arg
        esc(Meta.isexpr(arg, :(::)) ? arg : Expr(:(::), arg, :Number))
    end
    # Build the primal call with annotations removed and every name escaped. A fresh
    # expression is constructed so the caller's `call` expression is not mutated.
    call_args = map(arg -> esc(Meta.isexpr(arg, :(::)) ? first(arg.args) : arg), call.args)
    primal_call = Expr(:call, call_args...)
    if all(Meta.isexpr(partial, :tuple) for partial in partials)
        # One tuple of partials per output: each tuple yields one forward rule, and
        # each input position (a column across the tuples) yields one reverse rule.
        forward_rules = Any[rule_from_partials(partial.args...) for partial in partials]
        reverse_rules = Any[]
        for i in eachindex(inputs)
            reverse_partials = [partial.args[i] for partial in partials]
            push!(reverse_rules, rule_from_partials(reverse_partials...))
        end
    else
        # Bare (non-tuple) partials are only meaningful for a single-input function.
        @assert(
            length(inputs) == 1 &&
                all(!Meta.isexpr(partial, :tuple) for partial in partials),
            "non-tuple partials require a single-argument function",
        )
        forward_rules = Any[rule_from_partials(partial) for partial in partials]
        reverse_rules = Any[rule_from_partials(partials...)]
    end
    # A single rule is returned bare; several rules are returned as a tuple.
    forward_rules = length(forward_rules) == 1 ? forward_rules[1] : Expr(:tuple, forward_rules...)
    reverse_rules = length(reverse_rules) == 1 ? reverse_rules[1] : Expr(:tuple, reverse_rules...)
    return quote
        function ChainRulesCore.frule(::typeof($f), $(inputs...))
            # `Ω` is escaped so the user's setup and partial expressions can refer to it.
            $(esc(:Ω)) = $primal_call
            $(setup_stmts...)
            return $(esc(:Ω)), $forward_rules
        end
        function ChainRulesCore.rrule(::typeof($f), $(inputs...))
            $(esc(:Ω)) = $primal_call
            $(setup_stmts...)
            return $(esc(:Ω)), $reverse_rules
        end
    end
end
109+
110+
# Build the expression for one rule from the partial-derivative expressions `∂s`.
#
# Each partial `∂s[i]` is paired with a generated argument symbol `Δi`. Without any
# `Wirtinger(...)` partial the result is the expression
# `Rule((Δ1, Δ2, ...) -> add(mul(@thunk(∂s[1]), Δ1), ...))`. If any partial is
# syntactically a call to `Wirtinger`, the result is instead a block that first binds
# each such partial to a local `∂i` and then constructs
# `WirtingerRule(primal_rule, conjugate_rule)`.
#
# The partial expressions are escaped, so they are evaluated in the scope of the
# `@scalar_rule` call site.
function rule_from_partials(∂s...)
    # Detection is purely syntactic, so it has to happen before the expressions are
    # wrapped by `esc`.
    wirtinger_indices = findall(x -> Meta.isexpr(x, :call) && x.args[1] === :Wirtinger, ∂s)
    escaped_∂s = map(esc, ∂s)
    Δs = [Symbol(string(:Δ, i)) for i in eachindex(escaped_∂s)]
    Δs_tuple = Expr(:tuple, Δs...)
    if isempty(wirtinger_indices)
        ∂_mul_Δs = [:(mul(@thunk($(escaped_∂s[i])), $(Δs[i]))) for i in eachindex(escaped_∂s)]
        return :(Rule($Δs_tuple -> add($(∂_mul_Δs...))))
    else
        ∂_mul_Δs_primal = Any[]
        ∂_mul_Δs_conjugate = Any[]
        ∂_wirtinger_defs = Any[]
        for i in eachindex(escaped_∂s)
            if i in wirtinger_indices
                Δi = Δs[i]
                ∂i = Symbol(string(:∂, i))
                # Bind the Wirtinger partial to a local once; both generated rules
                # then refer to that binding.
                push!(∂_wirtinger_defs, :($∂i = $(escaped_∂s[i])))
                ∂f∂i_mul_Δ = :(mul(wirtinger_primal($∂i), wirtinger_primal($Δi)))
                ∂f∂ī_mul_Δ̄ = :(mul(conj(wirtinger_conjugate($∂i)), wirtinger_conjugate($Δi)))
                ∂f̄∂i_mul_Δ = :(mul(wirtinger_conjugate($∂i), wirtinger_primal($Δi)))
                ∂f̄∂ī_mul_Δ̄ = :(mul(conj(wirtinger_primal($∂i)), wirtinger_conjugate($Δi)))
                push!(∂_mul_Δs_primal, :(add($∂f∂i_mul_Δ, $∂f∂ī_mul_Δ̄)))
                push!(∂_mul_Δs_conjugate, :(add($∂f̄∂i_mul_Δ, $∂f̄∂ī_mul_Δ̄)))
            else
                # Non-Wirtinger partials contribute the same term to both rules.
                ∂_mul_Δ = :(mul(@thunk($(escaped_∂s[i])), $(Δs[i])))
                push!(∂_mul_Δs_primal, ∂_mul_Δ)
                push!(∂_mul_Δs_conjugate, ∂_mul_Δ)
            end
        end
        primal_rule = :(Rule($Δs_tuple -> add($(∂_mul_Δs_primal...))))
        conjugate_rule = :(Rule($Δs_tuple -> add($(∂_mul_Δs_conjugate...))))
        return quote
            $(∂_wirtinger_defs...)
            WirtingerRule($primal_rule, $conjugate_rule)
        end
    end
end

src/rules.jl

Lines changed: 0 additions & 150 deletions
Original file line numberDiff line numberDiff line change
@@ -350,153 +350,3 @@ true
350350
See also: [`frule`](@ref), [`AbstractRule`](@ref), [`@scalar_rule`](@ref)
351351
"""
352352
rrule(::Any, ::Vararg{Any}; kwargs...) = nothing
353-
354-
355-
#####
356-
##### macros
357-
#####
358-
359-
"""
360-
@scalar_rule(f(x₁, x₂, ...),
361-
@setup(statement₁, statement₂, ...),
362-
(∂f₁_∂x₁, ∂f₁_∂x₂, ...),
363-
(∂f₂_∂x₁, ∂f₂_∂x₂, ...),
364-
...)
365-
366-
A convenience macro that generates simple scalar forward or reverse rules using
367-
the provided partial derivatives. Specifically, generates the corresponding
368-
methods for `frule` and `rrule`:
369-
370-
function ChainRulesCore.frule(::typeof(f), x₁::Number, x₂::Number, ...)
371-
Ω = f(x₁, x₂, ...)
372-
\$(statement₁, statement₂, ...)
373-
return Ω, (Rule((Δx₁, Δx₂, ...) -> ∂f₁_∂x₁ * Δx₁ + ∂f₁_∂x₂ * Δx₂ + ...),
374-
Rule((Δx₁, Δx₂, ...) -> ∂f₂_∂x₁ * Δx₁ + ∂f₂_∂x₂ * Δx₂ + ...),
375-
...)
376-
end
377-
378-
function ChainRulesCore.rrule(::typeof(f), x₁::Number, x₂::Number, ...)
379-
Ω = f(x₁, x₂, ...)
380-
\$(statement₁, statement₂, ...)
381-
return Ω, (Rule((ΔΩ₁, ΔΩ₂, ...) -> ∂f₁_∂x₁ * ΔΩ₁ + ∂f₂_∂x₁ * ΔΩ₂ + ...),
382-
Rule((ΔΩ₁, ΔΩ₂, ...) -> ∂f₁_∂x₂ * ΔΩ₁ + ∂f₂_∂x₂ * ΔΩ₂ + ...),
383-
...)
384-
end
385-
386-
If no type constraints in `f(x₁, x₂, ...)` within the call to `@scalar_rule` are
387-
provided, each parameter in the resulting `frule`/`rrule` definition is given a
388-
type constraint of `Number`.
389-
Constraints may also be explicitly be provided to override the `Number` constraint,
390-
e.g. `f(x₁::Complex, x₂)`, which will constrain `x₁` to `Complex` and `x₂` to
391-
`Number`.
392-
393-
Note that the result of `f(x₁, x₂, ...)` is automatically bound to `Ω`. This
394-
allows the primal result to be conveniently referenced (as `Ω`) within the
395-
derivative/setup expressions.
396-
397-
Note that the `@setup` argument can be elided if no setup code is need. In other
398-
words:
399-
400-
@scalar_rule(f(x₁, x₂, ...),
401-
(∂f₁_∂x₁, ∂f₁_∂x₂, ...),
402-
(∂f₂_∂x₁, ∂f₂_∂x₂, ...),
403-
...)
404-
405-
is equivalent to:
406-
407-
@scalar_rule(f(x₁, x₂, ...),
408-
@setup(nothing),
409-
(∂f₁_∂x₁, ∂f₁_∂x₂, ...),
410-
(∂f₂_∂x₁, ∂f₂_∂x₂, ...),
411-
...)
412-
413-
For examples, see ChainRulesCore' `rules` directory.
414-
415-
See also: [`frule`](@ref), [`rrule`](@ref), [`AbstractRule`](@ref)
416-
"""
417-
macro scalar_rule(call, maybe_setup, partials...)
418-
if Meta.isexpr(maybe_setup, :macrocall) && maybe_setup.args[1] == Symbol("@setup")
419-
setup_stmts = map(esc, maybe_setup.args[3:end])
420-
else
421-
setup_stmts = (nothing,)
422-
partials = (maybe_setup, partials...)
423-
end
424-
@assert Meta.isexpr(call, :call)
425-
f = esc(call.args[1])
426-
# Annotate all arguments in the signature as scalars
427-
inputs = map(call.args[2:end]) do arg
428-
esc(Meta.isexpr(arg, :(::)) ? arg : Expr(:(::), arg, :Number))
429-
end
430-
# Remove annotations and escape names for the call
431-
for (i, arg) in enumerate(call.args)
432-
if Meta.isexpr(arg, :(::))
433-
call.args[i] = esc(first(arg.args))
434-
else
435-
call.args[i] = esc(arg)
436-
end
437-
end
438-
if all(Meta.isexpr(partial, :tuple) for partial in partials)
439-
forward_rules = Any[rule_from_partials(partial.args...) for partial in partials]
440-
reverse_rules = Any[]
441-
for i in 1:length(inputs)
442-
reverse_partials = [partial.args[i] for partial in partials]
443-
push!(reverse_rules, rule_from_partials(reverse_partials...))
444-
end
445-
else
446-
@assert length(inputs) == 1 && all(!Meta.isexpr(partial, :tuple) for partial in partials)
447-
forward_rules = Any[rule_from_partials(partial) for partial in partials]
448-
reverse_rules = Any[rule_from_partials(partials...)]
449-
end
450-
forward_rules = length(forward_rules) == 1 ? forward_rules[1] : Expr(:tuple, forward_rules...)
451-
reverse_rules = length(reverse_rules) == 1 ? reverse_rules[1] : Expr(:tuple, reverse_rules...)
452-
return quote
453-
function ChainRulesCore.frule(::typeof($f), $(inputs...))
454-
$(esc(:Ω)) = $call
455-
$(setup_stmts...)
456-
return $(esc(:Ω)), $forward_rules
457-
end
458-
function ChainRulesCore.rrule(::typeof($f), $(inputs...))
459-
$(esc(:Ω)) = $call
460-
$(setup_stmts...)
461-
return $(esc(:Ω)), $reverse_rules
462-
end
463-
end
464-
end
465-
466-
function rule_from_partials(∂s...)
467-
wirtinger_indices = findall(x -> Meta.isexpr(x, :call) && x.args[1] === :Wirtinger, ∂s)
468-
∂s = map(esc, ∂s)
469-
Δs = [Symbol(string(:Δ, i)) for i in 1:length(∂s)]
470-
Δs_tuple = Expr(:tuple, Δs...)
471-
if isempty(wirtinger_indices)
472-
∂_mul_Δs = [:(mul(@thunk($(∂s[i])), $(Δs[i]))) for i in 1:length(∂s)]
473-
return :(Rule($Δs_tuple -> add($(∂_mul_Δs...))))
474-
else
475-
∂_mul_Δs_primal = Any[]
476-
∂_mul_Δs_conjugate = Any[]
477-
∂_wirtinger_defs = Any[]
478-
for i in 1:length(∂s)
479-
if i in wirtinger_indices
480-
Δi = Δs[i]
481-
∂i = Symbol(string(:∂, i))
482-
push!(∂_wirtinger_defs, :($∂i = $(∂s[i])))
483-
∂f∂i_mul_Δ = :(mul(wirtinger_primal($∂i), wirtinger_primal($Δi)))
484-
∂f∂ī_mul_Δ̄ = :(mul(conj(wirtinger_conjugate($∂i)), wirtinger_conjugate($Δi)))
485-
∂f̄∂i_mul_Δ = :(mul(wirtinger_conjugate($∂i), wirtinger_primal($Δi)))
486-
∂f̄∂ī_mul_Δ̄ = :(mul(conj(wirtinger_primal($∂i)), wirtinger_conjugate($Δi)))
487-
push!(∂_mul_Δs_primal, :(add($∂f∂i_mul_Δ, $∂f∂ī_mul_Δ̄)))
488-
push!(∂_mul_Δs_conjugate, :(add($∂f̄∂i_mul_Δ, $∂f̄∂ī_mul_Δ̄)))
489-
else
490-
∂_mul_Δ = :(mul(@thunk($(∂s[i])), $(Δs[i])))
491-
push!(∂_mul_Δs_primal, ∂_mul_Δ)
492-
push!(∂_mul_Δs_conjugate, ∂_mul_Δ)
493-
end
494-
end
495-
primal_rule = :(Rule($Δs_tuple -> add($(∂_mul_Δs_primal...))))
496-
conjugate_rule = :(Rule($Δs_tuple -> add($(∂_mul_Δs_conjugate...))))
497-
return quote
498-
$(∂_wirtinger_defs...)
499-
WirtingerRule($primal_rule, $conjugate_rule)
500-
end
501-
end
502-
end

0 commit comments

Comments
 (0)