Skip to content

Commit 46e5e9c

Browse files
committed
unrelated test fixes in passing
1 parent 879ff8b commit 46e5e9c

File tree

1 file changed

+20
-5
lines changed

1 file changed

+20
-5
lines changed

test/rulesets/Base/base.jl

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -98,11 +98,16 @@
9898
@testset "($x) * ($y)" for
9999
x in test_points, y in test_points
100100

101-
# ensure all complex if any complex for FiniteDifferences
102-
x, y = Base.promote(x, y)
101+
# all complex if any complex, was a limitation of FiniteDifferences?
102+
xx, yy = Base.promote(x, y)
103+
test_frule(*, xx, yy)
104+
test_rrule(*, xx, yy)
103105

106+
# explicitly allow mixed types
104107
test_frule(*, x, y)
105108
test_rrule(*, x, y)
109+
rrule(*, x, y)[2](1)[2] isa typeof(x)
110+
rrule(*, x, y)[2](1)[3] isa typeof(y)
106111
end
107112
end
108113

@@ -136,14 +141,17 @@
136141
test_rrule(identity, Tuple(randn(T, 3)))
137142
end
138143

139-
@testset "Constants" for x in (-0.1, 6.4, 1.0+0.5im, -10.0+0im, 0.0+200im)
144+
@testset "one(::Number), zero(::Number)" for x in (-0.1, 6.4, 1.0+0.5im, -10.0+0im, 0.0+200im)
140145
test_scalar(one, x)
141146
test_scalar(zero, x)
147+
148+
rrule(one, x)[2](1) === (NoTangent(), zero(x))
149+
rrule(zero, x)[2](1) === (NoTangent(), zero(x))
142150
end
143151

144152
@testset "muladd(x::$T, y::$T, z::$T)" for T in (Float64, ComplexF64)
145-
test_frule(muladd, 10randn(), randn(), randn())
146-
test_rrule(muladd, 10randn(), randn(), randn())
153+
test_frule(muladd, 10randn(T), randn(T), randn(T))
154+
test_rrule(muladd, 10randn(T), randn(T), randn(T))
147155
end
148156

149157
@testset "fma" begin
@@ -163,6 +171,13 @@
163171
# to right
164172
test_frule(clamp, 4., 2., 3.)
165173
test_rrule(clamp, 4., 2., 3.)
174+
175+
# nonzero gradient at the boundaries
176+
@test frule((0,1,0,0), clamp, 2, 2, 3) == (2, 1)
177+
@test rrule(clamp, 2.0, 2, 3)[2](1)[2] == 1.0
178+
179+
@test frule((0,1,0,0), clamp, 3, 2, 3) == (3, 1)
180+
@test rrule(clamp, 3, 2, 3)[2](1)[2] == 1.0
166181
end
167182

168183
@testset "rounding" begin

0 commit comments

Comments
 (0)