Skip to content

Commit e569283

Browse files
authored
Merge pull request #773 from MasonProtter/patch-1
Use handwritten rules for `zero` and `one`
2 parents e37c9a6 + cc93409 commit e569283

File tree

3 files changed

+33
-4
lines changed

3 files changed

+33
-4
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRules"
22
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
3-
version = "1.59.0"
3+
version = "1.59.1"
44

55
[deps]
66
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/rulesets/Base/base.jl

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,30 @@
22
# that also have FastMath versions.
33

44
@scalar_rule copysign(y, x) (ifelse(signbit(x)!=signbit(y), -one(y), +one(y)), NoTangent())
5-
6-
@scalar_rule one(x) ZeroTangent()
7-
@scalar_rule zero(x) ZeroTangent()
85
@scalar_rule transpose(x) true
96

7+
# `zero`
8+
9+
function frule((_, _), ::typeof(zero), x)
10+
return (zero(x), ZeroTangent())
11+
end
12+
13+
function rrule(::typeof(zero), x)
14+
zero_pullback(_) = (NoTangent(), ZeroTangent())
15+
return (zero(x), zero_pullback)
16+
end
17+
18+
# `one`
19+
20+
function frule((_, _), ::typeof(one), x)
21+
return (one(x), ZeroTangent())
22+
end
23+
24+
function rrule(::typeof(one), x)
25+
one_pullback(_) = (NoTangent(), ZeroTangent())
26+
return (one(x), one_pullback)
27+
end
28+
1029
# `adjoint`
1130

1231
frule((_, Δz), ::typeof(adjoint), z::Number) = (z', Δz')

test/rulesets/Base/base.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,14 @@
11
@testset "base.jl" begin
2+
@testset "zero/one" begin
3+
for f in [zero, one]
4+
for x in [1.0, 1.0im, [10.0+im 11.0-im; 12.0+2im 13.0-3im]]
5+
test_frule(f, x)
6+
test_rrule(f, x)
7+
end
8+
end
9+
test_frule(zero, [1.0, 2.0, 3.0])
10+
test_rrule(zero, [1.0, 2.0, 3.0])
11+
end
212
@testset "copysign" begin
313
# don't go too close to zero as the numerics may jump over it yielding wrong results
414
@testset "at $y" for y in (-1.1, 0.1, 100.0)

0 commit comments

Comments
 (0)