Skip to content

Commit 9972f36

Browse files
authored
Use handwritten rules for zero and one
1 parent e37c9a6 commit 9972f36

File tree

1 file changed

+30
-3
lines changed

1 file changed

+30
-3
lines changed

src/rulesets/Base/base.jl

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,38 @@
22
# that also have FastMath versions.
33

44
@scalar_rule copysign(y, x) (ifelse(signbit(x)!=signbit(y), -one(y), +one(y)), NoTangent())
5-
6-
@scalar_rule one(x) ZeroTangent()
7-
@scalar_rule zero(x) ZeroTangent()
85
@scalar_rule transpose(x) true
96

7+
# `zero`
8+
9+
function frule((_, Δ1), ::typeof(zero), x)
10+
var"∂f/∂x" = ZeroTangent()
11+
(zero(x), Δ1 * var"∂f/∂x")
12+
end
13+
14+
function rrule(::typeof(zero), x)
15+
Ω = zero(x)
16+
proj_x = ProjectTo(x)
17+
var"∂f/∂x" = ZeroTangent()
18+
pullback(Δ1) = (NoTangent(), proj_x(conj(var"∂f/∂x") * Δ1))
19+
(Ω, pullback)
20+
end
21+
22+
# `one`
23+
24+
function frule((_, Δ1), ::typeof(one), x)
25+
var"∂f/∂x" = ZeroTangent()
26+
(one(x), Δ1 * var"∂f/∂x")
27+
end
28+
29+
function rrule(::typeof(one), x)
30+
Ω = one(x)
31+
proj_x = ProjectTo(x)
32+
var"∂f/∂x" = ZeroTangent()
33+
pullback(Δ1) = (NoTangent(), proj_x(conj(var"∂f/∂x") * Δ1))
34+
(Ω, pullback)
35+
end
36+
1037
# `adjoint`
1138

1239
frule((_, Δz), ::typeof(adjoint), z::Number) = (z', Δz')

0 commit comments

Comments
 (0)