Skip to content

Commit 0df040f

Browse files
authored
simplify zero and one rules
1 parent 275b93b commit 0df040f

File tree

1 file changed

+8
-16
lines changed

1 file changed

+8
-16
lines changed

src/rulesets/Base/base.jl

Lines changed: 8 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -6,32 +6,24 @@
66

77
# `zero`
88

9-
function frule((_, Δ1), ::typeof(zero), x)
10-
var"∂f/∂x" = ZeroTangent()
11-
return (zero(x), Δ1 * var"∂f/∂x")
9+
function frule((_, _), ::typeof(zero), x)
10+
return (zero(x), ZeroTangent())
1211
end
1312

1413
function rrule(::typeof(zero), x)
15-
Ω = zero(x)
16-
proj_x = ProjectTo(x)
17-
var"∂f/∂x" = ZeroTangent()
18-
pullback(Δ1) = (NoTangent(), proj_x(conj(var"∂f/∂x") * Δ1))
19-
return (Ω, pullback)
14+
zero_pullback(_) = (NoTangent(), ZeroTangent())
15+
return (zero(x), zero_pullback)
2016
end
2117

2218
# `one`
2319

24-
function frule((_, Δ1), ::typeof(one), x)
25-
var"∂f/∂x" = ZeroTangent()
26-
return (one(x), Δ1 * var"∂f/∂x")
20+
function frule((_, _), ::typeof(one), x)
21+
return (one(x), ZeroTangent())
2722
end
2823

2924
function rrule(::typeof(one), x)
30-
Ω = one(x)
31-
proj_x = ProjectTo(x)
32-
var"∂f/∂x" = ZeroTangent()
33-
pullback(Δ1) = (NoTangent(), proj_x(conj(var"∂f/∂x") * Δ1))
34-
return (Ω, pullback)
25+
one_pullback(_) = (NoTangent(), ZeroTangent())
26+
return (one(x), one_pullback)
3527
end
3628

3729
# `adjoint`

0 commit comments

Comments
 (0)