Skip to content

Commit 4aa2c8b

Browse files
add rules for constant returning zero and one function (JuliaDiff#78)
* add rules for constant returning zero and one function * Update test/rulesets/Base/base.jl Co-Authored-By: Nick Robinson <npr251@gmail.com> * Update test/rulesets/Base/base.jl Co-Authored-By: Nick Robinson <npr251@gmail.com>
1 parent 7f8af88 commit 4aa2c8b

File tree

2 files changed

+19
-0
lines changed

2 files changed

+19
-0
lines changed

src/rulesets/Base/base.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
@scalar_rule(one(x), Zero())
2+
@scalar_rule(zero(x), Zero())
13
@scalar_rule(abs2(x), Wirtinger(x', x))
24
@scalar_rule(log(x), inv(x))
35
@scalar_rule(log10(x), inv(x) / log(oftype(x, 10)))

test/rulesets/Base/base.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,5 +123,22 @@ end
123123
rrule_test(identity, randn(rng), (randn(rng), randn(rng)))
124124
rrule_test(identity, randn(rng, 4), (randn(rng, 4), randn(rng, 4)))
125125
end
126+
127+
@testset "Constants" begin
128+
function test_constant(f, x, expected)
129+
y, rule = frule(f, x)
130+
@test y == expected
131+
@test extern(rule(1)) == 0.0
132+
133+
y, rule = rrule(f, x)
134+
@test y == expected
135+
@test extern(rule(1)) == 0.0
136+
end
137+
test_constant(one, 5, 1)
138+
test_constant(one, -4.1, 1)
139+
140+
test_constant(zero, 5, 0)
141+
test_constant(zero, -4.1, 0)
142+
end
126143
end
127144
# TODO: Non-trig stuff

0 commit comments

Comments
 (0)